28 #include "tiny_dnn/util/util.h"
29 #include "tiny_dnn/layers/layer.h"
// Selects the axis along which slice_layer splits its input. The enumerators
// referenced elsewhere in this file are slice_samples and slice_channels.
33 enum class slice_type {
// Member-initializer list of the slice_layer constructor taking an explicit input shape.
// The base layer is declared with one input and num_outputs outputs, all of
// vector_type::data.
70 :
layer(std::vector<vector_type>(1, vector_type::data), std::vector<vector_type>(num_outputs, vector_type::data)),
// Remember the input shape, the slicing mode and the number of outputs.
71 in_shape_(
in_shape), slice_type_(slice_type), num_outputs_(num_outputs) {
// Constructs a slice layer whose input shape is taken from the first output
// shape of prev_layer.
// @param prev_layer  layer whose out_shape()[0] becomes this layer's input shape
// @param slice_type  slicing mode stored in slice_type_
// @param num_outputs number of data outputs declared to the base layer
75 slice_layer(
const layer& prev_layer, slice_type slice_type, serial_size_t num_outputs)
76 :
layer(std::vector<vector_type>(1, vector_type::data), std::vector<vector_type>(num_outputs, vector_type::data)),
// Only the first output shape of prev_layer is used.
77 in_shape_(prev_layer.
out_shape()[0]), slice_type_(slice_type), num_outputs_(num_outputs) {
// Returns the array of input shapes (width x height x depth); overrides the
// base-class accessor.
85 std::vector<shape3d>
in_shape()
const override {
// Forward pass: splits the first input tensor into out_data according to slice_type_.
94 std::vector<tensor_t*>& out_data)
override {
// Dispatch on the configured slicing mode.
95 switch (slice_type_) {
96 case slice_type::slice_samples:
// Only in_data[0] is consumed; all outputs are filled by the helper.
97 slice_data_forward(*in_data[0], out_data);
// NOTE(review): the statement terminating the previous case is not visible in this excerpt — confirm it does not fall through.
99 case slice_type::slice_channels:
100 slice_channels_forward(*in_data[0], out_data);
// Backward pass: gathers the output gradients back into the single input
// gradient tensor in_grad[0], according to slice_type_.
108 const std::vector<tensor_t*>& out_data,
109 std::vector<tensor_t*>& out_grad,
110 std::vector<tensor_t*>& in_grad)
override {
// The forward activations are not needed here; the macros mark them as intentionally unused.
111 CNN_UNREFERENCED_PARAMETER(in_data);
112 CNN_UNREFERENCED_PARAMETER(out_data);
// Dispatch on the configured slicing mode.
114 switch (slice_type_) {
115 case slice_type::slice_samples:
116 slice_data_backward(out_grad, *in_grad[0]);
// NOTE(review): the statement terminating the previous case is not visible in this excerpt — confirm it does not fall through.
118 case slice_type::slice_channels:
119 slice_channels_backward(out_grad, *in_grad[0]);
// cereal hook for deserializing a slice_layer: reads the constructor
// arguments from the archive and forwards them to construct().
// @param ar        archive to read from
// @param construct cereal::construct handle used to build the slice_layer
126 template <
class Archive>
127 static void load_and_construct(Archive & ar, cereal::construct<slice_layer> & construct) {
// Temporaries that receive the archived values.
// NOTE(review): the declaration of in_shape is not visible in this excerpt.
129 slice_type slice_type;
130 serial_size_t num_outputs;
// Field names must match the ones written by serialize().
132 ar(cereal::make_nvp(
"in_size",
in_shape), cereal::make_nvp(
"slice_type", slice_type), cereal::make_nvp(
"num_outputs", num_outputs));
// Build the layer from the three loaded values.
133 construct(
in_shape, slice_type, num_outputs);
136 template <
class Archive>
137 void serialize(Archive & ar) {
138 layer::serialize_prolog(ar);
139 ar(cereal::make_nvp(
"in_size", in_shape_), cereal::make_nvp(
"slice_type", slice_type_), cereal::make_nvp(
"num_outputs", num_outputs_));
142 void slice_data_forward(
const tensor_t& in_data,
143 std::vector<tensor_t*>& out_data) {
144 const vec_t* in = &in_data[0];
146 for (serial_size_t i = 0; i < num_outputs_; i++) {
147 tensor_t& out = *out_data[i];
149 std::copy(in, in + slice_size_[i], &out[0]);
151 in += slice_size_[i];
155 void slice_data_backward(std::vector<tensor_t*>& out_grad,
157 vec_t* in = &in_grad[0];
159 for (serial_size_t i = 0; i < num_outputs_; i++) {
160 tensor_t& out = *out_grad[i];
162 std::copy(&out[0], &out[0] + slice_size_[i], in);
164 in += slice_size_[i];
168 void slice_channels_forward(
const tensor_t& in_data,
169 std::vector<tensor_t*>& out_data) {
170 serial_size_t num_samples =
static_cast<serial_size_t
>(in_data.size());
171 serial_size_t channel_idx = 0;
172 serial_size_t spatial_dim = in_shape_.area();
174 for (serial_size_t i = 0; i < num_outputs_; i++) {
175 for (serial_size_t s = 0; s < num_samples; s++) {
176 float_t *out = &(*out_data[i])[s][0];
177 const float_t *in = &in_data[s][0] + channel_idx*spatial_dim;
179 std::copy(in, in + slice_size_[i] * spatial_dim, out);
181 channel_idx += slice_size_[i];
185 void slice_channels_backward(std::vector<tensor_t*>& out_grad,
187 serial_size_t num_samples =
static_cast<serial_size_t
>(in_grad.size());
188 serial_size_t channel_idx = 0;
189 serial_size_t spatial_dim = in_shape_.area();
191 for (serial_size_t i = 0; i < num_outputs_; i++) {
192 for (serial_size_t s = 0; s < num_samples; s++) {
193 const float_t *out = &(*out_grad[i])[s][0];
194 float_t *in = &in_grad[s][0] + channel_idx*spatial_dim;
196 std::copy(out, out + slice_size_[i] * spatial_dim, in);
198 channel_idx += slice_size_[i];
202 void set_sample_count(serial_size_t sample_count)
override {
203 if (slice_type_ == slice_type::slice_samples) {
204 if (num_outputs_ == 0)
205 throw nn_error(
"num_outputs must be positive integer");
207 serial_size_t sample_per_out = sample_count / num_outputs_;
209 slice_size_.resize(num_outputs_, sample_per_out);
210 slice_size_.back() = sample_count - (sample_per_out*(num_outputs_-1));
212 Base::set_sample_count(sample_count);
// Chooses how the output shapes are computed from the slicing mode.
// NOTE(review): the enclosing function header and the body of the
// slice_samples case are not visible in this excerpt.
216 switch (slice_type_) {
217 case slice_type::slice_samples:
220 case slice_type::slice_channels:
221 set_shape_channels();
// Raised for a slicing mode that has no implementation (presumably the default branch — confirm).
224 throw nn_not_implemented_error();
228 void set_shape_data() {
229 out_shapes_.resize(num_outputs_, in_shape_);
232 void set_shape_channels() {
233 serial_size_t channel_per_out = in_shape_.depth_ / num_outputs_;
236 for (serial_size_t i = 0; i < num_outputs_; i++) {
237 serial_size_t ch = channel_per_out;
239 if (i == num_outputs_ - 1) {
240 assert(in_shape_.depth_ >= i * channel_per_out);
241 ch = in_shape_.depth_ - i * channel_per_out;
244 slice_size_.push_back(ch);
245 out_shapes_.push_back(shape3d(in_shape_.width_, in_shape_.height_, ch));
// Slicing mode chosen at construction.
250 slice_type slice_type_;
// Number of outputs the input is split into.
251 serial_size_t num_outputs_;
// One shape per output, filled by set_shape_data() / set_shape_channels().
252 std::vector<shape3d> out_shapes_;
// Size of each output's slice: a sample count in slice_samples mode
// (see set_sample_count), a channel count in slice_channels mode.
253 std::vector<serial_size_t> slice_size_;
base class of all kinds of NN layers
Definition: layer.h:62
Definition: nn_error.h:83
slice an input data into multiple outputs along a given slice dimension.
Definition: slice_layer.h:42
void forward_propagation(const std::vector< tensor_t * > &in_data, std::vector< tensor_t * > &out_data) override
Definition: slice_layer.h:93
std::vector< shape3d > out_shape() const override
array of output shapes (width x height x depth)
Definition: slice_layer.h:89
slice_layer(const shape3d &in_shape, slice_type slice_type, serial_size_t num_outputs)
Definition: slice_layer.h:69
std::vector< shape3d > in_shape() const override
array of input shapes (width x height x depth)
Definition: slice_layer.h:85
std::string layer_type() const override
name of layer, should be unique for each concrete class
Definition: slice_layer.h:81
void back_propagation(const std::vector< tensor_t * > &in_data, const std::vector< tensor_t * > &out_data, std::vector< tensor_t * > &out_grad, std::vector< tensor_t * > &in_grad) override
return the delta of the previous layer ($\delta = \frac{dE}{da}$, where $a = wx$ in a fully-connected layer)
Definition: slice_layer.h:107