31 #include <unordered_map>
32 #include <cereal/types/utility.hpp>
33 #include <cereal/types/tuple.hpp>
35 #include "tiny_dnn/util/util.h"
36 #include "tiny_dnn/layers/layer.h"
37 #include "tiny_dnn/optimizers/optimizer.h"
41 template <
typename Archive>
42 void save(Archive & ar,
const std::vector<tiny_dnn::layerptr_t>& v) {
43 ar(cereal::make_size_tag((cereal::size_type)v.size()));
45 tiny_dnn::layer::save_layer(ar, *n);
50 template <
typename Archive>
51 void load(Archive & ar, std::vector<std::shared_ptr<tiny_dnn::layer>>& v) {
52 cereal::size_type size;
53 ar(cereal::make_size_tag(size));
55 for (
size_t i = 0; i < size; i++) {
87 typedef std::vector<layerptr_t>::iterator iterator;
88 typedef std::vector<layerptr_t>::const_iterator const_iterator;
96 void backward(
const std::vector<tensor_t>& first) = 0;
103 std::vector<tensor_t>
forward(
const std::vector<tensor_t>& first) = 0;
110 for (
auto l : nodes_) {
111 l->update_weight(opt, batch_size);
118 virtual void setup(
bool reset_weight) {
119 for (
auto l : nodes_) {
120 l->setup(reset_weight);
125 for (
auto l : nodes_) {
130 size_t size()
const {
return nodes_.size(); }
131 iterator begin() {
return nodes_.begin(); }
132 iterator end() {
return nodes_.end(); }
133 const_iterator begin()
const {
return nodes_.begin(); }
134 const_iterator end()
const {
return nodes_.end(); }
135 layer* operator[] (
size_t index) {
return nodes_[index]; }
136 const layer* operator[] (
size_t index)
const {
return nodes_[index]; }
137 serial_size_t in_data_size()
const {
return nodes_.front()->in_data_size(); }
138 serial_size_t out_data_size()
const {
return nodes_.back()->out_data_size(); }
140 template <
typename T>
141 const T& at(
size_t index)
const {
142 const T* v =
dynamic_cast<const T*
>(nodes_[index]);
144 throw nn_error(
"failed to cast");
147 template <
typename T>
148 T& at(
size_t index) {
149 T* v =
dynamic_cast<T*
>(nodes_[index]);
151 throw nn_error(
"failed to cast");
155 virtual float_t target_value_min(
int out_channel = 0)
const {
156 CNN_UNREFERENCED_PARAMETER(out_channel);
157 return nodes_.back()->out_value_range().first;
160 virtual float_t target_value_max(
int out_channel = 0)
const {
161 CNN_UNREFERENCED_PARAMETER(out_channel);
162 return nodes_.back()->out_value_range().second;
165 void save(std::ostream& os)
const {
166 for (
auto& l : nodes_) {
171 void load(std::istream& is) {
173 for (
auto& l : nodes_) {
178 virtual void load(
const std::vector<float_t>& vec) {
181 for (
auto& l : nodes_) {
186 void label2vec(
const label_t* t, serial_size_t num, std::vector<vec_t> *vec)
const {
187 serial_size_t outdim = out_data_size();
190 for (serial_size_t i = 0; i < num; i++) {
191 assert(t[i] < outdim);
192 vec->emplace_back(outdim, target_value_min());
193 vec->back()[t[i]] = target_value_max();
197 template <
typename OutputArchive>
198 void save_model(OutputArchive & oa)
const;
200 template <
typename InputArchive>
201 void load_model(InputArchive & ia);
204 template <
typename OutputArchive>
205 void save_weights(OutputArchive & oa)
const {
206 for (
auto n : nodes_) {
211 template <
typename InputArchive>
212 void load_weights(InputArchive & ia) {
213 for (
auto n : nodes_) {
219 template <
typename T>
220 void push_back(T&& node) {
221 push_back_impl(std::forward<T>(node),
222 typename std::is_rvalue_reference<decltype(node)>::type());
225 template <
typename T>
226 void push_back(std::shared_ptr<T> node) {
227 own_nodes_.push_back(node);
228 nodes_.push_back(own_nodes_.back().get());
234 std::vector<tensor_t> reorder_for_layerwise_processing(
const std::vector<tensor_t>& input) {
235 const serial_size_t sample_count =
static_cast<serial_size_t
>(input.size());
236 const serial_size_t channel_count =
static_cast<serial_size_t
>(input[0].size());
239 std::vector<tensor_t> output(channel_count, tensor_t(sample_count));
241 for (serial_size_t sample = 0; sample < sample_count; ++sample) {
242 assert(input[sample].size() == channel_count);
243 for (serial_size_t channel = 0; channel < channel_count; ++channel) {
244 output[channel][sample] = input[sample][channel];
251 template <
typename T>
252 void push_back_impl(T&& node, std::true_type) {
253 own_nodes_.push_back(std::make_shared<
254 typename std::remove_reference<T>::type>(std::forward<T>(node)));
255 nodes_.push_back(own_nodes_.back().get());
258 template <
typename T>
259 void push_back_impl(T&& node, std::false_type) {
260 nodes_.push_back(&node);
264 std::vector<std::shared_ptr<layer>> own_nodes_;
266 std::vector<layerptr_t> nodes_;
274 void backward(
const std::vector<tensor_t>& first)
override {
276 const std::vector<tensor_t> reordered_grad = reorder_for_layerwise_processing(first);
277 assert(reordered_grad.size() == 1);
279 nodes_.back()->set_out_grads({ reordered_grad[0] });
281 for (
auto l = nodes_.rbegin(); l != nodes_.rend(); l++) {
286 std::vector<tensor_t>
forward(
const std::vector<tensor_t>& first)
override {
288 const std::vector<tensor_t> reordered_data = reorder_for_layerwise_processing(first);
289 assert(reordered_data.size() == 1);
291 nodes_.front()->set_in_data({ reordered_data[0] });
293 for (
auto l : nodes_) {
297 const std::vector<tensor_t> out = nodes_.back()->output();
299 return normalize_out(out);
302 template <
typename T>
303 void add(T&&
layer) {
304 push_back(std::forward<T>(
layer));
306 if (nodes_.size() != 1) {
307 auto head = nodes_[nodes_.size()-2];
308 auto tail = nodes_[nodes_.size()-1];
309 connect(head, tail, 0, 0);
310 auto out = head->outputs();
311 auto in = tail->inputs();
313 check_connectivity();
316 void check_connectivity() {
317 for (serial_size_t i = 0; i < nodes_.size() - 1; i++) {
318 auto out = nodes_[i]->outputs();
319 auto in = nodes_[i+1]->inputs();
321 if (out[0] != in[0]) {
327 template <
typename InputArchive>
328 void load_connections(InputArchive& ia) {
329 for (serial_size_t i = 0; i < nodes_.size() - 1; i++) {
330 auto head = nodes_[i];
331 auto tail = nodes_[i + 1];
332 connect(head, tail, 0, 0);
336 template <
typename OutputArchive>
337 void save_connections(OutputArchive& )
const { }
342 std::vector<tensor_t> normalize_out(
const std::vector<tensor_t>& out)
345 std::vector<tensor_t> normalized_output;
347 const size_t sample_count = out[0].size();
348 normalized_output.resize(sample_count, tensor_t(1));
350 for (
size_t sample = 0; sample < sample_count; ++sample) {
351 normalized_output[sample][0] = out[0][sample];
354 return normalized_output;
364 void backward(
const std::vector<tensor_t>& out_grad)
override {
366 serial_size_t output_channel_count =
static_cast<serial_size_t
>(out_grad[0].size());
368 if (output_channel_count != output_layers_.size()) {
369 throw nn_error(
"input size mismatch");
372 const std::vector<tensor_t> reordered_grad = reorder_for_layerwise_processing(out_grad);
373 assert(reordered_grad.size() == output_channel_count);
375 for (serial_size_t i = 0; i < output_channel_count; i++) {
376 output_layers_[i]->set_out_grads({ reordered_grad[i] });
379 for (
auto l = nodes_.rbegin(); l != nodes_.rend(); l++) {
384 std::vector<tensor_t>
forward(
const std::vector<tensor_t>& in_data)
override {
386 serial_size_t input_data_channel_count =
static_cast<serial_size_t
>(in_data[0].size());
388 if (input_data_channel_count != input_layers_.size()) {
389 throw nn_error(
"input size mismatch");
392 const std::vector<tensor_t> reordered_data = reorder_for_layerwise_processing(in_data);
393 assert(reordered_data.size() == input_data_channel_count);
395 for (serial_size_t channel_index = 0; channel_index < input_data_channel_count; channel_index++) {
396 input_layers_[channel_index]->set_in_data({ reordered_data[channel_index] });
399 for (
auto l : nodes_) {
405 void construct(
const std::vector<layerptr_t>& input,
406 const std::vector<layerptr_t>& output) {
407 std::vector<layerptr_t> sorted;
408 std::vector<nodeptr_t> input_nodes(input.begin(), input.end());
409 std::unordered_map<node*, std::vector<uint8_t>> removed_edge;
412 while (!input_nodes.empty()) {
413 sorted.push_back(
dynamic_cast<layerptr_t>(input_nodes.back()));
414 input_nodes.pop_back();
417 std::vector<node*> next = curr->next_nodes();
419 for (
size_t i = 0; i < next.size(); i++) {
420 if (!next[i])
continue;
422 if (removed_edge.find(next[i]) == removed_edge.end()) {
423 removed_edge[next[i]] =
424 std::vector<uint8_t>(next[i]->prev_nodes().size(), 0);
427 std::vector<uint8_t>& removed = removed_edge[next[i]];
428 removed[find_index(next[i]->prev_nodes(), curr)] = 1;
430 if (std::all_of(removed.begin(), removed.end(), [](uint8_t x) {
432 input_nodes.push_back(next[i]);
437 for (
auto& n : sorted) {
441 input_layers_ = input;
442 output_layers_ = output;
450 struct _graph_connection {
451 void add_connection(serial_size_t head, serial_size_t tail, serial_size_t head_index, serial_size_t tail_index) {
452 if (!is_connected(head, tail, head_index, tail_index)) {
453 connections.emplace_back(head, tail, head_index, tail_index);
457 bool is_connected(serial_size_t head, serial_size_t tail, serial_size_t head_index, serial_size_t tail_index)
const {
458 return std::find(connections.begin(),
460 std::make_tuple(head, tail, head_index, tail_index)) != connections.end();
463 template <
typename Archive>
464 void serialize(Archive & ar) {
465 ar(CEREAL_NVP(connections), CEREAL_NVP(in_nodes), CEREAL_NVP(out_nodes));
468 std::vector<std::tuple<serial_size_t, serial_size_t, serial_size_t, serial_size_t>> connections;
469 std::vector<serial_size_t> in_nodes, out_nodes;
472 template <
typename OutputArchive>
473 void save_connections(OutputArchive& oa)
const {
474 _graph_connection gc;
475 std::unordered_map<node*, serial_size_t> node2id;
476 serial_size_t idx = 0;
478 for (
auto n : nodes_) {
481 for (
auto l : input_layers_) {
482 gc.in_nodes.push_back(node2id[l]);
484 for (
auto l : output_layers_) {
485 gc.out_nodes.push_back(node2id[l]);
488 for (
auto l : input_layers_) {
489 graph_traverse(l, [=](layer& l) {}, [&](edge& e) {
490 auto next = e.next();
491 serial_size_t head_index = e.prev()->next_port(e);
493 for (
auto n : next) {
494 serial_size_t tail_index = n->prev_port(e);
495 gc.add_connection(node2id[e.prev()], node2id[n], head_index, tail_index);
500 oa(cereal::make_nvp(
"graph", gc));
503 template <
typename InputArchive>
504 void load_connections(InputArchive& ia) {
505 _graph_connection gc;
506 ia(cereal::make_nvp(
"graph", gc));
508 for (
auto c : gc.connections) {
509 serial_size_t head, tail, head_index, tail_index;
510 std::tie(head, tail, head_index, tail_index) = c;
511 connect(nodes_[head], nodes_[tail], head_index, tail_index);
513 for (
auto in : gc.in_nodes) {
514 input_layers_.push_back(nodes_[in]);
516 for (
auto out : gc.out_nodes) {
517 output_layers_.push_back(nodes_[out]);
522 std::vector<tensor_t> merge_outs() {
523 std::vector<tensor_t> merged;
524 serial_size_t output_channel_count =
static_cast<serial_size_t
>(output_layers_.size());
525 for (serial_size_t output_channel = 0; output_channel < output_channel_count; ++output_channel) {
526 std::vector<tensor_t> out = output_layers_[output_channel]->output();
528 serial_size_t sample_count =
static_cast<serial_size_t
>(out[0].size());
529 if (output_channel == 0) {
530 assert(merged.empty());
531 merged.resize(sample_count, tensor_t(output_channel_count));
534 assert(merged.size() == sample_count);
536 for (serial_size_t sample = 0; sample < sample_count; ++sample) {
537 merged[sample][output_channel] = out[0][sample];
543 serial_size_t find_index(
const std::vector<node*>& nodes,
545 for (serial_size_t i = 0; i < nodes.size(); i++) {
546 if (nodes[i] ==
static_cast<node*
>(&*target))
return i;
548 throw nn_error(
"invalid connection");
550 std::vector<layerptr_t> input_layers_;
551 std::vector<layerptr_t> output_layers_;
556 template <
typename OutputArchive>
557 void nodes::save_model(OutputArchive & oa)
const {
558 oa(cereal::make_nvp(
"nodes", nodes_));
560 if (
typeid(*
this) ==
typeid(sequential)) {
561 dynamic_cast<const sequential*
>(
this)->save_connections(oa);
564 dynamic_cast<const graph*
>(
this)->save_connections(oa);
568 template <
typename InputArchive>
569 void nodes::load_model(InputArchive & ia) {
573 ia(cereal::make_nvp(
"nodes", own_nodes_));
575 for (
auto& n : own_nodes_) {
576 nodes_.push_back(&*n);
579 if (
typeid(*
this) ==
typeid(sequential)) {
580 dynamic_cast<sequential*
>(
this)->load_connections(ia);
583 dynamic_cast<graph*
>(
this)->load_connections(ia);
generic graph network
Definition: nodes.h:362
void backward(const std::vector< tensor_t > &out_grad) override
propagate gradient
Definition: nodes.h:364
std::vector< tensor_t > forward(const std::vector< tensor_t > &in_data) override
Definition: nodes.h:384
base class of all kinds of NN layers
Definition: layer.h:62
static std::shared_ptr< layer > load_layer(InputArchive &ia)
generate layer from cereal's Archive
Definition: deserialization_helper.h:159
error exception class for tiny-dnn
Definition: nn_error.h:37
basic class of various network types (sequential, multi-in/multi-out).
Definition: nodes.h:85
virtual void backward(const std::vector< tensor_t > &first)=0
propagate gradient
virtual void update_weights(optimizer *opt, int batch_size)
update weights and clear all gradients
Definition: nodes.h:109
virtual std::vector< tensor_t > forward(const std::vector< tensor_t > &first)=0
virtual void setup(bool reset_weight)
setup all weights, must be called before forward/backward
Definition: nodes.h:118
single-input, single-output feedforward network
Definition: nodes.h:272
std::vector< tensor_t > forward(const std::vector< tensor_t > &first) override
Definition: nodes.h:286
void backward(const std::vector< tensor_t > &first) override
propagate gradient
Definition: nodes.h:274
base class of all optimizers. usesHessian: true if the optimizer uses the Hessian (2nd-order derivative of the loss...
Definition: optimizer.h:37