// Copyright (C) 2007 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_SVm_FUNCTION
#define DLIB_SVm_FUNCTION
#include "function_abstract.h"
#include <cmath>
#include <limits>
#include <sstream>
#include "../matrix.h"
#include "../algs.h"
#include "../serialize.h"
#include "../rand.h"
#include "../statistics.h"
#include "kernel_matrix.h"
#include "kernel.h"
#include "sparse_kernel.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
// A kernel expansion that scores a sample x as
//     sum_i alpha(i)*kernel_function(x, basis_vectors(i)) - b
// All state is public; the serialize()/deserialize() overloads for this type
// read and write the members directly.
template <
    typename K
    >
struct decision_function
{
    typedef K kernel_type;
    typedef typename K::scalar_type scalar_type;
    typedef typename K::scalar_type result_type;
    typedef typename K::sample_type sample_type;
    typedef typename K::mem_manager_type mem_manager_type;

    typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
    typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;

    scalar_vector_type alpha;           // weight applied to each basis vector
    scalar_type b;                      // offset subtracted from the weighted sum
    K kernel_function;                  // kernel used to compare x with each basis vector
    sample_vector_type basis_vectors;   // samples the expansion is built from

    // Creates an empty expansion: no basis vectors, zero offset and a
    // default constructed kernel.
    decision_function (
    ) : b(0), kernel_function(K()) {}

    // Member-wise copy of d.
    decision_function (
        const decision_function& d
    ) :
        alpha(d.alpha),
        b(d.b),
        kernel_function(d.kernel_function),
        basis_vectors(d.basis_vectors)
    {}

    // Builds the expansion from its four parts.  No consistency check is
    // made between alpha_ and basis_vectors_ here.
    decision_function (
        const scalar_vector_type& alpha_,
        const scalar_type& b_,
        const K& kernel_function_,
        const sample_vector_type& basis_vectors_
    ) :
        alpha(alpha_),
        b(b_),
        kernel_function(kernel_function_),
        basis_vectors(basis_vectors_)
    {}

