47 #include "tiny_dnn/core/framework/device.fwd.h"
48 #include "tiny_dnn/core/params/conv_params.h"
59 : device_(device), params_(params) {}
62 Device* device()
const {
return device_; }
65 Params* params()
const {
return params_; }
79 Device* device_ptr =
nullptr;
82 layer* layer_ptr_ =
nullptr;
85 Params* params_ptr_ =
nullptr;
88 bool parallelize =
false;
90 backend_t engine = default_engine();
94 std::vector<tensor_t*>& out_data)
95 : in_data_(in_data), out_data_(out_data) {
96 op_params_ = std::unique_ptr<OpParams>(
new OpParams());
100 const std::vector<tensor_t*>& out_data,
101 std::vector<tensor_t*>& out_grad,
102 std::vector<tensor_t*>& in_grad)
104 , out_data_(out_data)
105 , out_grad_(out_grad)
106 , in_grad_(in_grad) {
107 op_params_ = std::unique_ptr<OpParams>(
new OpParams());
110 tensor_t& input(
const int idx)
const {
111 return *in_data_[idx];
114 tensor_t& output(
const int idx)
const {
115 return *out_data_[idx];
118 tensor_t& input_grad(
const int idx)
const {
119 return *in_grad_[idx];
122 tensor_t& output_grad(
const int idx)
const {
123 return *out_grad_[idx];
126 void setParams(Params* params) {
127 op_params_->params_ptr_ = params;
130 Params* params()
const {
131 return op_params_->params_ptr_;
134 void setParallelize(
const bool parallelize) {
135 op_params_->parallelize = parallelize;
138 bool parallelize()
const {
139 return op_params_->parallelize;
142 void setDevice(Device* device) {
143 op_params_->device_ptr = device;
146 Device* device()
const {
147 return op_params_->device_ptr;
150 void setLayer(layer* layer) {
151 op_params_->layer_ptr_ = layer;
154 layer* Layer()
const {
155 return op_params_->layer_ptr_;
158 backend_t engine()
const {
159 return op_params_->engine;
162 void setEngine(
const backend_t engine) {
163 op_params_->engine = engine;
167 std::vector<tensor_t*> in_data_;
168 std::vector<tensor_t*> out_data_;
169 std::vector<tensor_t*> out_grad_;
170 std::vector<tensor_t*> in_grad_;
172 std::unique_ptr<OpParams> op_params_;
179 : device_(context.device())
180 , params_(context.params()) {}
187 Device* device_ =
nullptr;
188 Params* params_ =
nullptr;
Definition: device.fwd.h:73
Definition: op_kernel.h:55
Definition: op_kernel.h:72
Definition: op_kernel.h:175
base class of all kind of NN layers
Definition: layer.h:62
Definition: op_kernel.h:74