47 #include "tiny_dnn/core/framework/op_kernel.h"
49 #include "tiny_dnn/core/kernels/fully_connected_op_avx.h"
50 #include "tiny_dnn/core/kernels/fully_connected_op_internal.h"
51 #include "tiny_dnn/core/kernels/fully_connected_op_nnpack.h"
// Forward pass of the fully-connected op: gather the tensors from the kernel
// context, clear the output, then hand off to the kernel that matches the
// selected backend engine.
// NOTE(review): this listing is incomplete - the enclosing function header,
// the braces and some call arguments are not visible here.

// Fully-connected parameters held by the kernel. `auto` makes a copy if
// fully() returns a reference - TODO confirm whether the copy is intended.
61 auto params = OpKernel::params_->fully();
// Input slot 0 is bound as the input data, slot 1 as the weights.
64 const tensor_t& in_data = context.input(0);
65 const tensor_t& W = context.input(1);
// The bias (input slot 2) is only looked up when has_bias_ is set; otherwise
// the pointer stays null and must not be dereferenced.
66 const tensor_t* bias = params.has_bias_ ? &context.input(2) :
nullptr;
// NOTE(review): the result is written to output slot 1, not slot 0 - confirm
// this matches the owning layer's output ordering.
67 tensor_t& out_data = context.output(1);
// Reset every output element to zero before the kernel runs (presumably the
// kernels accumulate into out_data - verify against the kernel sources).
70 fill_tensor(out_data, float_t(0));
// Backend selected for this invocation; it drives the dispatch below.
74 const core::backend_t engine = context.engine();
// Each branch passes row 0 of the bias tensor, or an empty vec_t when there
// is no bias, plus the context's parallelize flag.
// NOTE(review): because the second arm of the conditional is a temporary
// vec_t, the expression likely produces a copy of the bias row on each call -
// confirm whether that copy matters here.
76 if (engine == core::backend_t::internal) {
77 kernels::fully_connected_op_internal(
80 params.has_bias_ ? (*bias)[0] : vec_t(),
83 context.parallelize());
// NNPACK-backed kernel.
85 else if (engine == core::backend_t::nnpack) {
86 kernels::fully_connected_op_nnpack(
89 params.has_bias_ ? (*bias)[0] : vec_t(),
92 context.parallelize());
// AVX-backed kernel.
94 else if (engine == core::backend_t::avx) {
95 kernels::fully_connected_op_avx(
98 params.has_bias_ ? (*bias)[0] : vec_t(),
101 context.parallelize());
// Presumably the final `else` branch (not visible in this listing): any
// other engine value is reported by throwing nn_error with the engine name.
104 throw nn_error(
"Not supported engine: " + to_string(engine));
Definition: fully_connected_op.h:55
Definition: op_kernel.h:55
Definition: op_kernel.h:72
Definition: op_kernel.h:175
error exception class for tiny-dnn
Definition: nn_error.h:37