28 #include "tiny_dnn/util/util.h"
29 #include "tiny_dnn/util/image.h"
30 #include "tiny_dnn/activations/activation_function.h"
37 template <
typename Activation = activation::
identity>
40 CNN_USE_LAYER_MEMBERS;
50 serial_size_t in_height,
52 serial_size_t unpooling_size)
57 serial_size_t unpooling_size,
70 serial_size_t in_height,
72 serial_size_t unpooling_size,
74 :
Base({vector_type::data}),
75 unpool_size_(unpooling_size),
78 out_(unpool_out_dim(in_width, unpooling_size, stride), unpool_out_dim(in_height, unpooling_size, stride),
in_channels)
89 return in2out_[0].size();
// Max-unpooling forward pass for one worker.
//   index    : worker slot; selects max_unpooling_layer_worker_storage_[index]
//   in_data  : in_data[0] is the (pooled) input vector
//   out_data : out_data[1] receives the raw unpooled values; out_data[0] is
//              then filled from it by forward_activation
// Each output position i copies the value of its source input out2in_[i]
// only when that input's recorded winner (in2outmax_[in_index]) equals i;
// every other output position is set to zero.
// NOTE(review): the loop header that defines `r` (a for_(parallelize_, ...)
// call, original line ~100) is missing from this listing, so the iteration
// range cannot be confirmed here. It must span out_.size(), because `i` is
// used as an output index (a[i], out2in_[i]).
// NOTE(review): in the visible code in2outmax_ is only resized, never
// written, so its entries keep their value-initialized 0. Confirm where the
// winner indices are supposed to be populated.
// NOTE(review): `int i` is compared against r.end() and against a
// serial_size_t (max_idx[...]); this is a signed/unsigned mix.
92 void forward_propagation(serial_size_t index,
93 const std::vector<vec_t*>& in_data,
94 std::vector<vec_t*>& out_data)
override {
95 const vec_t& in = *in_data[0];
97 vec_t& a = *out_data[1];
98 std::vector<serial_size_t>& max_idx = max_unpooling_layer_worker_storage_[index].in2outmax_;
// i: output index; in_index: the input cell whose window covers output i.
101 for (int i = r.begin(); i < r.end(); i++) {
102 const auto& in_index = out2in_[i];
103 a[i] = (max_idx[in_index] == i) ? in[in_index] : float_t(0);
// Apply the layer's activation: raw values (out_data[1]) -> out_data[0].
107 this->forward_activation(*out_data[0], *out_data[1]);
// Max-unpooling backward pass for one worker.
//   index    : worker slot; selects max_unpooling_layer_worker_storage_[index]
//   in_data  : unused
//   out_data : out_data[0] (activated output) is passed to backward_activation
//   out_grad : out_grad[0] is handed to backward_activation together with
//              curr_delta (= out_grad[1]), which presumably receives the
//              gradient w.r.t. the raw unpooled values
//   in_grad  : in_grad[0] (prev_delta) receives the gradient w.r.t. the input
// NOTE(review): the loop runs over 0..in2out_.size(), i.e. over INPUT
// indices, yet it reads out2in_[i]. connect_kernel fills out2in_ by OUTPUT
// index (out2in_[out_index] = in_index). The test max_idx[outi] == i also
// mixes spaces: max_idx (in2outmax_) is indexed by input and compared with
// output indices in forward_propagation. The index spaces look swapped.
// The intended rule is presumably prev_delta[i] = curr_delta[max_idx[i]];
// confirm before changing.
// NOTE(review): the closing lines of the loop, lambda and function are absent
// from this listing.
110 void back_propagation(serial_size_t index,
111 const std::vector<vec_t*>& in_data,
112 const std::vector<vec_t*>& out_data,
113 std::vector<vec_t*>& out_grad,
114 std::vector<vec_t*>& in_grad)
override {
115 vec_t& prev_delta = *in_grad[0];
116 vec_t& curr_delta = *out_grad[1];
117 std::vector<serial_size_t>& max_idx = max_unpooling_layer_worker_storage_[index].in2outmax_;
119 CNN_UNREFERENCED_PARAMETER(in_data);
// Propagate through the activation first so curr_delta refers to raw outputs.
121 this->backward_activation(*out_grad[0], *out_data[0], curr_delta);
123 for_(parallelize_, 0, in2out_.size(), [&](
const blocked_range& r) {
124 for (int i = r.begin(); i != r.end(); i++) {
125 serial_size_t outi = out2in_[i];
126 prev_delta[i] = (max_idx[outi] == i) ? curr_delta[outi] : float_t(0);
// Shape of the single data input (width x height x depth), as stored in in_.
131 std::vector<index3d<serial_size_t>>
in_shape()
const override {
return {in_}; }
// Output shapes: two entries, both equal to out_. forward_propagation writes
// the raw unpooled values to out_data[1] and the activated values to
// out_data[0], so both outputs share the same geometry.
132 std::vector<index3d<serial_size_t>>
out_shape()
const override {
return {out_, out_}; }
// Layer type identifier; always the literal "max-unpool".
133 std::string
layer_type()
const override {
return "max-unpool"; }
// Extent of the unpooling window (unpool_size_), returned as size_t.
// connect_kernel uses it as the window bound along both x and y.
134 size_t unpool_size()
const {
return unpool_size_;}
// Forwards the worker count to Base::set_worker_count, then allocates one
// max_unpooling_layer_worker_specific_storage per worker and sizes each
// worker's in2outmax_ table.
// NOTE(review): in2outmax_ is sized to out_.size() here but to in_.size() in
// init_connection. forward/backward index it by input index. The two sizes
// should agree; confirm which one is intended.
// NOTE(review): the closing braces of the loop and function are absent from
// this listing.
136 virtual void set_worker_count(serial_size_t worker_count)
override {
137 Base::set_worker_count(worker_count);
138 max_unpooling_layer_worker_storage_.resize(worker_count);
139 for (max_unpooling_layer_worker_specific_storage& mws : max_unpooling_layer_worker_storage_) {
140 mws.in2outmax_.resize(out_.size());
// cereal construction hook: reads "in_size", "unpool_size" and "stride" from
// the archive (the same keys serialize writes) and constructs the layer.
// NOTE(review): the declaration of `in` (original line ~146) is absent from
// this listing. It is presumably the input-shape type written by serialize
// (in_); confirm. construct(in, unpool_size, stride) requires a constructor
// overload taking (shape, unpool_size, stride), which is only partially
// visible here.
144 template <
class Archive>
145 static void load_and_construct(Archive & ar, cereal::construct<max_unpooling_layer> & construct) {
147 serial_size_t stride, unpool_size;
149 ar(cereal::make_nvp(
"in_size", in), cereal::make_nvp(
"unpool_size", unpool_size), cereal::make_nvp(
"stride", stride));
150 construct(in, unpool_size, stride);
// cereal serialization: writes the common layer prolog, then the input shape,
// unpooling size and stride under the keys load_and_construct reads back.
153 template <
class Archive>
154 void serialize(Archive & ar) {
155 layer::serialize_prolog(ar);
156 ar(cereal::make_nvp(
"in_size", in_), cereal::make_nvp(
"unpool_size", unpool_size_), cereal::make_nvp(
"stride", stride_));
// Extent of the unpooling window along each axis.
160 serial_size_t unpool_size_;
// Step between windows: input (inx, iny) is placed at (inx*stride_, iny*stride_).
161 serial_size_t stride_;
// out2in_[out_index] = index of the input cell whose window covers that output.
162 std::vector<serial_size_t> out2in_;
// in2out_[in_index] = all output indices covered by that input's window.
163 std::vector<std::vector<serial_size_t> > in2out_;
165 struct max_unpooling_layer_worker_specific_storage {
// Read as max_idx in forward/backward and compared with output indices.
// NOTE(review): nothing visible writes it; confirm where it is populated.
166 std::vector<serial_size_t> in2outmax_;
// One entry per worker (resized in set_worker_count).
169 std::vector<max_unpooling_layer_worker_specific_storage> max_unpooling_layer_worker_storage_;
// Input and output shapes (width x height x depth).
171 index3d<serial_size_t> in_;
172 index3d<serial_size_t> out_;
// Output extent along one axis: in_size * stride + unpooling_size - 1.
// NOTE(review): the C-style casts bind only to in_size, i.e. the expression
// parses as ((int)((float_t)in_size)) * stride + unpooling_size - 1. They do
// not cast the whole result (which is converted back to serial_size_t on
// return) and look vestigial; confirm and drop them.
// NOTE(review): for stride > 1 this is (stride - 1) larger than
// (in_size - 1) * stride + unpooling_size, the span connect_kernel actually
// maps. The trailing stride - 1 rows/columns therefore get no source input.
// Confirm which extent is intended.
// NOTE(review): the closing brace is absent from this listing.
174 static serial_size_t unpool_out_dim(serial_size_t in_size, serial_size_t unpooling_size, serial_size_t stride) {
175 return (
int) (float_t)in_size * stride + unpooling_size - 1;
// Links input cell (inx, iny, c) to the output cells of its unpooling window,
// (inx*stride_ + dx, iny*stride_ + dy, c) for dx < dxmax and dy < dymax.
// For each such pair it records out2in_[out_index] = in_index and appends
// out_index to in2out_[in_index]. It throws nn_error("index overflow") if
// either index falls outside its table.
// NOTE(review): the clipping bounds compute inx * stride_ - out_.width_ (and
// iny * stride_ - out_.height_). inx * stride_ is normally below out_.width_.
// If serial_size_t is unsigned (confirm), this wraps to a huge value, so
// std::min always yields unpooling_size and no border clipping ever happens.
// The operands look reversed: out_.width_ - inx * stride_ (resp.
// out_.height_ - iny * stride_) is presumably intended.
// NOTE(review): the function's brace-only lines are absent from this listing.
178 void connect_kernel(serial_size_t unpooling_size, serial_size_t inx, serial_size_t iny, serial_size_t c)
180 serial_size_t dxmax =
static_cast<serial_size_t
>(std::min(unpooling_size, inx * stride_ - out_.width_));
181 serial_size_t dymax =
static_cast<serial_size_t
>(std::min(unpooling_size, iny * stride_ - out_.height_));
// Walk the window; each output cell is attributed to this single input cell.
183 for (serial_size_t dy = 0; dy < dymax; dy++) {
184 for (serial_size_t dx = 0; dx < dxmax; dx++) {
185 serial_size_t out_index = out_.get_index(
static_cast<serial_size_t
>(inx * stride_ + dx),
186 static_cast<serial_size_t
>(iny * stride_ + dy), c);
187 serial_size_t in_index = in_.get_index(inx, iny, c);
// Guard the flat indices against the mapping tables' sizes.
189 if (in_index >= in2out_.size())
190 throw nn_error(
"index overflow");
191 if (out_index >= out2in_.size())
192 throw nn_error(
"index overflow");
193 out2in_[out_index] = in_index;
194 in2out_[in_index].push_back(out_index);
199 void init_connection()
201 in2out_.resize(in_.size());
202 out2in_.resize(out_.size());
204 for (max_unpooling_layer_worker_specific_storage& mws : max_unpooling_layer_worker_storage_) {
205 mws.in2outmax_.resize(in_.size());
208 for (serial_size_t c = 0; c < in_.depth_; ++c)
209 for (serial_size_t y = 0; y < in_.height_; ++y)
210 for (serial_size_t x = 0; x < in_.width_; ++x)
211 connect_kernel(
static_cast<serial_size_t
>(unpool_size_),
single-input, single-output network with activation function
Definition: feedforward_layer.h:37
serial_size_t in_size() const
!
Definition: layer.h:176
bool parallelize_
Flag indicating whether the layer/node operations are parallelized.
Definition: layer.h:696
serial_size_t in_channels() const
number of incoming edges (input channels) of this layer
Definition: layer.h:146
applies max-unpooling operation to the spatial data
Definition: max_unpooling_layer.h:38
std::vector< index3d< serial_size_t > > in_shape() const override
array of input shapes (width x height x depth)
Definition: max_unpooling_layer.h:131
max_unpooling_layer(serial_size_t in_width, serial_size_t in_height, serial_size_t in_channels, serial_size_t unpooling_size, serial_size_t stride)
Definition: max_unpooling_layer.h:69
max_unpooling_layer(serial_size_t in_width, serial_size_t in_height, serial_size_t in_channels, serial_size_t unpooling_size)
Definition: max_unpooling_layer.h:49
size_t fan_out_size() const override
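number of outgoing connections for each input unit; used only by weight/bias initialization methods which require the fan-out size (e.g. xavier). Override if the layer has trainable weights and the scale of initialization is important.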
Definition: max_unpooling_layer.h:88
std::string layer_type() const override
name of layer, should be unique for each concrete class
Definition: max_unpooling_layer.h:133
std::vector< index3d< serial_size_t > > out_shape() const override
array of output shapes (width x height x depth)
Definition: max_unpooling_layer.h:132
size_t fan_in_size() const override
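number of incoming connections for each output unit; used only by weight/bias initialization methods which require the fan-in size (e.g. xavier). Override if the layer has trainable weights and the scale of initialization is important.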
Definition: max_unpooling_layer.h:84
Definition: parallel_for.h:70