libeblearn
ebl::supervised_trainer< Tnet, Tdata, Tlabel > Class Template Reference

#include <ebl_trainer.h>


Public Member Functions

 supervised_trainer (trainable_module< Tnet, Tdata, Tlabel > &m, parameter< Tnet, bbstate_idx< Tnet > > &p)
 constructor.
virtual ~supervised_trainer ()
 destructor.
bool test_sample (labeled_datasource< Tnet, Tdata, Tlabel > &ds, bbstate_idx< Tnet > &label, bbstate_idx< Tnet > &answers, infer_param &infp)
Tnet train_sample (labeled_datasource< Tnet, Tdata, Tlabel > &ds, gd_param &arg)
void test (labeled_datasource< Tnet, Tdata, Tlabel > &ds, classifier_meter &log, infer_param &infp, uint max_test=0)
void train (labeled_datasource< Tnet, Tdata, Tlabel > &ds, classifier_meter &log, gd_param &args, int niter, infer_param &infp, intg hessian_period=0, intg nhessian=0, double mu=.02)
void compute_diaghessian (labeled_datasource< Tnet, Tdata, Tlabel > &ds, intg niter, double mu)
 Compute the diagonal Hessian.
void set_iteration (int i)
void pretty (labeled_datasource< Tnet, Tdata, Tlabel > &ds)
 Pretty-print some information about training, e.g. input and network sizes.
void set_progress_file (const std::string &s)
void update_progress ()

Protected Member Functions

void init (labeled_datasource< Tnet, Tdata, Tlabel > &ds, classifier_meter *log=NULL, bool new_iteration=false)

Protected Attributes

trainable_module< Tnet, Tdata, Tlabel > & machine
parameter< Tnet, bbstate_idx< Tnet > > & param
 The learned parameters.
bbstate_idx< Tnet > energy
 Temporary energy buffer.
bbstate_idx< Tnet > * answers
 Temporary answer buffer.
bbstate_idx< Tnet > * label
 Temporary label buffer.
intg age
int iteration
void * iteration_ptr
bool prettied
 Flag used to pretty-print info just once.
std::string progress_file
 Name of progress file.
intg progress_cnt
 A count for updating progress.
bool test_running
 Show test on trained.

Friends

class supervised_trainer_gui

Detailed Description

template<typename Tnet, typename Tdata, typename Tlabel>
class ebl::supervised_trainer< Tnet, Tdata, Tlabel >

Supervised trainer. A specialisation of the generic trainer that takes samples (of type Tnet) and labels (of type Tlabel) as training input. The template parameter Tnet is the network's type and also the input data's type. However, datasources with a different data type may be provided, in which case each sample is converted after it is extracted from the datasource (via a deep idx_copy).
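
The per-sample conversion described above amounts to an element-wise deep copy with a cast. The following is a minimal sketch only: std::vector stands in for the library's idx type, and convert_sample is a hypothetical helper, not libeblearn's idx_copy.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for idx_copy: deep-copies a sample stored with the
// datasource's element type (Tdata) into a freshly allocated buffer of the
// network's element type (Tnet), casting each element.
template <typename Tnet, typename Tdata>
std::vector<Tnet> convert_sample(const std::vector<Tdata>& sample) {
  std::vector<Tnet> out(sample.size());
  for (std::size_t i = 0; i < sample.size(); ++i)
    out[i] = static_cast<Tnet>(sample[i]);
  return out;  // independent copy: modifying it leaves 'sample' untouched
}
```

For example, convert_sample<float>(raw) turns a vector of unsigned char pixels into floats that a float network could consume.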


Member Function Documentation

template<typename Tnet , typename Tdata , typename Tlabel >
void ebl::supervised_trainer< Tnet, Tdata, Tlabel >::init ( labeled_datasource< Tnet, Tdata, Tlabel > &  ds,
classifier_meter *  log = NULL,
bool  new_iteration = false 
) [protected]

Initialize the datasource to its beginning and assign indata to a buffer corresponding to ds's sample size. Also increment the iteration counter, unless new_iteration is false.
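
The bookkeeping described for init can be sketched as follows. This is an illustration under assumptions: toy_datasource, trainer_state and their members are hypothetical names, not libeblearn's classes, and only the documented behavior (rewind, size the buffer, optionally increment the iteration) is modeled.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical datasource: a fixed sample size and a read cursor.
struct toy_datasource {
  std::size_t sample_size = 0;
  std::size_t cursor = 0;
  void seek_begin() { cursor = 0; }
};

// Hypothetical trainer state mirroring the documented behavior of init().
struct trainer_state {
  std::vector<float> indata;  // input buffer sized to one sample
  int iteration = 0;

  void init(toy_datasource& ds, bool new_iteration = false) {
    ds.seek_begin();                      // rewind the datasource
    indata.assign(ds.sample_size, 0.0f);  // buffer matching ds's sample size
    if (new_iteration) ++iteration;       // count a new pass only when asked
  }
};
```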

template<typename Tnet , typename Tdata , typename Tlabel >
void ebl::supervised_trainer< Tnet, Tdata, Tlabel >::set_iteration ( int  i)

Set the iteration id to i. This can be useful when resuming training from a given iteration.

template<typename Tnet , typename Tdata , typename Tlabel >
void ebl::supervised_trainer< Tnet, Tdata, Tlabel >::set_progress_file ( const std::string &  s)

Sets the name of the file indicating progress of training. If set, this file will be 'touched' after each sample is trained or tested to indicate that training is still going on.

template<typename Tnet , typename Tdata , typename Tlabel >
void ebl::supervised_trainer< Tnet, Tdata, Tlabel >::test ( labeled_datasource< Tnet, Tdata, Tlabel > &  ds,
classifier_meter &  log,
infer_param &  infp,
uint  max_test = 0 
)

Measure the average energy and classification error rate on a dataset.

Parameters:
max_test If > 0, limit the number of tests to this number.

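
The two statistics that test reports, and the effect of max_test, can be sketched on precomputed per-sample results. The types and the summarize function are hypothetical; in libeblearn the accumulation is done by classifier_meter, whose interface is not reproduced here.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical per-sample outcome, as a test_sample-style call might yield.
struct sample_result {
  double energy;
  bool correct;
};

// Hypothetical summary standing in for what a meter would accumulate.
struct test_summary {
  std::size_t count = 0;
  double avg_energy = 0.0;
  double error_rate = 0.0;
};

// Averages energy and error over the results; if max_test > 0, only the
// first max_test samples are used, mirroring the documented parameter.
test_summary summarize(const std::vector<sample_result>& results,
                       unsigned max_test = 0) {
  test_summary s;
  std::size_t n = results.size();
  if (max_test > 0 && max_test < n) n = max_test;
  std::size_t errors = 0;
  double energy = 0.0;
  for (std::size_t i = 0; i < n; ++i) {
    energy += results[i].energy;
    if (!results[i].correct) ++errors;
  }
  s.count = n;
  if (n > 0) {
    s.avg_energy = energy / n;
    s.error_rate = static_cast<double>(errors) / n;
  }
  return s;
}
```
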
template<typename Tnet , typename Tdata , typename Tlabel >
bool ebl::supervised_trainer< Tnet, Tdata, Tlabel >::test_sample ( labeled_datasource< Tnet, Tdata, Tlabel > &  ds,
bbstate_idx< Tnet > &  label,
bbstate_idx< Tnet > &  answers,
infer_param &  infp 
)

Test the current sample of 'ds', put the answers in 'answers', and return true if the inferred label equals the ground-truth 'label'.
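
A common way to obtain the inferred label is to take the index of the highest-scoring output. That convention is an assumption in the sketch below (libeblearn's infer_param may select answers differently), and is_correct is a hypothetical helper operating on plain vectors.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical helper: returns true when the arg-max of the network outputs
// matches the ground-truth class index.
bool is_correct(const std::vector<float>& outputs, std::size_t label) {
  assert(!outputs.empty());
  auto best = std::max_element(outputs.begin(), outputs.end());
  auto inferred = static_cast<std::size_t>(std::distance(outputs.begin(), best));
  return inferred == label;
}
```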

template<typename Tnet , typename Tdata , typename Tlabel >
void ebl::supervised_trainer< Tnet, Tdata, Tlabel >::train ( labeled_datasource< Tnet, Tdata, Tlabel > &  ds,
classifier_meter &  log,
gd_param &  args,
int  niter,
infer_param &  infp,
intg  hessian_period = 0,
intg  nhessian = 0,
double  mu = .02 
)

Train for niter sweeps over the training set ds, which supplies both the input samples and their corresponding desired categories (labels), and compute the average energy on-the-fly. args holds the arguments for the parameter update method (e.g. learning rate and weight decay).

Parameters:
hessian_period Recompute 2nd-order derivatives every hessian_period samples if > 0.
nhessian Estimate 2nd-order derivatives on nhessian samples.

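
The scheduling implied by hessian_period and nhessian can be sketched as a plain loop. The callbacks are hypothetical stand-ins for train_sample and compute_diaghessian. As a further assumption, the estimate here is also refreshed before the very first sample; the real trainer may order this differently.

```cpp
#include <cstddef>
#include <functional>

// Hypothetical training-loop skeleton showing when a diagonal-Hessian
// estimate would be refreshed. 'train_one' and 'estimate_hessian' stand in
// for train_sample and compute_diaghessian; they are illustrative only.
void train_loop(std::size_t nsamples, int niter, long hessian_period,
                long nhessian,
                const std::function<void(std::size_t)>& train_one,
                const std::function<void(long)>& estimate_hessian) {
  long age = 0;  // number of samples trained so far
  for (int it = 0; it < niter; ++it) {
    for (std::size_t i = 0; i < nsamples; ++i) {
      // Refresh 2nd-order estimates every hessian_period samples (if > 0),
      // using nhessian samples for each estimate.
      if (hessian_period > 0 && age % hessian_period == 0)
        estimate_hessian(nhessian);
      train_one(i);
      ++age;
    }
  }
}
```

With 10 samples, niter = 2 and hessian_period = 5, this refreshes the estimate 4 times over 20 updates; with hessian_period = 0 it never does.
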
template<typename Tnet , typename Tdata , typename Tlabel >
Tnet ebl::supervised_trainer< Tnet, Tdata, Tlabel >::train_sample ( labeled_datasource< Tnet, Tdata, Tlabel > &  ds,
gd_param &  arg 
)

Perform a learning update on the current sample of 'ds', using the arguments in 'arg' for the parameter update method (e.g. learning rate and weight decay).
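
A learning update driven by a learning rate and weight decay can be sketched as one gradient step with an L2 penalty. Everything here is an assumption for illustration: toy_gd_param is hypothetical, and the optional per-weight step eta / (mu + h_i) follows the stochastic diagonal Levenberg-Marquardt scheme, which the mu and diagonal-Hessian parameters of train suggest but which this reference does not confirm.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for gd_param: only the two settings named above.
struct toy_gd_param {
  double eta;    // learning rate
  double decay;  // L2 weight-decay coefficient
};

// One gradient-descent step with weight decay:
//   w_i <- w_i - step_i * (grad_i + decay * w_i)
// where step_i = eta / (mu + h_i) if a diagonal Hessian estimate h is
// supplied, and step_i = eta otherwise.
void sgd_update(std::vector<double>& w, const std::vector<double>& grad,
                const toy_gd_param& p,
                const std::vector<double>* diag_hessian = nullptr,
                double mu = 0.02) {
  assert(w.size() == grad.size());
  if (diag_hessian) assert(diag_hessian->size() == w.size());
  for (std::size_t i = 0; i < w.size(); ++i) {
    double step = p.eta;
    if (diag_hessian) step = p.eta / (mu + (*diag_hessian)[i]);
    w[i] -= step * (grad[i] + p.decay * w[i]);
  }
}
```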

template<typename Tnet , typename Tdata , typename Tlabel >
void ebl::supervised_trainer< Tnet, Tdata, Tlabel >::update_progress ( )

If a progress file is defined, touch it to let the outside world know that training is still alive.
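
A "touch" of the progress file can be sketched with C++17 std::filesystem. touch_progress_file is a hypothetical helper, not libeblearn's implementation. Treating an empty name as "no progress file set" is an assumption based on progress_file being a std::string.

```cpp
#include <filesystem>
#include <fstream>
#include <string>
#include <system_error>

// Hypothetical equivalent of a "touch": creates the file if it does not
// exist, then sets its modification time to now. An empty name means no
// progress file was configured, so nothing is done. Returns false on error.
bool touch_progress_file(const std::string& name) {
  namespace fs = std::filesystem;
  if (name.empty()) return true;  // progress reporting disabled
  fs::path p(name);
  std::error_code ec;
  if (!fs::exists(p, ec)) {
    std::ofstream create(p);  // creates an empty file
    if (!create) return false;
  }
  fs::last_write_time(p, fs::file_time_type::clock::now(), ec);
  return !ec;
}
```

An external monitor can then compare the file's modification time with the current time to decide whether training has stalled.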

