33 #include <cereal/archives/json.hpp>
34 #include <cereal/types/memory.hpp>
35 #include "tiny_dnn/util/nn_error.h"
36 #include "tiny_dnn/util/macro.h"
37 #include "tiny_dnn/layers/layers.h"
41 template <
typename OutputArchive>
44 void register_saver(
const std::string& name, std::function<
void(OutputArchive&,
const layer*)> func) {
45 savers_[name] = [=](
void* ar,
const layer* l) {
46 return func(*
reinterpret_cast<OutputArchive*
>(ar), l);
51 void register_type(
const std::string& name) {
52 type_names_[
typeid(T)] = name;
55 void save(
const std::string& layer_name, OutputArchive & ar,
const layer *l) {
58 if (savers_.find(layer_name) == savers_.end()) {
59 throw nn_error(
"Failed to generate layer. Generator for " + layer_name +
" is not found.\n"
60 "Please use CNN_REGISTER_LAYER_DESERIALIZER macro to register appropriate generator");
63 savers_[layer_name](
reinterpret_cast<void*
>(&ar), l);
66 const std::string& type_name(std::type_index index)
const {
67 if (type_names_.find(index) == type_names_.end()) {
68 throw nn_error(
"Typename is not registered");
70 return type_names_.at(index);
79 void check_if_enabled()
const {
80 #ifdef CNN_NO_SERIALIZATION
81 static_assert(
sizeof(OutputArchive)==0,
82 "You are using save functions, but serialization function is disabled in current configuration.\n\n"
83 "You need to undef CNN_NO_SERIALIZATION to enable these functions.\n"
84 "If you are using cmake, you can use -DUSE_SERIALIZER=ON option.\n\n");
// Type-erased saver for each registered layer name. The void* argument is the
// OutputArchive; register_saver's lambda reinterpret_casts it back before
// forwarding to the typed saver.
89 std::map<std::string, std::function<void(
void*,
const layer*)>> savers_;
// Maps a layer's dynamic type (its typeid) to the name it was registered
// under via register_type; looked up by type_name().
91 std::map<std::type_index, std::string> type_names_;
94 static void save_layer_impl(OutputArchive& oa,
const layer*
layer);
// CNN_REGISTER_LAYER_BODY(layer_type, layer_name): registers `layer_type`
// under the string `layer_name` in two steps:
//   1. register_type<layer_type>(layer_name) maps typeid(layer_type) to the name;
//   2. register_saver installs save_layer_impl<layer_type> as that name's saver.
// NOTE: this expands to two statements, so it must not be used as the
// unbraced body of an if/else. The caller supplies the final semicolon.
// The calls are unqualified, so this is presumably meant to be expanded
// inside serialization_helper member code.
96 #define CNN_REGISTER_LAYER_BODY(layer_type, layer_name) \
97 register_type<layer_type>(layer_name);\
98 register_saver(layer_name, save_layer_impl<layer_type>)
// CNN_REGISTER_LAYER(layer_type, layer_name): registers `layer_type` under the
// stringized token `layer_name`. For example, a bare identifier `foo` is
// registered as the string "foo".
100 #define CNN_REGISTER_LAYER(layer_type, layer_name) CNN_REGISTER_LAYER_BODY(layer_type, #layer_name)
// CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, activation_type, layer_name):
// registers the instantiation layer_type<activation::activation_type>.
// Its name is built by string-literal concatenation of the stringized
// layer_name, "<", the stringized activation_type, and ">".
// For example, (foo, relu) is registered as "foo<relu>".
102 #define CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, activation_type, layer_name) \
103 CNN_REGISTER_LAYER_BODY(layer_type<activation::activation_type>, #layer_name "<" #activation_type ">")
// CNN_REGISTER_LAYER_WITH_ACTIVATIONS(layer_type, layer_name): registers
// layer_type once for each activation listed below, each under the name
// "layer_name<activation>". The listed activations are tan_h, softmax,
// identity, sigmoid, relu, leaky_relu, elu and tan_hp1m2.
// An activation not listed here is NOT registered by this macro.
// The last expansion has no trailing semicolon, so the caller supplies one.
105 #define CNN_REGISTER_LAYER_WITH_ACTIVATIONS(layer_type, layer_name) \
106 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, tan_h, layer_name); \
107 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, softmax, layer_name); \
108 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, identity, layer_name); \
109 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, sigmoid, layer_name); \
110 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, relu, layer_name); \
111 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, leaky_relu, layer_name); \
112 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, elu, layer_name); \
113 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, tan_hp1m2, layer_name)
116 #include "serialization_layer_list.h"
119 #undef CNN_REGISTER_LAYER_BODY
120 #undef CNN_REGISTER_LAYER
121 #undef CNN_REGISTER_LAYER_WITH_ACTIVATION
122 #undef CNN_REGISTER_LAYER_WITH_ACTIVATIONS
126 template <
typename OutputArchive>
127 template <
typename T>
130 *
dynamic_cast<const T*
>(
layer)));
133 template <
typename OutputArchive>
134 void layer::save_layer(OutputArchive & oa,
const layer& l) {
135 const std::string& name = serialization_helper<OutputArchive>::get_instance().type_name(
typeid(l));
136 serialization_helper<OutputArchive>::get_instance().save(name, oa, &l);
139 template <
class Archive>
140 void layer::serialize_prolog(Archive & ar) {
141 ar(cereal::make_nvp(
"type",
142 serialization_helper<Archive>::get_instance().type_name(
typeid(*
this))));
base class of all kinds of NN layers
Definition: layer.h:62
error exception class for tiny-dnn
Definition: nn_error.h:37
Definition: serialization_helper.h:42