18 #include "gemmlowp/public/gemmlowp.h"
19 #include "tiny_dnn/core/kernels/tiny_quantization_kernel.h"
// Multiplies two uint8 matrices through gemmlowp and writes the product
// through an int32_t pointer.
//
// Template parameters:
//   TransposeA / TransposeB / TransposeC - when false, the corresponding
//   matrix map (lhs / rhs / result) uses RowMajor order; when true, ColMajor.
//
// NOTE(review): only the first two parameters (a_data, b_data) are visible in
// this excerpt. The body also reads c_data, m, n, k, offset_a and offset_b, so
// those are presumably declared in the part of the signature that is not
// shown here -- confirm against the full file.
25 template <
bool TransposeA,
bool TransposeB,
bool TransposeC>
26 void gemmlowp_multiply(
const uint8_t* a_data,
27 const uint8_t* b_data,
// Local aliases for the two input buffers and the output buffer.
37 const uint8_t* a_data_as_uint8 = a_data;
38 const uint8_t* b_data_as_uint8 = b_data;
39 int32_t* c_data_as_int32 = c_data;
// Storage order of each matrix map, chosen at compile time from the transpose
// flags (false -> RowMajor, true -> ColMajor).
40 static const gemmlowp::MapOrder ResultOrder =
41 !TransposeC ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
42 static const gemmlowp::MapOrder LhsOrder =
43 !TransposeA ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
44 static const gemmlowp::MapOrder RhsOrder =
45 !TransposeB ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
// Wrap the raw buffers in gemmlowp matrix maps: lhs is built with (m, k),
// rhs with (k, n) and result with (m, n).
// NOTE(review): the trailing constructor argument of each map is cut off in
// this excerpt (presumably a stride) -- confirm against the full file.
46 gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(a_data_as_uint8, m, k,
48 gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(b_data_as_uint8, k, n,
50 gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(c_data_as_int32, m, n,
// Empty output pipeline: no output stages are handed to the GEMM call.
52 const std::tuple<> empty_pipeline = {};
// A new GemmContext is constructed on every call to this function.
53 gemmlowp::GemmContext context;
// Run the GEMM with uint8 inputs and int32 output. Note that offset_a and
// offset_b are negated before being passed in.
54 gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
55 gemmlowp::DefaultL8R8BitDepthParams>(
56 &context, lhs, rhs, &result, -offset_a, -offset_b, empty_pipeline);
59 template <
class T1,
class T2,
class Toutput>
60 void tiny_quantized_matmul(
const std::vector<T1>& a,
61 const std::vector<T2>& b,
62 std::vector<Toutput>& c,
63 const std::vector<size_t> shape_all,
64 const int32_t offset_a,
65 const int32_t offset_b,
66 const int32_t offset_c,
68 const int32_t shift_c) {
76 int a_dim_remaining = 1 - transpose_a_;
77 int b_dim_remaining = 1 - transpose_b_;
79 const T1* a_data = &a[0];
80 const T2* b_data = &b[0];
81 Toutput* c_data = &c[0];
83 const bool transpose_c =
false;
84 const size_t m = shape_all[a_dim_remaining];
85 const size_t n = shape_all[2 + b_dim_remaining];
86 const size_t k = shape_all[transpose_a_];
87 const size_t lda = shape_all[1];
88 const size_t ldb = shape_all[3];
94 if (std::is_same<T1, uint8_t>() && std::is_same<T2, uint8_t>() &&
95 std::is_same<Toutput, int32_t>() && (offset_c == 0) && (mult_c == 1) &&
96 (shift_c == 0) && (transpose_c ==
false)) {
99 gemmlowp_multiply<true, true, false>(a_data, b_data, c_data, m, n, k,
100 offset_a, offset_b, lda, ldb,
103 gemmlowp_multiply<true, false, false>(a_data, b_data, c_data, m, n, k,
104 offset_a, offset_b, lda, ldb,
109 gemmlowp_multiply<false, true, false>(a_data, b_data, c_data, m, n, k,
110 offset_a, offset_b, lda, ldb,
113 gemmlowp_multiply<false, false, false>(a_data, b_data, c_data, m, n, k,
114 offset_a, offset_b, lda, ldb,