29 #include "tiny_dnn/core/params/deconv_params.h"
35 inline void tiny_deconv2d_back_kernel(
const deconv_params& params,
36 const tensor_t& prev_out,
41 tensor_t* prev_delta) {
43 for_i(prev_out.size(), [&](
int sample) {
44 for (serial_size_t inc = 0; inc < params.in.depth_; inc++) {
45 for (serial_size_t outc = 0; outc < params.out.depth_; outc++) {
46 if (!params.tbl.is_connected(outc, inc)) continue;
48 serial_size_t idx = 0;
49 idx = params.in.depth_ * outc + inc;
50 idx = params.weight.get_index(0, 0, idx);
51 const float_t *pw = &W[idx];
53 idx = params.out_unpadded.get_index(0, 0, outc);
54 const float_t *pdelta_src = &curr_delta[sample][idx];
56 idx = params.in.get_index(0, 0, inc);
57 float_t *pdelta_dst = &(*prev_delta)[sample][idx];
59 for (serial_size_t y = 0; y < params.in.height_; y++) {
60 for (serial_size_t x = 0; x < params.in.width_; x++) {
61 const float_t * ppw = pw;
63 float_t * ppdelta_dst = pdelta_dst + y * params.in.width_ + x;
64 float_t sum = float_t(0);
66 for (serial_size_t wy = 0; wy < params.weight.height_; wy++) {
67 for (serial_size_t wx = 0; wx < params.weight.width_; wx++) {
68 idx = (y * params.h_stride + wy) *
69 params.out.width_ + (x *
70 params.w_stride + wx);
71 sum += ppw[wy * params.weight.width_ + wx] *
82 for (serial_size_t inc = 0; inc < params.in.depth_; inc++) {
83 for (serial_size_t outc = 0; outc < params.out.depth_; outc++) {
84 if (!params.tbl.is_connected(outc, inc))
continue;
86 for (serial_size_t wy = 0; wy < params.weight.height_; wy++) {
87 for (serial_size_t wx = 0; wx < params.weight.width_; wx++) {
88 float_t dst = float_t(0);
90 serial_size_t idx = 0;
91 idx = params.in.get_index(0, 0, inc);
92 const float_t * prevo = &prev_out[sample][idx];
94 idx = params.out.get_index(wx, wy, outc);
95 const float_t * delta = &curr_delta[sample][idx];
97 for (serial_size_t y = 0; y < params.in.height_; y++) {
98 dst += vectorize::dot(
99 prevo + y * params.in.width_,
100 delta + y * params.out.width_,
104 idx = params.in.depth_ * outc + inc;
105 dW[sample][params.weight.get_index(wx, wy, idx)] += dst;
112 if (params.has_bias) {
115 for (serial_size_t outc = 0; outc < params.out.depth_; outc++) {
116 serial_size_t idx = params.out.get_index(0, 0, outc);
117 const float_t * delta = &curr_delta[sample][idx];
118 const float_t * deltaa = delta + params.out.width_ *
120 db[sample][outc] += std::accumulate(delta, deltaa, float_t(0));