28 #include "tiny_dnn/util/util.h"
31 namespace weight_init {
35 virtual void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out) = 0;
40 scalable(float_t value) : scale_(value) {}
42 void scale(float_t value) {
61 void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out)
override {
62 const float_t weight_base = std::sqrt(scale_ / (fan_in + fan_out));
64 uniform_rand(weight->begin(), weight->end(), -weight_base, weight_base);
80 void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out)
override {
81 CNN_UNREFERENCED_PARAMETER(fan_out);
83 const float_t weight_base = scale_ / std::sqrt(float_t(fan_in));
85 uniform_rand(weight->begin(), weight->end(), -weight_base, weight_base);
94 void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out)
override {
95 CNN_UNREFERENCED_PARAMETER(fan_in);
96 CNN_UNREFERENCED_PARAMETER(fan_out);
98 gaussian_rand(weight->begin(), weight->end(), float_t(0), scale_);
107 void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out)
override {
108 CNN_UNREFERENCED_PARAMETER(fan_in);
109 CNN_UNREFERENCED_PARAMETER(fan_out);
111 std::fill(weight->begin(), weight->end(), scale_);
118 explicit he(float_t value) :
scalable(value) {}
120 void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out)
override {
121 CNN_UNREFERENCED_PARAMETER(fan_out);
123 const float_t sigma = std::sqrt(scale_ /fan_in);
125 gaussian_rand(weight->begin(), weight->end(), float_t(0), sigma);
Definition: weight_init.h:102
Definition: weight_init.h:33
Definition: weight_init.h:89
Definition: weight_init.h:115
Use fan-in (the number of input weights for each neuron) for scaling.
Definition: weight_init.h:75
Definition: weight_init.h:38
Use fan-in and fan-out for scaling.
Definition: weight_init.h:56