47 #include "tiny_dnn/core/framework/op_kernel.h"
49 #include "tiny_dnn/core/kernels/conv2d_grad_op_avx.h"
50 #include "tiny_dnn/core/kernels/conv2d_op_internal.h"
60 auto params = OpKernel::params_->conv();
63 const tensor_t& prev_out = context.input(0);
64 const tensor_t& W = context.input(1);
65 tensor_t& dW = context.input_grad(1);
66 tensor_t& db = context.input_grad(2);
67 tensor_t& prev_delta = context.input_grad(0);
68 tensor_t& curr_delta = context.output_grad(1);
71 fill_tensor(prev_delta, float_t(0));
76 const core::backend_t engine = context.engine();
78 if (engine == core::backend_t::internal) {
79 kernels::conv2d_op_internal(
87 context.parallelize());
89 else if (engine == core::backend_t::avx) {
90 kernels::conv2d_grad_op_avx(
98 context.parallelize());
101 throw nn_error(
"Not supported engine: " + to_string(engine));
Definition: conv2d_grad_op.h:54
Definition: op_kernel.h:55
Definition: op_kernel.h:72
Definition: op_kernel.h:175
Error exception class for tiny-dnn.
Definition: nn_error.h:37