tiny_dnn  1.0.0
A header-only, dependency-free deep learning framework in C++11
weight_init.h
1 /*
2  Copyright (c) 2015, Taiga Nomi
3  All rights reserved.
4 
5  Redistribution and use in source and binary forms, with or without
6  modification, are permitted provided that the following conditions are met:
7  * Redistributions of source code must retain the above copyright
8  notice, this list of conditions and the following disclaimer.
9  * Redistributions in binary form must reproduce the above copyright
10  notice, this list of conditions and the following disclaimer in the
11  documentation and/or other materials provided with the distribution.
12  * Neither the name of the <organization> nor the
13  names of its contributors may be used to endorse or promote products
14  derived from this software without specific prior written permission.
15 
16  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
17  EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
20  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27 #pragma once
28 #include "tiny_dnn/util/util.h"
29 
30 namespace tiny_dnn {
31 namespace weight_init {
32 
33 class function {
34 public:
35  virtual void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out) = 0;
36 };
37 
38 class scalable : public function {
39 public:
40  scalable(float_t value) : scale_(value) {}
41 
42  void scale(float_t value) {
43  scale_ = value;
44  }
45 protected:
46  float_t scale_;
47 };
48 
56 class xavier : public scalable {
57 public:
58  xavier() : scalable(float_t(6)) {}
59  explicit xavier(float_t value) : scalable(value) {}
60 
61  void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out) override {
62  const float_t weight_base = std::sqrt(scale_ / (fan_in + fan_out));
63 
64  uniform_rand(weight->begin(), weight->end(), -weight_base, weight_base);
65  }
66 };
67 
75 class lecun : public scalable {
76 public:
77  lecun() : scalable(float_t(1)) {}
78  explicit lecun(float_t value) : scalable(value) {}
79 
80  void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out) override {
81  CNN_UNREFERENCED_PARAMETER(fan_out);
82 
83  const float_t weight_base = scale_ / std::sqrt(float_t(fan_in));
84 
85  uniform_rand(weight->begin(), weight->end(), -weight_base, weight_base);
86  }
87 };
88 
89 class gaussian : public scalable {
90 public:
91  gaussian() : scalable(float_t(1)) {}
92  explicit gaussian(float_t sigma) : scalable(sigma) {}
93 
94  void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out) override {
95  CNN_UNREFERENCED_PARAMETER(fan_in);
96  CNN_UNREFERENCED_PARAMETER(fan_out);
97 
98  gaussian_rand(weight->begin(), weight->end(), float_t(0), scale_);
99  }
100 };
101 
102 class constant : public scalable {
103 public:
104  constant() : scalable(float_t(0)) {}
105  explicit constant(float_t value) : scalable(value) {}
106 
107  void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out) override {
108  CNN_UNREFERENCED_PARAMETER(fan_in);
109  CNN_UNREFERENCED_PARAMETER(fan_out);
110 
111  std::fill(weight->begin(), weight->end(), scale_);
112  }
113 };
114 
115 class he : public scalable {
116 public:
117  he() : scalable(float_t(2)) {}
118  explicit he(float_t value) : scalable(value) {}
119 
120  void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out) override {
121  CNN_UNREFERENCED_PARAMETER(fan_out);
122 
123  const float_t sigma = std::sqrt(scale_ /fan_in);
124 
125  gaussian_rand(weight->begin(), weight->end(), float_t(0), sigma);
126  }
127 };
128 
129 } // namespace weight_init
130 } // namespace tiny_dnn
Definition: weight_init.h:102
Definition: weight_init.h:33
Definition: weight_init.h:89
Definition: weight_init.h:115
Use fan-in (the number of input weights for each neuron) for scaling.
Definition: weight_init.h:75
Definition: weight_init.h:38
Use fan-in and fan-out for scaling.
Definition: weight_init.h:56