libeblearn
|
#include <ebl_trainer.h>
Public Member Functions | |
supervised_trainer (trainable_module< Tnet, Tdata, Tlabel > &m, parameter< Tnet, bbstate_idx< Tnet > > &p) | |
constructor. | |
virtual | ~supervised_trainer () |
destructor. | |
bool | test_sample (labeled_datasource< Tnet, Tdata, Tlabel > &ds, bbstate_idx< Tnet > &label, bbstate_idx< Tnet > &answers, infer_param &infp) |
Tnet | train_sample (labeled_datasource< Tnet, Tdata, Tlabel > &ds, gd_param &arg) |
void | test (labeled_datasource< Tnet, Tdata, Tlabel > &ds, classifier_meter &log, infer_param &infp, uint max_test=0) |
void | train (labeled_datasource< Tnet, Tdata, Tlabel > &ds, classifier_meter &log, gd_param &args, int niter, infer_param &infp, intg hessian_period=0, intg nhessian=0, double mu=.02) |
void | compute_diaghessian (labeled_datasource< Tnet, Tdata, Tlabel > &ds, intg niter, double mu) |
Compute the diagonal Hessian. | |
void | set_iteration (int i) |
void | pretty (labeled_datasource< Tnet, Tdata, Tlabel > &ds) |
Pretty-print some information about training, e.g. input and network sizes. | |
void | set_progress_file (const std::string &s) |
void | update_progress () |
Protected Member Functions | |
void | init (labeled_datasource< Tnet, Tdata, Tlabel > &ds, classifier_meter *log=NULL, bool new_iteration=false) |
Protected Attributes | |
trainable_module< Tnet, Tdata, Tlabel > & | machine |
parameter< Tnet, bbstate_idx < Tnet > > & | param |
The learned parameters. | |
bbstate_idx< Tnet > | energy |
Temporary energy buffer. | |
bbstate_idx< Tnet > * | answers |
Temporary answer buffer. | |
bbstate_idx< Tnet > * | label |
Temporary label buffer. | |
intg | age |
int | iteration |
void * | iteration_ptr |
bool | prettied |
Flag used to pretty-print info only once. | |
std::string | progress_file |
Name of progress file. | |
intg | progress_cnt |
A count for updating progress. | |
bool | test_running |
Show test on trained. | |
Friends | |
class | supervised_trainer_gui |
Supervised Trainer. A specialisation of the generic trainer, taking samples (of type Tnet) and labels (of type Tlabel) as training input. Template Tnet is the network's type and also the input data's type. However datasources with different data type may be provided in which case a conversion will occur after each sample extraction from the datasource (via a deep idx_copy).
void ebl::supervised_trainer< Tnet, Tdata, Tlabel >::init | ( | labeled_datasource< Tnet, Tdata, Tlabel > & | ds, |
classifier_meter * | log = NULL , |
||
bool | new_iteration = false |
||
) | [protected] |
Initialize the datasource to the beginning and assign indata to a buffer corresponding to ds's sample size. Also increment the iteration counter, unless new_iteration is false.
void ebl::supervised_trainer< Tnet, Tdata, Tlabel >::set_iteration | ( | int | i | ) |
Set the iteration id to i. This can be useful when resuming training from a given iteration.
void ebl::supervised_trainer< Tnet, Tdata, Tlabel >::set_progress_file | ( | const std::string & | s | ) |
Sets the name of the file indicating progress of training. If set, this file will be 'touched' after each sample is trained or tested to indicate that training is still going on.
void ebl::supervised_trainer< Tnet, Tdata, Tlabel >::test | ( | labeled_datasource< Tnet, Tdata, Tlabel > & | ds, |
classifier_meter & | log, | ||
infer_param & | infp, | ||
uint | max_test = 0 |
||
) |
Measure the average energy and classification error rate on a dataset.
max_test | If > 0, limit the number of tests to this number. |
bool ebl::supervised_trainer< Tnet, Tdata, Tlabel >::test_sample | ( | labeled_datasource< Tnet, Tdata, Tlabel > & | ds, |
bbstate_idx< Tnet > & | label, | ||
bbstate_idx< Tnet > & | answers, | ||
infer_param & | infp | ||
) |
Test the current sample of 'ds', put the answers in 'answers', and return true if the inferred label equals the ground-truth 'label'.
void ebl::supervised_trainer< Tnet, Tdata, Tlabel >::train | ( | labeled_datasource< Tnet, Tdata, Tlabel > & | ds, |
classifier_meter & | log, | ||
gd_param & | args, | ||
int | niter, | ||
infer_param & | infp, | ||
intg | hessian_period = 0 , |
||
intg | nhessian = 0 , |
||
double | mu = .02 |
||
) |
Train for <niter> sweeps over the training set <ds>, which provides the input samples and their corresponding desired categories (labels). The average energy is computed on the fly (note that this function itself returns void). <args> holds the arguments for the parameter update method (e.g. learning rate and weight decay).
hessian_period | If > 0, recompute second-order derivatives every 'hessian_period' samples. |
nhessian | Estimate 2nd order derivatives on 'nhessian' samples. |
Tnet ebl::supervised_trainer< Tnet, Tdata, Tlabel >::train_sample | ( | labeled_datasource< Tnet, Tdata, Tlabel > & | ds, |
gd_param & | arg | ||
) |
Perform a learning update on the current sample of 'ds', using the arguments 'arg' for the parameter update method (e.g. learning rate and weight decay).
void ebl::supervised_trainer< Tnet, Tdata, Tlabel >::update_progress | ( | ) |
If a progress file is defined, touch it to let the outside world know that training is still alive.