libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2008 by Yann LeCun, Pierre Sermanet * 00003 * yann@cs.nyu.edu, pierre.sermanet@gmail.com * 00004 * 00005 * Redistribution and use in source and binary forms, with or without 00006 * modification, are permitted provided that the following conditions are met: 00007 * * Redistributions of source code must retain the above copyright 00008 * notice, this list of conditions and the following disclaimer. 00009 * * Redistributions in binary form must reproduce the above copyright 00010 * notice, this list of conditions and the following disclaimer in the 00011 * documentation and/or other materials provided with the distribution. 00012 * * Redistribution under a license not approved by the Open Source 00013 * Initiative (http://www.opensource.org) must display the 00014 * following acknowledgement in all advertising material: 00015 * This product includes software developed at the Courant 00016 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00017 * * The names of the authors may not be used to endorse or promote products 00018 * derived from this software without specific prior written permission. 00019 * 00020 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00021 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00022 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00023 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00024 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00025 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00026 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00027 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00028 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00029 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00030 ***************************************************************************/ 00031 00032 namespace ebl { 00033 00035 // codec 00036 00037 template <class T> 00038 codec<T>::codec(module_1_1<T> &encoder_, 00039 ebm_2<T> &enc_cost_, 00040 double weight_energy_enc_, 00041 ebm_1<T> &z_cost_, 00042 double weight_energy_z_, 00043 module_1_1<T> &decoder_, 00044 ebm_2<T> &dec_cost_, 00045 double weight_energy_dec_, 00046 gd_param &infp_) 00047 : encoder(encoder_), enc_out(1, 1, 1), enc_cost(enc_cost_), 00048 weight_energy_enc(weight_energy_enc_), enc_energy(), 00049 z(1, 1, 1), z_cost(z_cost_), 00050 weight_energy_z(weight_energy_z_), z_energy(), 00051 decoder(decoder_), dec_out(1, 1, 1), dec_cost(dec_cost_), 00052 weight_energy_dec(weight_energy_dec_), dec_energy(), 00053 infp(infp_) { 00054 } 00055 00056 template <class T> 00057 codec<T>::~codec() { 00058 } 00059 00060 template <class T> 00061 void codec<T>::fprop(fstate_idx<T> &in1, fstate_idx<T> &in2, 00062 fstate_idx<T> &energy) { 00063 // initialize z with a simple one-pass fprop through whole machine 00064 fprop_one_pass(in1, in2, energy); 00065 } 00066 00067 template <class T> 00068 void codec<T>::bprop(fstate_idx<T> &in1, fstate_idx<T> &in2, 00069 fstate_idx<T> &energy) { 00070 // do gradient descent to find optimal code z 00071 bprop_optimal_code(in1, in2, energy, infp); 00072 // bprop through all modules 00073 bprop_one_pass(in1, in2, energy); 00074 } 00075 00076 template <class T> 00077 void codec<T>::bbprop(fstate_idx<T> &in1, fstate_idx<T> &in2, 00078 fstate_idx<T> &energy) { 00079 enc_out.clear_ddx(); 00080 dec_out.clear_ddx(); 00081 // initialize all energy 2nd derivatives with global energy derivative 00082 // so that we minimize the global cost function 00083 idx_dotc(energy.ddx, weight_energy_dec, dec_energy.ddx); 00084 idx_dotc(energy.ddx, weight_energy_enc, enc_energy.ddx); 00085 idx_dotc(energy.ddx, weight_energy_z, z_energy.ddx); 00086 // bprop through cost modules 00087 z_cost.bbprop(z, z_energy); 00088 enc_cost.bbprop(enc_out, z, enc_energy); 00089 dec_cost.bbprop(dec_out, in2, dec_energy); 00090 // bprop through encoder/decoder 00091 decoder.bbprop(z, dec_out); 00092 encoder.bbprop(in1, enc_out); 00093 } 00094 00095 template <class T> 00096 void codec<T>::forget(forget_param_linear &fp) { 00097 encoder.forget(fp); 00098 decoder.forget(fp); 00099 enc_cost.forget(fp); 00100 dec_cost.forget(fp); 00101 z_cost.forget(fp); 00102 normalize(); 00103 } 00104 00105 template <class T> 00106 void codec<T>::normalize() { 00107 decoder.normalize(); 00108 } 00109 00110 // simple one-pass forward propagation 00111 template <class T> 00112 void codec<T>::fprop_one_pass(fstate_idx<T> &in1, fstate_idx<T> &in2, 00113 fstate_idx<T> &energy) { 00114 encoder.fprop(in1, enc_out); 00115 // let the enc-cost produce its best guess 00116 // for what the code should be. If the cost 00117 // is an l2-distance, this will simply 00118 // copy :enc-out:x into :z:x 00119 // There is no need to do an fprop of 00120 // the encoder afterward, because infer2 does it. 00121 enc_cost.infer2_copy(enc_out, z, enc_energy); 00122 // compute cost penalty 00123 z_cost.fprop(z, z_energy); 00124 // fprop through decoder 00125 decoder.fprop(z, dec_out); 00126 // fprop through decoder cost. 00127 dec_cost.fprop(dec_out, in2, dec_energy); 00128 // add up energy terms 00129 energy.clear(); 00130 idx_dotcacc(enc_energy.x, weight_energy_enc, energy.x); 00131 idx_dotcacc(dec_energy.x, weight_energy_dec, energy.x); 00132 idx_dotcacc(z_energy.x, weight_energy_z, energy.x); 00133 } 00134 00135 // simple one-pass backward propagation 00136 template <class T> 00137 void codec<T>::bprop_one_pass(fstate_idx<T> &in1, fstate_idx<T> &in2, 00138 fstate_idx<T> &energy) { 00139 enc_out.clear_dx(); 00140 dec_out.clear_dx(); 00141 // initialize all energy derivatives with global energy derivative 00142 // so that we minimize the global cost function 00143 idx_dotc(energy.dx, weight_energy_dec, dec_energy.dx); 00144 idx_dotc(energy.dx, weight_energy_enc, enc_energy.dx); 00145 idx_dotc(energy.dx, weight_energy_z, z_energy.dx); 00146 // bprop through cost modules 00147 z_cost.bprop(z, z_energy); 00148 enc_cost.bprop(enc_out, z, enc_energy); 00149 dec_cost.bprop(dec_out, in2, dec_energy); 00150 // bprop through encoder/decoder 00151 decoder.bprop(z, dec_out); 00152 encoder.bprop(in1, enc_out); 00153 } 00154 00155 // multiple-pass bprop on the decoder only to find the optimal code z 00156 template <class T> 00157 void codec<T>::bprop_optimal_code(fstate_idx<T> &in1, fstate_idx<T> &in2, 00158 fstate_idx<T> &energy, gd_param &infp) { 00159 z.clear_dx(); 00160 bprop(in1, in2, energy); // bprop once to initialize energy 00161 gd_param temp_ip(infp.eta, 0, 0, 0, 0, 0, 0, 0, 0); 00162 double old_energy = energy.x.get() + 1; 00163 int cnt = 0; 00164 while ((cnt < infp.n) 00165 && check_code_threshold(z, infp) 00166 && (old_energy > energy.x.get())) { 00167 old_energy = energy.x.get(); 00168 z.clear_dx(); 00169 // bprop through decoder ///////////////////////////////////// 00170 dec_out.clear_dx(); 00171 enc_out.clear_dx(); 00172 idx_dotc(energy.dx, weight_energy_dec, dec_energy.dx); 00173 idx_dotc(energy.dx, weight_energy_z, z_energy.dx); 00174 z_cost.bprop(z, z_energy); 00175 enc_cost.bprop(enc_out, z, enc_energy); 00176 dec_cost.bprop(dec_out, in2, dec_energy); 00177 decoder.bprop(z, dec_out); 00178 z.update_gd(temp_ip); 00179 // now fprop through decoder ///////////////////////////////// 00180 decoder.fprop(z, dec_out); 00181 z_cost.fprop(z, z_energy); 00182 enc_cost.fprop(enc_out, z, enc_energy); 00183 dec_cost.fprop(dec_out, in2, dec_energy); 00184 // add up energy terms /////////////////////////////////////// 00185 idx_dotcacc(enc_energy.x, weight_energy_enc, energy.x); 00186 idx_dotcacc(dec_energy.x, weight_energy_dec, energy.x); 00187 idx_dotcacc(z_energy.x, weight_energy_z, energy.x); 00188 cnt++; 00189 if ((cnt % (int) infp.anneal_period) == 0) 00190 temp_ip.eta *= infp.anneal_value; 00191 } 00192 /* TODO: for logging 00193 (:nr-iter-infer:x cnt) 00194 (:exit-condition:x 00195 (if (>= cnt :ip:n) 1 (if (< old-energy (:energy:x)) 3 2))) 00196 (==> logger log-optimal this z) 00197 */ 00198 } 00199 00200 template <class T> 00201 bool codec<T>::check_code_threshold(fstate_idx<T> &z, gd_param &infp) { 00202 return idx_l2norm(z.dx) > infp.gradient_threshold; 00203 } 00204 00206 // codec_lone 00207 00208 template <class T> 00209 codec_lone<T>::codec_lone(module_1_1<T> &encoder_, 00210 module_1_1<T> &decoder_, 00211 double weight_energy_enc_, 00212 double weight_energy_z_, 00213 double weight_energy_dec_, 00214 double thres, 00215 gd_param &infp_) 00216 : codec<T>(encoder_, enc_cost_l2, weight_energy_enc_, 00217 z_cost_l1, weight_energy_z_, 00218 decoder_, dec_cost_l2, weight_energy_dec_, infp), 00219 z_cost_l1(thres) { 00220 } 00221 00222 template <class T> 00223 codec_lone<T>::~codec_lone() { 00224 } 00225 00226 } // end namespace ebl