tiny_dnn  1.0.0
A header only, dependency-free deep learning framework in C++11
mnist_parser.h
1 /*
2  Copyright (c) 2013, Taiga Nomi
3  All rights reserved.
4 
5  Redistribution and use in source and binary forms, with or without
6  modification, are permitted provided that the following conditions are met:
7  * Redistributions of source code must retain the above copyright
8  notice, this list of conditions and the following disclaimer.
9  * Redistributions in binary form must reproduce the above copyright
10  notice, this list of conditions and the following disclaimer in the
11  documentation and/or other materials provided with the distribution.
12  * Neither the name of the <organization> nor the
13  names of its contributors may be used to endorse or promote products
14  derived from this software without specific prior written permission.
15 
16  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
17  EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
20  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27 #pragma once
#include <cstdint>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

#include "tiny_dnn/util/util.h"
31 
32 namespace tiny_dnn {
33 namespace detail {
34 
/// Header of an MNIST image file: four 32-bit fields. They are big-endian on
/// disk; parse_mnist_header converts them to host byte order when filling this.
struct mnist_header {
  std::uint32_t magic_number;  ///< file-type tag (0x00000803 for image files)
  std::uint32_t num_items;     ///< number of images stored in the file
  std::uint32_t num_rows;      ///< rows (height) of every image
  std::uint32_t num_cols;      ///< columns (width) of every image
};
41 
42 inline void parse_mnist_header(std::ifstream& ifs, mnist_header& header) {
43  ifs.read((char*) &header.magic_number, 4);
44  ifs.read((char*) &header.num_items, 4);
45  ifs.read((char*) &header.num_rows, 4);
46  ifs.read((char*) &header.num_cols, 4);
47 
48  if (is_little_endian()) {
49  reverse_endian(&header.magic_number);
50  reverse_endian(&header.num_items);
51  reverse_endian(&header.num_rows);
52  reverse_endian(&header.num_cols);
53  }
54 
55  if (header.magic_number != 0x00000803 || header.num_items <= 0)
56  throw nn_error("MNIST label-file format error");
57  if (ifs.fail() || ifs.bad())
58  throw nn_error("file error");
59 }
60 
61 inline void parse_mnist_image(std::ifstream& ifs,
62  const mnist_header& header,
63  float_t scale_min,
64  float_t scale_max,
65  int x_padding,
66  int y_padding,
67  vec_t& dst) {
68  const int width = header.num_cols + 2 * x_padding;
69  const int height = header.num_rows + 2 * y_padding;
70 
71  std::vector<uint8_t> image_vec(header.num_rows * header.num_cols);
72 
73  ifs.read((char*) &image_vec[0], header.num_rows * header.num_cols);
74 
75  dst.resize(width * height, scale_min);
76 
77  for (uint32_t y = 0; y < header.num_rows; y++)
78  for (uint32_t x = 0; x < header.num_cols; x++)
79  dst[width * (y + y_padding) + x + x_padding]
80  = (image_vec[y * header.num_cols + x] / float_t(255)) * (scale_max - scale_min) + scale_min;
81 }
82 
83 } // namespace detail
84 
92 inline void parse_mnist_labels(const std::string& label_file, std::vector<label_t> *labels) {
93  std::ifstream ifs(label_file.c_str(), std::ios::in | std::ios::binary);
94 
95  if (ifs.bad() || ifs.fail())
96  throw nn_error("failed to open file:" + label_file);
97 
98  uint32_t magic_number, num_items;
99 
100  ifs.read((char*) &magic_number, 4);
101  ifs.read((char*) &num_items, 4);
102 
103  if (is_little_endian()) { // MNIST data is big-endian format
104  reverse_endian(&magic_number);
105  reverse_endian(&num_items);
106  }
107 
108  if (magic_number != 0x00000801 || num_items <= 0)
109  throw nn_error("MNIST label-file format error");
110 
111  for (uint32_t i = 0; i < num_items; i++) {
112  uint8_t label;
113  ifs.read((char*) &label, 1);
114  labels->push_back((label_t) label);
115  }
116 }
117 
140 inline void parse_mnist_images(const std::string& image_file,
141  std::vector<vec_t> *images,
142  float_t scale_min,
143  float_t scale_max,
144  int x_padding,
145  int y_padding) {
146 
147  if (x_padding < 0 || y_padding < 0)
148  throw nn_error("padding size must not be negative");
149  if (scale_min >= scale_max)
150  throw nn_error("scale_max must be greater than scale_min");
151 
152  std::ifstream ifs(image_file.c_str(), std::ios::in | std::ios::binary);
153 
154  if (ifs.bad() || ifs.fail())
155  throw nn_error("failed to open file:" + image_file);
156 
157  detail::mnist_header header;
158 
159  detail::parse_mnist_header(ifs, header);
160 
161  for (uint32_t i = 0; i < header.num_items; i++) {
162  vec_t image;
163  detail::parse_mnist_image(ifs, header, scale_min, scale_max, x_padding, y_padding, image);
164  images->push_back(image);
165  }
166 }
167 
168 } // namespace tiny_dnn
error exception class for tiny-dnn
Definition: nn_error.h:37
Definition: mnist_parser.h:35