// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <dlib/optimization.h>
#include <dlib/rand.h>
#include "tester.h"
namespace
{
using namespace test;
using namespace dlib;
using namespace std;
// Logger used by the tests in this file; it is created with the name
// "test.find_max_factor_graph_viterbi".
logger dlog("test.find_max_factor_graph_viterbi");
// ----------------------------------------------------------------------------------------
// Default-constructed random number generator shared by this file.  It is passed
// to randm() whenever a map_problem is constructed, so successive problems draw
// successive values from the same stream.
dlib::rand rnd;
// ----------------------------------------------------------------------------------------
// A randomly generated chain MAP problem used to exercise
// find_max_factor_graph_viterbi().
//
// Template parameters:
//   O            - value returned by order()
//   NS           - value returned by num_states()
//   num_nodes    - value returned by number_of_nodes()
//   all_negative - if true, the random factor table is negated after generation
//                  (presumably randm() yields non-negative values, making every
//                  factor non-positive -- confirm against the randm() docs)
//
// The factor table `data` has one row per node and NS^(O+1) columns, i.e. one
// column for every joint assignment of a node together with up to O other nodes.
template <
    unsigned long O,
    unsigned long NS,
    unsigned long num_nodes,
    bool all_negative
    >
class map_problem
{
public:
    unsigned long order() const { return O; }
    unsigned long num_states() const { return NS; }

    // Fills the factor table with random values drawn from the file-scope
    // generator rnd, negating all of them when all_negative is set.
    map_problem()
    {
        // Compute NS^(O+1) with integer arithmetic.  Going through std::pow()
        // and truncating the double back to long can lose a column if the
        // floating point result lands just below the exact integer.
        long num_columns = 1;
        for (unsigned long i = 0; i <= O; ++i)
            num_columns *= static_cast<long>(NS);

        data = randm(number_of_nodes(), num_columns, rnd);
        if (all_negative)
            data = -data;
    }

    unsigned long number_of_nodes (
    ) const
    {
        return num_nodes;
    }

    // Returns the factor value for node node_id given the states in node_states.
    //
    // The brute force checker in this file passes the state of node_id first,
    // followed by the states of up to order() preceding nodes, so node_states
    // holds between 1 and O+1 entries.  node_states must not be empty.
    //
    // The states are folded into a column index of `data`.  For 1 to 4 states
    // the layout is:
    //   1 state : s0
    //   2 states: s0 + s1*NS
    //   3 states: (s0 + s1*NS)*NS + s2
    //   4 states: ((s0 + s1*NS)*NS + s2)*NS + s3
    // and every further state is folded in the same way (index*NS + s_k), so
    // that all O+1 states influence the result.  Stopping after four states
    // would make an order 4 problem behave like an order 3 one.
    template <
        typename EXP
        >
    double factor_value (
        unsigned long node_id,
        const matrix_exp<EXP>& node_states
    ) const
    {
        const long num_given = node_states.size();

        unsigned long column = node_states(0);
        if (num_given >= 2)
            column += node_states(1)*NS;
        for (long k = 2; k < num_given; ++k)
            column = column*NS + node_states(k);

        return data(node_id, column);
    }

    // number_of_nodes() x NS^(O+1) table of random factor values.
    matrix<double> data;
};
// ----------------------------------------------------------------------------------------
template <
typename map_problem
>
void brute_force_find_max_factor_graph_viterbi (
const map_problem& prob,
std::vector<unsigned long>& map_assignment
)
{
using namespace dlib::impl;
const int order = prob.order();
const int num_states = prob.num_states();
map_assignment.resize(prob.number_of_nodes());
double best_score = -std::numeric_limits<double>::infinity();
matrix<unsigned long,1,0> node_states;
node_states.set_size(prob.number_of_nodes());
node_states = 0;
do
{
double score = 0;
for (unsigned long i = 0; i < prob.number_of_nodes(); ++i)
{
score += prob.factor_value(i, (colm(node_states,range(i,i-std::min<int>(order,i)))));
}
if (score > best_score)
{
for (unsigned long i = 0; i < map_assignment.size(); ++i)
map_assignment[i] = node_states(i);
best_score = score;
}
} while(advance_state(node_states,num_states));
}
// ----------------------------------------------------------------------------------------
template <
unsigned long order,
unsigned long num_states,
unsigned long num_nodes,
bool all_negative
>
void do_test_()
{
dlog << LINFO << "order: "<< order
<< " num_states: " << num_states
<< " num_nodes: " << num_nodes
<< " all_negative: " << all_negative;
for (int i = 0; i < 25; ++i)
{
print_spinner();
map_problem<order,num_states,num_nodes,all_negative> prob;
std::vector<unsigned long> assign, assign2;
brute_force_find_max_factor_graph_viterbi(prob, assign);
find_max_factor_graph_viterbi(prob, assign2);
DLIB_TEST_MSG(mat(assign) == mat(assign2),
trans(mat(assign))
<< trans(mat(assign2))
);
}
}
// Runs do_test_() with the random factor values left as generated
// (all_negative == false).
template <
    unsigned long order,
    unsigned long num_states,
    unsigned long num_nodes
    >
void do_test()
{
    const bool negate_factors = false;
    do_test_<order,num_states,num_nodes,negate_factors>();
}
// Runs do_test_() with the random factor values negated
// (all_negative == true).
template <
    unsigned long order,
    unsigned long num_states,
    unsigned long num_nodes
    >
void do_test_negative()
{
    const bool negate_factors = true;
    do_test_<order,num_states,num_nodes,negate_factors>();
}
// ----------------------------------------------------------------------------------------
// Test case that compares find_max_factor_graph_viterbi() against the brute
// force search over a range of problem shapes.  The file-scope instance `a`
// declared at the end of the class is what creates the test object
// (presumably the tester base class registers it with the test framework on
// construction -- confirm in tester.h).
class test_find_max_factor_graph_viterbi : public tester
{
public:
// Passes the test's name and description on to the tester base class.
test_find_max_factor_graph_viterbi (
) :
tester ("test_find_max_factor_graph_viterbi",
"Runs tests on the find_max_factor_graph_viterbi routine.")
{}
// Runs every configuration below.  The template arguments are
// <order, num_states, num_nodes>; do_test_negative() additionally negates
// the random factor values.  All problems draw from the shared generator
// rnd, so the order of these calls determines which random values each
// configuration sees.
void perform_test (
)
{
// 3-state problems, including empty (0 node) and single node chains and
// orders from 0 up to 4.
do_test<1,3,0>();
do_test<1,3,1>();
do_test<1,3,2>();
do_test<0,3,2>();
do_test_negative<0,3,2>();
do_test<1,3,8>();
do_test<2,3,7>();
do_test_negative<2,3,7>();
do_test<3,3,8>();
do_test<4,3,8>();
do_test_negative<4,3,8>();
do_test<0,3,8>();
// Chains with fewer nodes than the order.
do_test<4,3,1>();
do_test<4,3,0>();
do_test<3,2,1>();
do_test<3,2,0>();
do_test<3,2,2>();
do_test<2,2,1>();
do_test_negative<3,2,1>();
do_test_negative<3,2,0>();
do_test_negative<3,2,2>();
do_test_negative<2,2,1>();
do_test<0,3,0>();
// 2-state problems on longer chains.
do_test<1,2,8>();
do_test<2,2,7>();
do_test<3,2,8>();
do_test<0,2,8>();
// Degenerate problems where every node has exactly one possible state.
do_test<1,1,8>();
do_test<2,1,8>();
do_test<3,1,8>();
do_test<0,1,8>();
}
} a;
}