47 #include "tiny_dnn/core/framework/op_kernel.h"
49 #include "tiny_dnn/core/kernels/conv2d_op_avx.h"
50 #include "tiny_dnn/core/kernels/conv2d_op_internal.h"
51 #include "tiny_dnn/core/kernels/conv2d_op_nnpack.h"
61 auto params = OpKernel::params_->conv();
64 const tensor_t& in_data = context.input(0);
65 const tensor_t& W = context.input(1);
66 const tensor_t& bias = context.input(2);
67 tensor_t& out_data = context.output(1);
70 fill_tensor(out_data, float_t(0));
75 const core::backend_t engine = context.engine();
77 if (engine == core::backend_t::internal) {
78 kernels::conv2d_op_internal(
84 context.parallelize());
86 else if (engine == core::backend_t::nnpack) {
87 kernels::conv2d_op_nnpack(
94 else if (engine == core::backend_t::avx) {
95 kernels::conv2d_op_avx(
101 context.parallelize());
104 throw nn_error(
"Not supported engine: " + to_string(engine));
Definition: conv2d_op.h:55
Definition: op_kernel.h:55
Definition: op_kernel.h:72
Definition: op_kernel.h:175
error exception class for tiny-dnn
Definition: nn_error.h:37