#include "tiny_dnn/core/params/conv_params.h"
#include "tiny_dnn/core/kernels/conv2d_op_internal.h"

#include "tiny_dnn/core/kernels/avx_kernel_common.h"
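// Single-precision AVX kernel for 5x5 convolution. One output feature map is
// produced at a time; weights are laid out as 25 contiguous floats per
// (output channel, input channel) pair.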
template <typename Allocator>
void avx_conv2d_5x5_kernel(const core::conv_params& params,
                           const std::vector<float, Allocator>& in,
                           const std::vector<float, Allocator>& W,
                           const std::vector<float, Allocator>& bias,
                           std::vector<float, Allocator>& a,
                           const bool layer_parallelize) {
  assert(params.weight.height_ == 5 && params.weight.width_ == 5);

  auto& out       = params.out;
  auto& in_padded = params.in_padded;
  auto& tbl       = params.tbl;
  auto  w_stride  = params.w_stride;

  const serial_size_t out_area = out.area();
  serial_size_t oidx = 0;
  float bias_scale = params.has_bias ? 1.0f : 0.0f;
  const serial_size_t stride = params.h_stride * in_padded.width_;
  const serial_size_t inarea = in_padded.area();

  static const __m256i imask = _mm256_setr_epi32(-1, -1, -1, -1, -1, 0, 0, 0);

  const __m128 y_bias_scale = _mm_set_ss(bias_scale);
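  // Fast path: 1x1 output, i.e. the 5x5 window covers the whole padded input.
  // Each output value is a dot product of the 25 weights with the 25 padded
  // input values of every connected channel.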
  if (out.height_ == 1 && out.width_ == 1) {
    const float* pw = (const float*)&W[0];
    for (serial_size_t o = 0; o < out.depth_; ++o) {
      __m256 sum0 = _mm256_setzero_ps();
      __m256 sum1 = _mm256_setzero_ps();
      __m256 sum2 = _mm256_setzero_ps();
      __m128 sum3 = _mm_setzero_ps();
      const float* pi = (const float*)&in[0];
      for (serial_size_t inc = 0; inc < params.in.depth_; ++inc, pw += 25, pi += inarea) {
        if (!tbl.is_connected(o, inc)) {
          continue;
        }
        __m256 w0 = _mm256_loadu_ps(pw + 0);
        __m256 w1 = _mm256_loadu_ps(pw + 8);
        __m256 w2 = _mm256_loadu_ps(pw + 16);
        __m256 i0 = _mm256_loadu_ps(pi + 0);
        __m256 i1 = _mm256_loadu_ps(pi + 8);
        __m256 i2 = _mm256_loadu_ps(pi + 16);
        __m128 w3 = _mm_load_ss(pw + 24);
        __m128 i3 = _mm_load_ss(pi + 24);
        __m256 tmp0 = _mm256_mul_ps(w0, i0);
        __m256 tmp1 = _mm256_mul_ps(w1, i1);
        __m256 tmp2 = _mm256_mul_ps(w2, i2);
        __m128 tmp3 = _mm_mul_ps(w3, i3);
        sum0 = _mm256_add_ps(tmp0, sum0);
        sum1 = _mm256_add_ps(tmp1, sum1);
        sum2 = _mm256_add_ps(tmp2, sum2);
        sum3 = _mm_add_ps(tmp3, sum3);
      }
      __m256 sum = _mm256_add_ps(_mm256_add_ps(sum0, sum1), sum2);
      __m128 b = _mm_load_ss(&bias[o]);
      __m128 hsum = hsum256_ps(sum);
      b = madd128_ss(b, y_bias_scale, sum3);
      _mm_store_ss(&a[o], _mm_add_ss(hsum, b));
    }
  } else {
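    // General path: initialize every pixel of the output feature map to the
    // bias value, then accumulate the 5x5 convolution channel by channel.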
    const serial_size_t nblocks = out.width_ / 4;
    for (serial_size_t o = 0; o < out.depth_; ++o, oidx += out_area) {
      float* pa = &a[oidx];
      float b = bias[o] * bias_scale;

      __m256 b2 = _mm256_set1_ps(b);
      // Scalar head until pa reaches 32-byte alignment (a is assumed to be
      // 32-byte aligned, so alignment is determined by oidx modulo 8).
      size_t headSize = 0;
      if (oidx & 7) {
        headSize = 8 - (oidx & 7);
        assert(headSize < out_area);
        for (size_t i = 0; i < headSize; ++i) {
          _mm_store_ss(&pa[i], _mm256_castps256_ps128(b2));
        }
      }
      // Aligned 16-float blocks, then a scalar tail.
      size_t cnt = (out_area - headSize) / 16;
      float* pa2 = pa + headSize;
      for (size_t i = 0; i < cnt; ++i) {
        _mm256_store_ps(&pa2[i * 16 + 0], b2);
        _mm256_store_ps(&pa2[i * 16 + 8], b2);
      }
      for (size_t i = headSize + cnt * 16; i < out_area; ++i) {
        pa[i] = b;
      }
      for (serial_size_t inc = 0; inc < params.in.depth_; ++inc) {
        if (!tbl.is_connected(o, inc)) continue;

        const float* pw = (const float*)&W[25 * (params.in.depth_ * o + inc)];
        const float* pi = (const float*)&in[in_padded.get_index(0, 0, inc)];
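        // Load each 5-wide weight row (masked to 5 of 8 lanes) and make three
        // left-shifted copies so that four adjacent output pixels can reuse the
        // same 8-wide input loads.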
        __m256 w0a = _mm256_maskload_ps(pw + 0, imask);
        __m256 w1a = _mm256_maskload_ps(pw + 5, imask);
        __m256 w2a = _mm256_maskload_ps(pw + 10, imask);
        __m256 w3a = _mm256_maskload_ps(pw + 15, imask);
        __m256 w4a = _mm256_maskload_ps(pw + 20, imask);
        __m256 w0b = leftShift<4>(w0a);
        __m256 w1b = leftShift<4>(w1a);
        __m256 w2b = leftShift<4>(w2a);
        __m256 w3b = leftShift<4>(w3a);
        __m256 w4b = leftShift<4>(w4a);
        __m256 w0c = leftShift<8>(w0a);
        __m256 w1c = leftShift<8>(w1a);
        __m256 w2c = leftShift<8>(w2a);
        __m256 w3c = leftShift<8>(w3a);
        __m256 w4c = leftShift<8>(w4a);
        __m256 w0d = leftShift<12>(w0a);
        __m256 w1d = leftShift<12>(w1a);
        __m256 w2d = leftShift<12>(w2a);
        __m256 w3d = leftShift<12>(w3a);
        __m256 w4d = leftShift<12>(w4a);
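        // Each output row reads five consecutive padded-input rows; the block
        // loop below produces four output pixels per iteration.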
        float* ppa = pa;
        for (serial_size_t y = 0; y < out.height_; y++) {
          const float* pi0 = (pi + y * stride);
          const float* pi1 = pi0 + 1 * in_padded.width_;
          const float* pi2 = pi0 + 2 * in_padded.width_;
          const float* pi3 = pi0 + 3 * in_padded.width_;
          const float* pi4 = pi0 + 4 * in_padded.width_;
          serial_size_t x = 0;
          // The 4-wide block path writes four contiguous outputs, so it only
          // applies when the horizontal stride is 1.
          if (w_stride == 1) {
            __m256 dst0, dst1, dst2, dst3;
            float* ppa2 = ppa;
            for (size_t i = 0; i < nblocks; ++i) {
              __m256 i0 = _mm256_loadu_ps(pi0);
              __m256 i1 = _mm256_loadu_ps(pi1);
              __m256 i2 = _mm256_loadu_ps(pi2);
              __m256 i3 = _mm256_loadu_ps(pi3);
              __m256 i4 = _mm256_loadu_ps(pi4);
              __m128 sum = _mm_loadu_ps(ppa2);
              dst0 = _mm256_mul_ps(w0a, i0);
              dst1 = _mm256_mul_ps(w0b, i0);
              dst2 = _mm256_mul_ps(w0c, i0);
              dst3 = _mm256_mul_ps(w0d, i0);
              dst0 = madd256_ps(w1a, i1, dst0);
              dst1 = madd256_ps(w1b, i1, dst1);
              dst2 = madd256_ps(w1c, i1, dst2);
              dst3 = madd256_ps(w1d, i1, dst3);
              dst0 = madd256_ps(w2a, i2, dst0);
              dst1 = madd256_ps(w2b, i2, dst1);
              dst2 = madd256_ps(w2c, i2, dst2);
              dst3 = madd256_ps(w2d, i2, dst3);
              dst0 = madd256_ps(w3a, i3, dst0);
              dst1 = madd256_ps(w3b, i3, dst1);
              dst2 = madd256_ps(w3c, i3, dst2);
              dst3 = madd256_ps(w3d, i3, dst3);
              dst0 = madd256_ps(w4a, i4, dst0);
              dst1 = madd256_ps(w4b, i4, dst1);
              __m128 hsum01 = hsum2x256_ps(dst0, dst1);
              dst2 = madd256_ps(w4c, i4, dst2);
              dst3 = madd256_ps(w4d, i4, dst3);
              __m128 hsum23 = hsum2x256_ps(dst2, dst3);
              __m128 sum2 = _mm_castpd_ps(
                  _mm_unpacklo_pd(_mm_castps_pd(hsum01), _mm_castps_pd(hsum23)));
              sum = _mm_add_ps(sum, sum2);
              _mm_storeu_ps(ppa2, sum);
              pi0 += 4;
              pi1 += 4;
              pi2 += 4;
              pi3 += 4;
              pi4 += 4;
              ppa2 += 4;
            }
            x = nblocks * 4;
          }
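          // Scalar tail: remaining output pixels in this row, one at a time.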
          for (; x < out.width_; ++x) {
            __m128 sum = _mm_load_ss(&ppa[x]);
            __m256 i0 = _mm256_loadu_ps(pi0);
            __m256 i1 = _mm256_loadu_ps(pi1);
            __m256 i2 = _mm256_loadu_ps(pi2);
            __m256 i3 = _mm256_loadu_ps(pi3);
            __m256 i4 = _mm256_maskload_ps(pi4, imask);
            __m256 sum0 = _mm256_mul_ps(w0a, i0);
            __m256 sum1 = _mm256_mul_ps(w1a, i1);
            sum0 = madd256_ps(w2a, i2, sum0);
            sum1 = madd256_ps(w3a, i3, sum1);
            sum0 = madd256_ps(w4a, i4, sum0);
            sum0 = _mm256_add_ps(sum0, sum1);
            _mm_store_ss(&ppa[x], _mm_add_ss(sum, hsum256_ps(sum0)));
            pi0 += w_stride;
            pi1 += w_stride;
            pi2 += w_stride;
            pi3 += w_stride;
            pi4 += w_stride;
          }
          ppa += out.width_;
        }
      }
    }
  }
}
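// Double-precision AVX kernel for 5x5 convolution. Same structure as the float
// kernel (1x1-output fast path plus a general path); each 5-element weight row
// is handled as a 4-wide __m256d load plus a scalar remainder.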
template <typename Allocator>
void avx_conv2d_5x5_kernel(const core::conv_params& params,
                           const std::vector<double, Allocator>& in,
                           const std::vector<double, Allocator>& W,
                           const std::vector<double, Allocator>& bias,
                           std::vector<double, Allocator>& a,
                           const bool layer_parallelize) {
  assert(params.weight.height_ == 5 && params.weight.width_ == 5);

  auto& out       = params.out;
  auto& in_padded = params.in_padded;
  auto& tbl       = params.tbl;
  auto  w_stride  = params.w_stride;

  const size_t out_area = out.area();
  double bias_scale = params.has_bias ? 1.0 : 0.0;
  const __m128d y_bias_scale = _mm_set_sd(bias_scale);
  serial_size_t oidx = 0;

  const size_t in_stride = params.h_stride * in_padded.width_;
  const size_t in_padded_area = in_padded.area();
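  // Fast path: 1x1 output, a dot product of the 25 weights with the 25 padded
  // input values of every connected channel.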
  if (out.height_ == 1 && out.width_ == 1) {
    const double* pw = &W[0];
    for (size_t o = 0; o < out.depth_; ++o) {
      __m256d sum0 = _mm256_setzero_pd();
      __m256d sum1 = _mm256_setzero_pd();
      __m256d sum2 = _mm256_setzero_pd();
      __m256d sum3 = _mm256_setzero_pd();
      __m256d sum4 = _mm256_setzero_pd();
      __m256d sum5 = _mm256_setzero_pd();
      __m128d sum6 = _mm_setzero_pd();
      size_t inidx = 0;
      for (serial_size_t inc = 0; inc < params.in.depth_; ++inc, pw += 25, inidx += in_padded_area) {
        if (!tbl.is_connected(o, inc)) {
          continue;
        }
        __m256d w0 = _mm256_loadu_pd(pw + 0);
        __m256d w1 = _mm256_loadu_pd(pw + 4);
        __m256d w2 = _mm256_loadu_pd(pw + 8);
        __m256d w3 = _mm256_loadu_pd(pw + 12);
        __m256d w4 = _mm256_loadu_pd(pw + 16);
        __m256d w5 = _mm256_loadu_pd(pw + 20);
        __m128d w6 = _mm_load_sd(pw + 24);
        const double* pi = (const double*)&in[inidx];
        __m256d i0 = _mm256_loadu_pd(pi + 0);
        __m256d i1 = _mm256_loadu_pd(pi + 4);
        __m256d i2 = _mm256_loadu_pd(pi + 8);
        __m256d i3 = _mm256_loadu_pd(pi + 12);
        __m256d i4 = _mm256_loadu_pd(pi + 16);
        __m256d i5 = _mm256_loadu_pd(pi + 20);
        __m128d i6 = _mm_load_sd(pi + 24);
        __m256d tmp0 = _mm256_mul_pd(w0, i0);
        __m256d tmp1 = _mm256_mul_pd(w1, i1);
        __m256d tmp2 = _mm256_mul_pd(w2, i2);
        __m256d tmp3 = _mm256_mul_pd(w3, i3);
        __m256d tmp4 = _mm256_mul_pd(w4, i4);
        __m256d tmp5 = _mm256_mul_pd(w5, i5);
        __m128d tmp6 = _mm_mul_pd(w6, i6);
        sum0 = _mm256_add_pd(tmp0, sum0);
        sum1 = _mm256_add_pd(tmp1, sum1);
        sum2 = _mm256_add_pd(tmp2, sum2);
        sum3 = _mm256_add_pd(tmp3, sum3);
        sum4 = _mm256_add_pd(tmp4, sum4);
        sum5 = _mm256_add_pd(tmp5, sum5);
        sum6 = _mm_add_pd(tmp6, sum6);
      }
      sum0 = _mm256_add_pd(sum0, sum1);
      sum2 = _mm256_add_pd(sum2, sum3);
      sum4 = _mm256_add_pd(sum4, sum5);
      sum0 = _mm256_add_pd(sum0, sum2);
      __m256d sum = _mm256_add_pd(sum0, sum4);
      __m128d b = _mm_load_sd(&bias[o]);
      __m128d hsum = hsum256_pd(sum);
      b = madd128_sd(b, y_bias_scale, sum6);
      _mm_store_sd(&a[o], _mm_add_sd(hsum, b));
    }
  } else {
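    // General path: initialize the output feature map to the bias value, then
    // accumulate the 5x5 convolution channel by channel.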
    for (serial_size_t o = 0; o < out.depth_; ++o, oidx += out_area) {
      double* pa = &a[oidx];
      double b = bias[o] * bias_scale;

      __m256d b2 = _mm256_set1_pd(b);
      // Scalar head until pa reaches 32-byte alignment (oidx modulo 4 doubles).
      size_t headSize = 0;
      if (oidx & 3) {
        headSize = 4 - (oidx & 3);
        assert(headSize < out_area);
        for (size_t i = 0; i < headSize; ++i) {
          _mm_store_sd(&pa[i], _mm256_castpd256_pd128(b2));
        }
      }
      // Aligned 8-double blocks, then a scalar tail.
      size_t cnt = (out_area - headSize) / 8;
      double* pa2 = pa + headSize;
      for (size_t i = 0; i < cnt; ++i) {
        _mm256_store_pd(&pa2[i * 8 + 0], b2);
        _mm256_store_pd(&pa2[i * 8 + 4], b2);
      }
      for (size_t i = headSize + cnt * 8; i < out_area; ++i) {
        _mm_store_sd(&pa[i], _mm256_castpd256_pd128(b2));
      }
      for (serial_size_t inc = 0; inc < params.in.depth_; ++inc) {
        if (!tbl.is_connected(o, inc)) continue;

        const double* pw = (const double*)&W[25 * (params.in.depth_ * o + inc)];
        const double* pi = &in[in_padded.get_index(0, 0, inc)];

        __m256d w0a = _mm256_loadu_pd(pw + 0);
        __m128d w0b = _mm_load_sd(pw + 4);
        __m256d w1a = _mm256_loadu_pd(pw + 5);
        __m128d w1b = _mm_load_sd(pw + 9);
        __m256d w2a = _mm256_loadu_pd(pw + 10);
        __m128d w2b = _mm_load_sd(pw + 14);
        __m256d w3a = _mm256_loadu_pd(pw + 15);
        __m128d w3b = _mm_load_sd(pw + 19);
        __m256d w4a = _mm256_loadu_pd(pw + 20);
        __m128d w4b = _mm_load_sd(pw + 24);
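        // One output pixel per iteration: each 5-wide weight row is applied as
        // a 4-wide multiply-add plus a scalar multiply-add on the fifth element.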
        double* ppa = pa;
        for (serial_size_t y = 0; y < out.height_; ++y, pi += in_stride, ppa += out.width_) {
          const double* pi0 = pi + 0 * in_padded.width_;
          const double* pi1 = pi + 1 * in_padded.width_;
          const double* pi2 = pi + 2 * in_padded.width_;
          const double* pi3 = pi + 3 * in_padded.width_;
          const double* pi4 = pi + 4 * in_padded.width_;
          for (serial_size_t x = 0; x < out.width_; ++x) {
            __m128d sum = _mm_load_sd(&ppa[x]);
            __m256d i0a = _mm256_loadu_pd(pi0);
            __m128d i0b = _mm_load_sd(pi0 + 4);
            __m256d i1a = _mm256_loadu_pd(pi1);
            __m128d i1b = _mm_load_sd(pi1 + 4);
            __m256d i2a = _mm256_loadu_pd(pi2);
            __m128d i2b = _mm_load_sd(pi2 + 4);
            __m256d i3a = _mm256_loadu_pd(pi3);
            __m128d i3b = _mm_load_sd(pi3 + 4);
            __m256d i4a = _mm256_loadu_pd(pi4);
            __m128d i4b = _mm_load_sd(pi4 + 4);
            __m256d sum_a = _mm256_mul_pd(w0a, i0a);
            __m128d sum_b = _mm_mul_sd(w0b, i0b);
            sum_a = madd256_pd(w1a, i1a, sum_a);
            sum_b = madd128_pd(w1b, i1b, sum_b);
            sum_a = madd256_pd(w2a, i2a, sum_a);
            sum_b = madd128_pd(w2b, i2b, sum_b);
            sum_a = madd256_pd(w3a, i3a, sum_a);
            sum_b = madd128_pd(w3b, i3b, sum_b);
            sum_a = madd256_pd(w4a, i4a, sum_a);
            sum_b = madd128_pd(w4b, i4b, sum_b);
            __m128d sum_c = hsum256_pd(sum_a);
            sum = _mm_add_sd(sum, sum_b);
            _mm_store_sd(&ppa[x], _mm_add_sd(sum, sum_c));
            pi0 += w_stride;
            pi1 += w_stride;
            pi2 += w_stride;
            pi3 += w_stride;
            pi4 += w_stride;
          }
        }
      }
    }
  }
}
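// Dispatcher: use the specialized AVX kernel for 5x5 weights, otherwise fall
// back to the generic implementation.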
inline void conv2d_op_avx(const tensor_t& in_data,
                          const vec_t& W,
                          const vec_t& bias,
                          tensor_t& out_data,
                          const core::conv_params& params,
                          const bool layer_parallelize) {
  if (params.weight.height_ == 5 && params.weight.width_ == 5) {
    for_i(layer_parallelize, in_data.size(), [&](int i) {
      avx_conv2d_5x5_kernel(params, in_data[i], W, bias, out_data[i], layer_parallelize);
    });
    return;
  }

  conv2d_op_internal(in_data, W, bias, out_data, params, layer_parallelize);
}