libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_energy.hpp
00001 /***************************************************************************
00002  *   Copyright (C) 2011 by Pierre Sermanet *
00003  *   pierre.sermanet@gmail.com *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef EBL_ENERGY_HPP_
00034 #define EBL_ENERGY_HPP_
00035 
00036 namespace ebl {
00037 
00039   // l2_energy
00040   
00041   template <typename T, class Tstate>
00042   l2_energy<T,Tstate>::l2_energy(const char *name_) : ebm_2<Tstate>(name_) { 
00043   }
00044   
00045   template <typename T, class Tstate>
00046   l2_energy<T,Tstate>::~l2_energy() { 
00047   }
00048   
00049   template <typename T, class Tstate>
00050   void l2_energy<T,Tstate>::fprop(Tstate &in1, Tstate &in2, Tstate &energy) {
00051     idx_sqrdist(in1.x, in2.x, energy.x); // squared distance between in1 and in2
00052     idx_dotc(energy.x, 0.5, energy.x); // multiply by .5
00053   }
00054   
00055   template <typename T, class Tstate>
00056   void l2_energy<T,Tstate>::bprop(Tstate &in1, Tstate &in2, Tstate &energy) {
00057     idx_checkorder1(energy.x, 0); // energy.x must have an order of 0
00058     idx_sub(in1.x, in2.x, in1.dx); // derivative with respect to in1
00059     idx_dotc(in1.dx, energy.dx.get(), in1.dx); // multiply by energy derivative
00060     idx_minus(in1.dx, in2.dx); // derivative with respect to in2
00061   }
00062   
00063   template <typename T, class Tstate>
00064   void l2_energy<T,Tstate>::bbprop(Tstate &in1, Tstate &in2, Tstate &energy) {
00065     idx_addc(in1.ddx, energy.dx.get(), in1.ddx);
00066     idx_addc(in2.ddx, energy.dx.get(), in2.ddx);
00067   }
00068   
00069   template <typename T, class Tstate>
00070   void l2_energy<T,Tstate>::infer2_copy(Tstate &in1, Tstate &in2,
00071                                         Tstate &energy) {
00072     idx_copy(in1.x, in2.x);
00073     idx_clear(energy.x);
00074   }
00075     
00076   template <typename T, class Tstate>
00077   std::string l2_energy<T,Tstate>::describe() {
00078     std::string s;
00079     s << "energy " << this->name()
00080       << " is the euclidean distance between inputs";
00081     return s;
00082   }
00083     
00085   // l1_penalty
00086   
00087   template <typename T, class Tstate>
00088   l1_penalty<T,Tstate>::l1_penalty(T t, T c) : threshold(t), coeff(c) { 
00089   }
00090   
00091   template <typename T, class Tstate>
00092   l1_penalty<T,Tstate>::~l1_penalty() { 
00093   }
00094   
00095   template <typename T, class Tstate>
00096   void l1_penalty<T,Tstate>::fprop(Tstate &in, Tstate &energy) {
00097     idx_sumabs(in.x, energy.x.idx_ptr());
00098     energy.x.set(energy.x.get() * coeff);
00099   }
00100   
00101   template <typename T, class Tstate>
00102   void l1_penalty<T,Tstate>::bprop(Tstate &in, Tstate &energy) {
00103     idx_thresdotc_acc(in.x, energy.dx.get() * coeff, threshold, in.dx);
00104   }
00105   
00106   template <typename T, class Tstate>
00107   void l1_penalty<T,Tstate>::bbprop(Tstate &in, Tstate &energy) {
00108     idx_addc(in.ddx, energy.ddx.get() * coeff * coeff, in.ddx);
00109   }
00110 
00111   template <typename T, class Tstate>
00112   std::string l1_penalty<T,Tstate>::describe() {
00113     std::string s;
00114     s << "l1 penalty " << this->name()
00115       << " with threshold " << threshold << " and coefficient " << coeff;
00116     return s;
00117   }
00118   
00120   // cross_entropy_energy
00121   
00122   template <typename T, class Tstate>
00123   cross_entropy_energy<T,Tstate>::cross_entropy_energy(const char *name_)
00124   : ebm_2<Tstate>(name_) { 
00125   }
00126   
00127   template <typename T, class Tstate>
00128   cross_entropy_energy<T,Tstate>::~cross_entropy_energy() { 
00129   }
00130   
00131   template <typename T, class Tstate>
00132   void cross_entropy_energy<T,Tstate>::fprop(Tstate &in1, Tstate &in2, Tstate &energy) {
00133     idx_sqrdist(in1.x, in2.x, energy.x); // squared distance between in1 and in2
00134     idx_dotc(energy.x, 0.5, energy.x); // multiply by .5
00135   }
00136   
00137   template <typename T, class Tstate>
00138   void cross_entropy_energy<T,Tstate>::bprop(Tstate &in1, Tstate &in2, Tstate &energy) {
00139     idx_checkorder1(energy.x, 0); // energy.x must have an order of 0
00140     idx_sub(in1.x, in2.x, in1.dx); // derivative with respect to in1
00141     idx_dotc(in1.dx, energy.dx.get(), in1.dx); // multiply by energy derivative
00142     idx_minus(in1.dx, in2.dx); // derivative with respect to in2
00143   }
00144   
00145   template <typename T, class Tstate>
00146   void cross_entropy_energy<T,Tstate>::bbprop(Tstate &in1, Tstate &in2, Tstate &energy) {
00147     idx_addc(in1.ddx, energy.dx.get(), in1.ddx);
00148     idx_addc(in2.ddx, energy.dx.get(), in2.ddx);
00149   }
00150   
00151   template <typename T, class Tstate>
00152   void cross_entropy_energy<T,Tstate>::infer2_copy(Tstate &in1, Tstate &in2,
00153                                         Tstate &energy) {
00154     idx_copy(in1.x, in2.x);
00155     idx_clear(energy.x);
00156   }
00157     
00158   template <typename T, class Tstate>
00159   std::string cross_entropy_energy<T,Tstate>::describe() {
00160     std::string s;
00161     s << "energy " << this->name()
00162       << " is the euclidean distance between inputs";
00163     return s;
00164   }
00165     
00167   // scalerclass_energy
00168 
00169   template <typename T, class Tstate>
00170   scalerclass_energy<T,Tstate>::
00171   scalerclass_energy(bool apply_tanh_, uint jsize_, uint jselection,
00172                      float dist_coeff_, float scale_coeff_,
00173                      bool predict_conf_, bool predict_bconf_,
00174                      idx<T> *biases_, idx<T> *coeffs_,
00175                      const char *name_)
00176     : l2_energy<T,Tstate>(name_), jsize(jsize_), apply_tanh(apply_tanh_),
00177       jitter_selection(jselection), dist_coeff(dist_coeff_),
00178       scale_coeff(scale_coeff_), predict_conf(predict_conf_),
00179       predict_bconf(predict_bconf_), biases(NULL), coeffs(NULL) {
00180     if (biases_) biases = new idx<T>(*biases_);
00181     if (coeffs_) coeffs = new idx<T>(*coeffs_);
00182   }
00183 
00184   template <typename T, class Tstate>
00185   scalerclass_energy<T,Tstate>::~scalerclass_energy() {
00186     if (biases) delete biases;
00187     if (coeffs) delete coeffs;
00188   }
00189 
00190   template <typename T, class Tstate>
00191   void scalerclass_energy<T,Tstate>::
00192   fprop(Tstate &in, Tstate &in2, Tstate &energy) {
00193     // determine sizes
00194     int nclass = in.x.dim(0) - jsize;
00195     if (predict_conf) nclass--;
00196     // sanity checks
00197     // if (in.x.get_idxdim() != in2.x.get_idxdim())
00198     //   eblerror("expected same dimensions but got " << in.x << " and " << in2.x);
00199     // narrow inputs for regular l2 energy: class inputs
00200     tmp = in.narrow(0, nclass, 0);
00201     // apply tanh if requested
00202     if (apply_tanh) {
00203       if (tmp.x.get_idxdim() != tmp2.x.get_idxdim())
00204         tmp2 = Tstate(tmp.x.get_idxdim());
00205       mtanh.fprop(tmp, tmp2);
00206       tmp = tmp2;
00207     } else { // if no tanh, cap with -1/1 to avoid penalties beyond these
00208       idx_threshold(tmp.x, (T)-1); // cap below by -1
00209       idx_threshold2(tmp.x, (T)1); // cap above by 1
00210     }
00211     // select jitter target among all possible ones
00212     if (in2.x.dim(0) == 1) { // only 1 possible target
00213       best_target = in2.x.select(0, 0);
00214     } else { // multiple targets
00215       T minscore = limits<T>::max();
00216       switch (jitter_selection) {
00217       case 0: // select highest confidence target
00218         { uint jindex = 0;
00219           idx<T> tgt;
00220           if (predict_conf) { // use predict conf feature
00221             tgt = in2.x.narrow(1, 1, in2.x.dim(1) - 1);
00222             jindex = idx_indexmax(tgt);
00223           } else { // use class target
00224             idx<T> tgt = in2.x.narrow(1, nclass, 0);
00225             uint i = 0;
00226             T max_val = limits<T>::min();
00227             { idx_bloop1(t, tgt, T) {
00228                 T val = idx_max(t);
00229                 if (val > max_val) {
00230                   max_val = val;
00231                   jindex = i;
00232                 }
00233                 i++;
00234               }}
00235           }
00236           // select the highest confidence target
00237           best_target = in2.x.select(0, jindex);
00238         }
00239         break ;
00240       case 1: // select closest to center and scale 1
00241         // loop on all possible jitter
00242         { idx_bloop1(tgt, in2.x, T) {
00243             T s1 = tgt.gget(nclass);
00244             T h1 = tgt.gget(nclass + 1);
00245             T w1 = tgt.gget(nclass + 2);
00246             T score = fabs(s1 - 1) * scale_coeff // ~[.8,2.0]
00247               + sqrt(h1 * h1 + w1 * w1) * dist_coeff;
00248             if (score < minscore) { // we found a better match
00249               minscore = score;
00250               best_target = tgt;
00251             }
00252           }}
00253         break ;
00254       case 2: // select closest to current answer
00255         { T s = in.x.gget(nclass); // predicted scale
00256           T h = in.x.gget(nclass + 1); // predicted h
00257           T w = in.x.gget(nclass + 2); // predicted w
00258           // loop on all possible jitter
00259           { idx_bloop1(tgt, in2.x, T) {
00260               T s1 = tgt.gget(nclass);
00261               T h1 = tgt.gget(nclass + 1);
00262               T w1 = tgt.gget(nclass + 2);
00263               T score = sqrt((h1 - h) * (h1 - h) + (w1 - w) * (w1 - w))
00264                 * dist_coeff + fabs(s1 - s) * scale_coeff;
00265               if (score < minscore) { // we found a better match
00266                 minscore = score;
00267                 best_target = tgt;
00268               }
00269             }}
00270         }
00271         break ;
00272       default:
00273         eblerror("unknown selection mode " << jitter_selection);
00274       }
00275     }
00276     // resize target buffer
00277     idxdim d(best_target.get_idxdim());
00278     if (last_target.x.get_idxdim() != d) {
00279       if (last_target.x.order() != d.order()) {
00280         last_target = Tstate(d);
00281         last_target_raw = idx<T>(d);
00282       } else {
00283         last_target.resize(d);
00284         last_target_raw.resize(d);
00285       }
00286       last_class_target = last_target.narrow(0, nclass,0);
00287       last_jitt_target = last_target.narrow(0, jsize, nclass);
00288       if (predict_conf)
00289         last_conf_target = last_target.narrow(0, 1, in.x.dim(0) - 1);
00290     }
00291     idx_copy(best_target, last_target.x);
00292     // make confidence target binary if required
00293     uint conf_offset = idx_indexmax(last_class_target.x);
00294     if (predict_conf)
00295       conf_offset = in.x.dim(0) - 1;
00296     if (predict_bconf) { // make confidence binary (0, 1)
00297       if (last_target.x.gget(conf_offset) > .5)
00298         last_target.x.sset((T)1, conf_offset);
00299       else
00300         last_target.x.sset((T)0, conf_offset);
00301     }
00302     // save raw target
00303     idx_copy(last_target.x, last_target_raw);
00304     T s = last_target_raw.gget(nclass); // scale target
00305     // normalize jitt with bias then coeff
00306     if (biases) {
00307       idx<T> tmpbias = biases->narrow(0, jsize, 0);
00308       idx_add(last_jitt_target.x, tmpbias, last_jitt_target.x);
00309     }
00310     if (coeffs) {
00311       idx<T> tmpcoeff = coeffs->narrow(0, jsize, 0);
00312       idx_mul(last_jitt_target.x, tmpcoeff, last_jitt_target.x);
00313     }
00314     // normalize prediction with bias then coeff, only if using extra component
00315     // (otherwise, target uses the full -1,1 range already
00316     if (predict_conf) {
00317       if (biases) {
00318         idx<T> tmpbias = biases->narrow(0, 1, jsize);
00319         idx_add(last_conf_target.x, tmpbias, last_conf_target.x);
00320       }
00321       if (coeffs) {
00322         idx<T> tmpcoeff = coeffs->narrow(0, 1, jsize);
00323         idx_mul(last_conf_target.x, tmpcoeff, last_conf_target.x);
00324       }
00325     }
00326     // l2 energy
00327     l2_energy<T,Tstate>::fprop(tmp, last_class_target, energy);
00328     // energy of scale component
00329     T e = 0;
00330     // penalize quadraticaly only if scale is > 0
00331     if (s > 0 && last_target_raw.gget(conf_offset) > .5) {
00332       // narrow inputs for jitter energy
00333       tmp = in.narrow(0, jsize, nclass);
00334       e = .5 * idx_sqrdist(tmp.x, last_jitt_target.x);
00335       energy.x.set(energy.x.get() + e);
00336     }
00337     // penalize predicted confidence only if positive (i.e. scale > 0)
00338     if (predict_conf && s > 0) {
00339       tmp = in.narrow(0, 1, conf_offset);
00340       // cap below by 0 and above by 1 (or corresponding normalized values)
00341       T low = 0, high = 1;
00342       if (biases) { low += biases->gget(3); high += biases->gget(3); }
00343       if (coeffs) { low *= biases->gget(3); high *= biases->gget(3); }
00344       idx_threshold(tmp.x, low);
00345       idx_threshold2(tmp.x, high);
00346       e = .5 * idx_sqrdist(tmp.x, last_conf_target.x);
00347       energy.x.set(energy.x.get() + e);
00348     }
00349     EDEBUG("energy: " << energy.x.get() << " in: " << in.x.str() << " norm tgt: "
00350           << last_target.x.str() << " raw tgt: " << last_target_raw.str()
00351           << " conf penalized: " << ((predict_conf && s > 0) ? "yes":"no")
00352           << " jitt penalized: "
00353           << ((s > 0 && last_target_raw.gget(conf_offset) > .5) ? "yes":"no"));
00354   }
00355   
00356   template <typename T, class Tstate>
00357   void scalerclass_energy<T,Tstate>::
00358   bprop(Tstate &in, Tstate &in2, Tstate &energy) {
00359     idx_checkorder1(energy.x, 0); // energy.x must have an order of 0
00360     int nclass = in.x.dim(0) - jsize;
00361     if (predict_conf) nclass--;
00362     uint conf_offset = idx_indexmax(last_class_target.x);
00363     if (predict_conf)
00364       conf_offset = in.x.dim(0) - 1;
00365     // narrow inputs for regular l2 energy
00366     tmp = in.narrow(0, nclass, 0);
00367     l2_energy<T,Tstate>::bprop(tmp, last_class_target, energy);
00368     // get values
00369     T s = last_target_raw.gget(nclass);
00370     // penalize quadraticaly only if scale is > 0
00371     if (s > 0 && last_target_raw.gget(conf_offset) > .5) {
00372       // narrow inputs for jitter energy
00373       tmp = in.narrow(0, jsize, nclass);
00374       idx_sub(tmp.x, last_jitt_target.x, tmp.dx); // derivative w.r.t in1
00375       idx_dotc(tmp.dx, energy.dx.get(), tmp.dx);// multiply by energy derivative
00376     }
00377     // penalize predicted confidence only if positive (i.e. scale > 0)
00378     if (predict_conf && s > 0) {
00379       tmp = in.narrow(0, 1, conf_offset);
00380       idx_sub(tmp.x, last_conf_target.x, tmp.dx);
00381       idx_dotc(tmp.dx, energy.dx.get(), tmp.dx);// multiply by energy derivative
00382     }
00383   }
00384 
00385   template <typename T, class Tstate>
00386   void scalerclass_energy<T,Tstate>::
00387   bbprop(Tstate &in, Tstate &in2, Tstate &energy) {
00388     last_target.clear_ddx();
00389     // derivatives are all the same for everybody
00390     l2_energy<T,Tstate>::bbprop(in, last_target, energy);
00391   }
00392 
00393   template <typename T, class Tstate>
00394   void scalerclass_energy<T,Tstate>::infer2(Tstate &in, Tstate &in2,
00395                                             infer_param &ip, Tstate *energy) {
00396     idx_copy(in.x, in2.x);
00397     idx_clear(energy->x);
00398   }
00399   
00400   template <typename T, class Tstate>
00401   std::string scalerclass_energy<T,Tstate>::describe() {
00402     std::string s;
00403     s << "energy " << this->name()
00404       << " is the squared distance with target for class components and the "
00405       << "squared distance to the scale component when scale > 0 and jsize "
00406       << jsize;
00407     if (predict_conf)
00408       s << ", predicting confidence";
00409     s << ", target confidence is " << (predict_bconf?"binary":"continuous");
00410     s << ", biases: ";
00411     if (biases)
00412       biases->printElems(s);
00413     else
00414       s <<"none";
00415     s << ", coeffs: ";
00416     if (coeffs)
00417       coeffs->printElems(s);
00418     else
00419       s <<"none";
00420     return s;
00421   }
00422     
00424   // scaler_energy
00425 
00426   template <typename T, class Tstate>
00427   scaler_energy<T,Tstate>::scaler_energy(const char *name_)
00428     : ebm_2<Tstate>(name_) {
00429   }
00430 
00431   template <typename T, class Tstate>
00432   scaler_energy<T,Tstate>::~scaler_energy() {
00433   }
00434 
00435   template <typename T, class Tstate>
00436   void scaler_energy<T,Tstate>::
00437   fprop(Tstate &in, Tstate &in2, Tstate &energy) {
00438     // sanity checks
00439     idx_checknelems2_all(in, in2);
00440     // get values
00441     T i = in.x.gget();
00442     T s = in2.x.gget();
00443     T e = 0;
00444     // no scale case: penalize quadraticaly only if above -1
00445     if (s == 0) {
00446       e = std::max((T) 0, i + 1);
00447       e = e * e * .5;
00448     } else // positive case: penalize quadraticaly with distance to value
00449       e = .5 * idx_sqrdist(in.x, in2.x);
00450     energy.x.set(e);
00451   }
00452   
00453   template <typename T, class Tstate>
00454   void scaler_energy<T,Tstate>::
00455   bprop(Tstate &in, Tstate &in2, Tstate &energy) {
00456     idx_checkorder1(energy.x, 0); // energy.x must have an order of 0
00457     // get values
00458     T i = in.x.gget();
00459     T s = in2.x.gget();
00460     // no scale case: penalize quadraticaly only if above -1
00461     if (s == 0) {
00462       in.dx.sset(std::max((T) 0, i + 1));
00463       in.dx.sset(in.dx.gget() * energy.dx.get());
00464     } else { // scale case: penalize quadraticaly with distance to value
00465       idx_sub(in.x, in2.x, in.dx);
00466       idx_dotc(in.dx, energy.dx.get(), in.dx); // multiply by energy derivative
00467     }
00468   }
00469 
00470   template <typename T, class Tstate>
00471   void scaler_energy<T,Tstate>::
00472   bbprop(Tstate &in1, Tstate &in2, Tstate &energy) {
00473     idx_addc(in1.ddx, energy.dx.get(), in1.ddx);
00474   }
00475 
00476   template <typename T, class Tstate>
00477   void scaler_energy<T,Tstate>::infer2(Tstate &in, Tstate &scale,
00478                                        infer_param &ip, Tstate *energy) {
00479     T i = in.x.gget();
00480     if (i <= 0) // negative class
00481       scale.x.sset((T) 0);
00482     else // positive class
00483       scale.x.sset(i);
00484   }
00485   
00486   template <typename T, class Tstate>
00487   std::string scaler_energy<T,Tstate>::describe() {
00488     std::string s;
00489     s << "energy " << this->name()
00490       << " is the squared distance to -1 when input is > -1 for the "
00491       << "negative class and the squared distance to scale target for the "
00492       << "positive class";
00493     return s;
00494   }
00495     
00496 } // end namespace ebl
00497 
00498 #endif /* EBL_ENERGY_HPP */