2 #include "tiny_dnn/util/util.h"
10 inline std::vector<serial_size_t> calculate_label_counts(
const std::vector<label_t>& t) {
11 std::vector<serial_size_t> label_counts;
12 for (label_t label : t) {
13 if (label >= label_counts.size()) {
14 label_counts.resize(label + 1);
16 label_counts[label]++;
18 assert(std::accumulate(label_counts.begin(), label_counts.end(),
static_cast<serial_size_t
>(0)) == t.size());
26 inline float_t get_sample_weight_for_balanced_target_cost(serial_size_t classes, serial_size_t total_samples, serial_size_t this_class_samples)
28 assert(this_class_samples <= total_samples);
29 return total_samples /
static_cast<float_t
>(classes * this_class_samples);
40 inline std::vector<vec_t> create_balanced_target_cost(
const std::vector<label_t>& t, float_t w = 1.0)
42 const auto label_counts = calculate_label_counts(t);
43 const serial_size_t total_sample_count =
static_cast<serial_size_t
>(t.size());
44 const serial_size_t class_count =
static_cast<serial_size_t
>(label_counts.size());
46 std::vector<vec_t> target_cost(t.size());
48 for (serial_size_t i = 0; i < total_sample_count; ++i) {
49 vec_t& sample_cost = target_cost[i];
50 sample_cost.resize(class_count);
51 const float_t balanced_weight = get_sample_weight_for_balanced_target_cost(class_count, total_sample_count, label_counts[t[i]]);
52 const float_t unbalanced_weight = 1;
53 const float_t sample_weight = w * balanced_weight + (1 - w) * unbalanced_weight;
54 std::fill(sample_cost.begin(), sample_cost.end(), sample_weight);