29 #include "tiny_dnn/core/backend.h"
31 #include "tiny_dnn/core/kernels/avx_deconv2d_kernel.h"
32 #include "tiny_dnn/core/kernels/avx_deconv2d_back_kernel.h"
// NOTE(review): the opening of this constructor (its name and leading
// parameters, including `ptr`) is not visible in this listing.
// Visible parameters are three callbacks:
//   f1: void(const tensor_t&)                              -> copy_and_pad_input
//   f2: void(const tensor_t&, tensor_t&)                   -> copy_and_unpad_delta
//   f3: void(const tensor_t&, const tensor_t&, tensor_t&)  -> backward_activation
44 std::function<
void(
const tensor_t&)> f1,
45 std::function<
void(
const tensor_t&, tensor_t&)> f2,
46 std::function<
void(
const tensor_t&,
const tensor_t&, tensor_t&)> f3,
// Member initializers: keep the worker-storage pointer and copy the three
// callbacks into the corresponding members. The constructor body is empty.
49 , conv_layer_worker_storage_(ptr)
50 , copy_and_pad_input(f1)
51 , copy_and_unpad_delta(f2)
52 , backward_activation(f3) {}
// NOTE(review): the opening of this constructor (its name and leading
// parameters, including `ptr`) is not visible in this listing.
// Visible parameters are three callbacks:
//   f1: void(const tensor_t&)                              -> copy_and_unpad_output
//   f2: void(const tensor_t&, tensor_t&)                   -> copy_and_pad_delta
//   f3: void(const tensor_t&, const tensor_t&, tensor_t&)  -> backward_activation
56 std::function<
void(
const tensor_t&)> f1,
57 std::function<
void(
const tensor_t&, tensor_t&)> f2,
58 std::function<
void(
const tensor_t&,
const tensor_t&, tensor_t&)> f3,
// Member initializers: keep the deconvolution worker-storage pointer and copy
// the three callbacks into the members used by the deconv2d() overloads.
// The constructor body is empty.
61 , deconv_layer_worker_storage_(ptr)
62 , copy_and_unpad_output(f1)
63 , copy_and_pad_delta(f2)
64 , backward_activation(f3) {}
// Constructor overload taking two index-table pointers (out2in, in2out) and a
// three-tensor callback `f`.
// NOTE(review): some parameter and initializer lines of this constructor
// (including the declaration of `ptr`) are missing from this listing;
// presumably out2in/in2out are stored in out2in_/in2out_ -- confirm against
// the full file.
67 avx_backend(std::vector<std::vector<serial_size_t>>* out2in,
68 std::vector<serial_size_t>* in2out,
69 std::function<
void(
const tensor_t&,
const tensor_t&, tensor_t&)> f,
// Visible initializers: the max-pooling worker storage pointer and the
// backward_activation callback. The constructor body is empty.
71 : max_pooling_layer_worker_storage_(ptr)
74 , backward_activation(f) {}
// NOTE(review): only the tail of this constructor is visible; its name and
// earlier parameters/initializers are missing from this listing.
// The visible parameter `f` is a three-tensor callback that is copied into
// backward_activation. The constructor body is empty.
78 std::function<
void(
const tensor_t&,
const tensor_t&, tensor_t&)> f)
80 , backward_activation(f) {}
// conv2d overload taking input tensors and output tensors (overrides the
// backend interface).
// NOTE(review): in the visible code each guard returns as soon as the tested
// member is non-null, so nothing is computed when any of the three is set.
// This looks like a placeholder that only touches the members (e.g. to
// silence unused-member warnings) rather than real logic -- confirm against
// the full file; the rest of the body is not visible here.
84 void conv2d(
const std::vector<tensor_t*>& in_data,
85 std::vector<tensor_t*>& out_data)
override {
87 if (params_c_)
return;
88 if (params_f_)
return;
89 if (conv_layer_worker_storage_)
return;
// conv2d_q overload taking input and output tensors.
// Not supported by this backend: always throws nn_error.
102 void conv2d_q(
const std::vector<tensor_t*>& in_data,
103 std::vector<tensor_t*>& out_data)
override {
104 throw nn_error(
"not implemented yet.");
// conv2d_eq taking input and output tensors.
// Not supported by this backend: always throws nn_error.
107 void conv2d_eq(
const std::vector<tensor_t*>& in_data,
108 std::vector<tensor_t*>& out_data)
override {
109 throw nn_error(
"not implemented yet.");
// conv2d overload that additionally receives out_grad and in_grad
// (presumably the gradient/backward counterpart of the two-argument conv2d
// -- confirm). in_data and out_data are read-only vectors; out_grad and
// in_grad are passed by non-const reference.
// NOTE(review): the body of this function is not visible in this listing.
112 void conv2d(
const std::vector<tensor_t*>& in_data,
113 const std::vector<tensor_t*>& out_data,
114 std::vector<tensor_t*>& out_grad,
115 std::vector<tensor_t*>& in_grad)
override {
// conv2d_q overload that additionally receives out_grad and in_grad.
// Not supported by this backend: always throws nn_error.
142 void conv2d_q(
const std::vector<tensor_t*>& in_data,
143 const std::vector<tensor_t*>& out_data,
144 std::vector<tensor_t*>& out_grad,
145 std::vector<tensor_t*>& in_grad)
override {
146 throw nn_error(
"not implemented yet.");
// deconv2d overload taking input and output tensors: runs the AVX
// deconvolution kernel and writes the result through out_data[1].
//
// Reads:  in_data[0] (input tensor), first element of in_data[1] (W),
//         first element of in_data[2] (bias).
// Writes: *out_data[1], and prev_out_ in the deconv worker storage.
149 void deconv2d(
const std::vector<tensor_t*>& in_data,
150 std::vector<tensor_t*>& out_data)
override {
// Remember the input tensor pointer in the worker storage; the four-argument
// deconv2d overload reads it back via prev_out_.
151 (*deconv_layer_worker_storage_).prev_out_ = in_data[0];
152 const vec_t& W = (*in_data[1])[0];
153 const vec_t& bias = (*in_data[2])[0];
// `a` aliases *out_data[1]; everything below writes into that tensor.
154 tensor_t& a = *out_data[1];
155 const tensor_t &in = *in_data[0];
// Zero the output tensor before the kernel runs.
157 fill_tensor(a, float_t(0));
159 kernels::avx_deconv2d_kernel(*params_d_,
160 in, W, bias, a, layer_->parallelize());
// Hand the kernel output to the unpad callback, then copy the tensor that
// curr_out_unpadded_ points to back into `a`.
// NOTE(review): presumably copy_and_unpad_output() is what sets/fills
// curr_out_unpadded_ -- confirm against the owning layer.
162 copy_and_unpad_output(a);
163 a = *(*deconv_layer_worker_storage_).curr_out_unpadded_;
// deconv2d_q overload taking input and output tensors.
// Not supported by this backend: always throws nn_error.
166 void deconv2d_q(
const std::vector<tensor_t*>& in_data,
167 std::vector<tensor_t*>& out_data)
override {
168 throw nn_error(
"not implemented yet.");
// deconv2d_eq taking input and output tensors.
// Not supported by this backend: always throws nn_error.
171 void deconv2d_eq(
const std::vector<tensor_t*>& in_data,
172 std::vector<tensor_t*>& out_data)
override {
173 throw nn_error(
"not implemented yet.");
// deconv2d overload that additionally receives out_grad and in_grad: applies
// the backward_activation callback and then runs the AVX deconvolution
// back-kernel.
//
// Reads:  first element of in_data[1] (W), *out_data[0], *out_grad[0],
//         prev_out_ from the worker storage.
// Writes: *in_grad[0] (prev_delta), *in_grad[1] (dW), *in_grad[2] (db), and
//         the current-delta tensor selected below.
// NOTE(review): the declaration of `cws` is not visible in this listing;
// presumably it refers to *deconv_layer_worker_storage_ -- confirm.
176 void deconv2d(
const std::vector<tensor_t*>& in_data,
177 const std::vector<tensor_t*>& out_data,
178 std::vector<tensor_t*>& out_grad,
179 std::vector<tensor_t*>& in_grad)
override {
// With "same" padding, invoke the pad callback on the padded-delta buffer
// and *in_grad[0] before anything else.
182 if (params_d_->pad_type == padding::same)
183 copy_and_pad_delta(cws.curr_delta_padded, *in_grad[0]);
185 const tensor_t& prev_out = *(cws.prev_out_);
186 const vec_t& W = (*in_data[1])[0];
187 tensor_t& dW = *in_grad[1];
188 tensor_t& db = *in_grad[2];
// Current delta: the padded buffer for "same" padding, otherwise *out_grad[1].
189 tensor_t& curr_delta = (params_d_->pad_type == padding::same) ? cws.curr_delta_padded : *out_grad[1];
190 tensor_t* prev_delta = in_grad[0];
// Debug-only size checks against the layer parameters and output shape.
192 assert(W.size() == params_d_->weight.size());
193 assert(dW[0].size() == params_d_->weight.size());
194 assert(curr_delta[0].size() == layer_->
out_shape()[0].size());
// Callback receives *out_grad[0] and *out_data[0] as const inputs and
// curr_delta as its mutable argument.
196 backward_activation(*out_grad[0], *out_data[0], curr_delta);
// Zero prev_delta before the back-kernel runs.
198 fill_tensor(*prev_delta, float_t(0));
200 kernels::avx_deconv2d_back_kernel(*params_d_,
201 prev_out, W, dW, db, curr_delta, prev_delta);
// deconv2d_q overload that additionally receives out_grad and in_grad.
// Not supported by this backend: always throws nn_error.
204 void deconv2d_q(
const std::vector<tensor_t*>& in_data,
205 const std::vector<tensor_t*>& out_data,
206 std::vector<tensor_t*>& out_grad,
207 std::vector<tensor_t*>& in_grad)
override {
208 throw nn_error(
"not implemented yet.");
// maxpool overload taking input and output tensors.
// NOTE(review): the only visible statement tests
// max_pooling_layer_worker_storage_ and has an empty body, so it has no
// effect; the remainder of the function is not visible in this listing.
211 void maxpool(
const std::vector<tensor_t*>& in_data,
212 std::vector<tensor_t*>& out_data)
override {
214 if (max_pooling_layer_worker_storage_) {}
// maxpool overload that additionally receives out_grad and in_grad
// (presumably the gradient/backward counterpart -- confirm).
// NOTE(review): the body of this function is not visible in this listing.
227 void maxpool(
const std::vector<tensor_t*>& in_data,
228 const std::vector<tensor_t*>& out_data,
229 std::vector<tensor_t*>& out_grad,
230 std::vector<tensor_t*>& in_grad)
override {
// fully overload taking input and output tensors.
// NOTE(review): the body of this function is not visible in this listing.
244 void fully(
const std::vector<tensor_t*>& in_data,
245 std::vector<tensor_t*>& out_data)
override {
// fully_q overload taking input and output tensors.
// Not supported by this backend: always throws nn_error.
255 void fully_q(
const std::vector<tensor_t*>& in_data,
256 std::vector<tensor_t*>& out_data)
override {
257 throw nn_error(
"not implemented yet.");
// fully_eq taking input and output tensors.
// Not supported by this backend: always throws nn_error.
260 void fully_eq(
const std::vector<tensor_t*>& in_data,
261 std::vector<tensor_t*>& out_data)
override {
262 throw nn_error(
"not implemented yet.");
// fully overload that additionally receives out_grad and in_grad
// (presumably the gradient/backward counterpart -- confirm).
// NOTE(review): the body of this function is not visible in this listing.
265 void fully(
const std::vector<tensor_t*>& in_data,
266 const std::vector<tensor_t*>& out_data,
267 std::vector<tensor_t*>& out_grad,
268 std::vector<tensor_t*>& in_grad)
override {
// fully_q overload that additionally receives out_grad and in_grad.
// Not supported by this backend: always throws nn_error.
282 void fully_q(
const std::vector<tensor_t*>& in_data,
283 const std::vector<tensor_t*>& out_data,
284 std::vector<tensor_t*>& out_grad,
285 std::vector<tensor_t*>& in_grad)
override {
286 throw nn_error(
"not implemented yet.");
// Identifies this backend: always returns backend_t::avx.
289 backend_t type()
const override {
return backend_t::avx; }
// Index tables held by pointer.
// NOTE(review): presumably set from the out2in/in2out constructor arguments
// (those initializer lines are missing from this listing) -- confirm.
301 std::vector<std::vector<serial_size_t>>* out2in_;
302 std::vector<serial_size_t>* in2out_;
// Callbacks supplied through the constructors.
// Initialized from the first callback of the constructor that also stores
// conv_layer_worker_storage_.
305 std::function<void(
const tensor_t&)> copy_and_pad_input;
// Called by the two-argument deconv2d() on the kernel output.
306 std::function<void(
const tensor_t&)> copy_and_unpad_output;
// Initialized from the second callback of the constructor that also stores
// conv_layer_worker_storage_.
307 std::function<void(
const tensor_t&, tensor_t&)> copy_and_unpad_delta;
// Called by the four-argument deconv2d() when pad_type is padding::same.
308 std::function<void(
const tensor_t&, tensor_t&)> copy_and_pad_delta;
// Called by the four-argument deconv2d() with (*out_grad[0], *out_data[0],
// curr_delta); set by every visible constructor.
309 std::function<void(
const tensor_t&,
const tensor_t&, tensor_t&)> backward_activation;
Definition: backend_avx.h:37
Definition: conv_params.h:92
Definition: fully_params.h:34
virtual std::vector< shape3d > out_shape() const =0
array of output shapes (width x height x depth)
error exception class for tiny-dnn
Definition: nn_error.h:37
Definition: conv_params.h:34
Definition: deconv_params.h:32
Definition: deconv_params.h:39
Definition: maxpool_params.h:51