tiny_dnn  1.0.0
A header-only, dependency-free deep learning framework in C++11
avx_kernel_common.h
/*
 Copyright (c) 2016, Taiga Nomi, Edgar Riba
 All rights reserved.

 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are met:
 * Redistributions of source code must retain the above copyright
 notice, this list of conditions and the following disclaimer.
 * Redistributions in binary form must reproduce the above copyright
 notice, this list of conditions and the following disclaimer in the
 documentation and/or other materials provided with the distribution.
 * Neither the name of the <organization> nor the
 names of its contributors may be used to endorse or promote products
 derived from this software without specific prior written permission.

 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
 EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
 DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once

#ifndef CNN_USE_AVX
#error Advanced Vector Extensions required.
#endif

#include <immintrin.h>  // AVX intrinsics
#include <type_traits>  // std::false_type

// Fallback for compilers whose <immintrin.h> lacks _mm256_set_m128:
// place va in the upper 128 bits and vb in the lower 128 bits.
#ifndef _mm256_set_m128
#define _mm256_set_m128(va, vb) \
  _mm256_insertf128_ps(_mm256_castps128_ps256(vb), va, 1)
#endif
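
// Illustrative usage (a sketch, not part of the original header): stitch two
// SSE registers into one AVX register. The variable names are assumptions.
//
//   __m128 lo_half = _mm_set1_ps(1.0f);
//   __m128 hi_half = _mm_set1_ps(2.0f);
//   __m256 whole   = _mm256_set_m128(hi_half, lo_half);  // upper = 2s, lower = 1s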

// Emulated fused multiply-add helpers: each returns a * b + c. Plain AVX has
// no FMA instruction, so these lower to a separate multiply and add (two
// roundings, unlike a true fused operation).
inline __m256 madd256_ps(__m256 a, __m256 b, __m256 c) {
  return _mm256_add_ps(_mm256_mul_ps(a, b), c);
}
inline __m128 madd128_ps(__m128 a, __m128 b, __m128 c) {
  return _mm_add_ps(_mm_mul_ps(a, b), c);
}
inline __m128 madd128_ss(__m128 a, __m128 b, __m128 c) {
  return _mm_add_ss(_mm_mul_ss(a, b), c);
}
inline __m256d madd256_pd(__m256d a, __m256d b, __m256d c) {
  return _mm256_add_pd(_mm256_mul_pd(a, b), c);
}
inline __m128d madd128_pd(__m128d a, __m128d b, __m128d c) {
  return _mm_add_pd(_mm_mul_pd(a, b), c);
}
inline __m128d madd128_sd(__m128d a, __m128d b, __m128d c) {
  return _mm_add_sd(_mm_mul_sd(a, b), c);
}
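
// Illustrative usage (a sketch under assumed names): the madd helpers chain
// naturally as loop-carried accumulators, e.g. an 8-wide dot product where
// `a` and `b` are float arrays and `len` is a multiple of 8.
//
//   __m256 acc = _mm256_setzero_ps();
//   for (size_t i = 0; i < len; i += 8) {
//     acc = madd256_ps(_mm256_loadu_ps(&a[i]), _mm256_loadu_ps(&b[i]), acc);
//   }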

// Horizontally add the elements of a __m256 argument (sadly, _mm256_hadd_ps
// alone isn't good enough).
// http://stackoverflow.com/a/13222410/4699324
// x = ( x7, x6, x5, x4, x3, x2, x1, x0 )
inline __m128 hsum256_ps(__m256 x) {
  // hiQuad = ( x7, x6, x5, x4 )
  const __m128 hiQuad = _mm256_extractf128_ps(x, 1);
  // loQuad = ( x3, x2, x1, x0 )
  const __m128 loQuad = _mm256_castps256_ps128(x);
  // sumQuad = ( x3+x7, x2+x6, x1+x5, x0+x4 )
  const __m128 sumQuad = _mm_add_ps(loQuad, hiQuad);
  // loDual = ( -, -, x1+x5, x0+x4 )
  const __m128 loDual = sumQuad;
  // hiDual = ( -, -, x3+x7, x2+x6 )
  const __m128 hiDual = _mm_movehl_ps(sumQuad, sumQuad);
  // sumDual = ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 )
  const __m128 sumDual = _mm_add_ps(loDual, hiDual);
  // lo = ( -, -, -, x0+x2+x4+x6 )
  const __m128 lo = sumDual;
  // hi = ( -, -, -, x1+x3+x5+x7 )
  const __m128 hi = _mm_shuffle_ps(sumDual, sumDual, 0x1);
  // sum = ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 )
  const __m128 sum = _mm_add_ss(lo, hi);
  return sum;
}
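
// Illustrative usage (sketch): reduce the accumulator from the madd example
// above to a scalar; the total sits in lane 0 of the returned __m128.
//
//   float dot = _mm_cvtss_f32(hsum256_ps(acc));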

// Horizontally add the elements of two __m256 arguments at once
inline __m128 hsum2x256_ps(__m256 a, __m256 b) {
  // (b3, b2, b1, b0, a3, a2, a1, a0)
  __m256 x = _mm256_permute2f128_ps(a, b, 0x20);
  // (b7, b6, b5, b4, a7, a6, a5, a4)
  __m256 y = _mm256_permute2f128_ps(a, b, 0x31);
  // (b3+b7, b2+b6, b1+b5, b0+b4, a3+a7, a2+a6, a1+a5, a0+a4)
  x = _mm256_add_ps(x, y);
  // (-, -, b3+b7, b2+b6, -, -, a3+a7, a2+a6)
  y = _mm256_permute_ps(x, _MM_SHUFFLE(3, 2, 3, 2));
  // (-, -, b1+b5+b3+b7, b0+b4+b2+b6, -, -, a1+a5+a3+a7, a0+a4+a2+a6)
  x = _mm256_add_ps(x, y);
  // (-, -, -, b1+b5+b3+b7, -, -, -, a1+a5+a3+a7)
  y = _mm256_permute_ps(x, _MM_SHUFFLE(1, 1, 1, 1));
  // (-, -, -, b1+b5+b3+b7+b0+b4+b2+b6, -, -, -, a1+a5+a3+a7+a0+a4+a2+a6)
  x = _mm256_add_ps(x, y);
  // (-, -, -, b1+b5+b3+b7+b0+b4+b2+b6)
  __m128 upper = _mm256_extractf128_ps(x, 1);
  // (-, -, b1+b5+b3+b7+b0+b4+b2+b6, a1+a5+a3+a7+a0+a4+a2+a6)
  __m128 ret = _mm_unpacklo_ps(_mm256_castps256_ps128(x), upper);
  return ret;
}
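
// Illustrative usage (sketch): reduce two independent accumulators with one
// call; the two sums land in lanes 0 and 1 of the result.
//
//   __m128 sums = hsum2x256_ps(acc0, acc1);
//   float s0 = _mm_cvtss_f32(sums);                            // sum of acc0
//   float s1 = _mm_cvtss_f32(_mm_shuffle_ps(sums, sums, 0x1)); // sum of acc1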

inline __m128d hsum256_pd(__m256d x) {
  // hiDual = ( x3, x2 )
  const __m128d hiDual = _mm256_extractf128_pd(x, 1);
  // loDual = ( x1, x0 )
  const __m128d loDual = _mm256_castpd256_pd128(x);
  // sumDual = ( x2+x3, x0+x1 )
  const __m128d sumDual = _mm_add_pd(loDual, hiDual);
  // sum = ( 0, x0+x1+x2+x3 )
  const __m128d sum = _mm_hadd_pd(sumDual, _mm_setzero_pd());
  return sum;
}
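
// Illustrative usage (sketch), double precision; `acc_pd` is an assumed name:
//
//   double total = _mm_cvtsd_f64(hsum256_pd(acc_pd));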

// Dependent false type: makes the static_asserts in the unspecialized shift
// templates below fire only when an unsupported shift amount is instantiated.
template <int n>
struct foobar : std::false_type
{ };

// Byte-shift a YMM register across its 128-bit lanes.
// Limitation: the shift amount is an immediate and must be a multiple of 4.

template <int n>
inline __m256 leftShift(__m256 a) {
  static_assert(foobar<n>::value, "unsupported shift amount");
  return a;
}

// http://stackoverflow.com/q/19516585
template <>
inline __m256 leftShift<4>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x6, x5, x4, x7, x2, x1, x0, x3)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(2, 1, 0, 3));
  // t1 = (x2, x1, x0, x3, 0, 0, 0, 0)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x08);
  // y = (x6, x5, x4, x3, x2, x1, x0, 0)
  __m256 y = _mm256_blend_ps(t0, t1, 0x11 /* 0b00010001 */ );
  return y;
}

// http://stackoverflow.com/q/19516585
template <>
inline __m256 leftShift<8>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x5, x4, x7, x6, x1, x0, x3, x2)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(1, 0, 3, 2));
  // t1 = (x1, x0, x3, x2, 0, 0, 0, 0)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x08);
  // y = (x5, x4, x3, x2, x1, x0, 0, 0)
  __m256 y = _mm256_blend_ps(t0, t1, 0x33 /* 0b00110011 */ );
  return y;
}

template <>
inline __m256 leftShift<12>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x4, x7, x6, x5, x0, x3, x2, x1)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(0, 3, 2, 1));
  // t1 = (x0, x3, x2, x1, 0, 0, 0, 0)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x08);
  // y = (x4, x3, x2, x1, x0, 0, 0, 0)
  __m256 y = _mm256_blend_ps(t0, t1, 0x77 /* 0b01110111 */ );
  return y;
}

template <>
inline __m256 leftShift<16>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // y = (x3, x2, x1, x0, 0, 0, 0, 0)
  __m256 y = _mm256_permute2f128_ps(x, x, 0x08);
  return y;
}

template <>
inline __m256 leftShift<20>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x6, x5, x4, x7, x2, x1, x0, x3)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(2, 1, 0, 3));
  // t1 = (x2, x1, x0, x3, 0, 0, 0, 0)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x08);
  // y = (x2, x1, x0, 0, 0, 0, 0, 0)
  __m256 y = _mm256_blend_ps(t1, _mm256_setzero_ps(), 0x10 /* 0b00010000 */ );
  return y;
}

template <>
inline __m256 leftShift<24>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x5, x4, x7, x6, x1, x0, x3, x2)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(1, 0, 3, 2));
  // t1 = (x1, x0, x3, x2, 0, 0, 0, 0)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x08);
  // y = (x1, x0, 0, 0, 0, 0, 0, 0)
  __m256 y = _mm256_blend_ps(_mm256_setzero_ps(), t1, 0xC0 /* 0b11000000 */ );
  return y;
}

template <>
inline __m256 leftShift<28>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x4, x7, x6, x5, x0, x3, x2, x1)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(0, 3, 2, 1));
  // t1 = (x0, x3, x2, x1, 0, 0, 0, 0)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x08);
  // y = (x0, 0, 0, 0, 0, 0, 0, 0)
  __m256 y = _mm256_blend_ps(_mm256_setzero_ps(), t1, 0x80 /* 0b10000000 */ );
  return y;
}
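
// Illustrative usage (sketch): leftShift<4> moves every float up one element
// and zero-fills the bottom, a whole-register analogue of _mm_slli_si128.
//
//   __m256 v = _mm256_set_ps(7, 6, 5, 4, 3, 2, 1, 0); // (7,6,5,4,3,2,1,0)
//   __m256 s = leftShift<4>(v);                       // (6,5,4,3,2,1,0,0)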

template <int n>
inline __m256 rightShift(__m256 a) {
  static_assert(foobar<n>::value, "unsupported shift amount");
  return a;
}

// http://stackoverflow.com/a/19532415/4699324
template <>
inline __m256 rightShift<4>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x4, x7, x6, x5, x0, x3, x2, x1)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(0, 3, 2, 1));
  // t1 = ( 0, 0, 0, 0, x4, x7, x6, x5)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x81);

  //     ( -, x7, x6, x5, -, x3, x2, x1)
  //     ( 0, -, -, -, x4, -, -, -)
  // y = ( 0, x7, x6, x5, x4, x3, x2, x1)
  __m256 y = _mm256_blend_ps(t0, t1, 0x88 /* 0b10001000 */ );
  return y;
}

template <>
inline __m256 rightShift<8>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x5, x4, x7, x6, x1, x0, x3, x2)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(1, 0, 3, 2));
  // t1 = ( 0, 0, 0, 0, x5, x4, x7, x6)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x81);

  //     ( -, -, x7, x6, -, -, x3, x2)
  //     ( 0, 0, -, -, x5, x4, -, -)
  // y = ( 0, 0, x7, x6, x5, x4, x3, x2)
  __m256 y = _mm256_blend_ps(t0, t1, 0xCC /* 0b11001100 */ );
  return y;
}

template <>
inline __m256 rightShift<12>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x6, x5, x4, x7, x2, x1, x0, x3)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(2, 1, 0, 3));
  // t1 = ( 0, 0, 0, 0, x6, x5, x4, x7)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x81);

  //     ( -, -, -, x7, -, -, -, x3)
  //     ( 0, 0, 0, -, x6, x5, x4, -)
  // y = ( 0, 0, 0, x7, x6, x5, x4, x3)
  __m256 y = _mm256_blend_ps(t0, t1, 0xEE /* 0b11101110 */ );
  return y;
}

template <>
inline __m256 rightShift<16>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // y = ( 0, 0, 0, 0, x7, x6, x5, x4)
  __m256 y = _mm256_permute2f128_ps(x, x, 0x81);
  return y;
}

template <>
inline __m256 rightShift<20>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x4, x7, x6, x5, x0, x3, x2, x1)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(0, 3, 2, 1));
  // t1 = ( 0, 0, 0, 0, x4, x7, x6, x5)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x81);

  //     ( -, -, -, -, -, x7, x6, x5)
  //     ( 0, 0, 0, 0, 0, -, -, -)
  // y = ( 0, 0, 0, 0, 0, x7, x6, x5)
  __m256 y = _mm256_blend_ps(t1, _mm256_setzero_ps(), 0xF8 /* 0b11111000 */ );
  return y;
}

template <>
inline __m256 rightShift<24>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x5, x4, x7, x6, x1, x0, x3, x2)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(1, 0, 3, 2));
  // t1 = ( 0, 0, 0, 0, x5, x4, x7, x6)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x81);

  //     ( -, -, -, -, -, -, x7, x6)
  //     ( 0, 0, 0, 0, 0, 0, -, -)
  // y = ( 0, 0, 0, 0, 0, 0, x7, x6)
  __m256 y = _mm256_blend_ps(t1, _mm256_setzero_ps(), 0xFC /* 0b11111100 */ );
  return y;
}

template <>
inline __m256 rightShift<28>(__m256 x) {
  // x = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x6, x5, x4, x7, x2, x1, x0, x3)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(2, 1, 0, 3));
  // t1 = ( 0, 0, 0, 0, x6, x5, x4, x7)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x81);

  //     ( -, -, -, -, -, -, -, x7)
  //     ( 0, 0, 0, 0, 0, 0, 0, -)
  // y = ( 0, 0, 0, 0, 0, 0, 0, x7)
  __m256 y = _mm256_blend_ps(t1, _mm256_setzero_ps(), 0xFE /* 0b11111110 */ );
  return y;
}
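
// Illustrative usage (sketch): rightShift<8> drops the two lowest floats and
// zero-fills the top, handy for masking loop tails.
//
//   __m256 v = _mm256_set_ps(7, 6, 5, 4, 3, 2, 1, 0); // (7,6,5,4,3,2,1,0)
//   __m256 s = rightShift<8>(v);                      // (0,0,7,6,5,4,3,2)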