#include "tiny_dnn/core/params/conv_params.h"
#include "tiny_dnn/core/kernels/conv2d_op_internal.h"

#include "tiny_dnn/core/kernels/avx_kernel_common.h"
// float version
template <typename Allocator>
void avx_conv2d_5x5_back_kernel_one(const core::conv_params& params,
                                    const std::vector<float, Allocator>& prev_out,
                                    const std::vector<float, Allocator>& W,
                                    std::vector<float, Allocator>& dW,
                                    std::vector<float, Allocator>& db,
                                    std::vector<float, Allocator>& curr_delta,
                                    std::vector<float, Allocator>* prev_delta) {
  auto& in        = params.in;
  auto& out       = params.out;
  auto& in_padded = params.in_padded;
  auto& tbl       = params.tbl;
  auto  w_stride  = params.w_stride;
  const size_t in_padded_area = in_padded.area();
  float* pdelta_dst_org = &(*prev_delta)[0];
  const size_t h_stride2 = params.h_stride * in_padded.width_;
  // masks selecting the first 5 of 8 lanes (one 5-element kernel row)
  static const __m256i imask = _mm256_setr_epi32(-1, -1, -1, -1, -1, 0, 0, 0);
  static const __m256 mask =
    _mm256_castsi256_ps(_mm256_setr_epi32(-1, -1, -1, -1, -1, 0, 0, 0));
  if (w_stride == 1 && out.width_ >= 4) {
    const serial_size_t nblocks = out.width_ / 4;
    for (serial_size_t inc = 0; inc < in.depth_; ++inc, pdelta_dst_org += in_padded_area) {
      for (serial_size_t outc = 0; outc < out.depth_; outc++) {
        if (!tbl.is_connected(outc, inc)) continue;
        const float* pw = &W[25 * (in.depth_ * outc + inc)];
        const float* pdelta_src = &curr_delta[out.get_index(0, 0, outc)];
        float* pdelta_dst = pdelta_dst_org;
        __m256 w0a = _mm256_and_ps(_mm256_loadu_ps(pw + 0), mask);
        __m256 w1a = _mm256_and_ps(_mm256_loadu_ps(pw + 5), mask);
        __m256 w2a = _mm256_and_ps(_mm256_loadu_ps(pw + 10), mask);
        __m256 w3a = _mm256_and_ps(_mm256_loadu_ps(pw + 15), mask);
        __m256 w4a = _mm256_and_ps(_mm256_loadu_ps(pw + 20), mask);
        __m256 w0b = leftShift<4>(w0a);
        __m256 w1b = leftShift<4>(w1a);
        __m256 w2b = leftShift<4>(w2a);
        __m256 w3b = leftShift<4>(w3a);
        __m256 w4b = leftShift<4>(w4a);
        __m256 w0c = leftShift<8>(w0a);
        __m256 w1c = leftShift<8>(w1a);
        __m256 w2c = leftShift<8>(w2a);
        __m256 w3c = leftShift<8>(w3a);
        __m256 w4c = leftShift<8>(w4a);
        __m256 w0d = leftShift<12>(w0a);
        __m256 w1d = leftShift<12>(w1a);
        __m256 w2d = leftShift<12>(w2a);
        __m256 w3d = leftShift<12>(w3a);
        __m256 w4d = leftShift<12>(w4a);
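        // w{r}a holds kernel row r in lanes 0-4; the b/c/d copies are the same
        // row shifted by one, two and three lanes, so four neighbouring output
        // columns can be accumulated into a single 8-lane register.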
        for (serial_size_t y = 0; y < out.height_; y++) {
          const float* pdelta_src2 = pdelta_src;
          float* delta_dst0 = pdelta_dst;
          float* delta_dst1 = &pdelta_dst[in_padded.width_ * 1];
          float* delta_dst2 = &pdelta_dst[in_padded.width_ * 2];
          float* delta_dst3 = &pdelta_dst[in_padded.width_ * 3];
          float* delta_dst4 = &pdelta_dst[in_padded.width_ * 4];
          for (serial_size_t n = 0; n < nblocks; ++n) {
            __m256 delta_src = _mm256_broadcast_ps((const __m128*)pdelta_src2);
            __m256 dst0 = _mm256_loadu_ps(delta_dst0 + 4 * n);
            __m256 dst1 = _mm256_loadu_ps(delta_dst1 + 4 * n);
            __m256 dst2 = _mm256_loadu_ps(delta_dst2 + 4 * n);
            __m256 dst3 = _mm256_loadu_ps(delta_dst3 + 4 * n);
            __m256 dst4 = _mm256_loadu_ps(delta_dst4 + 4 * n);
            __m256 delta_src0 = _mm256_permute_ps(delta_src, _MM_SHUFFLE(0, 0, 0, 0));
            __m256 delta_src1 = _mm256_permute_ps(delta_src, _MM_SHUFFLE(1, 1, 1, 1));
            __m256 delta_src2 = _mm256_permute_ps(delta_src, _MM_SHUFFLE(2, 2, 2, 2));
            __m256 delta_src3 = _mm256_permute_ps(delta_src, _MM_SHUFFLE(3, 3, 3, 3));
            dst0 = madd256_ps(w0a, delta_src0, dst0);
            dst1 = madd256_ps(w1a, delta_src0, dst1);
            dst2 = madd256_ps(w2a, delta_src0, dst2);
            dst3 = madd256_ps(w3a, delta_src0, dst3);
            dst4 = madd256_ps(w4a, delta_src0, dst4);
            dst0 = madd256_ps(w0b, delta_src1, dst0);
            dst1 = madd256_ps(w1b, delta_src1, dst1);
            dst2 = madd256_ps(w2b, delta_src1, dst2);
            dst3 = madd256_ps(w3b, delta_src1, dst3);
            dst4 = madd256_ps(w4b, delta_src1, dst4);
            dst0 = madd256_ps(w0c, delta_src2, dst0);
            dst1 = madd256_ps(w1c, delta_src2, dst1);
            dst2 = madd256_ps(w2c, delta_src2, dst2);
            dst3 = madd256_ps(w3c, delta_src2, dst3);
            dst4 = madd256_ps(w4c, delta_src2, dst4);
            dst0 = madd256_ps(w0d, delta_src3, dst0);
            _mm256_storeu_ps(delta_dst0 + 4 * n, dst0);
            dst1 = madd256_ps(w1d, delta_src3, dst1);
            _mm256_storeu_ps(delta_dst1 + 4 * n, dst1);
            dst2 = madd256_ps(w2d, delta_src3, dst2);
            _mm256_storeu_ps(delta_dst2 + 4 * n, dst2);
            dst3 = madd256_ps(w3d, delta_src3, dst3);
            _mm256_storeu_ps(delta_dst3 + 4 * n, dst3);
            dst4 = madd256_ps(w4d, delta_src3, dst4);
            _mm256_storeu_ps(delta_dst4 + 4 * n, dst4);
            pdelta_src2 += 4;
          }
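          // tail: columns that do not fill a block of four are handled one at
          // a time, broadcasting a single delta across the masked kernel rows.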
          for (serial_size_t x = nblocks * 4; x < out.width_; x++) {
            __m256 delta_src = _mm256_broadcast_ss(pdelta_src + x);
            __m256 dst0 = _mm256_loadu_ps(delta_dst0 + x);
            __m256 dst1 = _mm256_loadu_ps(delta_dst1 + x);
            __m256 dst2 = _mm256_loadu_ps(delta_dst2 + x);
            __m256 dst3 = _mm256_loadu_ps(delta_dst3 + x);
            __m256 dst4 = _mm256_loadu_ps(delta_dst4 + x);
            dst0 = madd256_ps(w0a, delta_src, dst0);
            dst1 = madd256_ps(w1a, delta_src, dst1);
            dst2 = madd256_ps(w2a, delta_src, dst2);
            dst3 = madd256_ps(w3a, delta_src, dst3);
            dst4 = madd256_ps(w4a, delta_src, dst4);
            _mm256_storeu_ps(delta_dst0 + x, dst0);
            _mm256_storeu_ps(delta_dst1 + x, dst1);
            _mm256_storeu_ps(delta_dst2 + x, dst2);
            _mm256_storeu_ps(delta_dst3 + x, dst3);
            _mm256_storeu_ps(delta_dst4 + x, dst4);
          }
          pdelta_src += out.width_;
          pdelta_dst += h_stride2;
        }
      }
    }
  } else if (out.height_ == 1 && out.width_ == 1) {
    for (serial_size_t inc = 0; inc < in.depth_; ++inc, pdelta_dst_org += in_padded_area) {
      float* delta_dst0 = pdelta_dst_org;
      float* delta_dst1 = &pdelta_dst_org[in_padded.width_ * 1];
      float* delta_dst2 = &pdelta_dst_org[in_padded.width_ * 2];
      float* delta_dst3 = &pdelta_dst_org[in_padded.width_ * 3];
      float* delta_dst4 = &pdelta_dst_org[in_padded.width_ * 4];
      __m256 dst0 = _mm256_loadu_ps(delta_dst0);
      __m256 dst1 = _mm256_loadu_ps(delta_dst1);
      __m256 dst2 = _mm256_loadu_ps(delta_dst2);
      __m256 dst3 = _mm256_loadu_ps(delta_dst3);
      __m256 dst4 = _mm256_maskload_ps(delta_dst4, imask);
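      // Pack the 5x5 prev_delta patch (five padded rows of five floats) into
      // sum0..sum2 (8 lanes each) and sum3 (one lane) so that the 25 values
      // line up with the row-major weights W[0..24].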
      __m256 sum0 = _mm256_blend_ps(
        dst0, leftShift<20>(dst1), 0xE0);
      __m256 sum1 = _mm256_blend_ps(
        _mm256_blend_ps(leftShift<8>(dst2), rightShift<12>(dst1), 0x03),
        leftShift<28>(dst3), 0x80);
      __m256 sum2 = _mm256_blend_ps(
        rightShift<4>(dst3), leftShift<16>(dst4), 0xF0);
      __m128 sum3 = _mm256_extractf128_ps(dst4, 1);
      size_t widx = 25 * inc;
      size_t wstep = 25 * in.depth_;
      if (tbl.is_empty()) {
        for (serial_size_t outc = 0; outc < out.depth_; outc++, widx += wstep) {
          __m256 delta_src = _mm256_broadcast_ss(&curr_delta[outc]);
          const float* pw = (const float*)&W[widx];
          __m256 w0 = _mm256_loadu_ps(pw + 0);
          __m256 w1 = _mm256_loadu_ps(pw + 8);
          __m256 w2 = _mm256_loadu_ps(pw + 16);
          __m128 w3 = _mm_load_ss(pw + 24);
          sum0 = madd256_ps(w0, delta_src, sum0);
          sum1 = madd256_ps(w1, delta_src, sum1);
          sum2 = madd256_ps(w2, delta_src, sum2);
          sum3 = madd128_ss(w3, _mm256_castps256_ps128(delta_src), sum3);
        }
      } else {
        for (serial_size_t outc = 0; outc < out.depth_; outc++, widx += wstep) {
          if (!tbl.is_connected(outc, inc)) {
            continue;
          }
          __m256 delta_src = _mm256_broadcast_ss(&curr_delta[outc]);
          const float* pw = (const float*)&W[widx];
          __m256 w0 = _mm256_loadu_ps(pw + 0);
          __m256 w1 = _mm256_loadu_ps(pw + 8);
          __m256 w2 = _mm256_loadu_ps(pw + 16);
          __m128 w3 = _mm_load_ss(pw + 24);
          sum0 = madd256_ps(w0, delta_src, sum0);
          sum1 = madd256_ps(w1, delta_src, sum1);
          sum2 = madd256_ps(w2, delta_src, sum2);
          sum3 = madd128_ss(w3, _mm256_castps256_ps128(delta_src), sum3);
        }
      }
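      // Unpack sum0..sum3 back into the five row registers; only the first
      // five lanes of each row are taken, the remaining lanes keep the values
      // that were originally loaded from prev_delta.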
      dst0 = _mm256_blend_ps(dst0, sum0, 0x1F);
      dst1 = _mm256_blend_ps(
        dst1,
        _mm256_blend_ps(
          rightShift<20>(sum0), leftShift<12>(sum1), 0x18),
        0x1F);
      dst2 = _mm256_blend_ps(dst2, rightShift<8>(sum1), 0x1F);
      dst3 = _mm256_blend_ps(
        dst3,
        _mm256_blend_ps(
          rightShift<28>(sum1), leftShift<4>(sum2), 0x1E),
        0x1F);
      dst4 = _mm256_blend_ps(
        dst4,
        _mm256_insertf128_ps(
          _mm256_castps128_ps256(_mm256_extractf128_ps(sum2, 1)),
          sum3, 1),
        0x1F);
      _mm256_storeu_ps(delta_dst0, dst0);
      _mm256_storeu_ps(delta_dst1, dst1);
      _mm256_storeu_ps(delta_dst2, dst2);
      _mm256_storeu_ps(delta_dst3, dst3);
      _mm256_maskstore_ps(delta_dst4, imask, dst4);
    }
  } else {
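    // generic path: arbitrary output size and stride, one output pixel at a
    // time with masked 5-wide kernel rows.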
    for (serial_size_t inc = 0; inc < in.depth_; ++inc, pdelta_dst_org += in_padded_area) {
      for (serial_size_t outc = 0; outc < out.depth_; outc++) {
        if (!tbl.is_connected(outc, inc)) continue;
        const float* pw = &W[25 * (in.depth_ * outc + inc)];
        const float* pdelta_src = &curr_delta[out.get_index(0, 0, outc)];
        float* pdelta_dst = pdelta_dst_org;
        __m256 w0a = _mm256_maskload_ps(pw + 0, imask);
        __m256 w1a = _mm256_maskload_ps(pw + 5, imask);
        __m256 w2a = _mm256_maskload_ps(pw + 10, imask);
        __m256 w3a = _mm256_maskload_ps(pw + 15, imask);
        __m256 w4a = _mm256_maskload_ps(pw + 20, imask);
        for (serial_size_t y = 0; y < out.height_; y++) {
          float* delta_dst0 = pdelta_dst;
          float* delta_dst1 = &pdelta_dst[in_padded.width_ * 1];
          float* delta_dst2 = &pdelta_dst[in_padded.width_ * 2];
          float* delta_dst3 = &pdelta_dst[in_padded.width_ * 3];
          float* delta_dst4 = &pdelta_dst[in_padded.width_ * 4];
          for (serial_size_t x = 0; x < out.width_; x++) {
            __m256 delta_src = _mm256_broadcast_ss(pdelta_src + x);
            __m256 dst0 = _mm256_loadu_ps(delta_dst0);
            __m256 dst1 = _mm256_loadu_ps(delta_dst1);
            __m256 dst2 = _mm256_loadu_ps(delta_dst2);
            __m256 dst3 = _mm256_loadu_ps(delta_dst3);
            __m256 dst4 = _mm256_maskload_ps(delta_dst4, imask);
            dst0 = madd256_ps(w0a, delta_src, dst0);
            dst1 = madd256_ps(w1a, delta_src, dst1);
            dst2 = madd256_ps(w2a, delta_src, dst2);
            dst3 = madd256_ps(w3a, delta_src, dst3);
            dst4 = madd256_ps(w4a, delta_src, dst4);
            _mm256_storeu_ps(delta_dst0, dst0);
            _mm256_storeu_ps(delta_dst1, dst1);
            _mm256_storeu_ps(delta_dst2, dst2);
            _mm256_storeu_ps(delta_dst3, dst3);
            _mm256_maskstore_ps(delta_dst4, imask, dst4);
            delta_dst0 += w_stride;
            delta_dst1 += w_stride;
            delta_dst2 += w_stride;
            delta_dst3 += w_stride;
            delta_dst4 += w_stride;
          }
          pdelta_src += out.width_;
          pdelta_dst += h_stride2;
        }
      }
    }
  }
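  // Accumulate the weight gradient dW += prev_out (x) curr_delta. The 1x1
  // output case reuses a packed 25-lane layout of the input patch; otherwise
  // each of the 25 taps is a dot product between the delta map and a shifted
  // prev_out window.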
  if (out.width_ == 1 && out.height_ == 1) {
    const float* pprev_out = &prev_out[0];
    for (serial_size_t inc = 0; inc < in.depth_; ++inc, pprev_out += in_padded_area) {
      VECTORIZE_ALIGN(32) float floats[28];
      size_t in_padded_width = in_padded.width_;
      _mm256_store_ps(&floats[0], _mm256_loadu_ps(pprev_out + in_padded_width * 0));
      _mm256_storeu_ps(&floats[5], _mm256_loadu_ps(pprev_out + in_padded_width * 1));
      _mm256_storeu_ps(&floats[10], _mm256_loadu_ps(pprev_out + in_padded_width * 2));
      _mm256_storeu_ps(&floats[15], _mm256_loadu_ps(pprev_out + in_padded_width * 3));
      _mm256_storeu_ps(&floats[20], _mm256_maskload_ps(pprev_out + in_padded_width * 4, imask));
      __m256 prevos0 = _mm256_load_ps(&floats[0]);
      __m256 prevos1 = _mm256_load_ps(&floats[8]);
      __m256 prevos2 = _mm256_load_ps(&floats[16]);
      __m128 prevos3 = _mm_load_ss(&floats[24]);
      serial_size_t widx = 25 * inc;
      serial_size_t widx_delta = 25 * in.depth_;
      float* pdW = &dW[widx];
      for (serial_size_t outc = 0; outc < out.depth_; outc++, pdW += widx_delta) {
        if (!tbl.is_connected(outc, inc)) {
          continue;
        }
        __m256 delta = _mm256_broadcast_ss(&curr_delta[outc]);
        __m256 w0 = _mm256_loadu_ps(pdW + 0);
        __m256 w1 = _mm256_loadu_ps(pdW + 8);
        __m256 w2 = _mm256_loadu_ps(pdW + 16);
        __m128 w3 = _mm_load_ss(pdW + 24);
        w0 = madd256_ps(prevos0, delta, w0);
        w1 = madd256_ps(prevos1, delta, w1);
        w2 = madd256_ps(prevos2, delta, w2);
        w3 = madd128_ss(prevos3, _mm256_castps256_ps128(delta), w3);
        _mm256_storeu_ps(pdW + 0, w0);
        _mm256_storeu_ps(pdW + 8, w1);
        _mm256_storeu_ps(pdW + 16, w2);
        _mm_store_ss(pdW + 24, w3);
      }
    }
  } else {
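    // general dW path: for each of the 25 kernel taps, take the dot product
    // of the output delta map with the prev_out window shifted by (wx, wy),
    // eight elements per iteration plus a masked remainder.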
    const size_t nblocks = out.width_ >> 3;
    static const int32_t masks[] = {
      -1, -1, -1, -1, -1, -1, -1, -1,
       0,  0,  0,  0,  0,  0,  0,  0,
    };
    const size_t remainder = out.width_ & 7;
    __m256i mask = _mm256_loadu_si256((const __m256i*)(masks + 8 - remainder));
    auto& weight = params.weight;
    for (serial_size_t inc = 0; inc < in.depth_; ++inc) {
      for (serial_size_t outc = 0; outc < out.depth_; outc++) {
        if (!tbl.is_connected(outc, inc)) continue;
        const float* delta = &curr_delta[out.get_index(0, 0, outc)];
        serial_size_t widx = weight.get_index(0, 0, in.depth_ * outc + inc);
        for (serial_size_t wy = 0; wy < 5; wy++) {
          for (serial_size_t wx = 0; wx < 5; wx++) {
            const float* prevo = &prev_out[in_padded.get_index(wx, wy, inc)];
            if (w_stride > 1) {
              float_t dst = float_t(0);
              for (serial_size_t y = 0; y < params.out.height_; y++) {
                serial_size_t prevo_idx = y * params.in_padded.width_ * params.h_stride;
                serial_size_t delta_idx = y * params.out.width_;
                for (serial_size_t x = 0; x < params.out.width_; x++) {
                  dst += prevo[prevo_idx + x * params.w_stride] * delta[delta_idx + x];
                }
              }
              dW[widx] += dst;
            } else {
              __m128 prev_sum = _mm_load_ss(&dW[widx]);
              __m256 sum0 = _mm256_setzero_ps();
              __m256 sum1 = _mm256_setzero_ps();
              for (serial_size_t y = 0; y < out.height_; y++) {
                const float* pa = prevo + y * in_padded.width_ * params.h_stride;
                const float* pb = delta + y * out.width_;
                for (size_t i = 0; i < nblocks; ++i) {
                  __m256 a = _mm256_loadu_ps(pa + 8 * i);
                  __m256 b = _mm256_loadu_ps(pb + 8 * i);
                  sum0 = madd256_ps(a, b, sum0);
                }
                if (remainder) {
                  __m256 a = _mm256_maskload_ps(pa + 8 * nblocks, mask);
                  __m256 b = _mm256_maskload_ps(pb + 8 * nblocks, mask);
                  sum1 = madd256_ps(a, b, sum1);
                }
              }
              sum1 = _mm256_and_ps(sum1, _mm256_castsi256_ps(mask));
              __m256 sum = _mm256_add_ps(sum0, sum1);
              _mm_store_ss(&dW[widx], _mm_add_ps(prev_sum, hsum256_ps(sum)));
            }
            ++widx;
          }
        }
      }
    }
  }
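  // Bias gradient: db[outc] is the sum of curr_delta over each output
  // feature map.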
  if (params.has_bias) {
    if (out.width_ == 1 && out.height_ == 1) {
      for (serial_size_t outc = 0; outc < out.depth_; outc++) {
        db[outc] += curr_delta[outc];
      }
    } else {
      for (serial_size_t outc = 0; outc < out.depth_; outc++) {
        const float* delta = &curr_delta[out.get_index(0, 0, outc)];
        db[outc] += std::accumulate(delta, delta + out.width_ * out.height_,
                                    float(0));
      }
    }
  }
}
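// double version: there is no AVX double-precision specialization, so the
// generic reference kernel handles the whole batch.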
template <typename Allocator>
void avx_conv2d_5x5_back_kernel(const core::conv_params& params,
                                const std::vector<std::vector<double, Allocator>>& prev_out,
                                const std::vector<double, Allocator>& W,
                                std::vector<std::vector<double, Allocator>>& dW,
                                std::vector<std::vector<double, Allocator>>& db,
                                std::vector<std::vector<double, Allocator>>& curr_delta,
                                std::vector<std::vector<double, Allocator>>& prev_delta) {
  conv2d_op_internal(prev_out, W, dW, db, curr_delta, prev_delta, params,
                     true);
}
// float version: run the specialized per-sample kernel over the batch
template <typename Allocator>
void avx_conv2d_5x5_back_kernel(const core::conv_params& params,
                                const std::vector<std::vector<float, Allocator>>& prev_out,
                                const std::vector<float, Allocator>& W,
                                std::vector<std::vector<float, Allocator>>& dW,
                                std::vector<std::vector<float, Allocator>>& db,
                                std::vector<std::vector<float, Allocator>>& curr_delta,
                                std::vector<std::vector<float, Allocator>>& prev_delta) {
  for_i(prev_out.size(), [&](int sample) {
    avx_conv2d_5x5_back_kernel_one(params, prev_out[sample], W, dW[sample],
                                   db[sample], curr_delta[sample],
                                   &prev_delta[sample]);
  });
}
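// Entry point for the convolutional layer's backward pass: dispatch to the
// hand-written 5x5 kernel when the weight shape matches, otherwise fall back
// to the generic implementation.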
inline void conv2d_grad_op_avx(const tensor_t& prev_out,
                               const vec_t& W,
                               tensor_t& dW,
                               tensor_t& db,
                               tensor_t& curr_delta,
                               tensor_t& prev_delta,
                               const core::conv_params& params,
                               const bool layer_parallelize) {
  if (params.weight.height_ == 5 && params.weight.width_ == 5) {
    avx_conv2d_5x5_back_kernel(params, prev_out, W, dW, db, curr_delta,
                               prev_delta);
    return;
  }
  conv2d_op_internal(prev_out, W, dW, db, curr_delta,
                     prev_delta, params, layer_parallelize);
}