libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2008 by Yann LeCun and Pierre Sermanet * 00003 * yann@cs.nyu.edu, pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef EBL_STATES_H_ 00034 #define EBL_STATES_H_ 00035 00036 #include "libidx.h" 00037 00038 namespace ebl { 00039 00040 class infer_param { 00041 }; 00042 00047 class EXPORT gd_param: public infer_param { 00048 public: 00050 double eta; 00052 double n; 00054 double decay_l1; 00056 double decay_l2; 00058 intg decay_time; 00060 double inertia; 00062 double anneal_value; 00064 intg anneal_period; 00066 double gradient_threshold; 00068 int niter_done; 00069 00071 gd_param(); 00072 00074 gd_param(double leta, double ln, double l1, double l2, intg dtime, 00075 double iner, double a_v, intg a_p, double g_t); 00076 }; 00077 00078 EXPORT std::ostream& operator<<(std::ostream &out, const gd_param &p); 00079 00082 class forget_param { 00083 }; 00084 00085 class EXPORT forget_param_linear: public forget_param { 00086 public: 00090 forget_param_linear(double v, double e); 00091 00092 // public member variables ///////////////////////////////////////////////// 00093 public: 00096 double value; 00097 double exponent; 00098 random generator; 00099 }; 00100 00107 class EXPORT state : public smart_pointer { 00108 public: 00110 state(); 00112 virtual ~state(); 00114 virtual void clear(); 00116 virtual void clear_x(); 00118 virtual void clear_dx(); 00120 virtual void clear_ddx(); 00121 virtual void update_gd(gd_param &arg); 00124 virtual state& operator=(const state& other); 00125 }; 00126 00127 template <typename T> class bbstate_idx; 00128 template <typename T, class Tstate = bbstate_idx<T> > class parameter; 00129 00132 template <typename T> class fstate_idx: public state { 00133 public: 00134 virtual ~fstate_idx(); 00135 00138 00140 fstate_idx(); 00142 fstate_idx(intg s0); 00144 fstate_idx(intg s0, intg s1); 00146 fstate_idx(intg s0, intg s1, intg s2); 00148 fstate_idx(intg s0, intg s1, intg s2, intg s3, intg s4 = -1, intg s5 = -1, 00149 intg s6 = -1, intg s7 = -1); 00151 fstate_idx(const idxdim &d); 00154 fstate_idx(intg n, fstate_idx<T> &fs); 00155 00158 00163 fstate_idx(parameter<T,fstate_idx<T> > *st); 00168 fstate_idx(parameter<T,fstate_idx<T> > *st, intg s0); 00173 fstate_idx(parameter<T,fstate_idx<T> > *st, intg s0, intg s1); 00178 fstate_idx(parameter<T,fstate_idx<T> > *st, intg s0, intg s1, intg s2); 00184 fstate_idx(parameter<T,fstate_idx<T> > *st, intg s0, intg s1, intg s2, 00185 intg s3, intg s4 = -1, intg s5 = -1, intg s6 = -1, intg s7 = -1); 00191 fstate_idx(parameter<T,fstate_idx<T> > *st, const idxdim &d); 00192 00195 00199 fstate_idx(const idx<T> &x); 00200 00203 00205 virtual void clear(); 00207 virtual void clear_x(); 00208 00211 00213 virtual intg nelements(); 00215 virtual intg footprint(); 00217 virtual intg size(); 00218 00221 00223 virtual void resize(intg s0 = -1, intg s1 = -1, intg s2 = -1, intg s3 = -1, 00224 intg s4 = -1, intg s5 = -1, intg s6 = -1, intg s7 = -1); 00226 virtual void resize(const idxdim &d); 00229 virtual void resize1(intg dimn, intg size); 00232 virtual void resize_as(fstate_idx& s); 00235 virtual void resize_as_but1(fstate_idx<T>& s, intg fixed_dim); 00236 00237 // virtual void resize(const intg* dimsBegin, const intg* dimsEnd); 00238 00241 00247 fstate_idx<T> select(int dimension, intg slice_index); 00254 fstate_idx<T> narrow(int dimension, intg size, intg offset); 00255 00258 00260 fstate_idx<T> make_copy(); 00263 virtual fstate_idx<T>& operator=(const fstate_idx<T>& other); 00265 virtual void copy(fstate_idx<T> &cpy); 00266 00267 // info printing /////////////////////////////////////////////////////////// 00268 00270 virtual void pretty(); 00272 virtual void print(); 00273 00276 public: 00278 idx<T> x; 00280 idx<T> dx; 00282 idx<T> ddx; 00283 }; 00284 00285 00288 template <typename T> class bstate_idx: public fstate_idx<T> { 00289 public: 00290 virtual ~bstate_idx(); 00291 00294 00296 bstate_idx(); 00298 bstate_idx(intg s0); 00300 bstate_idx(intg s0, intg s1); 00302 bstate_idx(intg s0, intg s1, intg s2); 00304 bstate_idx(intg s0, intg s1, intg s2, intg s3, intg s4 = -1, intg s5 = -1, 00305 intg s6 = -1, intg s7 = -1); 00307 bstate_idx(const idxdim &d); 00310 bstate_idx(intg n, bstate_idx<T> &fs); 00311 00314 00319 bstate_idx(parameter<T,bstate_idx<T> > *st); 00324 bstate_idx(parameter<T,bstate_idx<T> > *st, intg s0); 00329 bstate_idx(parameter<T,bstate_idx<T> > *st, intg s0, intg s1); 00334 bstate_idx(parameter<T,bstate_idx<T> > *st, intg s0, intg s1, intg s2); 00340 bstate_idx(parameter<T,bstate_idx<T> > *st, intg s0, intg s1, intg s2, 00341 intg s3, intg s4 = -1, intg s5 = -1, intg s6 = -1, intg s7 = -1); 00347 bstate_idx(parameter<T,bstate_idx<T> > *st, const idxdim &d); 00348 00351 00355 bstate_idx(const idx<T> &x, const idx<T> &dx); 00356 00359 00361 virtual void clear(); 00363 virtual void clear_dx(); 00365 virtual void update_gd(gd_param &arg); 00366 00369 00371 using fstate_idx<T>::nelements; 00373 using fstate_idx<T>::footprint; 00375 using fstate_idx<T>::size; 00376 00379 00381 virtual void resize(intg s0 = -1, intg s1 = -1, intg s2 = -1, intg s3 = -1, 00382 intg s4 = -1, intg s5 = -1, intg s6 = -1, intg s7 = -1); 00384 virtual void resize(const idxdim &d); 00387 virtual void resize1(intg dimn, intg size); 00390 virtual void resize_as(bstate_idx& s); 00393 virtual void resize_as_but1(bstate_idx<T>& s, intg fixed_dim); 00394 00395 // virtual void resize(const intg* dimsBegin, const intg* dimsEnd); 00396 00399 00405 bstate_idx<T> select(int dimension, intg slice_index); 00412 bstate_idx<T> narrow(int dimension, intg size, intg offset); 00413 00416 00418 bstate_idx<T> make_copy(); 00421 virtual bstate_idx<T>& operator=(const bstate_idx<T>& other); 00423 virtual void copy(bstate_idx<T> &cpy); 00424 00425 // info printing /////////////////////////////////////////////////////////// 00426 00428 virtual void pretty(); 00430 virtual void print(); 00431 00434 public: 00436 using fstate_idx<T>::x; 00438 using fstate_idx<T>::dx; 00440 using fstate_idx<T>::ddx; 00441 }; 00442 00445 template <typename T> class bbstate_idx: public bstate_idx<T> { 00446 public: 00447 virtual ~bbstate_idx(); 00448 00451 00453 bbstate_idx(); 00455 bbstate_idx(intg s0); 00457 bbstate_idx(intg s0, intg s1); 00459 bbstate_idx(intg s0, intg s1, intg s2); 00461 bbstate_idx(intg s0, intg s1, intg s2, intg s3, intg s4 = -1, intg s5 = -1, 00462 intg s6 = -1, intg s7 = -1); 00464 bbstate_idx(const idxdim &d); 00467 bbstate_idx(intg n, bbstate_idx<T> &fs); 00468 00471 00476 bbstate_idx(parameter<T,bbstate_idx<T> > *st); 00481 bbstate_idx(parameter<T,bbstate_idx<T> > *st, intg s0); 00486 bbstate_idx(parameter<T,bbstate_idx<T> > *st, intg s0, intg s1); 00491 bbstate_idx(parameter<T,bbstate_idx<T> > *st, intg s0, intg s1, intg s2); 00497 bbstate_idx(parameter<T,bbstate_idx<T> > *st, intg s0, intg s1, intg s2, 00498 intg s3, intg s4 = -1, intg s5 = -1, intg s6 = -1, intg s7 =-1); 00504 bbstate_idx(parameter<T,bbstate_idx<T> > *st, const idxdim &d); 00505 00508 00509 // TODO: this causes bug in mnist part of tester 00510 /* //! Constructs a bbstate_idx from an idx to be used as x. dx and ddx */ 00511 /* //! will be allocated with the same size as x. */ 00512 /*//! Note: the data pointed to by x is not copied, we only create new idx*/ 00513 /* //! pointing to the same data. */ 00514 /* bbstate_idx(const idx<T> &x); */ 00515 00519 bbstate_idx(const idx<T> &x, const idx<T> &dx, const idx<T> &ddx); 00520 00523 00525 virtual void clear(); 00527 virtual void clear_ddx(); 00529 using bstate_idx<T>::update_gd; 00530 00533 00535 using fstate_idx<T>::nelements; 00537 using fstate_idx<T>::footprint; 00539 using fstate_idx<T>::size; 00540 00543 00545 virtual void resize(intg s0 = -1, intg s1 = -1, intg s2 = -1, intg s3 = -1, 00546 intg s4 = -1, intg s5 = -1, intg s6 = -1, intg s7 = -1); 00548 virtual void resize(const idxdim &d); 00551 virtual void resize1(intg dimn, intg size); 00554 virtual void resize_as(bbstate_idx& s); 00557 virtual void resize_as_but1(bbstate_idx<T>& s, intg fixed_dim); 00558 00559 // virtual void resize(const intg* dimsBegin, const intg* dimsEnd); 00560 00563 00569 bbstate_idx<T> select(int dimension, intg slice_index); 00576 bbstate_idx<T> narrow(int dimension, intg size, intg offset); 00577 00580 00582 bbstate_idx<T> make_copy(); 00585 virtual bbstate_idx<T>& operator=(const bbstate_idx<T>& other); 00587 virtual void copy(bbstate_idx<T> &cpy); 00588 00589 // info printing /////////////////////////////////////////////////////////// 00590 00592 virtual void pretty(); 00594 virtual void print(); 00595 00598 public: 00600 using fstate_idx<T>::x; 00602 using fstate_idx<T>::dx; 00604 using fstate_idx<T>::ddx; 00605 }; 00606 00609 template <typename T> 00610 class parameter<T, fstate_idx<T> > : public fstate_idx<T> { 00611 public: 00613 parameter(intg initial_size = 100); 00615 parameter(const char *param_filename); 00617 virtual ~parameter(); 00618 virtual void resize(intg s0); 00619 00621 // I/O methods 00622 00626 bool load_x(std::vector<string> &files); 00628 bool load_x(const char *param_filename); 00630 bool load_x(idx<T> &m); 00632 bool save_x(const char *param_filename); 00633 00636 00638 using fstate_idx<T>::nelements; 00640 using fstate_idx<T>::footprint; 00642 using fstate_idx<T>::size; 00643 00646 public: 00647 using fstate_idx<T>::x; 00648 }; 00649 00652 00654 template <typename T> 00655 class parameter<T, bstate_idx<T> > 00656 : public bstate_idx<T>, public parameter<T,fstate_idx<T> > { 00657 public: 00659 parameter(intg initial_size = 100); 00661 parameter(const char *param_filename); 00663 virtual ~parameter(); 00665 virtual void resize(intg s0); 00667 virtual void update(gd_param &arg); 00669 void clear_deltax(); 00671 void set_epsilon(T m); 00672 00674 // I/O methods 00675 00677 bool load_x(const char *param_filename); 00679 bool save_x(const char *param_filename); 00680 00683 00685 using fstate_idx<T>::nelements; 00687 using fstate_idx<T>::footprint; 00689 using fstate_idx<T>::size; 00690 00692 // protected methods 00693 00694 protected: 00695 00698 void update_gd(gd_param &arg); 00702 void update_deltax(T knew, T kold); 00703 00706 public: 00707 using fstate_idx<T>::x; 00708 using bstate_idx<T>::dx; 00709 00710 // idx<T> gradient; 00711 idx<T> deltax; 00712 idx<T> epsilons; 00713 }; 00714 00718 00719 template <typename T> 00720 class parameter<T, bbstate_idx<T> > : public bbstate_idx<T> { 00721 public: 00723 parameter(intg initial_size = 100); 00725 parameter(const char *param_filename); 00727 virtual ~parameter(); 00729 virtual void resize(intg s0); 00731 virtual void update(gd_param &arg); 00733 void clear_deltax(); 00735 void clear_ddeltax(); 00737 void set_epsilon(T m); 00740 void compute_epsilons(T mu); 00744 void update_ddeltax(T knew, T kold); 00745 00747 // I/O methods 00748 00752 bool load_x(std::vector<string> &files); 00754 bool load_x(const char *param_filename); 00756 bool load_x(idx<T> &m); 00758 bool save_x(const char *param_filename); 00759 00762 00764 using fstate_idx<T>::nelements; 00766 using fstate_idx<T>::footprint; 00768 using fstate_idx<T>::size; 00769 00771 // protected methods 00772 protected: 00773 00776 void update_gd(gd_param &arg); 00780 void update_deltax(T knew, T kold); 00781 00784 public: 00785 using fstate_idx<T>::x; 00786 using fstate_idx<T>::dx; 00787 using fstate_idx<T>::ddx; 00788 00789 //idx<T> gradient; 00790 idx<T> deltax; 00791 idx<T> epsilons; 00792 idx<T> ddeltax; 00793 }; 00794 00796 // state_idxloopers 00797 00798 template <typename Tstate> class state_idxlooper; 00799 00801 template <typename T> 00802 class state_idxlooper<fstate_idx<T> > : public fstate_idx<T> { 00803 public: 00804 using fstate_idx<T>::x; 00805 00806 idxlooper<T> lx; 00807 00809 state_idxlooper(fstate_idx<T> &s, int ld); 00810 virtual ~state_idxlooper(); 00812 bool notdone(); 00814 void next(); 00815 }; 00816 00819 template <typename T> 00820 class state_idxlooper<bstate_idx<T> > : public bstate_idx<T> { 00821 public: 00822 using bstate_idx<T>::x; 00823 using bstate_idx<T>::dx; 00824 00825 idxlooper<T> lx; 00826 idxlooper<T> ldx; 00827 00829 state_idxlooper(bstate_idx<T> &s, int ld); 00830 virtual ~state_idxlooper(); 00832 bool notdone(); 00834 void next(); 00835 }; 00836 00839 template <typename T> 00840 class state_idxlooper<bbstate_idx<T> > : public bbstate_idx<T> { 00841 public: 00842 using bbstate_idx<T>::x; 00843 using bbstate_idx<T>::dx; 00844 using bbstate_idx<T>::ddx; 00845 00846 idxlooper<T> lx; 00847 idxlooper<T> ldx; 00848 idxlooper<T> lddx; 00849 00851 state_idxlooper(bbstate_idx<T> &s, int ld); 00852 virtual ~state_idxlooper(); 00854 bool notdone(); 00856 void next(); 00857 }; 00858 00861 #define state_idx_eloop1(dst0,src0,type0) \ 00862 state_idxlooper<type0> dst0(src0, (src0).x.order() - 1); \ 00863 for ( ; dst0.notdone(); dst0.next()) 00864 00867 #define state_idx_eloop2(dst0,src0,type0,dst1,src1,type1) \ 00868 if ((src0).x.dim((src0).x.order() - 1) \ 00869 != (src1).x.dim((src1).x.order() - 1)) \ 00870 eblerror("incompatible state_idx for eloop\n"); \ 00871 state_idxlooper<type0> dst0(src0,(src0).x.order()-1); \ 00872 state_idxlooper<type1> dst1(src1,(src1).x.order()-1); \ 00873 for ( ; dst0.notdone(); dst0.next(), dst1.next()) 00874 00877 #define state_idx_eloop3(src0,type0,dst1,src1,type1,dst2,src2,type2) \ 00878 if (((src0).x.dim((src0).x.order() - 1) \ 00879 != (src1).x.dim((src1).x.order() - 1)) \ 00880 || ((src0).x.dim((src0).x.order() - 1) \ 00881 != (src2).x.dim((src2).x.order() - 1))) \ 00882 eblerror("incompatible idxs for eloop\n"); \ 00883 state_idxlooper<type0> dst0(src0,(src0).x.order()-1); \ 00884 state_idxlooper<type1> dst1(src1,(src1).x.order()-1); \ 00885 state_idxlooper<type2> dst2(src2,(src2).x.order()-1); \ 00886 for ( ; dst0.notdone(); dst0.next(), dst1.next(), dst2.next()) 00887 00891 template <class Tstate> class mstate : public svector<Tstate> { 00892 public: 00894 mstate(); 00900 mstate(const mstate<Tstate> &other, intg dims, intg nstates = -1); 00902 mstate(const mstate<Tstate> &other); 00904 virtual ~mstate(); 00905 00907 00909 virtual void clear_x(); 00911 virtual void clear_dx(); 00913 virtual void clear_ddx(); 00914 00916 virtual void copy(mstate<Tstate> &cpy); 00918 template <typename T> void copy(midx<T> &cpy); 00920 template <typename T> midx<T> copy(); 00921 00923 virtual mstate<Tstate> narrow(intg size, intg offset); 00925 virtual mstate<Tstate> narrow(int dimension, intg size, intg offset); 00927 virtual mstate<Tstate> narrow(midxdim ®ions); 00930 virtual mstate<Tstate> narrow_max(mfidxdim ®ions); 00931 00934 template <class T> void get_midx(mfidxdim ®ions, midx<T> &all); 00938 template <class T> void get_max_midx(mfidxdim ®ions, midx<T> &all); 00941 template <class T> void get_padded_midx(mfidxdim ®ions, midx<T> &all); 00944 00945 virtual void resize(mstate<Tstate> &s2, uint nmax = 0); 00951 template <class Tstate2> void resize(mstate<Tstate2> &other); 00953 virtual idxdim& get_idxdim0(); 00954 00955 // member variables //////////////////////////////////////////////////////// 00956 protected: 00957 typename svector<Tstate>::iterator it; 00958 }; 00959 00960 00961 00963 // stream operators 00964 00965 template <typename T> 00966 EXPORT std::ostream& operator<<(std::ostream &out, const fstate_idx<T> &p); 00967 template <typename T> 00968 EXPORT std::ostream& operator<<(std::ostream &out, const bstate_idx<T> &p); 00969 template <typename T> 00970 EXPORT std::ostream& operator<<(std::ostream &out, const bbstate_idx<T> &p); 00971 template <class Tstate> 00972 EXPORT std::ostream& operator<<(std::ostream &out, const mstate<Tstate> &p); 00973 00974 } // namespace ebl { 00975 00976 #include "ebl_states.hpp" 00977 00978 #endif /* EBL_STATES_H_ */