29 #include "tiny_dnn/node.h"
30 #include "tiny_dnn/layers/layer.h"
31 #include "tiny_dnn/network.h"
41 : root_(root_node), name_(graph_name) {}
45 : root_(
network[0]), name_(graph_name) {}
51 generate_header(stream);
52 generate_nodes(stream);
53 generate_footer(stream);
57 typedef std::unordered_map<const node*, std::string> node2name_t;
// generate_header: writes the opening of the DOT document -- the line
// `digraph "<name_>" {` -- followed by a default node attribute statement
// that gives every node shape=record.
// NOTE(review): name_ is written between the quotes without escaping; a graph
// name containing `"` would produce malformed DOT -- consider escaping it.
// NOTE(review): std::endl flushes the stream after every line; '\n' would
// avoid the repeated flushes.
// NOTE(review): the function's closing brace (original line 62) is not
// visible in this view.
59 void generate_header(std::ostream& stream) {
60 stream <<
"digraph \"" << name_ <<
"\" {" << std::endl;
// Default attributes applied to every node statement that follows.
61 stream <<
" node [ shape=record ];" << std::endl;
// generate_nodes: emits the body of the digraph. It first assigns every layer
// a quoted DOT identifier (get_layer_names). It then writes one node statement
// per visited layer (generate_layer) and one edge statement per visited edge
// (generate_edge). Both writers look names up in the same node2name map, so
// edge endpoints use exactly the identifiers given to the nodes.
64 void generate_nodes(std::ostream& stream) {
65 node2name_t node2name;
66 get_layer_names(node2name);
// NOTE(review): original lines 67-68 are not visible here. The two lambdas
// below are presumably the layer and edge callbacks of a
// graph_traverse(root_, ...) call, mirroring get_layer_names -- confirm.
69 [&](
const layer& l) { generate_layer(stream, l, node2name); },
70 [&](
const edge& e) { generate_edge(stream, e, node2name); });
// get_layer_names: fills node2name with a DOT identifier for each layer that
// graph_traverse visits starting from root_. Each identifier is:
//   - the layer's type string,
//   - followed by a per-type sequence number starting at 0,
//   - wrapped in double quotes (the quotes are part of the stored string),
// e.g. "<type>0", "<type>1".
73 void get_layer_names(node2name_t& node2name) {
// Number of layers of each type named so far; operator[] starts a new type at 0.
74 std::unordered_map<std::string, int> layer_counts;
// Layer callback: records "<type><n>" for this layer and bumps the type's counter.
76 auto namer = [&](
const layer& l) {
77 std::string ltype = l.layer_type();
// NOTE(review): the counter lookup below calls l.layer_type() again; reusing
// `ltype` would avoid the second call. Original lines 78-79 and 81-82
// (likely including the lambda's closing `};`) are not visible in this view.
80 node2name[&l] =
"\"" + ltype + to_string(layer_counts[l.layer_type()]++) +
"\"";
// Edges are not named, hence the empty edge callback.
83 graph_traverse(root_, namer, [&](
const edge&){});
// generate_edge: writes one DOT edge statement of the form
//   <name of prev>:out<src_port> -> <name of n>:in<dst_port>;
// The `out<k>` / `in<k>` port names match the record fields that
// generate_layer_channels writes with the "out" / "in" prefixes.
// NOTE(review): original lines 87-90, where `prev` and `n` are declared, are
// not visible in this view. Presumably `prev` is the edge's source node and
// `n` ranges over its destination node(s) -- confirm before relying on this.
// NOTE(review): node2name[...] uses operator[], so a node missing from the
// map would be silently inserted with an empty name rather than reported.
86 void generate_edge(std::ostream& stream,
const edge& e, node2name_t& node2name) {
// Index of edge e among n's inputs and among prev's outputs, respectively.
91 serial_size_t dst_port = n->prev_port(e);
92 serial_size_t src_port = prev->next_port(e);
93 stream <<
" " << node2name[prev] <<
":out" << src_port <<
94 " -> " << node2name[n] <<
":in" << dst_port <<
";" << std::endl;
// generate_layer: writes one DOT node statement for `layer`, named by its
// entry in node2name. The record label written between `label= "` and the
// closing `}}"` is assembled in this order:
//   - the layer type, followed by `|{{in`;
//   - the input-channel fields;
//   - the output-channel fields.
// The fields come from generate_layer_channels with the "in" / "out" prefixes.
// NOTE(review): the parameter named `layer` shadows the type `layer` inside
// this function; renaming it (e.g. to `l`) would be clearer.
// NOTE(review): original line 103, between the input and output channel
// lists, is not visible here; presumably it closes the input group and opens
// the output group of the record label -- confirm. The closing brace
// (original line 107) is not visible either.
98 void generate_layer(std::ostream& stream,
const layer& layer, node2name_t& node2name) {
99 stream <<
" " << node2name[&layer] <<
" [" << std::endl;
100 stream <<
" label= \"";
// Label begins with the layer type, then opens the input-port group.
101 stream << layer.layer_type() <<
"|{{in";
102 generate_layer_channels(stream, layer.in_shape(), layer.in_types(),
"in");
104 generate_layer_channels(stream, layer.out_shape(), layer.out_types(),
"out");
// Close the record groups and the label string, then the attribute list.
105 stream <<
"}}\""<< std::endl;
106 stream <<
" ];" << std::endl;
// generate_layer_channels: appends one record field per channel to the label
// currently being written. Field i has the form
//   |<{port_prefix}{i}>{shapes[i]}({vtypes[i]})
// The `<...>` part names the record port that generate_edge refers to as
// `:in<i>` / `:out<i>`.
// NOTE(review): CNN_UNREFERENCED_PARAMETER(vtypes) is misleading, since
// vtypes[i] is read in the loop below; the macro call can be dropped.
// NOTE(review): vtypes is indexed with the bound of `shapes`, so this assumes
// vtypes.size() >= shapes.size(); a mismatch would read out of bounds --
// confirm the invariant or guard the access.
// NOTE(review): the loop's and function's closing braces (original lines
// 116-118) are not visible in this view.
109 void generate_layer_channels(std::ostream& stream,
110 const std::vector<shape3d>& shapes,
111 const std::vector<vector_type>& vtypes,
112 const std::string& port_prefix) {
113 CNN_UNREFERENCED_PARAMETER(vtypes);
114 for (
size_t i = 0; i < shapes.size(); i++) {
115 stream <<
"|<" << port_prefix << i <<
">" << shapes[i] <<
"(" << vtypes[i] <<
")";
// generate_footer: writes the closing `}` of the digraph whose opening
// `digraph "<name_>" {` line is written by generate_header.
// NOTE(review): the function's closing brace (original line 121) is not
// visible in this view.
119 void generate_footer(std::ostream& stream) {
120 stream <<
"}" << std::endl;
Utility for graph visualization.
Definition: graph_visualizer.h:38
void generate(std::ostream &stream)
Generates the graph structure in DOT language format.
Definition: graph_visualizer.h:50
Base class of all kinds of NN layers.
Definition: layer.h:62
A model of neural networks in tiny-dnn.
Definition: network.h:167