tiny_dnn  1.0.0
A header only, dependency-free deep learning framework in C++11
conv_params.h
1 /*
2  Copyright (c) 2016, Taiga Nomi, Edgar Riba
3  All rights reserved.
4 
5  Redistribution and use in source and binary forms, with or without
6  modification, are permitted provided that the following conditions are met:
7  * Redistributions of source code must retain the above copyright
8  notice, this list of conditions and the following disclaimer.
9  * Redistributions in binary form must reproduce the above copyright
10  notice, this list of conditions and the following disclaimer in the
11  documentation and/or other materials provided with the distribution.
12  * Neither the name of the <organization> nor the
13  names of its contributors may be used to endorse or promote products
14  derived from this software without specific prior written permission.
15 
16  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
17  EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
20  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27 #pragma once
28 
#include <utility>

#include "params.h"
30 
31 namespace tiny_dnn {
32 namespace core {
33 
// Pointers to input vectors, presumably one per sample and pointing at the
// padded copy when padding is applied; TODO confirm against the conv layer.
std::vector<const vec_t*> prev_out_padded_;
// Owned vec_t storage; presumably backs prev_out_padded_ -- TODO confirm.
std::vector<vec_t> prev_out_buf_;
// Delta vectors in the padded layout; presumably filled during back
// propagation and later unpadded -- TODO confirm.
std::vector<vec_t> prev_delta_padded_;
};
39 
41  connection_table() : rows_(0), cols_(0) {}
42  connection_table(const bool *ar, serial_size_t rows, serial_size_t cols)
43  : connected_(rows * cols), rows_(rows), cols_(cols) {
44  std::copy(ar, ar + rows * cols, connected_.begin());
45  }
46  connection_table(serial_size_t ngroups, serial_size_t rows, serial_size_t cols)
47  : connected_(rows * cols, false), rows_(rows), cols_(cols) {
48  if (rows % ngroups || cols % ngroups) {
49  throw nn_error("invalid group size");
50  }
51 
52  serial_size_t row_group = rows / ngroups;
53  serial_size_t col_group = cols / ngroups;
54 
55  serial_size_t idx = 0;
56 
57  for (serial_size_t g = 0; g < ngroups; g++) {
58  for (serial_size_t r = 0; r < row_group; r++) {
59  for (serial_size_t c = 0; c < col_group; c++) {
60  idx = (r + g * row_group) * cols_ + c + g * col_group;
61  connected_[idx] = true;
62  }
63  }
64  }
65  }
66 
67  bool is_connected(serial_size_t x, serial_size_t y) const {
68  return is_empty() ? true : connected_[y * cols_ + x];
69  }
70 
71  bool is_empty() const {
72  return rows_ == 0 && cols_ == 0;
73  }
74 
75  template <typename Archive>
76  void serialize(Archive & ar) {
77  ar(cereal::make_nvp("rows", rows_), cereal::make_nvp("cols", cols_));
78 
79  if (is_empty()) {
80  ar(cereal::make_nvp("connection", std::string("all")));
81  }
82  else {
83  ar(cereal::make_nvp("connection", connected_));
84  }
85  }
86 
// Connection flags, indexed as y * cols_ + x (see is_connected()).
// std::deque<bool> is used rather than std::vector<bool>, presumably to
// avoid the packed-bit vector<bool> specialization -- TODO confirm intent.
std::deque<bool> connected_;
// Number of rows; rows_ == 0 together with cols_ == 0 means "empty",
// which is treated as fully connected.
serial_size_t rows_;
// Number of columns.
serial_size_t cols_;
};
91 
92 class conv_params : public Params {
93  public:
94  connection_table tbl;
96  index3d<serial_size_t> in_padded;
99  bool has_bias;
100  padding pad_type;
101  serial_size_t w_stride;
102  serial_size_t h_stride;
103 
104  friend std::ostream& operator<<(std::ostream &o,
105  const core::conv_params& param) {
106  o << "in: " << param.in << "\n";
107  o << "out: " << param.out << "\n";
108  o << "in_padded: " << param.in_padded << "\n";
109  o << "weight: " << param.weight << "\n";
110  o << "has_bias: " << param.has_bias << "\n";
111  o << "w_stride: " << param.w_stride << "\n";
112  o << "h_stride: " << param.h_stride << "\n";
113  return o;
114  }
115 };
116 
117 inline conv_params Params::conv() const {
118  return *(static_cast<const conv_params*>(this));
119 }
120 
122  public:
123  Conv2dPadding() {}
124  Conv2dPadding(const conv_params& params) : params_(params) {}
125 
126  /* Applies padding to an input tensor given the convolution parameters
127  *
128  * @param in The input tensor
129  * @param out The output tensor with padding applied
130  */
131  void copy_and_pad_input(const tensor_t& in, tensor_t& out) {
132  if (params_.pad_type == padding::valid) {
133  return;
134  }
135 
136  tensor_t buf(in.size());
137 
138  for_i(true, buf.size(), [&](int sample) {
139  // alloc temporary buffer.
140  buf[sample].resize(params_.in_padded.size());
141 
142  // make padded version in order to avoid corner-case in fprop/bprop
143  for (serial_size_t c = 0; c < params_.in.depth_; c++) {
144  float_t* pimg = &buf[sample][params_.in_padded.get_index(
145  params_.weight.width_ / 2,
146  params_.weight.height_ / 2, c)];
147  const float_t* pin = &in[sample][params_.in.get_index(0, 0, c)];
148 
149  for (serial_size_t y = 0; y < params_.in.height_; y++) {
150  std::copy(pin, pin + params_.in.width_, pimg);
151  pin += params_.in.width_;
152  pimg += params_.in_padded.width_;
153  }
154  }
155  });
156 
157  // shrink buffer to output
158  out = buf;
159  }
160 
161  /* Applies unpadding to an input tensor given the convolution parameters
162  *
163  * @param in The input tensor
164  * @param out The output tensor with unpadding applied
165  */
166  void copy_and_unpad_delta(const tensor_t& delta, tensor_t& delta_unpadded) {
167  if (params_.pad_type == padding::valid) {
168  return;
169  }
170 
171  tensor_t buf(delta.size());
172 
173  for_i(true, buf.size(), [&](int sample) {
174  // alloc temporary buffer.
175  buf[sample].resize(params_.in.size());
176 
177  for (serial_size_t c = 0; c < params_.in.depth_; c++) {
178  const float_t *pin =
179  &delta[sample][params_.in_padded.get_index(
180  params_.weight.width_ / 2,
181  params_.weight.height_ / 2, c)];
182  float_t *pdst = &buf[sample][params_.in.get_index(0, 0, c)];
183 
184  for (serial_size_t y = 0; y < params_.in.height_; y++) {
185  std::copy(pin, pin + params_.in.width_, pdst);
186  pdst += params_.in.width_;
187  pin += params_.in_padded.width_;
188  }
189  }
190  });
191 
192  // shrink buffer to output
193  delta_unpadded = buf;
194  }
195 
 private:
  // Copy of the convolution parameters. copy_and_pad_input and
  // copy_and_unpad_delta read pad_type and the in / in_padded / weight
  // geometry from it.
  conv_params params_;
};
199 
200 } // namespace core
201 } // namespace tiny_dnn
Definition: conv_params.h:121
Definition: params.h:37
Definition: conv_params.h:92
error exception class for tiny-dnn
Definition: nn_error.h:37
Definition: conv_params.h:40