51 conv2d_op_internal(
const tensor_t& in_data,
55 const core::conv_params& params,
56 const bool parallelize) {
57 for_i(parallelize, in_data.size(), [&](
int sample) {
58 const vec_t& in = in_data[sample];
59 vec_t& a = out_data[sample];
61 for (serial_size_t o = 0; o < params.out.depth_; o++) {
62 for (serial_size_t inc = 0; inc < params.in.depth_; inc++) {
63 if (!params.tbl.is_connected(o, inc)) continue;
65 serial_size_t idx = 0;
66 idx = params.in.depth_ * o + inc;
67 idx = params.weight.get_index(0, 0, idx);
68 const float_t *pw = &W[idx];
70 idx = params.in_padded.get_index(0, 0, inc);
71 const float_t *pi = &in[idx];
73 idx = params.out.get_index(0, 0, o);
74 float_t *pa = &a[idx];
76 for (serial_size_t y = 0; y < params.out.height_; y++) {
77 for (serial_size_t x = 0; x < params.out.width_; x++) {
78 const float_t * ppw = pw;
79 const float_t * ppi = pi + params.in_padded.width_ *
80 (y * params.h_stride) +
82 float_t sum = float_t(0);
85 for (serial_size_t wy = 0; wy < params.weight.height_; wy++) {
86 for (serial_size_t wx = 0; wx < params.weight.width_; wx++) {
87 idx = wy * params.in_padded.width_ + wx;
88 sum += *ppw++ * ppi[idx];
91 pa[y * params.out.width_ + x] += sum;
96 if (params.has_bias) {
97 float_t * pa = &a[params.out.get_index(0, 0, o)];
98 float_t * paa = pa + params.out.width_ * params.out.height_;
99 std::for_each(pa, paa, [&](float_t& f) { f += bias[o]; });
109 template <
typename tensor_t,
typename vec_t>
111 conv2d_op_internal(
const tensor_t& prev_out,
115 tensor_t& curr_delta,
116 tensor_t& prev_delta,
117 const core::conv_params& params,
118 const bool parallelize) {
120 typedef typename vec_t::value_type float_t;
122 for_i(parallelize, prev_out.size(), [&](
int sample) {
124 for (serial_size_t inc = 0; inc < params.in.depth_; inc++) {
125 for (serial_size_t outc = 0; outc < params.out.depth_; outc++) {
126 if (!params.tbl.is_connected(outc, inc)) continue;
128 serial_size_t idx = 0;
129 idx = params.in.depth_ * outc + inc;
130 idx = params.weight.get_index(0, 0, idx);
131 const float_t *pw = &W[idx];
133 idx = params.out.get_index(0, 0, outc);
134 const float_t *pdelta_src = &curr_delta[sample][idx];
136 idx = params.in_padded.get_index(0, 0, inc);
138 float_t *pdelta_dst = &prev_delta[sample][idx];
140 for (serial_size_t y = 0; y < params.out.height_; y++) {
141 for (serial_size_t x = 0; x < params.out.width_; x++) {
142 const float_t * ppw = pw;
144 idx = y * params.out.width_ + x;
145 const float_t ppdelta_src = pdelta_src[idx];
147 float_t * ppdelta_dst = pdelta_dst +
148 y * params.h_stride * params.in_padded.width_ +
151 for (serial_size_t wy = 0; wy < params.weight.height_; wy++) {
152 for (serial_size_t wx = 0; wx < params.weight.width_; wx++) {
153 idx = wy * params.in_padded.width_ + wx;
154 ppdelta_dst[idx] += *ppw++ * ppdelta_src;
163 for (serial_size_t inc = 0; inc < params.in.depth_; inc++) {
164 for (serial_size_t outc = 0; outc < params.out.depth_; outc++) {
165 if (!params.tbl.is_connected(outc, inc))
continue;
167 for (serial_size_t wy = 0; wy < params.weight.height_; wy++) {
168 for (serial_size_t wx = 0; wx < params.weight.width_; wx++) {
169 float_t dst = float_t(0);
171 serial_size_t idx = 0;
172 idx = params.in_padded.get_index(wx, wy, inc);
173 const float_t * prevo = &prev_out[sample][idx];
175 idx = params.out.get_index(0, 0, outc);
176 const float_t * delta = &curr_delta[sample][idx];
178 if (params.w_stride > 1) {
179 for (serial_size_t y = 0; y < params.out.height_; y++) {
180 serial_size_t prevo_idx = y * params.in_padded.width_ * params.h_stride;
181 serial_size_t delta_idx = y * params.out.width_;
183 for (serial_size_t x = 0; x < params.out.width_; x++) {
184 dst += prevo[prevo_idx + x * params.w_stride] * delta[delta_idx + x];
188 for (serial_size_t y = 0; y < params.out.height_; y++) {
189 dst += vectorize::dot(
190 prevo + y * params.in_padded.width_ * params.h_stride,
191 delta + y * params.out.width_,
197 idx = params.in.depth_ * outc + inc;
198 dW[sample][params.weight.get_index(wx, wy, idx)] += dst;
205 if (params.has_bias) {
206 for (serial_size_t outc = 0; outc < params.out.depth_; outc++) {
207 serial_size_t idx = params.out.get_index(0, 0, outc);
208 const float_t * delta = &curr_delta[sample][idx];
209 const float_t * deltaa = delta + params.out.width_ *
211 db[sample][outc] += std::accumulate(delta, deltaa, float_t(0));