libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_states.h
00001 /***************************************************************************
00002  *   Copyright (C) 2008 by Yann LeCun and Pierre Sermanet *
00003  *   yann@cs.nyu.edu, pierre.sermanet@gmail.com *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef EBL_STATES_H_
00034 #define EBL_STATES_H_
00035 
00036 #include "libidx.h"
00037 
00038 namespace ebl {
00039 
00040   class infer_param {
00041   };
00042 
00047   class EXPORT gd_param: public infer_param {
00048   public:
00050     double eta;
00052     double n;
00054     double decay_l1;
00056     double decay_l2;
00058     intg decay_time;
00060     double inertia;
00062     double anneal_value;
00064     intg anneal_period;
00066     double gradient_threshold;
00068     int niter_done;
00069 
00071     gd_param();
00072     
00074     gd_param(double leta, double ln, double l1, double l2, intg dtime,
00075              double iner, double a_v, intg a_p, double g_t);
00076   };
00077 
00078   EXPORT std::ostream& operator<<(std::ostream &out, const gd_param &p);
00079 
00082   class forget_param {
00083   };
00084 
00085   class EXPORT forget_param_linear: public forget_param {
00086   public:
00090     forget_param_linear(double v, double e);
00091 
00092     // public member variables /////////////////////////////////////////////////
00093   public:
00096     double value;
00097     double exponent;
00098     random generator; 
00099   };
00100 
00107   class EXPORT state : public smart_pointer {
00108   public:
00110     state();
00112     virtual ~state();
00114     virtual void clear();
00116     virtual void clear_x();
00118     virtual void clear_dx();
00120     virtual void clear_ddx();
00121     virtual void update_gd(gd_param &arg);
00124     virtual state& operator=(const state& other);
00125   };
00126 
00127   template <typename T> class bbstate_idx;
00128   template <typename T, class Tstate = bbstate_idx<T> > class parameter;
00129 
00132   template <typename T> class fstate_idx: public state {
00133   public:
00134     virtual ~fstate_idx();
00135 
00138 
00140     fstate_idx();
00142     fstate_idx(intg s0);
00144     fstate_idx(intg s0, intg s1);
00146     fstate_idx(intg s0, intg s1, intg s2);
00148     fstate_idx(intg s0, intg s1, intg s2, intg s3, intg s4 = -1, intg s5 = -1,
00149               intg s6 = -1, intg s7 = -1);
00151     fstate_idx(const idxdim &d);
00154     fstate_idx(intg n, fstate_idx<T> &fs);
00155 
00158 
00163     fstate_idx(parameter<T,fstate_idx<T> > *st);
00168     fstate_idx(parameter<T,fstate_idx<T> > *st, intg s0);
00173     fstate_idx(parameter<T,fstate_idx<T> > *st, intg s0, intg s1);
00178     fstate_idx(parameter<T,fstate_idx<T> > *st, intg s0, intg s1, intg s2);
00184     fstate_idx(parameter<T,fstate_idx<T> > *st, intg s0, intg s1, intg s2,
00185                intg s3, intg s4 = -1, intg s5 = -1, intg s6 = -1, intg s7 = -1);
00191     fstate_idx(parameter<T,fstate_idx<T> > *st, const idxdim &d);
00192 
00195 
00199     fstate_idx(const idx<T> &x);
00200 
00203 
00205     virtual void clear();
00207     virtual void clear_x();
00208 
00211 
00213     virtual intg nelements();
00215     virtual intg footprint();
00217     virtual intg size();
00218 
00221 
00223     virtual void resize(intg s0 = -1, intg s1 = -1, intg s2 = -1, intg s3 = -1,
00224                         intg s4 = -1, intg s5 = -1, intg s6 = -1, intg s7 = -1);
00226     virtual void resize(const idxdim &d);
00229     virtual void resize1(intg dimn, intg size);   
00232     virtual void resize_as(fstate_idx& s);
00235     virtual void resize_as_but1(fstate_idx<T>& s, intg fixed_dim);
00236 
00237     //    virtual void resize(const intg* dimsBegin, const intg* dimsEnd);
00238 
00241 
00247     fstate_idx<T> select(int dimension, intg slice_index);    
00254     fstate_idx<T> narrow(int dimension, intg size, intg offset);
00255     
00258     
00260     fstate_idx<T> make_copy();
00263     virtual fstate_idx<T>& operator=(const fstate_idx<T>& other);
00265     virtual void copy(fstate_idx<T> &cpy);
00266 
00267     // info printing ///////////////////////////////////////////////////////////
00268     
00270     virtual void pretty();    
00272     virtual void print();
00273     
00276   public:
00278     idx<T> x;
00280     idx<T> dx;
00282     idx<T> ddx;
00283   };
00284 
00285   
00288   template <typename T> class bstate_idx: public fstate_idx<T> {
00289   public:
00290     virtual ~bstate_idx();
00291 
00294 
00296     bstate_idx();
00298     bstate_idx(intg s0);
00300     bstate_idx(intg s0, intg s1);
00302     bstate_idx(intg s0, intg s1, intg s2);
00304     bstate_idx(intg s0, intg s1, intg s2, intg s3, intg s4 = -1, intg s5 = -1,
00305               intg s6 = -1, intg s7 = -1);
00307     bstate_idx(const idxdim &d);
00310     bstate_idx(intg n, bstate_idx<T> &fs);
00311 
00314 
00319     bstate_idx(parameter<T,bstate_idx<T> > *st);
00324     bstate_idx(parameter<T,bstate_idx<T> > *st, intg s0);
00329     bstate_idx(parameter<T,bstate_idx<T> > *st, intg s0, intg s1);
00334     bstate_idx(parameter<T,bstate_idx<T> > *st, intg s0, intg s1, intg s2);
00340     bstate_idx(parameter<T,bstate_idx<T> > *st, intg s0, intg s1, intg s2,
00341                intg s3, intg s4 = -1, intg s5 = -1, intg s6 = -1, intg s7 = -1);
00347     bstate_idx(parameter<T,bstate_idx<T> > *st, const idxdim &d);
00348 
00351 
00355     bstate_idx(const idx<T> &x, const idx<T> &dx);
00356 
00359 
00361     virtual void clear();
00363     virtual void clear_dx();
00365     virtual void update_gd(gd_param &arg);
00366       
00369 
00371     using fstate_idx<T>::nelements;
00373     using fstate_idx<T>::footprint;
00375     using fstate_idx<T>::size;
00376     
00379 
00381     virtual void resize(intg s0 = -1, intg s1 = -1, intg s2 = -1, intg s3 = -1,
00382                         intg s4 = -1, intg s5 = -1, intg s6 = -1, intg s7 = -1);
00384     virtual void resize(const idxdim &d);
00387     virtual void resize1(intg dimn, intg size);    
00390     virtual void resize_as(bstate_idx& s);
00393     virtual void resize_as_but1(bstate_idx<T>& s, intg fixed_dim);
00394 
00395     //    virtual void resize(const intg* dimsBegin, const intg* dimsEnd);
00396 
00399 
00405     bstate_idx<T> select(int dimension, intg slice_index);     
00412     bstate_idx<T> narrow(int dimension, intg size, intg offset);
00413     
00416     
00418     bstate_idx<T> make_copy();
00421     virtual bstate_idx<T>& operator=(const bstate_idx<T>& other);
00423     virtual void copy(bstate_idx<T> &cpy);
00424     
00425     // info printing ///////////////////////////////////////////////////////////
00426     
00428     virtual void pretty();    
00430     virtual void print();
00431     
00434   public:
00436     using fstate_idx<T>::x;
00438     using fstate_idx<T>::dx;
00440     using fstate_idx<T>::ddx;
00441   };
00442   
00445   template <typename T> class bbstate_idx: public bstate_idx<T> {
00446   public:
00447     virtual ~bbstate_idx();
00448 
00451 
00453     bbstate_idx();
00455     bbstate_idx(intg s0);
00457     bbstate_idx(intg s0, intg s1);
00459     bbstate_idx(intg s0, intg s1, intg s2);
00461     bbstate_idx(intg s0, intg s1, intg s2, intg s3, intg s4 = -1, intg s5 = -1,
00462               intg s6 = -1, intg s7 = -1);
00464     bbstate_idx(const idxdim &d);
00467     bbstate_idx(intg n, bbstate_idx<T> &fs);
00468 
00471 
00476     bbstate_idx(parameter<T,bbstate_idx<T> > *st);
00481     bbstate_idx(parameter<T,bbstate_idx<T> > *st, intg s0);
00486     bbstate_idx(parameter<T,bbstate_idx<T> > *st, intg s0, intg s1);
00491     bbstate_idx(parameter<T,bbstate_idx<T> > *st, intg s0, intg s1, intg s2);
00497     bbstate_idx(parameter<T,bbstate_idx<T> > *st, intg s0, intg s1, intg s2,
00498                 intg s3, intg s4 = -1, intg s5 = -1, intg s6 = -1, intg s7 =-1);
00504     bbstate_idx(parameter<T,bbstate_idx<T> > *st, const idxdim &d);
00505 
00508 
00509     // TODO: this causes bug in mnist part of tester
00510     /* //! Constructs a bbstate_idx from an idx to be used as x. dx and ddx */
00511     /* //! will be allocated with the same size as x. */
00512     /*//! Note: the data pointed to by x is not copied, we only create new idx*/
00513     /* //!   pointing to the same data. */
00514     /* bbstate_idx(const idx<T> &x); */
00515 
00519     bbstate_idx(const idx<T> &x, const idx<T> &dx, const idx<T> &ddx);
00520 
00523 
00525     virtual void clear();
00527     virtual void clear_ddx();
00529     using bstate_idx<T>::update_gd;
00530       
00533 
00535     using fstate_idx<T>::nelements;
00537     using fstate_idx<T>::footprint;
00539     using fstate_idx<T>::size;
00540     
00543 
00545     virtual void resize(intg s0 = -1, intg s1 = -1, intg s2 = -1, intg s3 = -1,
00546                         intg s4 = -1, intg s5 = -1, intg s6 = -1, intg s7 = -1);
00548     virtual void resize(const idxdim &d);
00551     virtual void resize1(intg dimn, intg size);    
00554     virtual void resize_as(bbstate_idx& s);
00557     virtual void resize_as_but1(bbstate_idx<T>& s, intg fixed_dim);
00558 
00559     //    virtual void resize(const intg* dimsBegin, const intg* dimsEnd);
00560 
00563 
00569     bbstate_idx<T> select(int dimension, intg slice_index);     
00576     bbstate_idx<T> narrow(int dimension, intg size, intg offset);
00577     
00580     
00582     bbstate_idx<T> make_copy();
00585     virtual bbstate_idx<T>& operator=(const bbstate_idx<T>& other);
00587     virtual void copy(bbstate_idx<T> &cpy);
00588     
00589     // info printing ///////////////////////////////////////////////////////////
00590     
00592     virtual void pretty();    
00594     virtual void print();
00595     
00598   public:
00600     using fstate_idx<T>::x;
00602     using fstate_idx<T>::dx;
00604     using fstate_idx<T>::ddx;
00605   };
00606 
00609   template <typename T>
00610     class parameter<T, fstate_idx<T> > : public fstate_idx<T> {
00611   public:
00613     parameter(intg initial_size = 100);
00615     parameter(const char *param_filename);
00617     virtual ~parameter();
00618     virtual void resize(intg s0);
00619 
00621     // I/O methods
00622     
00626     bool load_x(std::vector<string> &files);
00628     bool load_x(const char *param_filename);
00630     bool load_x(idx<T> &m);
00632     bool save_x(const char *param_filename);
00633 
00636 
00638     using fstate_idx<T>::nelements;
00640     using fstate_idx<T>::footprint;
00642     using fstate_idx<T>::size;
00643     
00646   public:
00647     using fstate_idx<T>::x;
00648   };
00649 
00652 
00654   template <typename T>
00655     class parameter<T, bstate_idx<T> >
00656     : public bstate_idx<T>, public parameter<T,fstate_idx<T> > {
00657   public:
00659     parameter(intg initial_size = 100);
00661     parameter(const char *param_filename);
00663     virtual ~parameter();
00665     virtual void resize(intg s0);
00667     virtual void update(gd_param &arg);
00669     void clear_deltax();    
00671     void set_epsilon(T m);
00672 
00674     // I/O methods
00675     
00677     bool load_x(const char *param_filename);
00679     bool save_x(const char *param_filename);
00680 
00683 
00685     using fstate_idx<T>::nelements;
00687     using fstate_idx<T>::footprint;
00689     using fstate_idx<T>::size;
00690     
00692     // protected methods
00693 
00694   protected:
00695 
00698     void update_gd(gd_param &arg);
00702     void update_deltax(T knew, T kold);
00703     
00706   public:
00707     using fstate_idx<T>::x;
00708     using bstate_idx<T>::dx;
00709 
00710     //    idx<T> gradient;
00711     idx<T> deltax;   
00712     idx<T> epsilons; 
00713   };
00714 
00718 
00719   template <typename T>
00720     class parameter<T, bbstate_idx<T> > : public bbstate_idx<T> {
00721   public:
00723     parameter(intg initial_size = 100);
00725     parameter(const char *param_filename);
00727     virtual ~parameter();
00729     virtual void resize(intg s0);
00731     virtual void update(gd_param &arg);
00733     void clear_deltax();
00735     void clear_ddeltax();
00737     void set_epsilon(T m);
00740     void compute_epsilons(T mu);
00744     void update_ddeltax(T knew, T kold);
00745     
00747     // I/O methods
00748     
00752     bool load_x(std::vector<string> &files);
00754     bool load_x(const char *param_filename);
00756     bool load_x(idx<T> &m);
00758     bool save_x(const char *param_filename);
00759 
00762 
00764     using fstate_idx<T>::nelements;
00766     using fstate_idx<T>::footprint;
00768     using fstate_idx<T>::size;
00769     
00771     // protected methods
00772   protected:
00773     
00776     void update_gd(gd_param &arg);    
00780     void update_deltax(T knew, T kold);
00781     
00784   public:
00785     using fstate_idx<T>::x;
00786     using fstate_idx<T>::dx;
00787     using fstate_idx<T>::ddx;
00788 
00789     //idx<T> gradient;    
00790     idx<T> deltax;   
00791     idx<T> epsilons; 
00792     idx<T> ddeltax;  
00793   };
00794 
00796   // state_idxloopers
00797 
00798   template <typename Tstate> class state_idxlooper;
00799   
00801   template <typename T>
00802     class state_idxlooper<fstate_idx<T> > : public fstate_idx<T> {
00803   public:
00804     using fstate_idx<T>::x;
00805 
00806     idxlooper<T> lx;
00807 
00809     state_idxlooper(fstate_idx<T> &s, int ld);
00810     virtual ~state_idxlooper();
00812     bool notdone();
00814     void next();
00815   };
00816 
00819   template <typename T>
00820     class state_idxlooper<bstate_idx<T> > : public bstate_idx<T> {
00821   public:
00822     using bstate_idx<T>::x;
00823     using bstate_idx<T>::dx;
00824 
00825     idxlooper<T> lx;
00826     idxlooper<T> ldx;
00827 
00829     state_idxlooper(bstate_idx<T> &s, int ld);
00830     virtual ~state_idxlooper();
00832     bool notdone();
00834     void next();
00835   };
00836 
00839   template <typename T>
00840     class state_idxlooper<bbstate_idx<T> > : public bbstate_idx<T> {
00841   public:
00842     using bbstate_idx<T>::x;
00843     using bbstate_idx<T>::dx;
00844     using bbstate_idx<T>::ddx;
00845 
00846     idxlooper<T> lx;
00847     idxlooper<T> ldx;
00848     idxlooper<T> lddx;
00849 
00851     state_idxlooper(bbstate_idx<T> &s, int ld);
00852     virtual ~state_idxlooper();
00854     bool notdone();
00856     void next();
00857   };
00858 
//! Loops over the last dimension of state 'st0' (state type 'Tst0'), binding
//! the looper variable 'lp0' to each successive slice, e.g.
//!   state_idx_eloop1(slice, in, bbstate_idx<float>) { ... slice.x ... }
//! Expands to a declaration followed by a for statement, so it must not be
//! used as the unbraced body of an if/else.
#define state_idx_eloop1(lp0,st0,Tst0)                                  \
  state_idxlooper<Tst0> lp0(st0, (st0).x.order() - 1);                  \
  for ( ; lp0.notdone(); lp0.next())
