#include "tiny_dnn/core/params/conv_params.h"
#include "tiny_dnn/core/kernels/tiny_quantization_kernel.h"
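
// Quantized 2-D convolution kernels. Floating-point activations and weights
// are mapped to uint8 over their observed [min, max] ranges, the convolution
// accumulates in int32, and the result is requantized to uint8 (and, at the
// float boundary, dequantized back to float_t).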
namespace tiny_dnn {
namespace core {
namespace kernels {

inline void tiny_quantized_conv2d_kernel(const conv_params& params,
                                         const vec_t&       in,
                                         const vec_t&       W,
                                         const vec_t&       bias,
                                         vec_t&             a,
                                         const bool         layer_parallelize) {
  // Determine the dynamic range of the input, then quantize it to uint8.
  float_t min_input(in[0]);
  float_t max_input(in[0]);
  for (serial_size_t inc = 0; inc < params.in.depth_; inc++) {
    for (serial_size_t ins = 0;
         ins < params.in_padded.height_ * params.in_padded.width_; ins++) {
      serial_size_t idx = params.in_padded.get_index(0, 0, inc);
      min_input = std::min(min_input, (&in[idx])[ins]);
      max_input = std::max(max_input, (&in[idx])[ins]);
    }
  }
  std::vector<uint8_t> in_quantized =
    float_tensor_to_quantized<uint8_t>(in, min_input, max_input);
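
  // For intuition: with min_input = -1.0f and max_input = 1.0f, the affine
  // uint8 mapping sends -1.0f to 0, 1.0f to 255, and 0.0f to roughly 128
  // (the exact rounding depends on float_tensor_to_quantized's
  // implementation).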
  // Determine the dynamic range of the filters, then quantize them to uint8.
  float_t min_filter(W[0]);
  float_t max_filter(W[0]);
  for (serial_size_t inc = 0; inc < params.weight.depth_; inc++) {
    for (serial_size_t ins = 0;
         ins < params.weight.height_ * params.weight.width_; ins++) {
      serial_size_t idx = params.weight.get_index(0, 0, inc);
      min_filter = std::min(min_filter, (&W[idx])[ins]);
      max_filter = std::max(max_filter, (&W[idx])[ins]);
    }
  }
  if (min_filter == max_filter) {
    // A zero-width range cannot be quantized; widen it artificially.
    max_filter = W[0] + 1e-3f;
    min_filter = W[0] - 1e-3f;
  }
  std::vector<uint8_t> W_quantized =
    float_tensor_to_quantized<uint8_t>(W, min_filter, max_filter);
  // Quantize the bias over its own range (widened if degenerate).
  float_t min_bias(0);
  float_t max_bias(0);
  std::vector<uint8_t> bias_quantized;
  if (params.has_bias) {
    for (serial_size_t inc = 0; inc < params.out.depth_; inc++) {
      min_bias = std::min(min_bias, bias[inc]);
      max_bias = std::max(max_bias, bias[inc]);
    }
    if (min_bias == max_bias) {
      max_bias = bias[0] + 1e-3f;
      min_bias = bias[0] - 1e-3f;
    }
    bias_quantized =
      float_tensor_to_quantized<uint8_t>(bias, min_bias, max_bias);
  }
  // Compute the int32 range that products of quantized inputs and filters
  // can span; this defines the accumulator's quantization parameters.
  float_t min_output_value;
  float_t max_output_value;
  quantization_range_for_multiplication<uint8_t, uint8_t, int32_t>(
    min_input, max_input, min_filter, max_filter, &min_output_value,
    &max_output_value);
  // int32 accumulator, one element per output activation.
  std::vector<int32_t> a_quantized(a.size(), static_cast<int32_t>(0));
  // Zero points: the quantized values that represent 0.0f in each space.
  const int32_t offset_input = int64_to_int32(
    float_to_quantized_unclamped<uint8_t>(0.0f, min_input, max_input));
  const int32_t offset_filter = int64_to_int32(
    float_to_quantized_unclamped<uint8_t>(0.0f, min_filter, max_filter));
  const int32_t zero_in_total_space = int64_to_int32(
    float_to_quantized<int32_t>(0.0f, min_output_value, max_output_value));
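
  // Each accumulation step below computes
  //   (q_w - offset_filter) * (q_x - offset_input)
  // which, up to the product of the two scale factors, is the exact int32
  // product of the dequantized weight and input values; summing these keeps
  // the convolution exact in the quantized domain.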
  for_i(layer_parallelize, params.out.depth_, [&](int o) {
    for (serial_size_t inc = 0; inc < params.in.depth_; inc++) {
      if (!params.tbl.is_connected(o, inc)) continue;

      serial_size_t idx = 0;
      idx = params.in.depth_ * o + inc;
      idx = params.weight.get_index(0, 0, idx);
      const uint8_t *pw = &W_quantized[idx];

      idx = params.in_padded.get_index(0, 0, inc);
      const uint8_t *pi = &in_quantized[idx];

      idx = params.out.get_index(0, 0, o);
      int32_t *pa_quantized = &a_quantized[idx];
      for (serial_size_t y = 0; y < params.out.height_; y++) {
        for (serial_size_t x = 0; x < params.out.width_; x++) {
          const uint8_t *ppw = pw;
          const uint8_t *ppi = pi + params.in_padded.width_ *
                               (y * params.h_stride) +
                               x * params.w_stride;
          int32_t sum = 0;

          // should be optimized for small kernels (3x3, 5x5)
          for (serial_size_t wy = 0; wy < params.weight.height_; wy++) {
            for (serial_size_t wx = 0; wx < params.weight.width_; wx++) {
              idx = wy * params.in_padded.width_ + wx;
              sum += (static_cast<int32_t>(*ppw++) - offset_filter)
                     * (static_cast<int32_t>(ppi[idx]) - offset_input);
            }
          }
          pa_quantized[y * params.out.width_ + x] += sum;
        }
      }
    }
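
    // Bias addition: the quantized bias is shifted by the accumulator's zero
    // point before being added to every activation in this output channel.
    // (Note the bias was quantized over its own range; the scales of the
    // bias and accumulator spaces are not reconciled here.)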
    if (params.has_bias) {
      int32_t *pa_quantized  = &a_quantized[params.out.get_index(0, 0, o)];
      int32_t *paa_quantized = pa_quantized + params.out.width_ * params.out.height_;
      std::for_each(pa_quantized, paa_quantized, [&](int32_t &f) {
        f += (bias_quantized[o] - zero_in_total_space);
      });
    }
  });
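
  // Requantize: the int32 accumulators are narrowed back to uint8 over the
  // range they actually used, which typically shrinks [min, max] and
  // recovers precision for the next layer.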
  float_t min_output_requantized;
  float_t max_output_requantized;
  std::vector<uint8_t> a_requantized(a_quantized.size(),
                                     static_cast<uint8_t>(0));

  // Requantize from 32 bits down to 8 bits for the next layer.
  quantize_down_and_shrink_range<int32_t, uint8_t>(
    a_quantized, min_output_value, max_output_value,
    &min_output_requantized, &max_output_requantized, &a_requantized);

  // Dequantize to float; this step could be elided inside a fully
  // quantized network.
  a = quantized_tensor_to_float<uint8_t>(a_requantized, min_output_requantized,
                                         max_output_requantized);
}
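
// Backward pass for the quantized convolution. It quantizes prev_out, W, and
// curr_delta to uint8, accumulates the input gradient (prev_delta) and the
// weight gradient (dW) in int32, then requantizes and dequantizes both; the
// bias gradient db is accumulated directly in floating point.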
inline void tiny_quantized_conv2d_back_kernel(const conv_params& params,
                                              const vec_t& prev_out,
                                              const vec_t& W,
                                              vec_t&       dW,
                                              vec_t&       db,
                                              vec_t&       curr_delta,
                                              vec_t*       prev_delta) {
  // Quantize the previous layer's output.
  float_t min_prev_out(prev_out[0]);
  float_t max_prev_out(prev_out[0]);
  for (serial_size_t inc = 0; inc < params.in.depth_; inc++) {
    for (serial_size_t ins = 0;
         ins < params.in_padded.height_ * params.in_padded.width_; ins++) {
      serial_size_t idx = params.in_padded.get_index(0, 0, inc);
      min_prev_out = std::min(min_prev_out, (&prev_out[idx])[ins]);
      max_prev_out = std::max(max_prev_out, (&prev_out[idx])[ins]);
    }
  }
  std::vector<uint8_t> prev_out_quantized =
    float_tensor_to_quantized<uint8_t>(prev_out, min_prev_out, max_prev_out);
  // Quantize the filters (range widened if degenerate).
  float_t min_filter(W[0]);
  float_t max_filter(W[0]);
  for (serial_size_t inc = 0; inc < params.weight.depth_; inc++) {
    for (serial_size_t ins = 0;
         ins < params.weight.height_ * params.weight.width_; ins++) {
      serial_size_t idx = params.weight.get_index(0, 0, inc);
      min_filter = std::min(min_filter, (&W[idx])[ins]);
      max_filter = std::max(max_filter, (&W[idx])[ins]);
    }
  }
  if (min_filter == max_filter) {
    max_filter = W[0] + 1e-3f;
    min_filter = W[0] - 1e-3f;
  }
  std::vector<uint8_t> W_quantized =
    float_tensor_to_quantized<uint8_t>(W, min_filter, max_filter);
  // Quantize the incoming gradient.
  float_t min_curr_delta(curr_delta[0]);
  float_t max_curr_delta(curr_delta[0]);
  for (serial_size_t inc = 0; inc < params.out.depth_; inc++) {
    for (serial_size_t ins = 0;
         ins < params.out.height_ * params.out.width_; ins++) {
      serial_size_t idx = params.out.get_index(0, 0, inc);
      min_curr_delta = std::min(min_curr_delta, (&curr_delta[idx])[ins]);
      max_curr_delta = std::max(max_curr_delta, (&curr_delta[idx])[ins]);
    }
  }
  std::vector<uint8_t> curr_delta_quantized =
    float_tensor_to_quantized<uint8_t>(curr_delta, min_curr_delta, max_curr_delta);
  // int32 range for products of curr_delta and W (drives prev_delta).
  float_t min_prev_delta_value;
  float_t max_prev_delta_value;
  quantization_range_for_multiplication<uint8_t, uint8_t, int32_t>(
    min_curr_delta, max_curr_delta, min_filter, max_filter,
    &min_prev_delta_value, &max_prev_delta_value);

  std::vector<int32_t> prev_delta_quantized(prev_delta->size(),
                                            static_cast<int32_t>(0));
  // int32 range for products of curr_delta and prev_out (drives dW).
  float_t min_dW_value;
  float_t max_dW_value;
  quantization_range_for_multiplication<uint8_t, uint8_t, int32_t>(
    min_curr_delta, max_curr_delta, min_prev_out, max_prev_out,
    &min_dW_value, &max_dW_value);

  std::vector<int32_t> dW_quantized(dW.size(), static_cast<int32_t>(0));
  const int32_t offset_prev_out = int64_to_int32(
    float_to_quantized_unclamped<uint8_t>(0.0f, min_prev_out, max_prev_out));
  const int32_t offset_filter = int64_to_int32(
    float_to_quantized_unclamped<uint8_t>(0.0f, min_filter, max_filter));
  const int32_t offset_curr_delta = int64_to_int32(
    float_to_quantized_unclamped<uint8_t>(0.0f, min_curr_delta, max_curr_delta));
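
  // prev_delta accumulation: propagate each quantized delta through the
  // filter taps, subtracting the relevant zero points exactly as in the
  // forward pass.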
  for_i(params.in.depth_, [&](int inc) {
    for (serial_size_t outc = 0; outc < params.out.depth_; outc++) {
      if (!params.tbl.is_connected(outc, inc)) continue;

      serial_size_t idx = 0;
      idx = params.in.depth_ * outc + inc;
      idx = params.weight.get_index(0, 0, idx);
      const uint8_t *pw = &W_quantized[idx];

      idx = params.out.get_index(0, 0, outc);
      const uint8_t *pdelta_src = &curr_delta_quantized[idx];

      idx = params.in_padded.get_index(0, 0, inc);
      int32_t *pdelta_quantized_dst = &prev_delta_quantized[idx];
      for (serial_size_t y = 0; y < params.out.height_; y++) {
        for (serial_size_t x = 0; x < params.out.width_; x++) {
          const uint8_t *ppw = pw;

          idx = y * params.out.width_ + x;
          const uint8_t ppdelta_src = pdelta_src[idx];

          int32_t *ppdelta_quantized_dst = pdelta_quantized_dst +
                                           y * params.h_stride * params.in_padded.width_ +
                                           x * params.w_stride;

          for (serial_size_t wy = 0; wy < params.weight.height_; wy++) {
            for (serial_size_t wx = 0; wx < params.weight.width_; wx++) {
              idx = wy * params.in_padded.width_ + wx;
              ppdelta_quantized_dst[idx] +=
                (static_cast<int32_t>(*ppw++) - offset_filter)
                * (static_cast<int32_t>(ppdelta_src) - offset_curr_delta);
            }
          }
        }
      }
    }
  });
  // Requantize prev_delta from 32 bits down to 8 bits.
  float_t min_prev_delta_requantized;
  float_t max_prev_delta_requantized;
  std::vector<uint8_t> prev_delta_requantized(prev_delta_quantized.size(),
                                              static_cast<uint8_t>(0));
  quantize_down_and_shrink_range<int32_t, uint8_t>(
    prev_delta_quantized, min_prev_delta_value, max_prev_delta_value,
    &min_prev_delta_requantized, &max_prev_delta_requantized,
    &prev_delta_requantized);
  // Dequantize to float and copy the result through the output pointer
  // (reassigning the pointer itself would be a no-op for the caller).
  vec_t prev_delta_vec = quantized_tensor_to_float<uint8_t>(
    prev_delta_requantized, min_prev_delta_requantized,
    max_prev_delta_requantized);
  *prev_delta = prev_delta_vec;
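
  // dW accumulation: correlate each quantized delta map with the
  // corresponding window of the quantized previous output, one filter tap at
  // a time.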
  for_i(params.in.depth_, [&](int inc) {
    for (serial_size_t outc = 0; outc < params.out.depth_; outc++) {
      if (!params.tbl.is_connected(outc, inc)) continue;

      for (serial_size_t wy = 0; wy < params.weight.height_; wy++) {
        for (serial_size_t wx = 0; wx < params.weight.width_; wx++) {
          int32_t dst = int32_t(0);

          serial_size_t idx = 0;
          idx = params.in_padded.get_index(wx, wy, inc);
          const uint8_t *prevo = &prev_out_quantized[idx];

          idx = params.out.get_index(0, 0, outc);
          const uint8_t *delta = &curr_delta_quantized[idx];

          for (serial_size_t y = 0; y < params.out.height_; y++) {
            for (serial_size_t x = 0; x < params.out.width_; x++) {
              dst += (static_cast<int32_t>(*(prevo + y * params.in_padded.width_ + x)) - offset_prev_out)
                     * (static_cast<int32_t>(*(delta + y * params.out.width_ + x)) - offset_curr_delta);
            }
          }

          idx = params.in.depth_ * outc + inc;
          dW_quantized[params.weight.get_index(wx, wy, idx)] += dst;
        }
      }
    }
  });
  // Requantize dW from 32 bits down to 8 bits, then dequantize to float.
  float_t min_dW_requantized;
  float_t max_dW_requantized;
  std::vector<uint8_t> dW_requantized(dW_quantized.size(),
                                      static_cast<uint8_t>(0));
  quantize_down_and_shrink_range<int32_t, uint8_t>(
    dW_quantized, min_dW_value, max_dW_value,
    &min_dW_requantized, &max_dW_requantized, &dW_requantized);

  dW = quantized_tensor_to_float<uint8_t>(dW_requantized, min_dW_requantized,
                                          max_dW_requantized);
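
  // The bias gradient needs no quantization: each db[outc] is simply the
  // float sum of the deltas over that output channel's plane.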
  if (params.has_bias) {
    for (serial_size_t outc = 0; outc < params.out.depth_; outc++) {
      serial_size_t idx = params.out.get_index(0, 0, outc);
      const float_t *delta  = &curr_delta[idx];
      const float_t *deltaa = delta + params.out.width_ *
                                      params.out.height_;
      db[outc] += std::accumulate(delta, deltaa, float_t(0));
    }
  }
}
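
// Overload used when quantization ranges are supplied by the caller: in, W,
// and bias are expected to already hold uint8 codes (stored in float_t), and
// in_r / W_r / b_r carry each tensor's [min, max] range. The output stays in
// quantized form, with its range reported through a_r.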
inline void tiny_quantized_conv2d_kernel(const conv_params& params,
                                         const vec_t& in,
                                         const vec_t& W,
                                         const vec_t& bias,
                                         const vec_t& in_r,
                                         const vec_t& W_r,
                                         const vec_t& b_r,
                                         vec_t&       a,
                                         vec_t&       a_r,
                                         const bool   layer_parallelize) {
  // Filter range comes from the caller; widen it if degenerate.
  float_t min_filter(W_r[0]);
  float_t max_filter(W_r[1]);
  if (W_r[0] == W_r[1]) {
    max_filter = W_r[1] + 1e-3f;
    min_filter = W_r[0] - 1e-3f;
  }
  // Bias range likewise comes from the caller.
  float_t min_bias(b_r[0]);
  float_t max_bias(b_r[1]);
  if (params.has_bias) {
    if (min_bias == max_bias) {
      max_bias = b_r[1] + 1e-3f;
      min_bias = b_r[0] - 1e-3f;
    }
  }
  // int32 accumulator range for products of inputs and filters.
  float_t min_output_value;
  float_t max_output_value;
  quantization_range_for_multiplication<uint8_t, uint8_t, int32_t>(
    in_r[0], in_r[1], min_filter, max_filter, &min_output_value,
    &max_output_value);
  // The incoming tensors already hold uint8 codes in float storage; copy
  // them into uint8 vectors without rescaling.
  std::vector<uint8_t> in_quantized, W_quantized, bias_quantized;
  for (size_t i = 0; i < in.size(); i++) {
    in_quantized.push_back(static_cast<uint8_t>(in[i]));
  }
  for (size_t i = 0; i < W.size(); i++) {
    W_quantized.push_back(static_cast<uint8_t>(W[i]));
  }
  for (size_t i = 0; i < bias.size(); i++) {
    bias_quantized.push_back(static_cast<uint8_t>(bias[i]));
  }
  std::vector<int32_t> a_quantized(a.size(), static_cast<int32_t>(0));
  const int32_t offset_input = int64_to_int32(
    float_to_quantized_unclamped<uint8_t>(0.0f, in_r[0], in_r[1]));
  const int32_t offset_filter = int64_to_int32(
    float_to_quantized_unclamped<uint8_t>(0.0f, min_filter, max_filter));
  const int32_t zero_in_total_space = int64_to_int32(
    float_to_quantized<int32_t>(0.0f, min_output_value, max_output_value));
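
  // Unlike the first overload, the zero points here are derived from the
  // caller-supplied ranges rather than from a scan of the data, so the
  // convolution loop itself is identical to the forward kernel above.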
  for_i(layer_parallelize, params.out.depth_, [&](int o) {
    for (serial_size_t inc = 0; inc < params.in.depth_; inc++) {
      if (!params.tbl.is_connected(o, inc)) continue;

      serial_size_t idx = 0;
      idx = params.in.depth_ * o + inc;
      idx = params.weight.get_index(0, 0, idx);
      const uint8_t *pw = &W_quantized[idx];

      idx = params.in_padded.get_index(0, 0, inc);
      const uint8_t *pi = &in_quantized[idx];

      idx = params.out.get_index(0, 0, o);
      int32_t *pa_quantized = &a_quantized[idx];
      for (serial_size_t y = 0; y < params.out.height_; y++) {
        for (serial_size_t x = 0; x < params.out.width_; x++) {
          const uint8_t *ppw = pw;
          const uint8_t *ppi = pi + params.in_padded.width_ *
                               (y * params.h_stride) +
                               x * params.w_stride;
          int32_t sum = 0;

          for (serial_size_t wy = 0; wy < params.weight.height_; wy++) {
            for (serial_size_t wx = 0; wx < params.weight.width_; wx++) {
              idx = wy * params.in_padded.width_ + wx;
              sum += (static_cast<int32_t>(*ppw++) - offset_filter)
                     * (static_cast<int32_t>(ppi[idx]) - offset_input);
            }
          }
          pa_quantized[y * params.out.width_ + x] += sum;
        }
      }
    }
    if (params.has_bias) {
      int32_t *pa_quantized  = &a_quantized[params.out.get_index(0, 0, o)];
      int32_t *paa_quantized = pa_quantized + params.out.width_ * params.out.height_;
      std::for_each(pa_quantized, paa_quantized, [&](int32_t &f) {
        // bias[o] already holds a uint8 code (see the copy loop above).
        f += static_cast<int32_t>((bias[o] - zero_in_total_space));
      });
    }
  });
  float_t min_output_requantized;
  float_t max_output_requantized;
  std::vector<uint8_t> a_requantized(a_quantized.size(),
                                     static_cast<uint8_t>(0));

  // Requantize from 32 bits down to 8 bits.
  quantize_down_and_shrink_range<int32_t, uint8_t>(
    a_quantized, min_output_value, max_output_value,
    &min_output_requantized, &max_output_requantized, &a_requantized);
  // Store the uint8 codes back into the float output buffer (no dequantize).
  for (size_t i = 0; i < a_requantized.size(); i++) {
    a[i] = static_cast<float_t>(a_requantized[i]);
  }
  // Report the output's quantization range to the caller.
  a_r[0] = min_output_requantized;
  a_r[1] = max_output_requantized;
}

}  // namespace kernels
}  // namespace core
}  // namespace tiny_dnn
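
// A minimal caller sketch (hedged: `make_params` is hypothetical and stands
// in for however the surrounding layer fills in conv_params):
//
//   tiny_dnn::core::conv_params params = make_params(/* geometry, stride */);
//   tiny_dnn::vec_t in(params.in_padded.size()), W(params.weight.size());
//   tiny_dnn::vec_t bias(params.out.depth_), a(params.out.size(), 0);
//   // ... fill in and W with float data ...
//   tiny_dnn::core::kernels::tiny_quantized_conv2d_kernel(
//     params, in, W, bias, a, false);
//   // `a` now holds the dequantized convolution result.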