47 #include "tiny_dnn/core/framework/op_kernel.h"
57 #if defined(USE_OPENCL) || defined(USE_CUDA)
58 auto params = OpKernel::params_->conv();
61 const tensor_t& in_data = context.input(0);
62 const tensor_t& W = context.input(1);
63 const tensor_t& bias = context.input(2);
64 tensor_t& out_data = context.output(1);
67 fill_tensor(out_data, float_t(0));
70 CLCudaAPI::Program program = ProgramManager::getInstance()
71 .program(
Program(context.device(), context.Layer()));
77 auto kernel = CLCudaAPI::Kernel(program,
"CFMulti");
81 CLCudaAPI::Context ctx = context.device()->context();
82 CLCudaAPI::Queue queue = context.device()->queue();
85 for (serial_size_t i = 0; i < in_data.size(); ++i) {
90 auto dev_in = CLCudaAPI::Buffer<float_t>(ctx, queue,
91 in_data[i].begin(), in_data[i].end());
93 auto dev_W = CLCudaAPI::Buffer<float_t>(ctx, queue,
94 W[0].begin(), W[0].end());
96 auto dev_bias = CLCudaAPI::Buffer<float_t>(ctx, queue,
97 bias[0].begin(), bias[0].end());
99 auto dev_out = CLCudaAPI::Buffer<float_t>(ctx, queue,
100 out_data[i].begin(), out_data[i].end());
102 kernel.SetArgument(0, dev_in);
103 kernel.SetArgument(1, 0);
104 kernel.SetArgument(2, dev_W);
105 kernel.SetArgument(3, 0);
106 kernel.SetArgument(4, dev_bias);
107 kernel.SetArgument(5, 0);
108 kernel.SetArgument(6, dev_out);
109 kernel.SetArgument(7, 0);
111 kernel.SetArgument(8,
static_cast<cl_ushort
>(params.in.width_));
112 kernel.SetArgument(9,
static_cast<cl_ushort
>(params.in.height_));
113 kernel.SetArgument(10,
static_cast<cl_ushort
>(params.out.width_));
114 kernel.SetArgument(11,
static_cast<cl_ushort
>(params.out.height_));
117 serial_size_t res = device->device().MaxWorkGroupSize() % 16;
118 serial_size_t size = device->device().MaxWorkGroupSize() - res;
120 auto global = std::vector<size_t>{size};
121 auto local = std::vector<size_t>{16};
124 auto event = CLCudaAPI::Event();
129 nn_info(
"## Running the kernel ...");
131 kernel.Launch(queue, global, local, event.pointer());
134 nn_info(
" > Took " + to_string(event.GetElapsedTime()) +
" ms");
137 std::vector<float_t> out(out_data[i].size(), 0);
138 dev_out.Read(queue, out_data[i].size(), out);
142 for (serial_size_t j = 0; j < out.size(); ++j) {
143 std::cout << out[j] <<
" ";
145 std::cout << std::endl;
148 std::copy(std::begin(out), std::end(out), std::back_inserter(out_data[i]));
151 throw nn_error(
"Not compiled with OpenCL");
Definition: conv2d_op_opencl.h:156
Definition: conv2d_op_opencl.h:51
Definition: device.fwd.h:73
Definition: op_kernel.h:55
Definition: op_kernel.h:72
Definition: op_kernel.h:175
error exception class for tiny-dnn
Definition: nn_error.h:37
info class for tiny-dnn (for debug)
Definition: nn_error.h:69
warning class for tiny-dnn (for debug)
Definition: nn_error.h:52