00864 
//! Loops simultaneously over the last dimension of states 'st0' and 'st1'
//! (state types 'Tst0' and 'Tst1'), binding loopers 'lp0' and 'lp1' to
//! matching slices. Calls eblerror() first if the two last dimensions
//! differ. Expands to several statements, so it must not be used as the
//! unbraced body of an if/else.
#define state_idx_eloop2(lp0,st0,Tst0,lp1,st1,Tst1)                     \
  if ((st0).x.dim((st0).x.order() - 1)                                  \
      != (st1).x.dim((st1).x.order() - 1))                              \
    eblerror("incompatible state_idx for eloop\n");                     \
  state_idxlooper<Tst0> lp0(st0,(st0).x.order()-1);                     \
  state_idxlooper<Tst1> lp1(st1,(st1).x.order()-1);                     \
  for ( ; lp0.notdone(); lp0.next(), lp1.next())
00874 
//! Loops simultaneously over the last dimension of states src0, src1 and
//! src2 (state types type0, type1, type2), binding the caller-named loopers
//! dst0, dst1 and dst2 to matching slices. Calls eblerror() first if the
//! three last dimensions differ. Like the eloop1/eloop2 macros, this expands
//! to several statements, so it must not be used as the unbraced body of an
//! if/else.
#define state_idx_eloop3_named(dst0,src0,type0,dst1,src1,type1,dst2,src2,type2) \
  if (((src0).x.dim((src0).x.order() - 1)                               \
       != (src1).x.dim((src1).x.order() - 1))                           \
      || ((src0).x.dim((src0).x.order() - 1)                            \
          != (src2).x.dim((src2).x.order() - 1)))                       \
    eblerror("incompatible idxs for eloop\n");                          \
  state_idxlooper<type0> dst0(src0,(src0).x.order()-1);                 \
  state_idxlooper<type1> dst1(src1,(src1).x.order()-1);                 \
  state_idxlooper<type2> dst2(src2,(src2).x.order()-1);                 \
  for ( ; dst0.notdone(); dst0.next(), dst1.next(), dst2.next())

