// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_
#define DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_
#include <vector>
#include "../matrix.h"
#include "cross_validate_multiclass_trainer_abstract.h"
#include <sstream>
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
    typename dec_funct_type,
    typename sample_type,
    typename label_type
    >
const matrix<double> test_multiclass_decision_function (
    const dec_funct_type& dec_funct,
    const std::vector<sample_type>& x_test,
    const std::vector<label_type>& y_test
)
{
    // make sure requires clause is not broken
    DLIB_ASSERT( is_learning_problem(x_test,y_test) == true,
        "\tmatrix test_multiclass_decision_function()"
        << "\n\t invalid inputs were given to this function"
        << "\n\t is_learning_problem(x_test,y_test): "
        << is_learning_problem(x_test,y_test));

    // The rows and columns of the confusion matrix correspond to the labels the
    // decision function knows about, in the order reported by get_labels().
    const std::vector<label_type> known_labels = dec_funct.get_labels();

    // Build a lookup table mapping each label to its row/column index.
    std::map<label_type,unsigned long> label_index;
    for (unsigned long k = 0; k < known_labels.size(); ++k)
        label_index[known_labels[k]] = k;

    matrix<double, 0, 0, typename dec_funct_type::mem_manager_type> confusion;
    confusion.set_size(known_labels.size(), known_labels.size());
    confusion = 0;

    // Run the decision function over every test sample and tally
    // (true label, predicted label) pairs into the confusion matrix.
    typename std::map<label_type,unsigned long>::const_iterator truth_pos;
    for (unsigned long k = 0; k < x_test.size(); ++k)
    {
        truth_pos = label_index.find(y_test[k]);
        // ignore samples with labels that the decision function doesn't know about.
        if (truth_pos == label_index.end())
            continue;

        // Predicted labels always come from known_labels, so this lookup never
        // inserts a new map entry.
        confusion(truth_pos->second, label_index[dec_funct(x_test[k])]) += 1;
    }

    return confusion;
}
// ----------------------------------------------------------------------------------------
// Exception type thrown by cross_validate_multiclass_trainer() when a valid
// cross-validation split can't be formed (e.g. a class has fewer samples than
// there are folds).
class cross_validation_error : public dlib::error
{
public:
    cross_validation_error (
        const std::string& msg
    ) : dlib::error(msg) {}
};
// Performs stratified k-fold cross validation of a multiclass trainer and
// returns the sum of the per-fold confusion matrices computed by
// test_multiclass_decision_function() (rows = true labels, columns = predicted
// labels, ordered as in select_all_distinct_labels(y)).  Throws
// cross_validation_error if any class has fewer than `folds` samples.
template <
typename trainer_type,
typename sample_type,
typename label_type
>
const matrix<double> cross_validate_multiclass_trainer (
const trainer_type& trainer,
const std::vector<sample_type>& x,
const std::vector<label_type>& y,
const long folds
)
{
typedef typename trainer_type::mem_manager_type mem_manager_type;
// make sure requires clause is not broken
DLIB_ASSERT(is_learning_problem(x,y) == true &&
1 < folds && folds <= static_cast<long>(x.size()),
"\tmatrix cross_validate_multiclass_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.size(): " << x.size()
<< "\n\t folds: " << folds
<< "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y)
);
const std::vector<label_type> all_labels = select_all_distinct_labels(y);
// count the number of times each label shows up
std::map<label_type,long> label_counts;
for (unsigned long i = 0; i < y.size(); ++i)
label_counts[y[i]] += 1;
// figure out how many samples from each class will be in the test and train splits
// (the split is stratified: each fold's test set contains count/folds samples of
// every class, and the rest of that class goes into the training set)
std::map<label_type,long> num_in_test, num_in_train;
for (typename std::map<label_type,long>::iterator i = label_counts.begin(); i != label_counts.end(); ++i)
{
const long in_test = i->second/folds;
// a class with fewer samples than folds would contribute nothing to any test
// set, so report this to the caller instead of silently dropping the class
if (in_test == 0)
{
std::ostringstream sout;
sout << "In dlib::cross_validate_multiclass_trainer(), the number of folds was larger" << std::endl;
sout << "than the number of elements of one of the training classes." << std::endl;
sout << "  folds: "<< folds << std::endl;
sout << "  size of class " << i->first << ": "<< i->second << std::endl;
throw cross_validation_error(sout.str());
}
num_in_test[i->first] = in_test;
num_in_train[i->first] = i->second - in_test;
}
std::vector<sample_type> x_test, x_train;
std::vector<label_type> y_test, y_train;
// res accumulates the per-fold confusion matrices.  NOTE(review): res is
// default constructed (empty) before the first +=; this appears to rely on
// dlib's matrix operator+= assigning when the target is empty — confirm.
matrix<double, 0, 0, mem_manager_type> res;
// next_test_idx[l] is the position in x/y where the scan for class l's next
// batch of test samples begins.  It advances after every fold so successive
// folds draw disjoint test samples for each class.
std::map<label_type,long> next_test_idx;
for (unsigned long i = 0; i < all_labels.size(); ++i)
next_test_idx[all_labels[i]] = 0;
label_type label;
for (long i = 0; i < folds; ++i)
{
x_test.clear();
y_test.clear();
x_train.clear();
y_train.clear();
// load up the test samples
for (unsigned long j = 0; j < all_labels.size(); ++j)
{
label = all_labels[j];
long next = next_test_idx[label];
long cur = 0;
const long num_needed = num_in_test[label];
// scan forward (wrapping around the dataset with the modulo below) until
// num_needed samples of this class have been collected for the test set
while (cur < num_needed)
{
if (y[next] == label)
{
x_test.push_back(x[next]);
y_test.push_back(label);
++cur;
}
next = (next + 1)%x.size();
}
// remember where the test scan stopped so the next fold tests on the
// following samples of this class
next_test_idx[label] = next;
}
// load up the training samples
for (unsigned long j = 0; j < all_labels.size(); ++j)
{
label = all_labels[j];
// continue scanning from where the test scan stopped; the index is NOT
// written back, so the remaining num_in_train samples of this class (the
// ones not just placed in the test set) land in the training set
long next = next_test_idx[label];
long cur = 0;
const long num_needed = num_in_train[label];
while (cur < num_needed)
{
if (y[next] == label)
{
x_train.push_back(x[next]);
y_train.push_back(label);
++cur;
}
next = (next + 1)%x.size();
}
}
try
{
// do the training and testing
res += test_multiclass_decision_function(trainer.train(x_train,y_train),x_test,y_test);
}
catch (invalid_nu_error&)
{
// just ignore cases which result in an invalid nu
// (if every fold fails this way the returned matrix stays empty)
}
} // for (long i = 0; i < folds; ++i)
return res;
}
}
// ----------------------------------------------------------------------------------------
#endif // DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_