    // Returns the weighted kernel sum over all basis vectors minus b.
    result_type operator() (
        const sample_type& x
    ) const
    {
        const long num_terms = alpha.nr();
        result_type weighted_sum = 0;
        for (long idx = 0; idx < num_terms; ++idx)
        {
            weighted_sum += alpha(idx) * kernel_function(x,basis_vectors(idx));
        }
        return weighted_sum - b;
    }
};
template <
typename K
>
void serialize (
const decision_function<K>& item,
std::ostream& out
)
{
try
{
serialize(item.alpha, out);
serialize(item.b, out);
serialize(item.kernel_function, out);
serialize(item.basis_vectors, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type decision_function");
}
}
// Reads alpha, b, kernel_function and basis_vectors from in, in that order,
// directly into item.  A serialization_error raised by any of the reads is
// rethrown with a note naming decision_function appended to its message.
template <
    typename K
    >
void deserialize (
    decision_function<K>& item,
    std::istream& in
)
{
    try
    {
        deserialize(item.alpha, in);
        deserialize(item.b, in);
        deserialize(item.kernel_function, in);
        deserialize(item.basis_vectors, in);
    }
    catch (const serialization_error& err)
    {
        throw serialization_error(err.info + "\n while deserializing object of type decision_function");
    }
}
// ----------------------------------------------------------------------------------------
// Wraps an arbitrary scoring function and squashes its output through a
// sigmoid:  1/(1 + exp(alpha*f(x) + beta)),  where f is decision_funct.
template <
    typename function_type
    >
struct probabilistic_function
{
    typedef typename function_type::scalar_type scalar_type;
    typedef typename function_type::result_type result_type;
    typedef typename function_type::sample_type sample_type;
    typedef typename function_type::mem_manager_type mem_manager_type;

    scalar_type alpha;              // slope applied to the raw score
    scalar_type beta;               // offset added inside the exponential
    function_type decision_funct;   // produces the raw score for a sample

    // alpha and beta start at 0 and the wrapped function is default constructed.
    probabilistic_function (
    ) : alpha(0), beta(0), decision_funct(function_type()) {}

    // Member-wise copy of d.
    probabilistic_function (
        const probabilistic_function& d
    ) :
        alpha(d.alpha),
        beta(d.beta),
        decision_funct(d.decision_funct)
    {}

    // a_ becomes alpha, b_ becomes beta and decision_funct_ is copied.
    probabilistic_function (
        const scalar_type a_,
        const scalar_type b_,
        const function_type& decision_funct_
    ) :
        alpha(a_),
        beta(b_),
        decision_funct(decision_funct_)
    {}

    // Evaluates the wrapped function on x and maps the score through the sigmoid.
    result_type operator() (
        const sample_type& x
    ) const
    {
        const result_type score = decision_funct(x);
        return 1/(1 + std::exp(alpha*score + beta));
    }
};
template <
typename function_type
>
void serialize (
const probabilistic_function<function_type>& item,
std::ostream& out
)
{
try
{
serialize(item.alpha, out);
serialize(item.beta, out);
serialize(item.decision_funct, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type probabilistic_function");
}
}
// Reads alpha, beta and decision_funct from in, in that order, directly into
// item.  A serialization_error raised by any of the reads is rethrown with a
// note naming probabilistic_function appended to its message.
template <
    typename function_type
    >
void deserialize (
    probabilistic_function<function_type>& item,
    std::istream& in
)
{
    try
    {
        deserialize(item.alpha, in);
        deserialize(item.beta, in);
        deserialize(item.decision_funct, in);
    }
    catch (const serialization_error& err)
    {
        throw serialization_error(err.info + "\n while deserializing object of type probabilistic_function");
    }
}
// ----------------------------------------------------------------------------------------
// A decision_function<K> whose output is squashed through a sigmoid:
//     1/(1 + exp(alpha*decision_funct(x) + beta))
// It can be built implicitly from a probabilistic_function wrapping a
// decision_function<K>.
template <
    typename K
    >
struct probabilistic_decision_function
{
    typedef K kernel_type;
    typedef typename K::scalar_type scalar_type;
    typedef typename K::scalar_type result_type;
    typedef typename K::sample_type sample_type;
    typedef typename K::mem_manager_type mem_manager_type;

    scalar_type alpha;                      // slope applied to the raw score
    scalar_type beta;                       // offset added inside the exponential
    decision_function<K> decision_funct;    // produces the raw score for a sample

    // alpha and beta start at 0 and the wrapped function is default constructed.
    probabilistic_decision_function (
    ) : alpha(0), beta(0), decision_funct(decision_function<K>()) {}

    // Converting constructor: copies alpha, beta and the wrapped function
    // out of the more general probabilistic_function.
    probabilistic_decision_function (
        const probabilistic_function<decision_function<K> >& d
    ) :
        alpha(d.alpha),
        beta(d.beta),
        decision_funct(d.decision_funct)
    {}

    // Member-wise copy of d.
    probabilistic_decision_function (
        const probabilistic_decision_function& d
    ) :
        alpha(d.alpha),
        beta(d.beta),
        decision_funct(d.decision_funct)
    {}

    // a_ becomes alpha, b_ becomes beta and decision_funct_ is copied.
    probabilistic_decision_function (
        const scalar_type a_,
        const scalar_type b_,
        const decision_function<K>& decision_funct_
    ) :
        alpha(a_),
        beta(b_),
        decision_funct(decision_funct_)
    {}

    // Evaluates the wrapped function on x and maps the score through the sigmoid.
    result_type operator() (
        const sample_type& x
    ) const
    {
        const result_type score = decision_funct(x);
        return 1/(1 + std::exp(alpha*score + beta));
    }
};
template <
typename K
>
void serialize (
const probabilistic_decision_function<K>& item,
std::ostream& out
)
{
try
{
serialize(item.alpha, out);
serialize(item.beta, out);
serialize(item.decision_funct, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type probabilistic_decision_function");
}
}
// Reads alpha, beta and decision_funct from in, in that order, directly into
// item.  A serialization_error raised by any of the reads is rethrown with a
// note naming probabilistic_decision_function appended to its message.
template <
    typename K
    >
void deserialize (
    probabilistic_decision_function<K>& item,
    std::istream& in
)
{
    try
    {
        deserialize(item.alpha, in);
        deserialize(item.beta, in);
        deserialize(item.decision_funct, in);
    }
    catch (const serialization_error& err)
    {
        throw serialization_error(err.info + "\n while deserializing object of type probabilistic_decision_function");
    }
}
// ----------------------------------------------------------------------------------------
// Stores a kernel expansion (weights alpha over basis_vectors) together with
// b, the cached value of trans(alpha)*kernel_matrix(kernel,basis_vectors)*alpha
// that get_squared_norm() reports.  Evaluating the object on a sample x gives
//     sqrt(b + kernel(x,x) - 2*sum_i alpha(i)*kernel(x, basis_vectors(i)))
// clamped to 0 when the quantity under the root is not positive.
template <
    typename K
    >
class distance_function
{
public:
    typedef K kernel_type;
    typedef typename K::scalar_type scalar_type;
    typedef typename K::scalar_type result_type;
    typedef typename K::sample_type sample_type;
    typedef typename K::mem_manager_type mem_manager_type;

    typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
    typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;

    // Empty expansion with squared norm 0 and a default constructed kernel.
    distance_function (
    ) : b(0), kernel_function(K()) {}

    // Empty expansion with squared norm 0 that uses the given kernel.
    explicit distance_function (
        const kernel_type& kern
    ) : b(0), kernel_function(kern) {}

    // Expansion made of the single sample samp with weight 1; the squared
    // norm is therefore kern(samp,samp).
    distance_function (
        const kernel_type& kern,
        const sample_type& samp
    ) :
        alpha(ones_matrix<scalar_type>(1,1)),
        b(kern(samp,samp)),
        kernel_function(kern)
    {
        basis_vectors.set_size(1,1);
        basis_vectors(0) = samp;
    }

    // Takes alpha, the kernel and the basis vectors from f and computes the
    // squared norm from them.  f.b is not used.
    // Requires f.alpha.size() == f.basis_vectors.size().
    distance_function (
        const decision_function<K>& f
    ) :
        alpha(f.alpha),
        b(0),
        kernel_function(f.kernel_function),
        basis_vectors(f.basis_vectors)
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(f.alpha.size() == f.basis_vectors.size(),
            "\t distance_function(f)"
            << "\n\t The supplied decision_function is invalid."
            << "\n\t f.alpha.size(): " << f.alpha.size()
            << "\n\t f.basis_vectors.size(): " << f.basis_vectors.size()
            );

        // The squared norm is computed only after the sizes have been
        // validated so that a mismatch is reported by the check above rather
        // than by the matrix product itself.
        b = static_cast<scalar_type>(trans(f.alpha)*kernel_matrix(f.kernel_function,f.basis_vectors)*f.alpha);
    }

    // Member-wise copy of d.
    distance_function (
        const distance_function& d
    ) :
        alpha(d.alpha),
        b(d.b),
        kernel_function(d.kernel_function),
        basis_vectors(d.basis_vectors)
    {
    }

    // Builds the object from its four parts, trusting the caller supplied
    // squared norm b_.
    // Requires alpha_.size() == basis_vectors_.size().
    distance_function (
        const scalar_vector_type& alpha_,
        const scalar_type& b_,
        const K& kernel_function_,
        const sample_vector_type& basis_vectors_
    ) :
        alpha(alpha_),
        b(b_),
        kernel_function(kernel_function_),
        basis_vectors(basis_vectors_)
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(alpha_.size() == basis_vectors_.size(),
            "\t distance_function()"
            << "\n\t The supplied arguments are invalid."
            << "\n\t alpha_.size(): " << alpha_.size()
            << "\n\t basis_vectors_.size(): " << basis_vectors_.size()
            );
    }

    // Same as the four argument constructor except that the squared norm is
    // computed from alpha_, kernel_function_ and basis_vectors_.
    // Requires alpha_.size() == basis_vectors_.size().
    distance_function (
        const scalar_vector_type& alpha_,
        const K& kernel_function_,
        const sample_vector_type& basis_vectors_
    ) :
        alpha(alpha_),
        b(0),
        kernel_function(kernel_function_),
        basis_vectors(basis_vectors_)
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(alpha_.size() == basis_vectors_.size(),
            "\t distance_function()"
            << "\n\t The supplied arguments are invalid."
            << "\n\t alpha_.size(): " << alpha_.size()
            << "\n\t basis_vectors_.size(): " << basis_vectors_.size()
            );

        // Computed in the body, after validation, and from the arguments
        // rather than from members, so the result does not depend on the
        // order in which the data members are declared.
        b = static_cast<scalar_type>(trans(alpha_)*kernel_matrix(kernel_function_,basis_vectors_)*alpha_);
    }

    const scalar_vector_type& get_alpha (
    ) const { return alpha; }

    const scalar_type& get_squared_norm (
    ) const { return b; }

    const K& get_kernel(
    ) const { return kernel_function; }

    const sample_vector_type& get_basis_vectors (
    ) const { return basis_vectors; }

    // Distance between the sample x and this expansion (see the formula at
    // the top of the class).  Returns 0 when the squared distance is not
    // positive.
    result_type operator() (
        const sample_type& x
    ) const
    {
        result_type temp = 0;
        for (long i = 0; i < alpha.nr(); ++i)
            temp += alpha(i) * kernel_function(x,basis_vectors(i));

        temp = b + kernel_function(x,x) - 2*temp;
        if (temp > 0)
            return std::sqrt(temp);
        else
            return 0;
    }

    // Distance between this expansion and the expansion x:
    //     sqrt(b + x.b - 2*sum_ij alpha(i)*x.alpha(j)*kernel(basis_vectors(i), x.basis_vectors(j)))
    // evaluated with this object's kernel.  Returns 0 when the squared
    // distance is not positive.
    result_type operator() (
        const distance_function& x
    ) const
    {
        result_type temp = 0;
        for (long i = 0; i < alpha.nr(); ++i)
            for (long j = 0; j < x.alpha.nr(); ++j)
                temp += alpha(i)*x.alpha(j) * kernel_function(basis_vectors(i), x.basis_vectors(j));

        temp = b + x.b - 2*temp;
        if (temp > 0)
            return std::sqrt(temp);
        else
            return 0;
    }

    // Scales the expansion by val: the weights are multiplied by val and the
    // squared norm by val*val.
    distance_function operator* (
        const scalar_type& val
    ) const
    {
        return distance_function(val*alpha,
                                 val*val*b,
                                 kernel_function,
                                 basis_vectors);
    }

    // Divides the expansion by val: the weights are divided by val and the
    // squared norm by val twice.  No check is made for val == 0.
    distance_function operator/ (
        const scalar_type& val
    ) const
    {
        return distance_function(alpha/val,
                                 b/val/val,
                                 kernel_function,
                                 basis_vectors);
    }

    // Sum of two expansions.  The weights and basis vectors are concatenated
    // and the squared norm is updated with the cross term between the two
    // operands.  If either operand is empty the other one is returned as is.
    // Requires get_kernel() == rhs.get_kernel().
    distance_function operator+ (
        const distance_function& rhs
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(get_kernel() == rhs.get_kernel(),
            "\t distance_function distance_function::operator+()"
            << "\n\t You can only add two distance_functions together if they use the same kernel."
            );

        if (alpha.size() == 0)
            return rhs;
        else if (rhs.alpha.size() == 0)
            return *this;
        else
            return distance_function(join_cols(alpha, rhs.alpha),
                                     b + rhs.b + 2*trans(alpha)*kernel_matrix(kernel_function,basis_vectors,rhs.basis_vectors)*rhs.alpha,
                                     kernel_function,
                                     join_cols(basis_vectors, rhs.basis_vectors));
    }

    // Difference of two expansions.  The weights of rhs are negated before
    // being concatenated and the cross term is subtracted from the squared
    // norm.  Empty operands are handled separately: both empty gives an empty
    // result, an empty rhs gives *this and an empty left side gives -1*rhs.
    // Requires get_kernel() == rhs.get_kernel().
    distance_function operator- (
        const distance_function& rhs
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(get_kernel() == rhs.get_kernel(),
            "\t distance_function distance_function::operator-()"
            << "\n\t You can only subtract two distance_functions if they use the same kernel."
            );

        if (alpha.size() == 0 && rhs.alpha.size() == 0)
            return distance_function(kernel_function);
        else if (alpha.size() != 0 && rhs.alpha.size() == 0)
            return *this;
        else if (alpha.size() == 0 && rhs.alpha.size() != 0)
            return -1*rhs;
        else
            return distance_function(join_cols(alpha, -rhs.alpha),
                                     b + rhs.b - 2*trans(alpha)*kernel_matrix(kernel_function,basis_vectors,rhs.basis_vectors)*rhs.alpha,
                                     kernel_function,
                                     join_cols(basis_vectors, rhs.basis_vectors));
    }

private:
    // alpha must stay declared before b: constructors initialise the members
    // in this order.
    scalar_vector_type alpha;
    scalar_type b;
    K kernel_function;
    sample_vector_type basis_vectors;
};
// Scalar-on-the-left form of distance_function scaling; forwards to the
// member operator* so that val*df and df*val give the same result.
template <
    typename K
    >
