35 #include <unordered_set>
37 #include "tiny_dnn/util/util.h"
38 #include "tiny_dnn/util/product.h"
39 #include "tiny_dnn/util/image.h"
40 #include "tiny_dnn/util/weight_init.h"
41 #include "tiny_dnn/optimizers/optimizer.h"
43 #include "tiny_dnn/activations/activation_function.h"
51 typedef node* nodeptr_t;
52 typedef std::shared_ptr<edge> edgeptr_t;
54 typedef layer* layerptr_t;
59 class node :
public std::enable_shared_from_this<node> {
// Constructs a node with `in_size` incoming and `out_size` outgoing edge
// slots. Both vectors are count-constructed, so every slot starts out as a
// null edgeptr_t (std::shared_ptr<edge>).
61 node(serial_size_t in_size, serial_size_t out_size)
62 : prev_(in_size), next_(out_size) {}
// Incoming edge slots; the index of an edge in this vector is its input port
// number. Individual slots may still be null.
65 const std::vector<edgeptr_t>& prev()
const {
return prev_; }
// Outgoing edge slots; the index of an edge in this vector is its output port
// number. Individual slots may still be null.
66 const std::vector<edgeptr_t>& next()
const {
return next_; }
68 serial_size_t prev_port(
const edge& e)
const {
69 auto it = std::find_if(prev_.begin(), prev_.end(),
70 [&](edgeptr_t ep) { return ep.get() == &e; });
71 return (serial_size_t)std::distance(prev_.begin(), it);
74 serial_size_t next_port(
const edge& e)
const {
75 auto it = std::find_if(next_.begin(), next_.end(),
76 [&](edgeptr_t ep) { return ep.get() == &e; });
77 return (serial_size_t)std::distance(next_.begin(), it);
// Distinct nodes found at the source end of the incoming edges; null edges
// and edges without a source node are skipped. Results are collected in a
// std::set<node*>, so they are ordered by pointer value, not by port.
80 std::vector<node*> prev_nodes()
const;
// Distinct nodes attached to the outgoing edges (presumably each edge's
// consumer nodes — confirm against the full definition). Also collected in a
// std::set<node*>, so ordered by pointer value, not by port.
81 std::vector<node*> next_nodes()
const;
86 serial_size_t head_index, serial_size_t tail_index);
// Incoming / outgoing edge slots, indexed by port number.
// NOTE(review): declared `mutable`; no const member visible here mutates
// them — presumably required by connect() or similar code elsewhere, confirm.
88 mutable std::vector<edgeptr_t> prev_;
89 mutable std::vector<edgeptr_t> next_;
100 data_({vec_t(shape.size())}),
101 grad_({vec_t(shape.size())}),
104 void merge_grads(vec_t *dst) {
105 dst->resize(grad_[0].size());
106 std::fill(dst->begin(), dst->end(),
static_cast<float_t
>(0));
109 for (
size_t sample = 0, sample_count = grad_.size(); sample < sample_count; ++sample) {
110 vectorize::reduce<float_t>(&grad_[sample][0], dst->size(), &(*dst)[0]);
115 for (
size_t sample = 0, sample_count = grad_.size(); sample < sample_count; ++sample) {
116 std::fill(grad_[sample].begin(), grad_[sample].end(), (float_t)0);
120 tensor_t* get_data() {
124 const tensor_t* get_data()
const {
128 tensor_t* get_gradient() {
132 const tensor_t* get_gradient()
const {
// Nodes registered on the far side of this edge (presumably its consumers);
// populated through add_next_node().
136 const std::vector<node*>& next()
const {
return next_; }
// Node on the near side of this edge (presumably its producer). May be null;
// node::prev_nodes() checks for that.
137 node* prev() {
return prev_; }
138 const node* prev()
const {
return prev_; }
// Shape of the data held by this edge; its size() is used to size the
// per-sample data and gradient vectors.
140 const shape3d& shape()
const {
return shape_; }
// vector_type tag recorded for this edge.
141 vector_type vtype()
const {
return vtype_; }
// Appends another node to next_. No duplicate check is performed.
142 void add_next_node(
node* next) { next_.push_back(next); }
150 std::vector<node*> next_;
153 inline std::vector<node*> node::prev_nodes()
const {
154 std::set<node*> sets;
155 for (
auto& e : prev_) {
156 if (e && e->prev()) sets.insert(e->prev());
158 return std::vector<node*>(sets.begin(), sets.end());
161 inline std::vector<node*> node::next_nodes()
const {
162 std::set<node*> sets;
163 for (
auto& e : next_) {
166 sets.insert(n.begin(), n.end());
169 return std::vector<node*>(sets.begin(), sets.end());
172 template <
typename T>
175 nodes_.push_back(l1); nodes_.push_back(l2);
177 std::vector<T> nodes_;
180 template <
typename T>
185 template <
typename T>
186 node_tuple<std::shared_ptr<T>> operator , (std::shared_ptr<T> l1, std::shared_ptr<T> l2) {
187 return node_tuple<std::shared_ptr<T>>(l1, l2);
190 template <
typename T>
191 node_tuple<std::shared_ptr<T>> operator , (node_tuple<std::shared_ptr<T>> lhs, std::shared_ptr<T>& rhs) {
192 lhs.nodes_.push_back(rhs);
196 template <
typename T>
197 node_tuple<T*> operator , (node_tuple<T*> lhs, T& rhs) {
198 lhs.nodes_.push_back(&rhs);
202 template <
typename T,
typename U>
203 inline std::shared_ptr<U>& operator << (std::shared_ptr<T>& lhs,
204 std::shared_ptr<U>& rhs) {
205 connect(lhs.get(), rhs.get());
209 template <
typename T,
typename U>
210 inline U& operator << (
const node_tuple<T>& lhs, U& rhs) {
211 for (serial_size_t i = 0; i < static_cast<serial_size_t>(lhs.nodes_.size()); i++) {
212 connect(&*lhs.nodes_[i], &*rhs, 0, i);
217 template <
typename T,
typename U>
218 inline node_tuple<T>& operator << (U& lhs,
const node_tuple<T>& rhs) {
219 for (serial_size_t i = 0; i < static_cast<serial_size_t>(rhs.nodes_.size()); i++) {
220 connect(&*lhs, &*rhs.nodes_[i], i, 0);
225 template <
typename T,
typename U>
226 inline U& operator << (
const node_tuple<T*>& lhs, U& rhs) {
227 for (serial_size_t i = 0; i < static_cast<serial_size_t>(lhs.nodes_.size()); i++) {
228 connect(lhs.nodes_[i], &rhs, 0, i);
233 template <
typename T,
typename U>
234 inline node_tuple<T*>& operator << (U& lhs,
const node_tuple<T*>& rhs) {
235 for (serial_size_t i = 0; i < static_cast<serial_size_t>(rhs.nodes_.size()); i++) {
236 connect(&lhs, rhs.nodes_[i], i, 0);
class containing input/output data
Definition: node.h:95
base class of all kinds of NN layers
Definition: layer.h:62
base class of all kinds of tiny-dnn data
Definition: node.h:59