47 #include "tiny_dnn/core/framework/op_kernel.h"
49 #include "tiny_dnn/core/kernels/fully_connected_op_avx.h"
50 #include "tiny_dnn/core/kernels/fully_connected_op_internal.h"
60 auto params = OpKernel::params_->fully();
63 const tensor_t& prev_out = context.input(0);
64 const tensor_t& W = context.input(1);
65 tensor_t& dW = context.input_grad(1);
66 tensor_t* db = params.has_bias_ ? &context.input_grad(2) :
nullptr;
67 tensor_t& prev_delta = context.input_grad(0);
68 tensor_t& curr_delta = context.output_grad(1);
72 fill_tensor(prev_delta, float_t(0));
76 const core::backend_t engine = context.engine();
78 if (engine == core::backend_t::internal) {
79 kernels::fully_connected_op_internal(
83 params.has_bias_ ? *db : dummy,
87 context.parallelize());
89 else if (engine == core::backend_t::avx) {
90 kernels::fully_connected_op_avx(
94 params.has_bias_ ? *db : dummy,
98 context.parallelize());
101 throw nn_error(
"Not supported engine: " + to_string(engine));
Definition: fully_connected_grad_op.h:54
Definition: op_kernel.h:55
Definition: op_kernel.h:72
Definition: op_kernel.h:175
error exception class for tiny-dnn
Definition: nn_error.h:37