tiny_dnn 1.0.0
A header-only, dependency-free deep learning framework in C++11
conv2d_grad_op_avx.h
/*
    Copyright (c) 2016, Taiga Nomi, Edgar Riba
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:
    * Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.
    * Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.
    * Neither the name of the <organization> nor the
    names of its contributors may be used to endorse or promote products
    derived from this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
    EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
    DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
    LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
    ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once

#include <numeric>  // std::accumulate (used for the bias gradient below)
#include <vector>
#include "tiny_dnn/core/params/conv_params.h"
#include "tiny_dnn/core/kernels/conv2d_op_internal.h"

#ifdef CNN_USE_AVX
#include "tiny_dnn/core/kernels/avx_kernel_common.h"
#endif

namespace tiny_dnn {
namespace kernels {

#ifdef CNN_USE_AVX

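// avx_conv2d_5x5_back_kernel_one computes the backward pass of a 5x5
// convolution for one sample: it propagates curr_delta back into prev_delta
// through the filters, and accumulates the weight gradient dW and the bias
// gradient db. madd256_ps / madd128_ss, leftShift / rightShift and
// hsum256_ps are the multiply-accumulate and shuffle helpers used
// throughout this kernel.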
// float ver
template <typename Allocator>
void avx_conv2d_5x5_back_kernel_one(const core::conv_params& params,
                                    const std::vector<float, Allocator>& prev_out,
                                    const std::vector<float, Allocator>& W,
                                    std::vector<float, Allocator>& dW,
                                    std::vector<float, Allocator>& db,
                                    std::vector<float, Allocator>& curr_delta,
                                    std::vector<float, Allocator>* prev_delta) {
    auto& in = params.in;
    auto& out = params.out;
    auto& in_padded = params.in_padded;
    auto& tbl = params.tbl;
    auto w_stride = params.w_stride;
    const size_t in_padded_area = in_padded.area();
    float* pdelta_dst_org = &(*prev_delta)[0];
    const size_t h_stride2 = params.h_stride * in_padded.width_;
    static const __m256i imask = _mm256_setr_epi32(-1, -1, -1, -1, -1, 0, 0, 0);
    static const __m256 mask = _mm256_castsi256_ps(_mm256_setr_epi32(-1, -1, -1, -1, -1, 0, 0, 0));
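    // imask/mask enable only the five low lanes of a 256-bit register,
    // matching the width of one 5-element filter row.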
    // propagate delta to previous layer
    if (w_stride == 1 && out.width_ >= 4) {
        const serial_size_t nblocks = out.width_ / 4;
        for (serial_size_t inc = 0; inc < in.depth_; ++inc, pdelta_dst_org += in_padded_area) {
            for (serial_size_t outc = 0; outc < out.depth_; outc++) {
                if (!tbl.is_connected(outc, inc)) continue;
                const float* pw = &W[25 * (in.depth_ * outc + inc)];
                const float* pdelta_src = &curr_delta[out.get_index(0, 0, outc)];
                float* pdelta_dst = pdelta_dst_org;
                __m256 w0a = _mm256_and_ps(_mm256_loadu_ps(pw + 0), mask);
                __m256 w1a = _mm256_and_ps(_mm256_loadu_ps(pw + 5), mask);
                __m256 w2a = _mm256_and_ps(_mm256_loadu_ps(pw + 10), mask);
                __m256 w3a = _mm256_and_ps(_mm256_loadu_ps(pw + 15), mask);
                __m256 w4a = _mm256_and_ps(_mm256_loadu_ps(pw + 20), mask);
                __m256 w0b = leftShift<4>(w0a);
                __m256 w1b = leftShift<4>(w1a);
                __m256 w2b = leftShift<4>(w2a);
                __m256 w3b = leftShift<4>(w3a);
                __m256 w4b = leftShift<4>(w4a);
                __m256 w0c = leftShift<8>(w0a);
                __m256 w1c = leftShift<8>(w1a);
                __m256 w2c = leftShift<8>(w2a);
                __m256 w3c = leftShift<8>(w3a);
                __m256 w4c = leftShift<8>(w4a);
                __m256 w0d = leftShift<12>(w0a);
                __m256 w1d = leftShift<12>(w1a);
                __m256 w2d = leftShift<12>(w2a);
                __m256 w3d = leftShift<12>(w3a);
                __m256 w4d = leftShift<12>(w4a);
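                // w*a holds one filter row in lanes 0-4; w*b/c/d are the same
                // row shifted left by 1, 2 and 3 lanes (4, 8, 12 bytes), so
                // the contributions of four adjacent output deltas can be
                // accumulated into one 8-wide destination load below.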
                for (serial_size_t y = 0; y < out.height_; y++) {
                    const float* pdelta_src2 = pdelta_src;
                    float* delta_dst0 = pdelta_dst;
                    float* delta_dst1 = &pdelta_dst[in_padded.width_ * 1];
                    float* delta_dst2 = &pdelta_dst[in_padded.width_ * 2];
                    float* delta_dst3 = &pdelta_dst[in_padded.width_ * 3];
                    float* delta_dst4 = &pdelta_dst[in_padded.width_ * 4];
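                    // Main loop: handle four adjacent output columns per
                    // iteration. The four deltas are loaded once, broadcast
                    // lane-wise, and reused across all five filter rows.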
                    for (serial_size_t n = 0; n < nblocks; ++n) {
                        __m256 delta_src = _mm256_broadcast_ps((const __m128*)pdelta_src2);
                        __m256 dst0 = _mm256_loadu_ps(delta_dst0 + 4 * n);
                        __m256 dst1 = _mm256_loadu_ps(delta_dst1 + 4 * n);
                        __m256 dst2 = _mm256_loadu_ps(delta_dst2 + 4 * n);
                        __m256 dst3 = _mm256_loadu_ps(delta_dst3 + 4 * n);
                        __m256 dst4 = _mm256_loadu_ps(delta_dst4 + 4 * n);
                        __m256 delta_src0 = _mm256_permute_ps(delta_src, _MM_SHUFFLE(0, 0, 0, 0));
                        __m256 delta_src1 = _mm256_permute_ps(delta_src, _MM_SHUFFLE(1, 1, 1, 1));
                        __m256 delta_src2 = _mm256_permute_ps(delta_src, _MM_SHUFFLE(2, 2, 2, 2));
                        __m256 delta_src3 = _mm256_permute_ps(delta_src, _MM_SHUFFLE(3, 3, 3, 3));
                        dst0 = madd256_ps(w0a, delta_src0, dst0);
                        dst1 = madd256_ps(w1a, delta_src0, dst1);
                        dst2 = madd256_ps(w2a, delta_src0, dst2);
                        dst3 = madd256_ps(w3a, delta_src0, dst3);
                        dst4 = madd256_ps(w4a, delta_src0, dst4);
                        dst0 = madd256_ps(w0b, delta_src1, dst0);
                        dst1 = madd256_ps(w1b, delta_src1, dst1);
                        dst2 = madd256_ps(w2b, delta_src1, dst2);
                        dst3 = madd256_ps(w3b, delta_src1, dst3);
                        dst4 = madd256_ps(w4b, delta_src1, dst4);
                        dst0 = madd256_ps(w0c, delta_src2, dst0);
                        dst1 = madd256_ps(w1c, delta_src2, dst1);
                        dst2 = madd256_ps(w2c, delta_src2, dst2);
                        dst3 = madd256_ps(w3c, delta_src2, dst3);
                        dst4 = madd256_ps(w4c, delta_src2, dst4);
                        dst0 = madd256_ps(w0d, delta_src3, dst0);
                        _mm256_storeu_ps(delta_dst0 + 4 * n, dst0);
                        dst1 = madd256_ps(w1d, delta_src3, dst1);
                        _mm256_storeu_ps(delta_dst1 + 4 * n, dst1);
                        dst2 = madd256_ps(w2d, delta_src3, dst2);
                        _mm256_storeu_ps(delta_dst2 + 4 * n, dst2);
                        dst3 = madd256_ps(w3d, delta_src3, dst3);
                        _mm256_storeu_ps(delta_dst3 + 4 * n, dst3);
                        dst4 = madd256_ps(w4d, delta_src3, dst4);
                        _mm256_storeu_ps(delta_dst4 + 4 * n, dst4);
                        pdelta_src2 += 4;
                    }
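                    // Tail loop: the remaining out.width_ % 4 columns are
                    // handled one delta at a time with the unshifted weights.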
                    for (serial_size_t x = nblocks * 4; x < out.width_; x++) {
                        __m256 delta_src = _mm256_broadcast_ss(pdelta_src + x);
                        __m256 dst0 = _mm256_loadu_ps(delta_dst0 + x);
                        __m256 dst1 = _mm256_loadu_ps(delta_dst1 + x);
                        __m256 dst2 = _mm256_loadu_ps(delta_dst2 + x);
                        __m256 dst3 = _mm256_loadu_ps(delta_dst3 + x);
                        __m256 dst4 = _mm256_loadu_ps(delta_dst4 + x);
                        dst0 = madd256_ps(w0a, delta_src, dst0);
                        dst1 = madd256_ps(w1a, delta_src, dst1);
                        dst2 = madd256_ps(w2a, delta_src, dst2);
                        dst3 = madd256_ps(w3a, delta_src, dst3);
                        dst4 = madd256_ps(w4a, delta_src, dst4);
                        _mm256_storeu_ps(delta_dst0 + x, dst0);
                        _mm256_storeu_ps(delta_dst1 + x, dst1);
                        _mm256_storeu_ps(delta_dst2 + x, dst2);
                        _mm256_storeu_ps(delta_dst3 + x, dst3);
                        _mm256_storeu_ps(delta_dst4 + x, dst4);
                    }
                    pdelta_src += out.width_;
                    pdelta_dst += h_stride2;
                }
            }
        }
    } else if (out.height_ == 1 && out.width_ == 1) {
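        // Special case: 1x1 output. Each output channel sees exactly one 5x5
        // input patch, so prev_delta receives W(outc, inc) * curr_delta[outc]
        // once per channel pair, with no spatial loops.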
        for (serial_size_t inc = 0; inc < in.depth_; ++inc, pdelta_dst_org += in_padded_area) {
            float* delta_dst0 = pdelta_dst_org;
            float* delta_dst1 = &pdelta_dst_org[in_padded.width_ * 1];
            float* delta_dst2 = &pdelta_dst_org[in_padded.width_ * 2];
            float* delta_dst3 = &pdelta_dst_org[in_padded.width_ * 3];
            float* delta_dst4 = &pdelta_dst_org[in_padded.width_ * 4];
            __m256 dst0 = _mm256_loadu_ps(delta_dst0);
            __m256 dst1 = _mm256_loadu_ps(delta_dst1);
            __m256 dst2 = _mm256_loadu_ps(delta_dst2);
            __m256 dst3 = _mm256_loadu_ps(delta_dst3);
            __m256 dst4 = _mm256_maskload_ps(delta_dst4, imask);

            // *FROM
            // ---0 0000
            // ---1 1111
            // ---2 2222
            // ---3 3333
            // ---4 4444
            //
            // *TO
            // 1110 0000
            // 3222 2211
            // 4444 3333
            // ---- ---4
            __m256 sum0 = _mm256_blend_ps(
                dst0,
                leftShift<20>(dst1),
                0xE0 /* 0b11100000 */
            );
            __m256 sum1 = _mm256_blend_ps(
                leftShift<28>(dst3),
                _mm256_blend_ps(leftShift<8>(dst2), rightShift<12>(dst1), 0x03 /* 0b00000011 */),
                0x7F /* 0b01111111 */
            );
            __m256 sum2 = _mm256_blend_ps(
                leftShift<16>(dst4),
                rightShift<4>(dst3),
                0x0F /* 0b00001111 */
            );
            __m128 sum3 = _mm256_extractf128_ps(dst4, 1);
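            // The five 5-wide rows of the prev_delta patch are now packed
            // contiguously: sum0 = elements 0-7, sum1 = 8-15, sum2 = 16-23,
            // sum3 = element 24. This matches the flat 25-float weight
            // layout, so each output channel costs three 8-wide
            // multiply-adds plus one scalar multiply-add.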

            size_t widx = 25 * inc;
            size_t wstep = 25 * in.depth_;

            if (tbl.is_empty()) {
                for (serial_size_t outc = 0; outc < out.depth_; outc++, widx += wstep) {
                    __m256 delta_src = _mm256_broadcast_ss(&curr_delta[outc]);
                    const float* pw = (const float*)&W[widx];
                    __m256 w0 = _mm256_loadu_ps(pw + 0);
                    __m256 w1 = _mm256_loadu_ps(pw + 8);
                    __m256 w2 = _mm256_loadu_ps(pw + 16);
                    __m128 w3 = _mm_load_ss(pw + 24);
                    sum0 = madd256_ps(w0, delta_src, sum0);
                    sum1 = madd256_ps(w1, delta_src, sum1);
                    sum2 = madd256_ps(w2, delta_src, sum2);
                    sum3 = madd128_ss(w3, _mm256_castps256_ps128(delta_src), sum3);
                }
            } else {
                for (serial_size_t outc = 0; outc < out.depth_; outc++, widx += wstep) {
                    if (!tbl.is_connected(outc, inc)) {
                        continue;
                    }
                    __m256 delta_src = _mm256_broadcast_ss(&curr_delta[outc]);
                    const float* pw = (const float*)&W[widx];
                    __m256 w0 = _mm256_loadu_ps(pw + 0);
                    __m256 w1 = _mm256_loadu_ps(pw + 8);
                    __m256 w2 = _mm256_loadu_ps(pw + 16);
                    __m128 w3 = _mm_load_ss(pw + 24);
                    sum0 = madd256_ps(w0, delta_src, sum0);
                    sum1 = madd256_ps(w1, delta_src, sum1);
                    sum2 = madd256_ps(w2, delta_src, sum2);
                    sum3 = madd128_ss(w3, _mm256_castps256_ps128(delta_src), sum3);
                }
            }

            // *FROM
            // 1110 0000
            // 3222 2211
            // 4444 3333
            // ---- ---4
            //
            // *TO
            // ---0 0000
            // ---1 1111
            // ---2 2222
            // ---3 3333
            // ---4 4444
            dst0 = _mm256_blend_ps(
                dst0,
                sum0,
                0x1F /* 0b00011111 */
            );
            dst1 = _mm256_blend_ps(
                dst1,
                _mm256_or_ps(
                    rightShift<20>(sum0),
                    leftShift<12>(sum1)
                ),
                0x1F /* 0b00011111 */
            );
            dst2 = _mm256_blend_ps(
                dst2,
                rightShift<8>(sum1),
                0x1F /* 0b00011111 */
            );
            dst3 = _mm256_blend_ps(
                dst3,
                _mm256_or_ps(
                    rightShift<28>(sum1),
                    leftShift<4>(sum2)
                ),
                0x1F /* 0b00011111 */
            );
            dst4 = _mm256_blend_ps(
                dst4,
                _mm256_set_m128(
                    sum3,
                    _mm256_extractf128_ps(sum2, 1)
                ),
                0x1F /* 0b00011111 */
            );

            _mm256_storeu_ps(delta_dst0, dst0);
            _mm256_storeu_ps(delta_dst1, dst1);
            _mm256_storeu_ps(delta_dst2, dst2);
            _mm256_storeu_ps(delta_dst3, dst3);
            _mm256_maskstore_ps(delta_dst4, imask, dst4);
        } // for
    } else {
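        // General case (strided and/or narrow output): process one output
        // position at a time, updating a full 5x5 window of prev_delta per
        // step.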
        for (serial_size_t inc = 0; inc < in.depth_; ++inc, pdelta_dst_org += in_padded_area) {
            for (serial_size_t outc = 0; outc < out.depth_; outc++) {
                if (!tbl.is_connected(outc, inc)) continue;

                const float* pw = &W[25 * (in.depth_ * outc + inc)];
                const float* pdelta_src = &curr_delta[out.get_index(0, 0, outc)];
                float* pdelta_dst = pdelta_dst_org;
                __m256 w0a = _mm256_maskload_ps(pw + 0, imask);
                __m256 w1a = _mm256_maskload_ps(pw + 5, imask);
                __m256 w2a = _mm256_maskload_ps(pw + 10, imask);
                __m256 w3a = _mm256_maskload_ps(pw + 15, imask);
                __m256 w4a = _mm256_maskload_ps(pw + 20, imask);
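                // Masked loads read only the five valid weights of each
                // filter row, so the last row never reads past the end of
                // the 25-float filter.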
                for (serial_size_t y = 0; y < out.height_; y++) {
                    float* delta_dst0 = pdelta_dst;
                    float* delta_dst1 = &pdelta_dst[in_padded.width_ * 1];
                    float* delta_dst2 = &pdelta_dst[in_padded.width_ * 2];
                    float* delta_dst3 = &pdelta_dst[in_padded.width_ * 3];
                    float* delta_dst4 = &pdelta_dst[in_padded.width_ * 4];
                    for (serial_size_t x = 0; x < out.width_; x++) {
                        __m256 delta_src = _mm256_broadcast_ss(pdelta_src + x);
                        __m256 dst0 = _mm256_loadu_ps(delta_dst0);
                        __m256 dst1 = _mm256_loadu_ps(delta_dst1);
                        __m256 dst2 = _mm256_loadu_ps(delta_dst2);
                        __m256 dst3 = _mm256_loadu_ps(delta_dst3);
                        __m256 dst4 = _mm256_maskload_ps(delta_dst4, imask);
                        dst0 = madd256_ps(w0a, delta_src, dst0);
                        dst1 = madd256_ps(w1a, delta_src, dst1);
                        dst2 = madd256_ps(w2a, delta_src, dst2);
                        dst3 = madd256_ps(w3a, delta_src, dst3);
                        dst4 = madd256_ps(w4a, delta_src, dst4);
                        _mm256_storeu_ps(delta_dst0, dst0);
                        _mm256_storeu_ps(delta_dst1, dst1);
                        _mm256_storeu_ps(delta_dst2, dst2);
                        _mm256_storeu_ps(delta_dst3, dst3);
                        _mm256_maskstore_ps(delta_dst4, imask, dst4);
                        delta_dst0 += w_stride;
                        delta_dst1 += w_stride;
                        delta_dst2 += w_stride;
                        delta_dst3 += w_stride;
                        delta_dst4 += w_stride;
                    } // for x
                    pdelta_src += out.width_;
                    pdelta_dst += h_stride2;
                } // for y
            } // for outc
        } // for inc
    }

    // accumulate dw
    if (out.width_ == 1 && out.height_ == 1) {
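        // 1x1 output: the weight gradient of each filter is simply the 5x5
        // input patch scaled by the single delta of its output channel.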
        const float* pprev_out = &prev_out[0];
        for (serial_size_t inc = 0; inc < in.depth_; ++inc, pprev_out += in_padded_area) {
            VECTORIZE_ALIGN(32) float floats[28];
            size_t in_padded_width = in_padded.width_;
            _mm256_store_ps(&floats[0], _mm256_loadu_ps(pprev_out + in_padded_width * 0));
            _mm256_storeu_ps(&floats[5], _mm256_loadu_ps(pprev_out + in_padded_width * 1));
            _mm256_storeu_ps(&floats[10], _mm256_loadu_ps(pprev_out + in_padded_width * 2));
            _mm256_storeu_ps(&floats[15], _mm256_loadu_ps(pprev_out + in_padded_width * 3));
            _mm256_storeu_ps(&floats[20], _mm256_maskload_ps(pprev_out + in_padded_width * 4, imask));
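            // The five padded input rows are copied into one aligned 28-float
            // scratch buffer; the overlapping 8-wide stores leave the 25
            // patch values contiguous so they can be reloaded as 8 + 8 + 8 + 1.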
            __m256 prevos0 = _mm256_load_ps(&floats[0]);
            __m256 prevos1 = _mm256_load_ps(&floats[8]);
            __m256 prevos2 = _mm256_load_ps(&floats[16]);
            __m128 prevos3 = _mm_load_ss(&floats[24]);
            serial_size_t widx = 25 * inc;
            serial_size_t widx_delta = 25 * in.depth_;
            float* pdW = &dW[widx];
            for (serial_size_t outc = 0; outc < out.depth_; outc++, pdW += widx_delta) {
                if (!tbl.is_connected(outc, inc)) {
                    continue;
                }
                __m256 delta = _mm256_broadcast_ss(&curr_delta[outc]);
                __m256 w0 = _mm256_loadu_ps(pdW + 0);
                __m256 w1 = _mm256_loadu_ps(pdW + 8);
                __m256 w2 = _mm256_loadu_ps(pdW + 16);
                __m128 w3 = _mm_load_ss(pdW + 24);
                w0 = madd256_ps(prevos0, delta, w0);
                w1 = madd256_ps(prevos1, delta, w1);
                w2 = madd256_ps(prevos2, delta, w2);
                w3 = madd128_ss(prevos3, _mm256_castps256_ps128(delta), w3);
                _mm256_storeu_ps(pdW + 0, w0);
                _mm256_storeu_ps(pdW + 8, w1);
                _mm256_storeu_ps(pdW + 16, w2);
                _mm_store_ss(pdW + 24, w3);
            }
        }
    } else {
        // prepare load-mask beforehand
        const size_t nblocks = out.width_ >> 3;
        static const int32_t masks[] = {
            -1, -1, -1, -1,
            -1, -1, -1, -1,
             0,  0,  0,  0,
             0,  0,  0,  0,
        };
        const size_t remainder = out.width_ & 7;
        __m256i mask = _mm256_loadu_si256((const __m256i*)(masks + 8 - remainder));
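        // masks + 8 - remainder yields a mask whose first (out.width_ % 8)
        // lanes are enabled; it is used below for the tail of each row-wise
        // dot product.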
        auto& weight = params.weight;
        for (serial_size_t inc = 0; inc < in.depth_; ++inc) {
            for (serial_size_t outc = 0; outc < out.depth_; outc++) {

                if (!tbl.is_connected(outc, inc)) continue;
                const float* delta = &curr_delta[out.get_index(0, 0, outc)];

                serial_size_t widx = weight.get_index(0, 0, in.depth_ * outc + inc);
                for (serial_size_t wy = 0; wy < 5 /* weight.height_ */; wy++) {
                    for (serial_size_t wx = 0; wx < 5 /* weight.width_ */; wx++) {
                        const float* prevo = &prev_out[in_padded.get_index(wx, wy, inc)];

                        if (w_stride > 1) {
                            float_t dst = float_t(0);

                            for (serial_size_t y = 0; y < params.out.height_; y++) {
                                serial_size_t prevo_idx = y * params.in_padded.width_ * params.h_stride;
                                serial_size_t delta_idx = y * params.out.width_;

                                for (serial_size_t x = 0; x < params.out.width_; x++) {
                                    dst += prevo[prevo_idx + x * params.w_stride] * delta[delta_idx + x];
                                }
                            }
                            dW[widx] += dst;
                        } else {
                            __m128 prev_sum = _mm_load_ss(&dW[widx]);
                            __m256 sum0 = _mm256_setzero_ps();
                            __m256 sum1 = _mm256_setzero_ps();
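                            // sum0 accumulates the full 8-wide blocks of the
                            // dot product, sum1 the masked tail; both are
                            // reduced with hsum256_ps after the row loop.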
                            for (serial_size_t y = 0; y < out.height_; y++) {
                                // vectorize::dot
                                const float* pa = prevo + y * in_padded.width_ * params.h_stride;
                                const float* pb = delta + y * out.width_;
                                for (size_t i = 0; i < nblocks; ++i) {
                                    __m256 a = _mm256_loadu_ps(pa + 8 * i);
                                    __m256 b = _mm256_loadu_ps(pb + 8 * i);
                                    sum0 = madd256_ps(a, b, sum0);
                                }
                                if (remainder) {
                                    __m256 a = _mm256_maskload_ps(pa + 8 * nblocks, mask);
                                    __m256 b = _mm256_maskload_ps(pb + 8 * nblocks, mask);
                                    sum1 = madd256_ps(a, b, sum1);
                                }
                            }
                            sum1 = _mm256_and_ps(sum1, _mm256_castsi256_ps(mask));
                            __m256 sum = _mm256_add_ps(sum0, sum1);
                            _mm_store_ss(&dW[widx], _mm_add_ps(prev_sum, hsum256_ps(sum)));
                        }
                        ++widx;
                    }
                }
            }
        }
    }

    // accumulate db
    if (params.has_bias) {
        //fvec_t& db = *in_grad[2];

        if (out.width_ == 1 && out.height_ == 1) {
            for (serial_size_t outc = 0; outc < out.depth_; outc++) {
                db[outc] += curr_delta[outc];
            }
        } else {
            for (serial_size_t outc = 0; outc < out.depth_; outc++) {
                const float* delta = &curr_delta[out.get_index(0, 0, outc)];
                db[outc] += std::accumulate(delta, delta + out.width_ * out.height_, float(0));
            }
        }
    }
} // avx_conv2d_5x5_back_kernel float ver

// double ver
template <typename Allocator>
void avx_conv2d_5x5_back_kernel(const core::conv_params& params,
                                const std::vector<std::vector<double, Allocator>>& prev_out,
                                const std::vector<double, Allocator>& W,
                                std::vector<std::vector<double, Allocator>>& dW,
                                std::vector<std::vector<double, Allocator>>& db,
                                std::vector<std::vector<double, Allocator>>& curr_delta,
                                std::vector<std::vector<double, Allocator>>& prev_delta) {
    // when float_t == double, the backward pass falls back to the internal
    // (non-AVX) backend
    conv2d_op_internal(prev_out, W, dW, db, curr_delta, prev_delta, params, true);
}

// float ver
template <typename Allocator>
void avx_conv2d_5x5_back_kernel(const core::conv_params& params,
                                const std::vector<std::vector<float, Allocator>>& prev_out,
                                const std::vector<float, Allocator>& W,
                                std::vector<std::vector<float, Allocator>>& dW,
                                std::vector<std::vector<float, Allocator>>& db,
                                std::vector<std::vector<float, Allocator>>& curr_delta,
                                std::vector<std::vector<float, Allocator>>& prev_delta) {
    for_i(prev_out.size(), [&](int sample) {
        avx_conv2d_5x5_back_kernel_one(params, prev_out[sample], W, dW[sample], db[sample],
                                       curr_delta[sample], &prev_delta[sample]);
    });
}

#endif // CNN_USE_AVX

inline void
conv2d_grad_op_avx(const tensor_t& prev_out,
                   const vec_t& W,
                   tensor_t& dW,
                   tensor_t& db,
                   tensor_t& curr_delta,
                   tensor_t& prev_delta,
                   const core::conv_params& params,
                   const bool layer_parallelize) {
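    // Dispatch to the specialized AVX kernel only for 5x5 filters; all other
    // filter sizes fall through to the generic internal implementation below.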
#ifdef CNN_USE_AVX
    if (params.weight.height_ == 5 && params.weight.width_ == 5) {
        avx_conv2d_5x5_back_kernel(params, prev_out, W, dW, db, curr_delta, prev_delta);
        return;
    }
#endif

    conv2d_op_internal(prev_out, W, dW, db, curr_delta,
                       prev_delta, params, layer_parallelize);
}

} // namespace kernels
} // namespace tiny_dnn