libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_codec.hpp
00001 /***************************************************************************
00002  *   Copyright (C) 2008 by Yann LeCun, Pierre Sermanet *
00003  *   yann@cs.nyu.edu, pierre.sermanet@gmail.com *
00004  *
00005  * Redistribution and use in source and binary forms, with or without
00006  * modification, are permitted provided that the following conditions are met:
00007  *     * Redistributions of source code must retain the above copyright
00008  *       notice, this list of conditions and the following disclaimer.
00009  *     * Redistributions in binary form must reproduce the above copyright
00010  *       notice, this list of conditions and the following disclaimer in the
00011  *       documentation and/or other materials provided with the distribution.
00012  *     * Redistribution under a license not approved by the Open Source
00013  *       Initiative (http://www.opensource.org) must display the
00014  *       following acknowledgement in all advertising material:
00015  *        This product includes software developed at the Courant
00016  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00017  *     * The names of the authors may not be used to endorse or promote products
00018  *       derived from this software without specific prior written permission.
00019  *
00020  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00021  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00022  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00023  * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY
00024  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00025  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00026  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00027  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00028  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00029  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00030  ***************************************************************************/
00031 
00032 namespace ebl {
00033 
00035   // codec
00036 
00037   template <class T>
00038   codec<T>::codec(module_1_1<T> &encoder_,
00039                   ebm_2<T>      &enc_cost_,
00040                   double        weight_energy_enc_,
00041                   ebm_1<T>      &z_cost_,
00042                   double        weight_energy_z_,
00043                   module_1_1<T> &decoder_,
00044                   ebm_2<T>      &dec_cost_,
00045                   double        weight_energy_dec_,
00046                   gd_param      &infp_)
00047     : encoder(encoder_), enc_out(1, 1, 1), enc_cost(enc_cost_),
00048       weight_energy_enc(weight_energy_enc_), enc_energy(),
00049       z(1, 1, 1), z_cost(z_cost_),
00050       weight_energy_z(weight_energy_z_), z_energy(),
00051       decoder(decoder_), dec_out(1, 1, 1), dec_cost(dec_cost_),
00052       weight_energy_dec(weight_energy_dec_), dec_energy(),
00053       infp(infp_) {
00054   }
00055 
00056   template <class T>
00057   codec<T>::~codec() {
00058   }
00059 
00060   template <class T>
00061   void codec<T>::fprop(fstate_idx<T> &in1, fstate_idx<T> &in2,
00062                        fstate_idx<T> &energy) {
00063     // initialize z with a simple one-pass fprop through whole machine
00064     fprop_one_pass(in1, in2, energy);
00065   }
00066   
00067   template <class T>
00068   void codec<T>::bprop(fstate_idx<T> &in1, fstate_idx<T> &in2,
00069                        fstate_idx<T> &energy) {
00070     // do gradient descent to find optimal code z
00071     bprop_optimal_code(in1, in2, energy, infp);
00072     // bprop through all modules
00073     bprop_one_pass(in1, in2, energy);
00074   }
00075 
00076   template <class T>
00077   void codec<T>::bbprop(fstate_idx<T> &in1, fstate_idx<T> &in2,
00078                         fstate_idx<T> &energy) {
00079     enc_out.clear_ddx();
00080     dec_out.clear_ddx();
00081     // initialize all energy 2nd derivatives with global energy derivative
00082     // so that we minimize the global cost function
00083     idx_dotc(energy.ddx, weight_energy_dec, dec_energy.ddx);
00084     idx_dotc(energy.ddx, weight_energy_enc, enc_energy.ddx);
00085     idx_dotc(energy.ddx, weight_energy_z, z_energy.ddx);
00086     // bprop through cost modules
00087     z_cost.bbprop(z, z_energy);
00088     enc_cost.bbprop(enc_out, z, enc_energy);
00089     dec_cost.bbprop(dec_out, in2, dec_energy);
00090     // bprop through encoder/decoder
00091     decoder.bbprop(z, dec_out);
00092     encoder.bbprop(in1, enc_out);
00093   }
00094 
00095   template <class T>
00096   void codec<T>::forget(forget_param_linear &fp) {
00097     encoder.forget(fp);
00098     decoder.forget(fp);
00099     enc_cost.forget(fp);
00100     dec_cost.forget(fp);
00101     z_cost.forget(fp);
00102     normalize();
00103   }
00104 
00105   template <class T>
00106   void codec<T>::normalize() {
00107     decoder.normalize();
00108   }
00109 
00110   // simple one-pass forward propagation
00111   template <class T>
00112   void codec<T>::fprop_one_pass(fstate_idx<T> &in1, fstate_idx<T> &in2, 
00113                                 fstate_idx<T> &energy) {
00114     encoder.fprop(in1, enc_out);
00115     // let the enc-cost produce its best guess
00116     // for what the code should be. If the cost
00117     // is an l2-distance, this will simply 
00118     // copy :enc-out:x into :z:x
00119     // There is no need to do an fprop of
00120     // the encoder afterward, because infer2 does it.
00121     enc_cost.infer2_copy(enc_out, z, enc_energy);
00122     // compute cost penalty
00123     z_cost.fprop(z, z_energy);
00124     // fprop through decoder
00125     decoder.fprop(z, dec_out);
00126     // fprop through decoder cost.
00127     dec_cost.fprop(dec_out, in2, dec_energy);
00128     // add up energy terms
00129     energy.clear();
00130     idx_dotcacc(enc_energy.x, weight_energy_enc, energy.x);
00131     idx_dotcacc(dec_energy.x, weight_energy_dec, energy.x);
00132     idx_dotcacc(z_energy.x, weight_energy_z, energy.x);
00133   }
00134   
00135   // simple one-pass backward propagation
00136   template <class T>
00137   void codec<T>::bprop_one_pass(fstate_idx<T> &in1, fstate_idx<T> &in2, 
00138                                 fstate_idx<T> &energy) {
00139     enc_out.clear_dx();
00140     dec_out.clear_dx();
00141     // initialize all energy derivatives with global energy derivative
00142     // so that we minimize the global cost function
00143     idx_dotc(energy.dx, weight_energy_dec, dec_energy.dx);
00144     idx_dotc(energy.dx, weight_energy_enc, enc_energy.dx);
00145     idx_dotc(energy.dx, weight_energy_z, z_energy.dx);
00146     // bprop through cost modules
00147     z_cost.bprop(z, z_energy);
00148     enc_cost.bprop(enc_out, z, enc_energy);
00149     dec_cost.bprop(dec_out, in2, dec_energy);
00150     // bprop through encoder/decoder
00151     decoder.bprop(z, dec_out);
00152     encoder.bprop(in1, enc_out);
00153   }
00154   
00155   // multiple-pass bprop on the decoder only to find the optimal code z
00156   template <class T>
00157   void codec<T>::bprop_optimal_code(fstate_idx<T> &in1, fstate_idx<T> &in2, 
00158                                     fstate_idx<T> &energy, gd_param &infp) {
00159     z.clear_dx();
00160     bprop(in1, in2, energy); // bprop once to initialize energy
00161     gd_param temp_ip(infp.eta, 0, 0, 0, 0, 0, 0, 0, 0); 
00162     double old_energy = energy.x.get() + 1;
00163     int cnt = 0;
00164     while ((cnt < infp.n)
00165            && check_code_threshold(z, infp)
00166            && (old_energy > energy.x.get())) {
00167       old_energy = energy.x.get();
00168       z.clear_dx();
00169       // bprop through decoder /////////////////////////////////////
00170       dec_out.clear_dx();
00171       enc_out.clear_dx();
00172       idx_dotc(energy.dx, weight_energy_dec, dec_energy.dx);
00173       idx_dotc(energy.dx, weight_energy_z, z_energy.dx);
00174       z_cost.bprop(z, z_energy);
00175       enc_cost.bprop(enc_out, z, enc_energy);
00176       dec_cost.bprop(dec_out, in2, dec_energy);
00177       decoder.bprop(z, dec_out);
00178       z.update_gd(temp_ip);
00179       // now fprop through decoder /////////////////////////////////
00180       decoder.fprop(z, dec_out);
00181       z_cost.fprop(z, z_energy);
00182       enc_cost.fprop(enc_out, z, enc_energy);
00183       dec_cost.fprop(dec_out, in2, dec_energy);
00184       // add up energy terms ///////////////////////////////////////
00185       idx_dotcacc(enc_energy.x, weight_energy_enc, energy.x);
00186       idx_dotcacc(dec_energy.x, weight_energy_dec, energy.x);
00187       idx_dotcacc(z_energy.x, weight_energy_z, energy.x);
00188       cnt++;
00189       if ((cnt % (int) infp.anneal_period) == 0)
00190         temp_ip.eta *= infp.anneal_value;
00191     }
00192     /* TODO: for logging
00193        (:nr-iter-infer:x cnt) 
00194        (:exit-condition:x
00195        (if (>= cnt :ip:n) 1 (if (< old-energy (:energy:x)) 3 2)))
00196        (==> logger log-optimal this z)
00197     */
00198   }
00199 
00200   template <class T>
00201   bool codec<T>::check_code_threshold(fstate_idx<T> &z, gd_param &infp) {
00202     return idx_l2norm(z.dx) > infp.gradient_threshold;
00203   }
00204 
00206   // codec_lone
00207 
00208   template <class T>
00209   codec_lone<T>::codec_lone(module_1_1<T> &encoder_,
00210                             module_1_1<T> &decoder_,
00211                             double weight_energy_enc_,
00212                             double weight_energy_z_,
00213                             double      weight_energy_dec_,
00214                             double thres,
00215                             gd_param &infp_)
00216     : codec<T>(encoder_, enc_cost_l2, weight_energy_enc_,
00217                z_cost_l1, weight_energy_z_,
00218                decoder_, dec_cost_l2, weight_energy_dec_, infp),
00219       z_cost_l1(thres) {
00220   }
00221 
00222   template <class T>
00223   codec_lone<T>::~codec_lone() {
00224   }
00225 
00226 } // end namespace ebl