//! Original 8-argument form, kept unchanged for existing callers. Its first
//! looper has no caller-supplied name: it is always declared as 'dst0', and
//! callers must refer to it by that hard-coded name inside the loop body.
//! Prefer state_idx_eloop3_named(), whose signature matches
//! state_idx_eloop1/state_idx_eloop2.
#define state_idx_eloop3(src0,type0,dst1,src1,type1,dst2,src2,type2)    \
  state_idx_eloop3_named(dst0,src0,type0,dst1,src1,type1,dst2,src2,type2)
00887 
00891   template <class Tstate> class mstate : public svector<Tstate> {
00892   public:
00894     mstate();
00900     mstate(const mstate<Tstate> &other, intg dims, intg nstates = -1);
00902     mstate(const mstate<Tstate> &other);
00904     virtual ~mstate();
00905 
00907 
00909     virtual void clear_x();
00911     virtual void clear_dx();
00913     virtual void clear_ddx();
00914     
00916     virtual void copy(mstate<Tstate> &cpy);
00918     template <typename T> void copy(midx<T> &cpy);
00920     template <typename T> midx<T> copy();
00921     
00923     virtual mstate<Tstate> narrow(intg size, intg offset);
00925     virtual mstate<Tstate> narrow(int dimension, intg size, intg offset);
00927     virtual mstate<Tstate> narrow(midxdim &regions);
00930     virtual mstate<Tstate> narrow_max(mfidxdim &regions);
00931     
00934     template <class T> void get_midx(mfidxdim &regions, midx<T> &all);
00938     template <class T> void get_max_midx(mfidxdim &regions, midx<T> &all);
00941     template <class T> void get_padded_midx(mfidxdim &regions, midx<T> &all);
00944     
00945     virtual void resize(mstate<Tstate> &s2, uint nmax = 0);
00951     template <class Tstate2> void resize(mstate<Tstate2> &other);
00953     virtual idxdim& get_idxdim0();
00954     
00955     // member variables ////////////////////////////////////////////////////////
00956   protected:
00957     typename svector<Tstate>::iterator it;
00958   };
00959 
00960   
00961 
00963   // stream operators
00964 
00965   template <typename T>
00966   EXPORT std::ostream& operator<<(std::ostream &out, const fstate_idx<T> &p);
00967   template <typename T>
00968   EXPORT std::ostream& operator<<(std::ostream &out, const bstate_idx<T> &p);
00969   template <typename T>
00970   EXPORT std::ostream& operator<<(std::ostream &out, const bbstate_idx<T> &p);  
00971   template <class Tstate>
00972   EXPORT std::ostream& operator<<(std::ostream &out, const mstate<Tstate> &p);
00973 
00974 } // namespace ebl {
00975 
00976 #include "ebl_states.hpp"
00977 
00978 #endif /* EBL_STATES_H_ */