// Copyright (C) 2009 Davis E. King (davis@dlib.net)
// License: Boost Software License. See LICENSE.txt for the full license.
#ifndef DLIB_PEGASoS_
#define DLIB_PEGASoS_
#include "pegasos_abstract.h"
#include <cmath>
#include "../algs.h"
#include "function.h"
#include "kernel.h"
#include "kcentroid.h"
#include <iostream>
#include <memory>
#include <algorithm>
#include <utility>
#include <vector>
#include "../rand.h"
#include "../matrix.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
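// This object implements the kernelized Pegasos algorithm of Shalev-Shwartz,
// Singer, and Srebro for training a support vector machine one sample at a
// time.  The weight vector w is maintained as a sparse kernel expansion by a
// kcentroid object, which caps the number of basis (support) vectors at
// max_sv.  See pegasos_abstract.h for the documented interface.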
template <
typename K
>
class svm_pegasos
{
typedef kcentroid<offset_kernel<K> > kc_type;
public:
typedef K kernel_type;
typedef typename kernel_type::scalar_type scalar_type;
typedef typename kernel_type::sample_type sample_type;
typedef typename kernel_type::mem_manager_type mem_manager_type;
typedef decision_function<kernel_type> trained_function_type;
template <typename K_>
struct rebind {
typedef svm_pegasos<K_> other;
};
svm_pegasos (
) :
max_sv(40),
lambda_c1(0.0001),
lambda_c2(0.0001),
max_wnorm(100),
tau(0.01),
tolerance(0.01),
train_count(0),
w(offset_kernel<kernel_type>(kernel,tau),tolerance, max_sv, false)
{
}
svm_pegasos (
const kernel_type& kernel_,
const scalar_type& lambda_,
const scalar_type& tolerance_,
unsigned long max_num_sv
) :
max_sv(max_num_sv),
kernel(kernel_),
lambda_c1(lambda_),
lambda_c2(lambda_),
max_wnorm(1/std::sqrt(lambda_)),
tau(0.01),
tolerance(tolerance_),
train_count(0),
w(offset_kernel<kernel_type>(kernel,tau),tolerance, max_sv, false)
{
// make sure requires clause is not broken
DLIB_ASSERT(lambda_ > 0 && tolerance_ > 0 && max_num_sv > 0,
"\tsvm_pegasos::svm_pegasos(kernel,lambda,tolerance,max_num_sv)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t lambda_: " << lambda_
<< "\n\t tolerance_: " << tolerance_
<< "\n\t max_num_sv: " << max_num_sv
);
}
void clear (
)
{
// reset the w vector back to its initial state
w = kc_type(offset_kernel<kernel_type>(kernel,tau),tolerance, max_sv, false);
train_count = 0;
}
void set_kernel (
kernel_type k
)
{
kernel = k;
clear();
}
void set_max_num_sv (
unsigned long max_num_sv
)
{
// make sure requires clause is not broken
DLIB_ASSERT(max_num_sv > 0,
"\tvoid svm_pegasos::set_max_num_sv(max_num_sv)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t max_num_sv: " << max_num_sv
);
max_sv = max_num_sv;
clear();
}
unsigned long get_max_num_sv (
) const
{
return max_sv;
}
void set_tolerance (
double tol
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 < tol,
"\tvoid svm_pegasos::set_tolerance(tol)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t tol: " << tol
);
tolerance = tol;
clear();
}
void set_lambda (
scalar_type lambda_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 < lambda_,
"\tvoid svm_pegasos::set_lambda(lambda_)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t lambda_: " << lambda_
);
lambda_c1 = lambda_;
lambda_c2 = lambda_;
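// the optimal Pegasos solution is known to lie in the ball of radius
// 1/sqrt(lambda), so it is safe to project w back into that ball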
max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2));
clear();
}
void set_lambda_class1 (
scalar_type lambda_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 < lambda_,
"\tvoid svm_pegasos::set_lambda_class1(lambda_)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t lambda_: " << lambda_
);
lambda_c1 = lambda_;
max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2));
clear();
}
void set_lambda_class2 (
scalar_type lambda_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 < lambda_,
"\tvoid svm_pegasos::set_lambda_class2(lambda_)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t lambda_: " << lambda_
);
lambda_c2 = lambda_;
max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2));
clear();
}
const scalar_type get_lambda_class1 (
) const
{
return lambda_c1;
}
const scalar_type get_lambda_class2 (
) const
{
return lambda_c2;
}
const scalar_type get_tolerance (
) const
{
return tolerance;
}
const kernel_type get_kernel (
) const
{
return kernel;
}
unsigned long get_train_count (
) const
{
return static_cast<unsigned long>(train_count);
}
scalar_type train (
const sample_type& x,
const scalar_type& y
)
{
// make sure requires clause is not broken
DLIB_ASSERT(y == -1 || y == 1,
"\tscalar_type svm_pegasos::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t y: " << y
);
const double lambda = (y==+1)? lambda_c1 : lambda_c2;
++train_count;
const scalar_type learning_rate = 1/(lambda*train_count);
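// Pegasos step t uses learning rate 1/(lambda*t), where t = train_count is
// the number of samples seen so far.  Margin violations receive the gradient
// update below followed by a projection back onto the ball of radius
// max_wnorm; all other samples just shrink w.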
// if this sample violates the margin constraint (it is inside the margin or misclassified)
if (y*w.inner_product(x) < 1)
{
// compute: w = (1-learning_rate*lambda)*w + y*learning_rate*x
w.train(x, 1 - learning_rate*lambda, y*learning_rate);
scalar_type wnorm = std::sqrt(w.squared_norm());
scalar_type temp = max_wnorm/wnorm;
if (temp < 1)
w.scale_by(temp);
}
else
{
w.scale_by(1 - learning_rate*lambda);
}
// return the current learning rate.  We use the smaller of the two lambdas
// so callers polling this value (e.g. batch_trainer) see the larger, more
// conservative of the two per-class rates.
return 1/(std::min(lambda_c1,lambda_c2)*train_count);
}
scalar_type operator() (
const sample_type& x
) const
{
return w.inner_product(x);
}
const decision_function<kernel_type> get_decision_function (
) const
{
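// w was trained with an offset_kernel, which adds tau to every kernel
// product to simulate a bias term.  To express the result in terms of the
// plain kernel_type we fold that constant back in: the implicit bias is tau
// times the sum of the alphas, and since decision_function subtracts its b
// member, b = -tau*sum(alpha).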
distance_function<offset_kernel<kernel_type> > df = w.get_distance_function();
return decision_function<kernel_type>(df.get_alpha(), -tau*sum(df.get_alpha()), kernel, df.get_basis_vectors());
}
void swap (
svm_pegasos& item
)
{
exchange(max_sv, item.max_sv);
exchange(kernel, item.kernel);
exchange(lambda_c1, item.lambda_c1);
exchange(lambda_c2, item.lambda_c2);
exchange(max_wnorm, item.max_wnorm);
exchange(tau, item.tau);
exchange(tolerance, item.tolerance);
exchange(train_count, item.train_count);
exchange(w, item.w);
}
friend void serialize(const svm_pegasos& item, std::ostream& out)
{
serialize(item.max_sv, out);
serialize(item.kernel, out);
serialize(item.lambda_c1, out);
serialize(item.lambda_c2, out);
serialize(item.max_wnorm, out);
serialize(item.tau, out);
serialize(item.tolerance, out);
serialize(item.train_count, out);
serialize(item.w, out);
}
friend void deserialize(svm_pegasos& item, std::istream& in)
{
deserialize(item.max_sv, in);
deserialize(item.kernel, in);
deserialize(item.lambda_c1, in);
deserialize(item.lambda_c2, in);
deserialize(item.max_wnorm, in);
deserialize(item.tau, in);
deserialize(item.tolerance, in);
deserialize(item.train_count, in);
deserialize(item.w, in);
}
private:
unsigned long max_sv;     // max number of support vectors w may use
kernel_type kernel;
scalar_type lambda_c1;    // regularization parameter for the +1 class
scalar_type lambda_c2;    // regularization parameter for the -1 class
scalar_type max_wnorm;    // radius of the ball w is projected back onto
scalar_type tau;          // offset used by offset_kernel to simulate a bias term
scalar_type tolerance;    // approximation tolerance used by the kcentroid
scalar_type train_count;  // number of calls to train()
kc_type w;                // the current normal vector of the hyperplane
}; // end of class svm_pegasos
template <
typename K
>
void swap (
svm_pegasos<K>& a,
svm_pegasos<K>& b
) { a.swap(b); }
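// A minimal usage sketch (hypothetical values; assumes dlib's matrix and
// radial_basis_kernel types):
//
//    typedef matrix<double,2,1> sample_type;
//    typedef radial_basis_kernel<sample_type> kernel_type;
//
//    svm_pegasos<kernel_type> trainer;
//    trainer.set_kernel(kernel_type(0.5));
//    trainer.set_lambda(0.0001);
//
//    sample_type samp;
//    samp = 1, 2;                      // dlib comma initialization
//    trainer.train(samp, +1);          // feed labeled samples one at a time
//    double prediction = trainer(samp);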
// ----------------------------------------------------------------------------------------
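// Copy the training parameters (tolerance, per-class lambdas, and the support
// vector budget, but not the kernel, since the kernel types may differ) from
// one svm_pegasos object to another.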
template <
typename T,
typename U
>
void replicate_settings (
const svm_pegasos<T>& source,
svm_pegasos<U>& dest
)
{
dest.set_tolerance(source.get_tolerance());
dest.set_lambda_class1(source.get_lambda_class1());
dest.set_lambda_class2(source.get_lambda_class2());
dest.set_max_num_sv(source.get_max_num_sv());
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
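// batch_trainer adapts an online trainer such as svm_pegasos to the batch
// training interface: train(x,y) feeds randomly chosen samples to the online
// trainer until the learning rate it reports drops below min_learning_rate.
// Optionally, kernel products are cached via the caching_kernel defined
// below.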
template <
typename trainer_type
>
class batch_trainer
{
// ------------------------------------------------------------------------------------
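// caching_kernel wraps a kernel K so that samples are referred to by index.
// Kernel products against the cache_size most frequently used samples are
// precomputed into a cache_size x num_samples matrix, and the cache is
// rebuilt, LFU style, once the number of cache misses exceeds the number of
// entries in the cache.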
template <
typename K,
typename sample_vector_type
>
class caching_kernel
{
public:
typedef typename K::scalar_type scalar_type;
// samples are referenced by their index into the sample vector rather than
// passed by value, so this kernel's sample_type is an index type
typedef long sample_type;
typedef typename K::mem_manager_type mem_manager_type;
caching_kernel () {}
caching_kernel (
const K& kern,
const sample_vector_type& samps,
long cache_size_
) : real_kernel(kern), samples(&samps), counter(0)
{
cache_size = std::min<long>(cache_size_, samps.size());
cache.reset(new cache_type);
cache->frequency_of_use.resize(samps.size());
for (long i = 0; i < samps.size(); ++i)
cache->frequency_of_use[i] = std::make_pair(0, i);
// Set the cache build/rebuild threshold so that we have to have
// as many cache misses as there are entries in the cache before
// we build/rebuild.
counter_threshold = samps.size()*cache_size;
cache->sample_location.assign(samples->size(), -1);
}
scalar_type operator() (
const sample_type& a,
const sample_type& b
) const
{
// rebuild the cache every so often
if (counter > counter_threshold)
{
build_cache();
}
const long a_loc = cache->sample_location[a];
const long b_loc = cache->sample_location[b];
cache->frequency_of_use[a].first += 1;
cache->frequency_of_use[b].first += 1;
if (a_loc != -1)
{
return cache->kernel(a_loc, b);
}
else if (b_loc != -1)
{
return cache->kernel(b_loc, a);
}
else
{
++counter;
return real_kernel((*samples)(a), (*samples)(b));
}
}
bool operator== (
const caching_kernel& item
) const
{
return item.real_kernel == real_kernel &&
item.samples == samples;
}
private:
K real_kernel;
void build_cache (
) const
{
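// sort the samples in descending order of use so the most frequently used
// ones land in the cache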
std::sort(cache->frequency_of_use.rbegin(), cache->frequency_of_use.rend());
counter = 0;
cache->kernel.set_size(cache_size, samples->size());
cache->sample_location.assign(samples->size(), -1);
// loop over all the samples in the cache
for (long i = 0; i < cache_size; ++i)
{
const long cur = cache->frequency_of_use[i].second;
cache->sample_location[cur] = i;
// now populate all possible kernel products with the current sample
for (long j = 0; j < samples->size(); ++j)
{
cache->kernel(i, j) = real_kernel((*samples)(cur), (*samples)(j));
}
}
// reset the frequency of use metrics
for (long i = 0; i < samples->size(); ++i)
cache->frequency_of_use[i] = std::make_pair(0, i);
}
struct cache_type
{
matrix<scalar_type> kernel;
std::vector<long> sample_location; // where in the cache a sample is. -1 means not in cache
std::vector<std::pair<long,long> > frequency_of_use;
};
const sample_vector_type* samples = nullptr;
std::shared_ptr<cache_type> cache;
mutable unsigned long counter = 0;
unsigned long counter_threshold = 0;
long cache_size = 0;
};
// ------------------------------------------------------------------------------------
public:
typedef typename trainer_type::kernel_type kernel_type;
typedef typename trainer_type::scalar_type scalar_type;
typedef typename trainer_type::sample_type sample_type;
typedef typename trainer_type::mem_manager_type mem_manager_type;
typedef typename trainer_type::trained_function_type trained_function_type;
batch_trainer (
) :
min_learning_rate(0.1),
verbose(false),
use_cache(false),
cache_size(100)
{
}
batch_trainer (
const trainer_type& trainer_,
const scalar_type min_learning_rate_,
bool verbose_,
bool use_cache_,
long cache_size_ = 100
) :
trainer(trainer_),
min_learning_rate(min_learning_rate_),
verbose(verbose_),
use_cache(use_cache_),
cache_size(cache_size_)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 < min_learning_rate_ &&
cache_size_ > 0,
"\tbatch_trainer::batch_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t min_learning_rate_: " << min_learning_rate_
<< "\n\t cache_size_: " << cache_size_
);
trainer.clear();
}
const scalar_type get_min_learning_rate (
) const
{
return min_learning_rate;
}
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const decision_function<kernel_type> train (
const in_sample_vector_type& x,
const in_scalar_vector_type& y
) const
{
if (use_cache)
return do_train_cached(mat(x), mat(y));
else
return do_train(mat(x), mat(y));
}
private:
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const decision_function<kernel_type> do_train (
const in_sample_vector_type& x,
const in_scalar_vector_type& y
) const
{
dlib::rand rnd;
trainer_type my_trainer(trainer);
scalar_type cur_learning_rate = min_learning_rate + 10;  // anything above the threshold so the loop runs at least once
unsigned long count = 0;
while (cur_learning_rate > min_learning_rate)
{
const long i = rnd.get_random_32bit_number()%x.size();
// keep feeding the trainer data until its learning rate goes below our threshold
cur_learning_rate = my_trainer.train(x(i), y(i));
if (verbose)
{
if ( (count&0x7FF) == 0)  // print a status update every 2048 samples
{
std::cout << "\rbatch_trainer(): Percent complete: "
<< 100*min_learning_rate/cur_learning_rate << " " << std::flush;
}
++count;
}
}
if (verbose)
{
decision_function<kernel_type> df = my_trainer.get_decision_function();
std::cout << "\rbatch_trainer(): Percent complete: 100 " << std::endl;
std::cout << " Num sv: " << df.basis_vectors.size() << std::endl;
std::cout << " bias: " << df.b << std::endl;
return df;
}
else
{
return my_trainer.get_decision_function();
}
}
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const decision_function<kernel_type> do_train_cached (
const in_sample_vector_type& x,
const in_scalar_vector_type& y
) const
{
dlib::rand rnd;
// make a caching kernel
typedef caching_kernel<kernel_type, in_sample_vector_type> ckernel_type;
ckernel_type ck(trainer.get_kernel(), x, cache_size);
// now rebind the trainer to use the caching kernel
typedef typename trainer_type::template rebind<ckernel_type>::other rebound_trainer_type;
rebound_trainer_type my_trainer;
my_trainer.set_kernel(ck);
replicate_settings(trainer, my_trainer);
scalar_type cur_learning_rate = min_learning_rate + 10;  // anything above the threshold so the loop runs at least once
unsigned long count = 0;
while (cur_learning_rate > min_learning_rate)
{
const long i = rnd.get_random_32bit_number()%x.size();
// keep feeding the trainer data until its learning rate goes below our
// threshold.  Note that we pass the sample's index rather than the sample
// itself, since the caching kernel identifies samples by index.
cur_learning_rate = my_trainer.train(i, y(i));
if (verbose)
{
if ( (count&0x7FF) == 0)  // print a status update every 2048 samples
{
std::cout << "\rbatch_trainer(): Percent complete: "
<< 100*min_learning_rate/cur_learning_rate << " " << std::flush;
}
++count;
}
}
if (verbose)
{
decision_function<ckernel_type> cached_df;
cached_df = my_trainer.get_decision_function();
std::cout << "\rbatch_trainer(): Percent complete: 100 " << std::endl;
std::cout << " Num sv: " << cached_df.basis_vectors.size() << std::endl;
std::cout << " bias: " << cached_df.b << std::endl;
return decision_function<kernel_type> (
cached_df.alpha,
cached_df.b,
trainer.get_kernel(),
rowm(x, cached_df.basis_vectors)
);
}
else
{
decision_function<ckernel_type> cached_df;
cached_df = my_trainer.get_decision_function();
return decision_function<kernel_type> (
cached_df.alpha,
cached_df.b,
trainer.get_kernel(),
rowm(x, cached_df.basis_vectors)
);
}
}
trainer_type trainer;
scalar_type min_learning_rate;
bool verbose;
bool use_cache;
long cache_size;
}; // end of class batch_trainer
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const batch_trainer<trainer_type> batch (
const trainer_type& trainer,
const typename trainer_type::scalar_type min_learning_rate = 0.1
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, false, false); }
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const batch_trainer<trainer_type> verbose_batch (
const trainer_type& trainer,
const typename trainer_type::scalar_type min_learning_rate = 0.1
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, true, false); }
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const batch_trainer<trainer_type> batch_cached (
const trainer_type& trainer,
const typename trainer_type::scalar_type min_learning_rate = 0.1,
long cache_size = 100
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, false, true, cache_size); }
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const batch_trainer<trainer_type> verbose_batch_cached (
const trainer_type& trainer,
const typename trainer_type::scalar_type min_learning_rate = 0.1,
long cache_size = 100
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, true, true, cache_size); }
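// A minimal usage sketch for these helpers, reusing the hypothetical
// sample_type and kernel_type typedefs from the svm_pegasos sketch above:
//
//    svm_pegasos<kernel_type> pegasos_trainer;
//    pegasos_trainer.set_kernel(kernel_type(0.5));
//    pegasos_trainer.set_lambda(0.0001);
//
//    std::vector<sample_type> samples;
//    std::vector<double> labels;       // each label must be +1 or -1
//    // ... fill samples and labels ...
//
//    // train, caching up to 100 rows of kernel products
//    decision_function<kernel_type> df =
//        batch_cached(pegasos_trainer, 0.1, 100).train(samples, labels);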
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_PEGASoS_