libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2008 by Yann LeCun and Pierre Sermanet * 00003 * yann@cs.nyu.edu, pierre.sermanet@gmail.com * 00004 * 00005 * Redistribution and use in source and binary forms, with or without 00006 * modification, are permitted provided that the following conditions are met: 00007 * * Redistributions of source code must retain the above copyright 00008 * notice, this list of conditions and the following disclaimer. 00009 * * Redistributions in binary form must reproduce the above copyright 00010 * notice, this list of conditions and the following disclaimer in the 00011 * documentation and/or other materials provided with the distribution. 00012 * * Redistribution under a license not approved by the Open Source 00013 * Initiative (http://www.opensource.org) must display the 00014 * following acknowledgement in all advertising material: 00015 * This product includes software developed at the Courant 00016 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00017 * * The names of the authors may not be used to endorse or promote products 00018 * derived from this software without specific prior written permission. 00019 * 00020 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00021 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00022 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00023 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00024 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00025 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00026 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00027 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00028 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00029 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00030 ***************************************************************************/ 00031 00032 #ifndef EBL_CODEC_H_ 00033 #define EBL_CODEC_H_ 00034 00035 #include "ebl_defines.h" 00036 #include "libidx.h" 00037 #include "ebl_arch.h" 00038 #include "ebl_states.h" 00039 #include "ebl_cost.h" 00040 00041 namespace ebl { 00042 00045 template <class T> class codec: public ebm_2<T> { 00046 public: 00047 // encoder 00048 module_1_1<T> &encoder; 00049 fstate_idx<T> enc_out; 00050 ebm_2<T> &enc_cost; 00051 double weight_energy_enc; 00052 fstate_idx<T> enc_energy; 00053 // z 00054 fstate_idx<T> z; 00055 ebm_1<T> &z_cost; 00056 double weight_energy_z; 00057 fstate_idx<T> z_energy; 00058 // decoder 00059 module_1_1<T> &decoder; 00060 fstate_idx<T> dec_out; 00061 ebm_2<T> &dec_cost; 00062 double weight_energy_dec; 00063 fstate_idx<T> dec_energy; 00064 gd_param &infp; 00065 00067 codec(module_1_1<T> &encoder_, 00068 ebm_2<T> &enc_cost_, 00069 double weight_energy_enc_, 00070 ebm_1<T> &z_cost_, 00071 double weight_energy_z_, 00072 module_1_1<T> &decoder_, 00073 ebm_2<T> &dec_cost_, 00074 double weight_energy_dec_, 00075 gd_param &infp_); 00077 virtual ~codec(); 00079 virtual void fprop(fstate_idx<T> &in1, fstate_idx<T> &in2, 00080 fstate_idx<T> &energy); 00082 virtual void bprop(fstate_idx<T> &in1, fstate_idx<T> &in2, 00083 fstate_idx<T> &energy); 00085 virtual void bbprop(fstate_idx<T> &in1, fstate_idx<T> &in2, 00086 fstate_idx<T> &energy); 00088 virtual void forget(forget_param_linear &fp); 00090 virtual void normalize(); 00091 00092 protected: 00094 virtual void fprop_one_pass(fstate_idx<T> &in1, fstate_idx<T> &in2, 00095 fstate_idx<T> &energy); 00097 virtual void bprop_one_pass(fstate_idx<T> &in1, fstate_idx<T> &in2, 00098 fstate_idx<T> &energy); 00100 virtual void bprop_optimal_code(fstate_idx<T> &in1, fstate_idx<T> &in2, 00101 fstate_idx<T> &energy, gd_param &infp); 00102 00105 virtual bool check_code_threshold(fstate_idx<T> &z, gd_param &infp); 00106 }; 00107 00110 template <class T> class codec_lone: codec<T> { 00111 public: 00112 distance_l2<T> enc_cost_l2; 00113 penalty_l1<T> z_cost_l1; 00114 distance_l2<T> dec_cost_l2; 00115 00117 codec_lone(module_1_1<T> &encoder_, 00118 module_1_1<T> &decoder_, 00119 double weight_energy_enc_, 00120 double weight_energy_z_, 00121 double weight_energy_dec_, 00122 double thres, 00123 gd_param &infp_); 00125 virtual ~codec_lone(); 00126 protected: 00127 using codec<T>::infp; 00128 }; 00129 00130 } // namespace ebl { 00131 00132 #include "ebl_codec.hpp" 00133 00134 #endif /* EBL_CODEC_H_ */