distance_function<K> operator* (
    const typename K::scalar_type& val,
    const distance_function<K>& df
)
{
    return df*val;
}
template <
typename K
>
void serialize (
const distance_function<K>& item,
std::ostream& out
)
{
try
{
serialize(item.alpha, out);
serialize(item.b, out);
serialize(item.kernel_function, out);
serialize(item.basis_vectors, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type distance_function");
}
}
// Reads an alpha vector, squared norm, kernel and basis vectors from in, in
// that order, and assigns a distance_function built from them to item.  The
// data members of distance_function are private, so the values are read into
// locals and passed to the public four argument constructor; as a result item
// is left untouched if any of the reads throws.  A serialization_error raised
// by any of the reads is rethrown with a note naming distance_function
// appended to its message.
template <
    typename K
    >
void deserialize (
    distance_function<K>& item,
    std::istream& in
)
{
    try
    {
        typename distance_function<K>::scalar_vector_type alpha;
        typename distance_function<K>::scalar_type b = 0;
        K kernel_function;
        typename distance_function<K>::sample_vector_type basis_vectors;

        deserialize(alpha, in);
        deserialize(b, in);
        deserialize(kernel_function, in);
        deserialize(basis_vectors, in);

        item = distance_function<K>(alpha, b, kernel_function, basis_vectors);
    }
    catch (const serialization_error& e)
    {
        throw serialization_error(e.info + "\n while deserializing object of type distance_function");
    }
}
// ----------------------------------------------------------------------------------------
// Pairs a function object with a normalizer: evaluating the pair on x yields
// function(normalizer(x)).  get_labels() and number_of_classes() simply
// forward to the wrapped function, so they can only be used when
// function_type provides members with those names.
template <
    typename function_type,
    typename normalizer_type = vector_normalizer<typename function_type::sample_type>
    >
struct normalized_function
{
    typedef typename function_type::result_type result_type;
    typedef typename function_type::sample_type sample_type;
    typedef typename function_type::mem_manager_type mem_manager_type;

    normalizer_type normalizer;     // applied to the sample first
    function_type function;         // applied to the normalized sample

    // Both members are default constructed.
    normalized_function (
    ){}

    // Member-wise copy of f.
    normalized_function (
        const normalized_function& f
    ) :
        normalizer(f.normalizer),
        function(f.function)
    {}

    // NOTE(review): the first parameter is vector_normalizer<sample_type>
    // rather than normalizer_type, so with a non-default normalizer_type this
    // constructor only works if normalizer_type can be constructed from a
    // vector_normalizer<sample_type> -- confirm whether that is intended.
    normalized_function (
        const vector_normalizer<sample_type>& normalizer_,
        const function_type& funct
    ) :
        normalizer(normalizer_),
        function(funct)
    {}

    // Returns a copy of whatever function.get_labels() returns.
    const std::vector<result_type> get_labels(
    ) const
    {
        return function.get_labels();
    }

    // Forwards to function.number_of_classes().
    unsigned long number_of_classes (
    ) const
    {
        return function.number_of_classes();
    }

    // Normalizes x and then evaluates the wrapped function on the result.
    result_type operator() (
        const sample_type& x
    ) const
    {
        return function(normalizer(x));
    }
};
template <
typename function_type,
typename normalizer_type
>
void serialize (
const normalized_function<function_type,normalizer_type>& item,
std::ostream& out
)
{
try
{
serialize(item.normalizer, out);
serialize(item.function, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type normalized_function");
}
}
// Reads the normalizer and then the function from in, directly into item.  A
// serialization_error raised by either read is rethrown with a note naming
// normalized_function appended to its message.
template <
    typename function_type,
    typename normalizer_type
    >
void deserialize (
    normalized_function<function_type,normalizer_type>& item,
    std::istream& in
)
{
    try
    {
        deserialize(item.normalizer, in);
        deserialize(item.function, in);
    }
    catch (const serialization_error& err)
    {
        throw serialization_error(err.info + "\n while deserializing object of type normalized_function");
    }
}
// ----------------------------------------------------------------------------------------
// Maps a sample to a vector: the sample is compared against every basis
// vector with the kernel and the resulting column of kernel values is
// multiplied by the weights matrix.
template <
    typename K
    >
struct projection_function
{
    typedef K kernel_type;
    typedef typename K::scalar_type scalar_type;
    typedef typename K::sample_type sample_type;
    typedef typename K::mem_manager_type mem_manager_type;

    typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
    typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;
    typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
    typedef scalar_vector_type result_type;

    scalar_matrix_type weights;         // one row per output dimension
    K kernel_function;
    sample_vector_type basis_vectors;

    // All members are default constructed.
    projection_function (
    ) {}

    // Copies weights, kernel and basis vectors from f.  The scratch buffers
    // are not copied.
    projection_function (
        const projection_function& f
    ) :
        weights(f.weights),
        kernel_function(f.kernel_function),
        basis_vectors(f.basis_vectors)
    {}

    // Builds the projection from its three parts.
    projection_function (
        const scalar_matrix_type& weights_,
        const K& kernel_function_,
        const sample_vector_type& basis_vectors_
    ) :
        weights(weights_),
        kernel_function(kernel_function_),
        basis_vectors(basis_vectors_)
    {}

    // Number of rows of the weights matrix.
    long out_vector_size (
    ) const
    {
        return weights.nr();
    }

    // Projects x.  The returned reference points at a scratch member of this
    // object, so its contents are overwritten by the next call made on the
    // same object.
    const result_type& operator() (
        const sample_type& x
    ) const
    {
        // Evaluate the kernel between x and every basis vector, then apply
        // the weights.  Member scratch vectors are used instead of locals so
        // their memory does not have to be reallocated on every call.
        kernel_column = kernel_matrix(kernel_function, basis_vectors, x);
        projected = weights*kernel_column;
        return projected;
    }

private:
    // Scratch storage reused between calls to operator().
    mutable result_type kernel_column, projected;
};
template <
typename K
>
void serialize (
const projection_function<K>& item,
std::ostream& out
)
{
try
{
serialize(item.weights, out);
serialize(item.kernel_function, out);
serialize(item.basis_vectors, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type projection_function");
}
}
// Reads weights, kernel_function and basis_vectors from in, in that order,
// directly into item.  A serialization_error raised by any of the reads is
// rethrown with a note naming projection_function appended to its message.
template <
    typename K
    >
void deserialize (
    projection_function<K>& item,
    std::istream& in
)
{
    try
    {
        deserialize(item.weights, in);
        deserialize(item.kernel_function, in);
        deserialize(item.basis_vectors, in);
    }
    catch (const serialization_error& err)
    {
        throw serialization_error(err.info + "\n while deserializing object of type projection_function");
    }
}
// ----------------------------------------------------------------------------------------
// A linear multiclass classifier.  Class i is scored as
//     dot(rowm(weights,i), x) - b(i)
// and the label stored at the index of the highest score is the prediction.
template <
    typename K,
    typename result_type_ = typename K::scalar_type
    >
struct multiclass_linear_decision_function
{
    typedef result_type_ result_type;

    typedef K kernel_type;
    typedef typename K::scalar_type scalar_type;
    typedef typename K::sample_type sample_type;
    typedef typename K::mem_manager_type mem_manager_type;

    typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
    typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;

    // You are getting a compiler error on this line because you supplied a non-linear kernel
    // to the multiclass_linear_decision_function object. You have to use one of the linear
    // kernels with this object.
    COMPILE_TIME_ASSERT((is_same_type<K, linear_kernel<sample_type> >::value ||
                         is_same_type<K, sparse_linear_kernel<sample_type> >::value ));

    scalar_matrix_type weights;         // one row of weights per class
    scalar_vector_type b;               // one offset per class
    std::vector<result_type> labels;    // labels[i] is the label of class i

    const std::vector<result_type>& get_labels(
    ) const { return labels; }

    unsigned long number_of_classes (
    ) const { return labels.size(); }

    // Returns the winning label together with its score.  When several
    // classes share the highest score the one with the lowest index wins.
    // Requires weights to be non-empty and to have exactly one row per label
    // and per element of b.
    std::pair<result_type, scalar_type> predict (
        const sample_type& x
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(weights.size() > 0 &&
                    weights.nr() == static_cast<long>(number_of_classes()) &&
                    weights.nr() == b.size(),
            "\t pair<result_type,scalar_type> multiclass_linear_decision_function::predict(x)"
            << "\n\t This object must be properly initialized before you can use it."
            << "\n\t weights.size(): " << weights.size()
            << "\n\t weights.nr(): " << weights.nr()
            << "\n\t number_of_classes(): " << number_of_classes()
            << "\n\t b.size(): " << b.size()
            );

        // Rather than doing something like, best_idx = index_of_max(weights*x-b)
        // we do the following somewhat more complex thing because this supports
        // both sparse and dense samples.
        scalar_type best_val = dot(rowm(weights,0),x) - b(0);
        unsigned long best_idx = 0;

        for (unsigned long i = 1; i < labels.size(); ++i)
        {
            scalar_type temp = dot(rowm(weights,i),x) - b(i);
            if (temp > best_val)
            {
                best_val = temp;
                best_idx = i;
            }
        }

        return std::make_pair(labels[best_idx], best_val);
    }

    // Returns only the winning label; see predict() for the tie-breaking
    // rule and the requirements.
    result_type operator() (
        const sample_type& x
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(weights.size() > 0 &&
                    weights.nr() == static_cast<long>(number_of_classes()) &&
                    weights.nr() == b.size(),
            "\t result_type multiclass_linear_decision_function::operator()(x)"
            << "\n\t This object must be properly initialized before you can use it."
            << "\n\t weights.size(): " << weights.size()
            << "\n\t weights.nr(): " << weights.nr()
            << "\n\t number_of_classes(): " << number_of_classes()
            << "\n\t b.size(): " << b.size()
            );

        return predict(x).first;
    }
};
template <
typename K,
typename result_type_
>
void serialize (
const multiclass_linear_decision_function<K,result_type_>& item,
std::ostream& out
)
{
try
{
serialize(item.weights, out);
serialize(item.b, out);
serialize(item.labels, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type multiclass_linear_decision_function");
}
}
// Reads weights, b and labels from in, in that order, directly into item.  A
// serialization_error raised by any of the reads is rethrown with a note
// naming multiclass_linear_decision_function appended to its message.
template <
    typename K,
    typename result_type_
    >
void deserialize (
    multiclass_linear_decision_function<K,result_type_>& item,
    std::istream& in
)
{
    try
    {
        deserialize(item.weights, in);
        deserialize(item.b, in);
        deserialize(item.labels, in);
    }
    catch (const serialization_error& err)
    {
        throw serialization_error(err.info + "\n while deserializing object of type multiclass_linear_decision_function");
    }
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_SVm_FUNCTION