33 #include "tiny_dnn/util/util.h"
37 #pragma warning(disable:4996)
40 #define STB_IMAGE_IMPLEMENTATION
41 #define STB_IMAGE_INLINE
42 #include "third_party/stb/stb_image.h"
44 #define STB_IMAGE_RESIZE_IMPLEMENTATION
45 #define STB_IMAGE_RESIZE_INLINE
46 #include "third_party/stb/stb_image_resize.h"
48 #define STB_IMAGE_WRITE_IMPLEMENTATION
49 #define STB_IMAGE_WRITE_INLINE
50 #include "third_party/stb/stb_image_write.h"
/// Subtraction for unsigned intensity types that clamps at zero instead of
/// wrapping around (e.g. 3 - 5 yields 0 for unsigned char).
template <typename T>
typename std::enable_if<std::is_unsigned<T>::value, T>::type saturated_sub(T s1, T s2) {
  if (s1 <= s2) {
    return 0;
  }
  return static_cast<T>(s1 - s2);
}

/// Ordinary subtraction for every type that is not unsigned (signed integers
/// and floating point); no clamping is applied.
template <typename T>
typename std::enable_if<!std::is_unsigned<T>::value, T>::type saturated_sub(T s1, T s2) {
  const auto difference = s1 - s2;
  return static_cast<T>(difference);
}
/// Returns true when `value` ends with the character sequence `ending`.
/// An empty `ending` matches every string.
inline bool ends_with(std::string const& value, std::string const& ending) {
  const std::size_t tail = ending.size();
  return tail <= value.size() && value.compare(value.size() - tail, tail, ending) == 0;
}
72 inline void resize_image_core(
const uint8_t* src,
int srcw,
int srch, uint8_t* dst,
int dstw,
int dsth,
int channels)
74 stbir_resize_uint8(src, srcw, srch, 0, dst, dstw, dsth, 0, channels);
77 inline void resize_image_core(
const float* src,
int srcw,
int srch,
float* dst,
int dstw,
int dsth,
int channels)
79 stbir_resize_float(src, srcw, srch, 0, dst, dstw, dsth, 0, channels);
// Channel layout of an image. Code in this file uses image_type::grayscale
// (1 channel) and image_type::rgb / image_type::bgr (3 channels); the
// channel order of rgb vs bgr is resolved by image::depth_order().
84 enum class image_type {
// Simple image utility class. T is the intensity type stored for every
// channel of every pixel (unsigned char by default).
93 template<
typename T =
unsigned char>
// Element type of the pixel buffer (one value per pixel per channel).
96 typedef T intensity_t;
// Iterators over the raw, channel-planar intensity buffer.
97 typedef typename std::vector<intensity_t>::iterator iterator;
98 typedef typename std::vector<intensity_t>::const_iterator const_iterator;
100 image() : width_(0), height_(0), depth_(1) {}
105 image(
const T* data,
size_t width,
size_t height, image_type type)
106 : width_(width), height_(height), depth_(type == image_type::grayscale ? 1: 3), type_(type), data_(depth_ * width_ * height_, 0)
108 std::copy(data, data + width * height * depth_, &data_[0]);
// image(const shape3d& size, image_type type): allocates a zero-filled
// size.width_ x size.height_ x size.depth_ buffer, then verifies that the
// requested depth agrees with the image_type.
// @throws nn_error if grayscale is requested with depth != 1, or rgb/bgr with
//         depth != 3.
115 : width_(size.width_), height_(size.height_), depth_(size.depth_),
117 data_(depth_ * width_ * height_, 0){
// Grayscale images must have exactly one channel...
118 if (type == image_type::grayscale && size.depth_ != 1) {
119 throw nn_error(
"depth must be 1 in grayscale");
// ...and color (rgb/bgr) images exactly three.
121 else if (type != image_type::grayscale && size.depth_ != 3) {
122 throw nn_error(
"depth must be 3 in rgb/bgr");
126 template <
typename U>
127 image(
const image<U>& rhs) : width_(rhs.width()), height_(rhs.height()), depth_(rhs.depth()), type_(rhs.type()), data_(rhs.shape().size()) {
128 std::transform(rhs.begin(), rhs.end(), data_.begin(), [](T src) { return static_cast<intensity_t>(src); });
// Loads an image file via stb_image, requesting 1 channel for grayscale and
// 3 channels otherwise, then converts the interleaved pixels returned by
// stbi_load into this image's channel-planar buffer with from_rgb().
// @throws nn_error("failed to open image:" + stbi_failure_reason()) if
//         stbi_load returns nullptr.
136 image(
const std::string& filename, image_type type)
139 stbi_uc* input_pixels = stbi_load(filename.c_str(), &w, &h, &d, type == image_type::grayscale ? 1 : 3);
140 if (input_pixels ==
nullptr) {
141 throw nn_error(
"failed to open image:" + std::string(stbi_failure_reason()));
144 width_ =
static_cast<size_t>(w);
145 height_ =
static_cast<size_t>(h);
// Matches the channel count requested from stbi_load above.
146 depth_ = type == image_type::grayscale ? 1 : 3;
// NOTE(review): input_pixels is freed manually below. If anything in between
// throws (e.g. std::bad_alloc from data_.resize), the stb buffer leaks.
// Consider a std::unique_ptr with stbi_image_free as the deleter.
149 data_.resize(width_*height_*depth_);
// Interleaved -> channel-planar; reads exactly data_.size() values.
152 from_rgb(input_pixels, input_pixels + data_.size());
154 stbi_image_free(input_pixels);
// Writes the image to `path`: PNG when the path ends with "png", otherwise BMP.
// Pixels are first converted to interleaved 8-bit order with to_rgb<uint8_t>().
// @throws nn_error("failed to save image:" + path) when the stb writer
//         reports failure (presumably ret == 0 — confirm).
// NOTE(review): the extension test is case-sensitive and does not require a
// dot, so "x.PNG" is written as BMP and "xpng" as PNG.
// NOTE(review): &buf[0] is undefined behavior when the image is empty.
157 void save(
const std::string& path)
const {
159 std::vector<uint8_t> buf = to_rgb<uint8_t>();
161 if (detail::ends_with(path,
"png")) {
162 ret = stbi_write_png(path.c_str(),
163 static_cast<int>(width_),
164 static_cast<int>(height_),
165 static_cast<int>(depth_),
166 (
const void*)&buf[0], 0);
169 ret = stbi_write_bmp(path.c_str(),
170 static_cast<int>(width_),
171 static_cast<int>(height_),
172 static_cast<int>(depth_),
173 (
const void*)&buf[0]);
176 throw nn_error(
"failed to save image:" + path);
// Saves the image to `path`.
// NOTE(review): presumably an alias for save(path) — confirm.
180 void write(
const std::string& path)
const {
// Reallocates the intensity buffer to hold width * height * depth_ values.
// NOTE(review): this does not resample existing pixels (resize_image() does
// that); std::vector::resize only keeps the leading elements of the old buffer.
// Presumably width_/height_ are updated to the new values as well — confirm.
184 void resize(
size_t width,
size_t height)
186 data_.resize(width * height * depth_);
192 void fill(intensity_t value) {
193 std::fill(data_.begin(), data_.end(), value);
// Mutable access to channel z of pixel (x, y).
// The buffer is channel-planar: plane z starts at z * width_ * height_ and
// each row within a plane is width_ elements long.
// NOTE(review): data_ is indexed with operator[], so out-of-range x/y/z is
// undefined behavior unless validated beforehand.
196 intensity_t& at(
size_t x,
size_t y,
size_t z = 0) {
200 return data_[z * width_ * height_ + y * width_ + x];
// Read-only access to channel z of pixel (x, y); same channel-planar indexing
// as the mutable overload (z * width_ * height_ + y * width_ + x).
// NOTE(review): out-of-range indices are undefined behavior unless validated
// beforehand.
203 const intensity_t& at(
size_t x,
size_t y,
size_t z = 0)
const {
207 return data_[z * width_ * height_ + y * width_ + x];
210 bool empty()
const {
return data_.empty(); }
211 iterator begin() {
return data_.begin(); }
212 iterator end() {
return data_.end(); }
213 const_iterator begin()
const {
return data_.begin(); }
214 const_iterator end()
const {
return data_.end(); }
216 intensity_t& operator[](std::size_t idx) {
return data_[idx]; };
217 const intensity_t& operator[](std::size_t idx)
const {
return data_[idx]; };
219 size_t width()
const {
return width_; }
220 size_t height()
const {
return height_; }
221 size_t depth()
const {
return depth_;}
222 image_type type()
const {
return type_; }
223 shape3d shape()
const {
224 return shape3d(
static_cast<serial_size_t
>(width_),
225 static_cast<serial_size_t
>(height_),
226 static_cast<serial_size_t
>(depth_));
228 const std::vector<intensity_t>& data()
const {
return data_; }
229 vec_t to_vec()
const {
return vec_t(begin(), end()); }
// Returns the pixels as std::vector<U> in interleaved order: for each row y,
// for each column x, the depth_ channel values, each static_cast to U.
// On the reordering path, interleaved slot i is read from channel order[i],
// where order = depth_order(type_).
// One branch (presumably depth_ == 1) returns data_ converted element-wise
// without reordering.
231 template <
typename U>
232 std::vector<U> to_rgb()
const {
234 return std::vector<U>(data_.begin(), data_.end());
237 std::vector<U> buf(shape().size());
238 auto order = depth_order(type_);
239 auto dst = buf.begin();
241 for (
size_t y = 0; y < height_; y++)
242 for (
size_t x = 0; x < width_; x++)
243 for (
size_t i = 0; i < depth_; i++)
244 *dst++ =
static_cast<U
>(at(x, y, order[i]));
// Inverse of to_rgb(): reads interleaved values (row y, column x, channel
// slot i) from [begin, end) and stores them in the channel-planar buffer.
// Slot i is written to channel order[i], where order = depth_order(type_).
// One branch (presumably depth_ == 1) copies the range verbatim with std::copy.
// The reordering path asserts that the range holds exactly data_.size() values.
// NOTE(review): the assert narrows std::distance to serial_size_t before
// comparing with data_.size(); comparing as size_t would avoid truncation.
249 template <
typename Iter>
250 void from_rgb(Iter begin, Iter end) {
252 std::copy(begin, end, data_.begin());
255 auto order = depth_order(type_);
256 assert(
static_cast<serial_size_t
>(
257 std::distance(begin, end)) == data_.size());
259 for (
size_t y = 0; y < height_; y++)
260 for (
size_t x = 0; x < width_; x++)
261 for (
size_t i = 0; i < depth_; i++)
262 at(x, y, order[i]) =
static_cast<intensity_t
>(*begin++);
// Returns, for interleaved channel slot i (0..2), the channel index z that
// to_rgb()/from_rgb() pass to at(). rgb gets one permutation; any other type
// must be bgr (asserted) and gets its own.
// NOTE(review): with NDEBUG the assert vanishes, so a grayscale type would
// presumably fall through to the bgr mapping — confirm callers never pass it.
267 std::array<size_t, 3> depth_order(image_type img)
const {
268 if (img == image_type::rgb) {
272 assert(img == image_type::bgr);
// Pixel intensities in channel-planar layout: element (x, y, z) is stored at
// index z * width_ * height_ + y * width_ + x.
280 std::vector<intensity_t> data_;
// Computes the mean intensity of each channel of `src` and stores it in a
// 1x1xdepth image<float_t> (`mean`) with src's image_type. That image is
// presumably the return value — confirm.
// NOTE(review): the divisor src.width() * src.height() is 0 for an empty
// image; callers should reject empty input.
// NOTE(review): the declaration of `sum` is not part of the visible lines;
// presumably a float_t reset to 0 per channel — confirm.
283 template <
typename T>
284 image<float_t> mean_image(
const image<T>& src)
286 image<float_t> mean(shape3d(1, 1, (serial_size_t)src.depth()), src.type());
288 for (
size_t i = 0; i < src.depth(); i++) {
290 for (
size_t y = 0; y < src.height(); y++) {
291 for (
size_t x = 0; x < src.width(); x++) {
292 sum += src.at(x, y, i);
295 mean.at(0, 0, i) = sum / (src.width() * src.height());
// Resamples `src` to width x height. The pipeline is:
//   1. convert src to interleaved order with to_rgb<T>();
//   2. resample with detail::resize_image_core (stb_image_resize);
//   3. convert back into `resized` with from_rgb().
// `resized` is a new image of the same depth, presumably returned — confirm.
// NOTE(review): &src_rgb[0] is undefined behavior for an empty src, and a
// negative width/height wraps when cast to serial_size_t.
306 template <
typename T>
307 inline image<T> resize_image(
const image<T>& src,
int width,
int height)
309 image<T> resized(shape3d(
static_cast<serial_size_t
>(width),
310 static_cast<serial_size_t
>(height),
311 static_cast<serial_size_t
>(src.depth())),
313 std::vector<T> src_rgb = src.template to_rgb<T>();
314 std::vector<T> dst_rgb(resized.shape().size());
316 detail::resize_image_core(&src_rgb[0],
317 static_cast<int>(src.width()),
318 static_cast<int>(src.height()),
322 static_cast<int>(src.depth()));
324 resized.from_rgb(dst_rgb.begin(), dst_rgb.end());
331 template <
typename T>
332 image<T> subtract_image(
const image<T>& lhs,
const image<T>& rhs)
334 if (lhs.shape() != rhs.shape()) {
335 throw nn_error(
"Shapes of lhs/rhs must be same. lhs:" + to_string(lhs.shape()) +
",rhs:" + to_string(rhs.shape()));
338 image<T> dst(lhs.shape(), lhs.type());
340 auto dstit = dst.begin();
341 auto lhsit = lhs.begin();
342 auto rhsit = rhs.begin();
344 for (; dstit != dst.end(); ++dstit, ++lhsit, ++rhsit) {
345 *dstit = detail::saturated_sub(*lhsit, *rhsit);
350 template <
typename T>
351 image<T> subtract_scalar(
const image<T>& lhs,
const image<T>& rhs)
353 if (lhs.depth() != rhs.depth()) {
354 throw nn_error(
"Depth of lhs/rhs must be same. lhs:" + to_string(lhs.depth()) +
",rhs:" + to_string(rhs.depth()));
356 if (rhs.width() != 1 || rhs.height() != 1) {
357 throw nn_error(
"rhs must be 1x1xN");
360 image<T> dst(lhs.shape(), lhs.type());
362 auto dstit = dst.begin();
363 auto lhsit = lhs.begin();
364 auto rhsit = rhs.begin();
366 for (
size_t i = 0; i < lhs.depth(); i++, ++rhsit) {
367 for (
size_t j = 0; j < lhs.width() * lhs.height(); j++, ++dstit, ++lhsit) {
368 *dstit = detail::saturated_sub(*lhsit, *rhsit);
// Renders `vec` as a grid of block_size x block_size squares, at most max_cols
// per row, separated by 1-pixel borders. Each square's intensity is its value
// mapped with rescale() from [min(vec), max(vec)] onto 0..255.
// bg_color (255) is presumably the background fill.
// @throws nn_error (presumably when vec is empty).
// NOTE(review): the loop over c never checks current_idx < vec.size(). When
// vec.size() is not a multiple of cols, the last row reads vec[current_idx]
// past the end of the vector; bound the inner loop by vec.size().
// NOTE(review): the error message misspells "visualize" ("visialize").
389 inline image<T> vec2image(
const vec_t& vec, serial_size_t block_size = 2, serial_size_t max_cols = 20)
392 throw nn_error(
"failed to visialize image: vector is empty");
395 const serial_size_t border_width = 1;
396 const auto cols = vec.size() >= (serial_size_t)max_cols ? (serial_size_t)max_cols : vec.size();
397 const auto rows = (vec.size() - 1) / cols + 1;
398 const auto pitch = block_size + border_width;
399 const auto width = pitch * cols + border_width;
400 const auto height = pitch * rows + border_width;
401 const typename image<T>::intensity_t bg_color = 255;
402 serial_size_t current_idx = 0;
404 img.resize(width, height);
407 auto minmax = std::minmax_element(vec.begin(), vec.end());
409 for (
unsigned int r = 0; r < rows; r++) {
410 serial_size_t topy = pitch * r + border_width;
412 for (
unsigned int c = 0; c < cols; c++, current_idx++) {
413 serial_size_t leftx = pitch * c + border_width;
414 const float_t src = vec[current_idx];
// NOTE(review): declared as image<>::intensity_t (unsigned char), not
// typename image<T>::intensity_t, so fractional values are truncated when T
// is a floating-point type.
415 image<>::intensity_t dst
416 =
static_cast<typename image<T>::intensity_t
>(rescale(src, *minmax.first, *minmax.second, 0, 255));
418 for (serial_size_t y = 0; y < block_size; y++)
419 for (serial_size_t x = 0; x < block_size; x++)
420 img.at(x + leftx, y + topy) = dst;
422 if (current_idx == vec.size())
return img;
// Renders `vec` as maps.depth_ feature maps, each maps.width_ x maps.height_,
// placed side by side from left to right with 1-pixel borders. Each value is
// mapped with rescale() from [min(vec), max(vec)] onto 0..255.
// bg_color (255) is presumably the background fill.
// @throws nn_error if vec is empty (presumably) or vec.size() != maps.size().
443 inline image<T> vec2image(
const vec_t& vec,
const index3d<serial_size_t>& maps) {
445 throw nn_error(
"failed to visualize image: vector is empty");
446 if (vec.size() != maps.size())
447 throw nn_error(
"failed to visualize image: vector size invalid");
449 const serial_size_t border_width = 1;
// Horizontal distance between the left edges of consecutive maps.
450 const auto pitch = maps.width_ + border_width;
451 const auto width = maps.depth_ * pitch + border_width;
452 const auto height = maps.height_ + 2 * border_width;
453 const typename image<T>::intensity_t bg_color = 255;
456 img.resize(width, height);
459 auto minmax = std::minmax_element(vec.begin(), vec.end());
461 for (serial_size_t c = 0; c < maps.depth_; ++c) {
462 const auto top = border_width;
463 const auto left = c * pitch + border_width;
465 for (serial_size_t y = 0; y < maps.height_; ++y) {
466 for (serial_size_t x = 0; x < maps.width_; ++x) {
467 const float_t val = vec[maps.get_index(x, y, c)];
469 img.at(left + x, top + y)
470 =
static_cast<typename image<T>::intensity_t
>(rescale(val, *minmax.first, *minmax.second, 0, 255));
Simple image utility class.
Definition: image.h:94
image(const T *data, size_t width, size_t height, image_type type)
create image from raw pointer
Definition: image.h:105
image(const std::string &filename, image_type type)
create image from file supported file format: JPEG/PNG/TGA/BMP/PSD/GIF/HDR/PIC/PNM (see detail at the...
Definition: image.h:136
image(const shape3d &size, image_type type)
create WxHxD image filled with 0
Definition: image.h:114
error exception class for tiny-dnn
Definition: nn_error.h:37