tiny_dnn 1.0.0
A header-only, dependency-free deep learning framework in C++11
conv2d_op_avx.h
/*
  Copyright (c) 2016, Taiga Nomi, Edgar Riba
  All rights reserved.

  Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions are met:
  * Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.
  * Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.
  * Neither the name of the <organization> nor the
    names of its contributors may be used to endorse or promote products
    derived from this software without specific prior written permission.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
  EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once

#include <vector>

#include "tiny_dnn/core/params/conv_params.h"
#include "tiny_dnn/core/kernels/conv2d_op_internal.h"

#ifdef CNN_USE_AVX
#include "tiny_dnn/core/kernels/avx_kernel_common.h"
#endif

namespace tiny_dnn {
namespace kernels {

#ifdef CNN_USE_AVX

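// The two overloads below are hand-vectorized specializations of the
// convolution forward pass for 5x5 kernels (single- and double-precision).
// They read the padded input, weights, and output layout from
// core::conv_params; any other kernel size is handled by the generic
// fallback dispatched at the bottom of this file.
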
// float ver
template <typename Allocator>
void avx_conv2d_5x5_kernel(const core::conv_params& params,
                           const std::vector<float, Allocator>& in,
                           const std::vector<float, Allocator>& W,
                           const std::vector<float, Allocator>& bias,
                           std::vector<float, Allocator>& a,
                           const bool layer_parallelize) {
  assert(params.weight.height_ == 5 && params.weight.width_ == 5);

  auto& out       = params.out;
  auto& in_padded = params.in_padded;
  auto& tbl       = params.tbl;
  auto  w_stride  = params.w_stride;

  const serial_size_t out_area = out.area();
  serial_size_t oidx = 0;
  float bias_scale = params.has_bias ? 1.0f : 0.0f;
  const serial_size_t stride = params.h_stride * in_padded.width_;
  const serial_size_t inarea = in_padded.area();

  static const __m256i imask = _mm256_setr_epi32(-1, -1, -1, -1, -1, 0, 0, 0);

  const __m128 y_bias_scale = _mm_set_ss(bias_scale);
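  // Fast path: when the output plane is a single pixel, each output value is
  // the dot product of the 25 kernel weights with the 25 values of the padded
  // input channel, accumulated over all connected input channels. The 25
  // products are split across three 8-float AVX lanes plus one scalar.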
  if (out.height_ == 1 && out.width_ == 1) {
    const float* pw = (const float*)&W[0];
    for (serial_size_t o = 0; o < out.depth_; ++o) {
      __m256 sum0 = _mm256_setzero_ps();
      __m256 sum1 = _mm256_setzero_ps();
      __m256 sum2 = _mm256_setzero_ps();
      __m128 sum3 = _mm_setzero_ps();
      const float* pi = (const float*)&in[0];
      for (serial_size_t inc = 0; inc < params.in.depth_;
           ++inc, pw += 25, pi += inarea) {
        if (!tbl.is_connected(o, inc)) {
          continue;
        }
        __m256 w0 = _mm256_loadu_ps(pw + 0);
        __m256 w1 = _mm256_loadu_ps(pw + 8);
        __m256 w2 = _mm256_loadu_ps(pw + 16);
        __m256 i0 = _mm256_loadu_ps(pi + 0);
        __m256 i1 = _mm256_loadu_ps(pi + 8);
        __m256 i2 = _mm256_loadu_ps(pi + 16);
        __m128 w3 = _mm_load_ss(pw + 24);
        __m128 i3 = _mm_load_ss(pi + 24);
        __m256 tmp0 = _mm256_mul_ps(w0, i0);
        __m256 tmp1 = _mm256_mul_ps(w1, i1);
        __m256 tmp2 = _mm256_mul_ps(w2, i2);
        __m128 tmp3 = _mm_mul_ps(w3, i3);
        sum0 = _mm256_add_ps(tmp0, sum0);
        sum1 = _mm256_add_ps(tmp1, sum1);
        sum2 = _mm256_add_ps(tmp2, sum2);
        sum3 = _mm_add_ps(tmp3, sum3);
      }
      __m256 sum = _mm256_add_ps(_mm256_add_ps(sum0, sum1), sum2);
      __m128 b = _mm_load_ss(&bias[o]);
      __m128 hsum = hsum256_ps(sum);
      b = madd128_ss(b, y_bias_scale, sum3);
      _mm_store_ss(&a[o], _mm_add_ss(hsum, b));
    }
  } else {
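    // General path: for every output channel, fill the output plane with the
    // bias, then accumulate the contribution of each connected input channel
    // row by row.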
    const serial_size_t nblocks = out.width_ / 4;
    for (serial_size_t o = 0; o < out.depth_; ++o, oidx += out_area) {
      float* pa = &a[oidx];
      // init to bias value
      float b = bias[o] * bias_scale;
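      // Broadcast the bias across the output plane: a scalar head loop brings
      // the destination up to a 32-byte boundary, the main loop issues two
      // aligned 256-bit stores per iteration, and a scalar tail covers the
      // remainder. The aligned stores rely on the output buffer itself being
      // 32-byte aligned, which the Allocator is expected to guarantee.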
      {
        size_t headSize = 0;
        __m256 b2 = _mm256_set1_ps(b);
        if (oidx & 7) {
          headSize = 8 - (oidx & 7);
          assert(headSize < out_area);
          for (size_t i = 0; i < headSize; ++i) {
            _mm_store_ss(&pa[i], _mm256_castps256_ps128(b2));
          }
        }
        size_t cnt = (out_area - headSize) / 16;
        float* pa2 = pa + headSize;
        for (size_t i = 0; i < cnt; ++i) {
          _mm256_store_ps(&pa2[i * 16 + 0], b2);
          _mm256_store_ps(&pa2[i * 16 + 8], b2);
        }
        for (size_t i = headSize + cnt * 16; i < out_area; ++i) {
          pa[i] = b;
        }
      }
      for (serial_size_t inc = 0; inc < params.in.depth_; ++inc) {
        if (!tbl.is_connected(o, inc)) continue;

        const float* pw = (const float*)&W[25 * (params.in.depth_ * o + inc)];
        const float* pi = (const float*)&in[in_padded.get_index(0, 0, inc)];

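        // Load each 5-wide kernel row into the low lanes of a 256-bit
        // register (w*a), then prepare copies shifted up by 1, 2 and 3 float
        // lanes (w*b, w*c, w*d). Multiplying the shifted copies against the
        // same 8-float input window yields partial sums for four horizontally
        // adjacent output pixels from a single set of input loads.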
        __m256 w0a = _mm256_maskload_ps(pw + 0,  imask);
        __m256 w1a = _mm256_maskload_ps(pw + 5,  imask);
        __m256 w2a = _mm256_maskload_ps(pw + 10, imask);
        __m256 w3a = _mm256_maskload_ps(pw + 15, imask);
        __m256 w4a = _mm256_maskload_ps(pw + 20, imask);
        __m256 w0b = leftShift<4>(w0a);
        __m256 w1b = leftShift<4>(w1a);
        __m256 w2b = leftShift<4>(w2a);
        __m256 w3b = leftShift<4>(w3a);
        __m256 w4b = leftShift<4>(w4a);
        __m256 w0c = leftShift<8>(w0a);
        __m256 w1c = leftShift<8>(w1a);
        __m256 w2c = leftShift<8>(w2a);
        __m256 w3c = leftShift<8>(w3a);
        __m256 w4c = leftShift<8>(w4a);
        __m256 w0d = leftShift<12>(w0a);
        __m256 w1d = leftShift<12>(w1a);
        __m256 w2d = leftShift<12>(w2a);
        __m256 w3d = leftShift<12>(w3a);
        __m256 w4d = leftShift<12>(w4a);
        float* ppa = pa;
        for (serial_size_t y = 0; y < out.height_; y++) {
          const float* pi0 = (pi + y * stride);
          const float* pi1 = pi0 + 1 * in_padded.width_;
          const float* pi2 = pi0 + 2 * in_padded.width_;
          const float* pi3 = pi0 + 3 * in_padded.width_;
          const float* pi4 = pi0 + 4 * in_padded.width_;
          serial_size_t x = 0;
          if (w_stride == 1) {
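            // Stride-1 inner loop: compute four consecutive output pixels per
            // iteration using the pre-shifted weight copies, horizontally
            // summing each 256-bit accumulator down to one float.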
            __m256 dst0, dst1, dst2, dst3;
            float* ppa2 = ppa;
            for (size_t i = 0; i < nblocks; ++i) {
              __m256 i0 = _mm256_loadu_ps(pi0);
              __m256 i1 = _mm256_loadu_ps(pi1);
              __m256 i2 = _mm256_loadu_ps(pi2);
              __m256 i3 = _mm256_loadu_ps(pi3);
              __m256 i4 = _mm256_loadu_ps(pi4);
              __m128 sum = _mm_loadu_ps(ppa2);
              dst0 = _mm256_mul_ps(w0a, i0);
              dst1 = _mm256_mul_ps(w0b, i0);
              dst2 = _mm256_mul_ps(w0c, i0);
              dst3 = _mm256_mul_ps(w0d, i0);
              dst0 = madd256_ps(w1a, i1, dst0);
              dst1 = madd256_ps(w1b, i1, dst1);
              dst2 = madd256_ps(w1c, i1, dst2);
              dst3 = madd256_ps(w1d, i1, dst3);
              dst0 = madd256_ps(w2a, i2, dst0);
              dst1 = madd256_ps(w2b, i2, dst1);
              dst2 = madd256_ps(w2c, i2, dst2);
              dst3 = madd256_ps(w2d, i2, dst3);
              dst0 = madd256_ps(w3a, i3, dst0);
              dst1 = madd256_ps(w3b, i3, dst1);
              dst2 = madd256_ps(w3c, i3, dst2);
              dst3 = madd256_ps(w3d, i3, dst3);
              dst0 = madd256_ps(w4a, i4, dst0);
              dst1 = madd256_ps(w4b, i4, dst1);
              __m128 hsum01 = hsum2x256_ps(dst0, dst1);
              dst2 = madd256_ps(w4c, i4, dst2);
              dst3 = madd256_ps(w4d, i4, dst3);
              __m128 hsum23 = hsum2x256_ps(dst2, dst3);
              __m128 sum2 = _mm_castpd_ps(
                _mm_unpacklo_pd(_mm_castps_pd(hsum01), _mm_castps_pd(hsum23)));
              sum = _mm_add_ps(sum, sum2);
              _mm_storeu_ps(ppa2, sum);
              pi0 += 4;
              pi1 += 4;
              pi2 += 4;
              pi3 += 4;
              pi4 += 4;
              ppa2 += 4;
            }
            x = nblocks * 4;
          }
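          // Scalar remainder: the columns not covered by the 4-wide loop
          // (and every column when w_stride != 1) are computed one output
          // pixel at a time.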
          for (; x < out.width_; ++x) {
            __m128 sum = _mm_load_ss(&ppa[x]);
            __m256 i0 = _mm256_loadu_ps(pi0);
            __m256 i1 = _mm256_loadu_ps(pi1);
            __m256 i2 = _mm256_loadu_ps(pi2);
            __m256 i3 = _mm256_loadu_ps(pi3);
            __m256 i4 = _mm256_maskload_ps(pi4, imask);
            __m256 sum0 = _mm256_mul_ps(w0a, i0);
            __m256 sum1 = _mm256_mul_ps(w1a, i1);
            sum0 = madd256_ps(w2a, i2, sum0);
            sum1 = madd256_ps(w3a, i3, sum1);
            sum0 = madd256_ps(w4a, i4, sum0);
            sum0 = _mm256_add_ps(sum0, sum1);
            _mm_store_ss(&ppa[x], _mm_add_ss(sum, hsum256_ps(sum0)));
            pi0 += w_stride;
            pi1 += w_stride;
            pi2 += w_stride;
            pi3 += w_stride;
            pi4 += w_stride;
          }  // x loop
          ppa += out.width_;
        }  // y loop
      }  // in depth loop
    }  // out depth loop
  }  // else
}  // avx_conv2d_5x5_kernel float ver

// double ver
template <typename Allocator>
void avx_conv2d_5x5_kernel(const core::conv_params& params,
                           const std::vector<double, Allocator>& in,
                           const std::vector<double, Allocator>& W,
                           const std::vector<double, Allocator>& bias,
                           std::vector<double, Allocator>& a,
                           const bool layer_parallelize) {
  assert(params.weight.height_ == 5 && params.weight.width_ == 5);

  auto& out       = params.out;
  auto& in_padded = params.in_padded;
  auto& tbl       = params.tbl;
  auto  w_stride  = params.w_stride;

  const size_t out_area = out.area();
  double bias_scale = params.has_bias ? 1.0 : 0.0;
  const __m128d y_bias_scale = _mm_set_sd(bias_scale);
  serial_size_t oidx = 0;

  const size_t in_stride = params.h_stride * in_padded.width_;
  const size_t in_padded_area = in_padded.area();

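  // Fast path: a 1x1 output plane reduces to a per-channel dot product of the
  // 25 weights with the 25 padded input values, split into six 4-double AVX
  // lanes plus one scalar (4 * 6 + 1 = 25).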
  if (out.height_ == 1 && out.width_ == 1) {
    const double* pw = &W[0];
    for (size_t o = 0; o < out.depth_; ++o) {
      __m256d sum0 = _mm256_setzero_pd();
      __m256d sum1 = _mm256_setzero_pd();
      __m256d sum2 = _mm256_setzero_pd();
      __m256d sum3 = _mm256_setzero_pd();
      __m256d sum4 = _mm256_setzero_pd();
      __m256d sum5 = _mm256_setzero_pd();
      __m128d sum6 = _mm_setzero_pd();
      size_t inidx = 0;
      for (serial_size_t inc = 0; inc < params.in.depth_;
           ++inc, pw += 25, inidx += in_padded_area) {
        if (!tbl.is_connected(o, inc)) {
          continue;
        }
        __m256d w0 = _mm256_loadu_pd(pw + 0);
        __m256d w1 = _mm256_loadu_pd(pw + 4);
        __m256d w2 = _mm256_loadu_pd(pw + 8);
        __m256d w3 = _mm256_loadu_pd(pw + 12);
        __m256d w4 = _mm256_loadu_pd(pw + 16);
        __m256d w5 = _mm256_loadu_pd(pw + 20);
        __m128d w6 = _mm_load_sd(pw + 24);
        const double* pi = (const double*)&in[inidx];
        __m256d i0 = _mm256_loadu_pd(pi + 0);
        __m256d i1 = _mm256_loadu_pd(pi + 4);
        __m256d i2 = _mm256_loadu_pd(pi + 8);
        __m256d i3 = _mm256_loadu_pd(pi + 12);
        __m256d i4 = _mm256_loadu_pd(pi + 16);
        __m256d i5 = _mm256_loadu_pd(pi + 20);
        __m128d i6 = _mm_load_sd(pi + 24);
        __m256d tmp0 = _mm256_mul_pd(w0, i0);
        __m256d tmp1 = _mm256_mul_pd(w1, i1);
        __m256d tmp2 = _mm256_mul_pd(w2, i2);
        __m256d tmp3 = _mm256_mul_pd(w3, i3);
        __m256d tmp4 = _mm256_mul_pd(w4, i4);
        __m256d tmp5 = _mm256_mul_pd(w5, i5);
        __m128d tmp6 = _mm_mul_pd(w6, i6);
        sum0 = _mm256_add_pd(tmp0, sum0);
        sum1 = _mm256_add_pd(tmp1, sum1);
        sum2 = _mm256_add_pd(tmp2, sum2);
        sum3 = _mm256_add_pd(tmp3, sum3);
        sum4 = _mm256_add_pd(tmp4, sum4);
        sum5 = _mm256_add_pd(tmp5, sum5);
        sum6 = _mm_add_pd(tmp6, sum6);
      }
      sum0 = _mm256_add_pd(sum0, sum1);
      sum2 = _mm256_add_pd(sum2, sum3);
      sum4 = _mm256_add_pd(sum4, sum5);
      sum0 = _mm256_add_pd(sum0, sum2);
      __m256d sum = _mm256_add_pd(sum0, sum4);
      __m128d b = _mm_load_sd(&bias[o]);
      __m128d hsum = hsum256_pd(sum);
      b = madd128_sd(b, y_bias_scale, sum6);
      _mm_store_sd(&a[o], _mm_add_sd(hsum, b));
    }
  } else {
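    // General path: same structure as the float kernel, but without the
    // 4-wide column blocking; each output pixel is computed individually
    // after the plane has been filled with the bias.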
    for (serial_size_t o = 0; o < out.depth_; ++o, oidx += out_area) {
      double* pa = &a[oidx];
      double b = bias[o] * bias_scale;
      {
        size_t headSize = 0;
        __m256d b2 = _mm256_set1_pd(b);
        if (oidx & 3) {
          headSize = 4 - (oidx & 3);
          assert(headSize < out_area);
          for (size_t i = 0; i < headSize; ++i) {
            _mm_store_sd(&pa[i], _mm256_castpd256_pd128(b2));
          }
        }
        size_t cnt = (out_area - headSize) / 8;
        double* pa2 = pa + headSize;
        for (size_t i = 0; i < cnt; ++i) {
          _mm256_store_pd(&pa2[i * 8 + 0], b2);
          _mm256_store_pd(&pa2[i * 8 + 4], b2);
        }
        for (size_t i = headSize + cnt * 8; i < out_area; ++i) {
          _mm_store_sd(&pa[i], _mm256_castpd256_pd128(b2));
        }
      }

      for (serial_size_t inc = 0; inc < params.in.depth_; ++inc) {
        if (!tbl.is_connected(o, inc)) continue;

        const double* pw = (const double*)&W[25 * (params.in.depth_ * o + inc)];
        const double* pi = &in[in_padded.get_index(0, 0, inc)];

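        // A 256-bit register holds only four doubles, so each 5-wide kernel
        // row is loaded as a 4-double vector (w*a) plus one scalar (w*b).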
        __m256d w0a = _mm256_loadu_pd(pw + 0);
        __m128d w0b = _mm_load_sd(pw + 4);
        __m256d w1a = _mm256_loadu_pd(pw + 5);
        __m128d w1b = _mm_load_sd(pw + 9);
        __m256d w2a = _mm256_loadu_pd(pw + 10);
        __m128d w2b = _mm_load_sd(pw + 14);
        __m256d w3a = _mm256_loadu_pd(pw + 15);
        __m128d w3b = _mm_load_sd(pw + 19);
        __m256d w4a = _mm256_loadu_pd(pw + 20);
        __m128d w4b = _mm_load_sd(pw + 24);

        double* ppa = pa;
        for (serial_size_t y = 0; y < out.height_;
             ++y, pi += in_stride, ppa += out.width_) {
          const double* pi0 = pi + 0 * in_padded.width_;
          const double* pi1 = pi + 1 * in_padded.width_;
          const double* pi2 = pi + 2 * in_padded.width_;
          const double* pi3 = pi + 3 * in_padded.width_;
          const double* pi4 = pi + 4 * in_padded.width_;
          for (serial_size_t x = 0; x < out.width_; ++x) {
            __m128d sum = _mm_load_sd(&ppa[x]);
            __m256d i0a = _mm256_loadu_pd(pi0);
            __m128d i0b = _mm_load_sd(pi0 + 4);
            __m256d i1a = _mm256_loadu_pd(pi1);
            __m128d i1b = _mm_load_sd(pi1 + 4);
            __m256d i2a = _mm256_loadu_pd(pi2);
            __m128d i2b = _mm_load_sd(pi2 + 4);
            __m256d i3a = _mm256_loadu_pd(pi3);
            __m128d i3b = _mm_load_sd(pi3 + 4);
            __m256d i4a = _mm256_loadu_pd(pi4);
            __m128d i4b = _mm_load_sd(pi4 + 4);
            __m256d sum_a = _mm256_mul_pd(w0a, i0a);
            __m128d sum_b = _mm_mul_sd(w0b, i0b);
            sum_a = madd256_pd(w1a, i1a, sum_a);
            sum_b = madd128_pd(w1b, i1b, sum_b);
            sum_a = madd256_pd(w2a, i2a, sum_a);
            sum_b = madd128_pd(w2b, i2b, sum_b);
            sum_a = madd256_pd(w3a, i3a, sum_a);
            sum_b = madd128_pd(w3b, i3b, sum_b);
            sum_a = madd256_pd(w4a, i4a, sum_a);
            sum_b = madd128_pd(w4b, i4b, sum_b);
            __m128d sum_c = hsum256_pd(sum_a);
            sum = _mm_add_sd(sum, sum_b);
            _mm_store_sd(&ppa[x], _mm_add_sd(sum, sum_c));
            pi0 += w_stride;
            pi1 += w_stride;
            pi2 += w_stride;
            pi3 += w_stride;
            pi4 += w_stride;
          }  // x loop
        }  // y loop
      }  // in depth loop
    }  // out depth loop
  }  // else
}  // avx_conv2d_5x5_kernel double ver

#endif  // CNN_USE_AVX

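// Dispatcher: use the AVX 5x5 kernels when the library is built with
// CNN_USE_AVX and the convolution window is exactly 5x5; otherwise fall back
// to the portable implementation in conv2d_op_internal.h.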
inline void conv2d_op_avx(const tensor_t& in_data,
                          const vec_t& W,
                          const vec_t& bias,
                          tensor_t& out_data,
                          const core::conv_params& params,
                          const bool layer_parallelize) {
#ifdef CNN_USE_AVX
  if (params.weight.height_ == 5 && params.weight.width_ == 5) {
    // @todo consider better parallelization
    for_i(layer_parallelize, in_data.size(), [&](int i) {
      avx_conv2d_5x5_kernel(params, in_data[i], W, bias, out_data[i],
                            layer_parallelize);
    });
    return;
  }
#endif
  conv2d_op_internal(in_data, W, bias, out_data, params, layer_parallelize);
}

}  // namespace kernels
}  // namespace tiny_dnn