35 std::vector<const vec_t*> prev_out_padded_;
36 std::vector<vec_t> prev_out_buf_;
37 std::vector<vec_t> prev_delta_padded_;
43 : connected_(rows * cols), rows_(rows), cols_(cols) {
44 std::copy(ar, ar + rows * cols, connected_.begin());
46 connection_table(serial_size_t ngroups, serial_size_t rows, serial_size_t cols)
47 : connected_(rows * cols,
false), rows_(rows), cols_(cols) {
48 if (rows % ngroups || cols % ngroups) {
49 throw nn_error(
"invalid group size");
52 serial_size_t row_group = rows / ngroups;
53 serial_size_t col_group = cols / ngroups;
55 serial_size_t idx = 0;
57 for (serial_size_t g = 0; g < ngroups; g++) {
58 for (serial_size_t r = 0; r < row_group; r++) {
59 for (serial_size_t c = 0; c < col_group; c++) {
60 idx = (r + g * row_group) * cols_ + c + g * col_group;
61 connected_[idx] =
true;
67 bool is_connected(serial_size_t x, serial_size_t y)
const {
68 return is_empty() ? true : connected_[y * cols_ + x];
71 bool is_empty()
const {
72 return rows_ == 0 && cols_ == 0;
75 template <
typename Archive>
76 void serialize(Archive & ar) {
77 ar(cereal::make_nvp(
"rows", rows_), cereal::make_nvp(
"cols", cols_));
80 ar(cereal::make_nvp(
"connection", std::string(
"all")));
83 ar(cereal::make_nvp(
"connection", connected_));
87 std::deque<bool> connected_;
101 serial_size_t w_stride;
102 serial_size_t h_stride;
104 friend std::ostream& operator<<(std::ostream &o,
106 o <<
"in: " << param.in <<
"\n";
107 o <<
"out: " << param.out <<
"\n";
108 o <<
"in_padded: " << param.in_padded <<
"\n";
109 o <<
"weight: " << param.weight <<
"\n";
110 o <<
"has_bias: " << param.has_bias <<
"\n";
111 o <<
"w_stride: " << param.w_stride <<
"\n";
112 o <<
"h_stride: " << param.h_stride <<
"\n";
131 void copy_and_pad_input(
const tensor_t& in, tensor_t& out) {
132 if (params_.pad_type == padding::valid) {
136 tensor_t buf(in.size());
138 for_i(
true, buf.size(), [&](
int sample) {
140 buf[sample].resize(params_.in_padded.size());
143 for (serial_size_t c = 0; c < params_.in.depth_; c++) {
144 float_t* pimg = &buf[sample][params_.in_padded.get_index(
145 params_.weight.width_ / 2,
146 params_.weight.height_ / 2, c)];
147 const float_t* pin = &in[sample][params_.in.get_index(0, 0, c)];
149 for (serial_size_t y = 0; y < params_.in.height_; y++) {
150 std::copy(pin, pin + params_.in.width_, pimg);
151 pin += params_.in.width_;
152 pimg += params_.in_padded.width_;
166 void copy_and_unpad_delta(
const tensor_t& delta, tensor_t& delta_unpadded) {
167 if (params_.pad_type == padding::valid) {
171 tensor_t buf(delta.size());
173 for_i(
true, buf.size(), [&](
int sample) {
175 buf[sample].resize(params_.in.size());
177 for (serial_size_t c = 0; c < params_.in.depth_; c++) {
179 &delta[sample][params_.in_padded.get_index(
180 params_.weight.width_ / 2,
181 params_.weight.height_ / 2, c)];
182 float_t *pdst = &buf[sample][params_.in.get_index(0, 0, c)];
184 for (serial_size_t y = 0; y < params_.in.height_; y++) {
185 std::copy(pin, pin + params_.in.width_, pdst);
186 pdst += params_.in.width_;
187 pin += params_.in_padded.width_;
193 delta_unpadded = buf;
Definition: conv_params.h:121
Definition: conv_params.h:92
error exception class for tiny-dnn
Definition: nn_error.h:37
Definition: conv_params.h:40
Definition: conv_params.h:34