33 #include <cereal/archives/json.hpp>
34 #include <cereal/types/memory.hpp>
35 #include "tiny_dnn/util/nn_error.h"
36 #include "tiny_dnn/util/macro.h"
37 #include "tiny_dnn/layers/layers.h"
// Template parameter of the enclosing helper: the cereal input-archive type that
// layers are read from (presumably e.g. cereal::JSONInputArchive -- confirm).
41 template <
typename InputArchive>
44 void register_loader(
const std::string& name, std::function<std::shared_ptr<layer>(InputArchive&)> func) {
45 loaders_[name] = [=](
void* ar) {
46 return func(*
reinterpret_cast<InputArchive*
>(ar));
51 void register_type(
const std::string& name) {
52 type_names_[
typeid(T)] = name;
55 std::shared_ptr<layer> load(
const std::string& layer_name, InputArchive& ar) {
58 if (loaders_.find(layer_name) == loaders_.end()) {
59 throw nn_error(
"Failed to generate layer. Generator for " + layer_name +
" is not found.\n"
60 "Please use CNN_REGISTER_LAYER_DESERIALIZER macro to register appropriate generator");
63 return loaders_[layer_name](
reinterpret_cast<void*
>(&ar));
66 const std::string& type_name(std::type_index index)
const {
67 if (type_names_.find(index) == type_names_.end()) {
68 throw nn_error(
"Typename is not registered");
70 return type_names_.at(index);
79 void check_if_enabled()
const {
80 #ifdef CNN_NO_SERIALIZATION
81 static_assert(
sizeof(InputArchive)==0,
82 "You are using load functions, but deserialization function is disabled in current configuration.\n\n"
83 "You need to undef CNN_NO_SERIALIZATION to enable these functions.\n"
84 "If you are using cmake, you can use -DUSE_SERIALIZER=ON option.\n\n");
// Layer name -> loader. The void* argument is the caller's InputArchive, erased so
// every loader shares one signature (filled by register_loader, used by load).
89 std::map<std::string, std::function<std::shared_ptr<layer>(
void*)>> loaders_;
// C++ type -> registered layer name (filled by register_type, read by type_name).
91 std::map<std::type_index, std::string> type_names_;
// Per-layer-type loader; CNN_REGISTER_LAYER_BODY registers load_layer_impl<layer_type>
// through register_loader.
94 static std::shared_ptr<layer> load_layer_impl(InputArchive& ia);
// Registers one concrete layer type: load_layer_impl<layer_type> becomes the loader
// for layer_name, and layer_name is recorded as the type's name. The body ends in ';',
// so a trailing ';' at the call site is an empty statement.
96 #define CNN_REGISTER_LAYER_BODY(layer_type, layer_name) \
97 register_loader(layer_name, load_layer_impl<layer_type>);\
98 register_type<layer_type>(layer_name);
// Same as above, with the layer name given as a bare token and stringized.
100 #define CNN_REGISTER_LAYER(layer_type, layer_name) CNN_REGISTER_LAYER_BODY(layer_type, #layer_name)
// Registers layer_type<activation::activation_type> under the name
// "layer_name<activation_type>" (built by string-literal concatenation).
102 #define CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, activation_type, layer_name) \
103 CNN_REGISTER_LAYER_BODY(layer_type<activation::activation_type>, #layer_name "<" #activation_type ">")
// Registers layer_type with each of: tan_h, softmax, identity, sigmoid, relu,
// leaky_relu, elu, tan_hp1m2. The final registration has no ';' -- the caller adds it.
105 #define CNN_REGISTER_LAYER_WITH_ACTIVATIONS(layer_type, layer_name) \
106 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, tan_h, layer_name); \
107 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, softmax, layer_name); \
108 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, identity, layer_name); \
109 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, sigmoid, layer_name); \
110 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, relu, layer_name); \
111 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, leaky_relu, layer_name); \
112 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, elu, layer_name); \
113 CNN_REGISTER_LAYER_WITH_ACTIVATION(layer_type, tan_hp1m2, layer_name)
// Presumably expands to the list of CNN_REGISTER_LAYER* invocations. The macros call
// register_loader/register_type unqualified, so this include is expected to sit
// inside a member function of the helper (e.g. its constructor) -- confirm.
116 #include "serialization_layer_list.h"
// Remove the registration macros so they do not leak into files that include this header.
119 #undef CNN_REGISTER_LAYER_BODY
120 #undef CNN_REGISTER_LAYER
121 #undef CNN_REGISTER_LAYER_WITH_ACTIVATION
122 #undef CNN_REGISTER_LAYER_WITH_ACTIVATIONS
// Out-of-class definition of the per-type loader load_layer_impl<T>: the outer
// template head is the helper's InputArchive, the inner one the concrete layer type.
// It deserializes T through cereal's load-and-construct wrapper into raw storage.
126 template <
typename InputArchive>
127 template <
typename T>
// Raw storage with T's size and alignment; no T is constructed yet.
130 using ST =
typename std::aligned_storage<
sizeof(T), CNN_ALIGNOF(T)>::type;
// Heap storage for the object, owned by bn until it is handed to the shared_ptr.
132 std::unique_ptr<ST> bn(
new ST());
// Wrapper that lets cereal construct a T in the storage bn points to.
134 cereal::memory_detail::LoadAndConstructLoadWrapper<InputArchive, T> wrapper(
reinterpret_cast<T*
>(bn.get()));
// Read from the archive; presumably constructs the T in place -- confirm with cereal.
136 wrapper.CEREAL_SERIALIZE_FUNCTION_NAME(ia);
// Hand the constructed object to a shared_ptr<layer>.
// NOTE(review): the storage came from `new ST`, but the shared_ptr will later
// `delete` it as a T*, so the allocation and deallocation types do not match.
// bn must also give up ownership after this reset (that step is not shown here), or
// the storage is freed twice. shared_ptr::reset deletes the pointer itself if
// allocating the control block throws, so releasing bn first and passing a deleter
// that runs ~T() and frees the ST would close both holes.
138 std::shared_ptr<layer> t;
139 t.reset(
reinterpret_cast<T*
>(bn.get()));
/// Hook run before a layer is read from an archive of type T.
/// The generic version is a no-op; archive types that need extra work get their
/// own non-template overload.
template <typename T>
void start_loading_layer(T& /*ar*/) {
    // Intentionally empty.
}

/// Hook run after a layer has been read from an archive of type T; no-op by default.
template <typename T>
void finish_loading_layer(T& /*ar*/) {
    // Intentionally empty.
}
151 inline void start_loading_layer(cereal::JSONInputArchive & ia) { ia.startNode(); }
153 inline void finish_loading_layer(cereal::JSONInputArchive & ia) { ia.finishNode(); }
// Archive-generic definition of load_layer ("generate layer from cereal's Archive").
// Lines between the visible statements are not part of this listing.
158 template <
typename InputArchive>
// Archive-specific setup; for JSON archives this opens a nested node.
160 start_loading_layer(ia);
// Read the "type" entry into p; presumably the registered layer name that selects
// the loader -- confirm against the elided lines.
163 ia(cereal::make_nvp(
"type", p));
// Matching teardown for start_loading_layer (closes the JSON node, if any).
166 finish_loading_layer(ia);
Definition: deserialization_helper.h:42
static std::shared_ptr< layer > load_layer(InputArchive &ia)
Generates a layer from a cereal input archive.
Definition: deserialization_helper.h:159
Exception class used to report tiny-dnn errors.
Definition: nn_error.h:37