tiny_dnn  1.0.0
A header-only, dependency-free deep learning framework in C++11
target_cost.h
#pragma once
#include "tiny_dnn/util/util.h"
#include <numeric> // std::accumulate

namespace tiny_dnn {

// calculate the number of samples for each class label
// - for example, if there are 10 samples having label 0, and
//   20 samples having label 1, returns a vector [10, 20]
inline std::vector<serial_size_t> calculate_label_counts(const std::vector<label_t>& t) {
    std::vector<serial_size_t> label_counts;
    for (label_t label : t) {
        if (label >= label_counts.size()) {
            label_counts.resize(label + 1);
        }
        label_counts[label]++;
    }
    assert(std::accumulate(label_counts.begin(), label_counts.end(), static_cast<serial_size_t>(0)) == t.size());
    return label_counts;
}
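// A quick illustration (the labels below are hypothetical, not part of the
// library): for the input {0, 0, 1, 1, 1, 2}, calculate_label_counts returns
// the vector {2, 3, 1}:
//
//   std::vector<label_t> t = {0, 0, 1, 1, 1, 2};
//   std::vector<serial_size_t> counts = calculate_label_counts(t);
//   assert(counts == (std::vector<serial_size_t>{2, 3, 1}));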

// calculate the weight of a given sample needed for a balanced target cost
// NB: we call a target cost matrix "balanced" if the cost of each *class* is equal
// (this happens when the product weight * sample count is equal between the different
//  classes, and the sum of these products equals the total number of samples)
inline float_t get_sample_weight_for_balanced_target_cost(serial_size_t classes, serial_size_t total_samples, serial_size_t this_class_samples)
{
    assert(this_class_samples <= total_samples);
    return total_samples / static_cast<float_t>(classes * this_class_samples);
}
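// Worked example (illustrative numbers, not taken from the library): with
// 2 classes and 30 samples split as {10, 20}, the per-sample weights are
//   class 0: 30 / (2 * 10) = 1.5
//   class 1: 30 / (2 * 20) = 0.75
// Each class then contributes 1.5 * 10 = 0.75 * 20 = 15 to the total cost,
// and the two contributions sum to 30, the total number of samples, which is
// exactly the "balanced" property described above.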

// create a target cost matrix implying equal cost for each *class* (distinct label)
// - by default, each *sample* has an equal cost, which means e.g. that a classifier
//   may prefer to always guess the majority class (in case the degree of imbalance
//   is relatively high, and the classification task is relatively difficult)
// - the parameter w can be used to fine-tune the balance:
//   * use 0 to have an equal cost for each *sample* (equal to not supplying any target costs at all)
//   * use 1 to have an equal cost for each *class* (default behaviour of this function)
//   * use a value between 0 and 1 to have something between the two extremes
inline std::vector<vec_t> create_balanced_target_cost(const std::vector<label_t>& t, float_t w = 1.0)
{
    const auto label_counts = calculate_label_counts(t);
    const serial_size_t total_sample_count = static_cast<serial_size_t>(t.size());
    const serial_size_t class_count = static_cast<serial_size_t>(label_counts.size());

    std::vector<vec_t> target_cost(t.size());

    for (serial_size_t i = 0; i < total_sample_count; ++i) {
        vec_t& sample_cost = target_cost[i];
        sample_cost.resize(class_count);
        const float_t balanced_weight = get_sample_weight_for_balanced_target_cost(class_count, total_sample_count, label_counts[t[i]]);
        const float_t unbalanced_weight = 1;
        const float_t sample_weight = w * balanced_weight + (1 - w) * unbalanced_weight;
        std::fill(sample_cost.begin(), sample_cost.end(), sample_weight);
    }

    return target_cost;
}
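// Usage sketch (hypothetical labels; how the resulting matrix is consumed is
// assumed here: tiny_dnn's training overloads that accept a target-cost
// argument are the intended consumers):
//
//   std::vector<label_t> t = {0, 0, 0, 0, 1};          // imbalanced: 4 vs 1
//   auto cost = create_balanced_target_cost(t, 1.0);   // fully balanced
//   // each class-0 sample gets weight 5 / (2 * 4) = 0.625,
//   // the single class-1 sample gets weight 5 / (2 * 1) = 2.5
//   auto blend = create_balanced_target_cost(t, 0.5);  // halfway blend
//   // class 0: 0.5 * 0.625 + 0.5 * 1 = 0.8125
//   // class 1: 0.5 * 2.5   + 0.5 * 1 = 1.75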

} // namespace tiny_dnn