// libeblearn
00001 /*************************************************************************** 00002 * Copyright (C) 2011 by Pierre Sermanet * 00003 * pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. 
IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef EBL_ENERGY_HPP_ 00034 #define EBL_ENERGY_HPP_ 00035 00036 namespace ebl { 00037 00039 // l2_energy 00040 00041 template <typename T, class Tstate> 00042 l2_energy<T,Tstate>::l2_energy(const char *name_) : ebm_2<Tstate>(name_) { 00043 } 00044 00045 template <typename T, class Tstate> 00046 l2_energy<T,Tstate>::~l2_energy() { 00047 } 00048 00049 template <typename T, class Tstate> 00050 void l2_energy<T,Tstate>::fprop(Tstate &in1, Tstate &in2, Tstate &energy) { 00051 idx_sqrdist(in1.x, in2.x, energy.x); // squared distance between in1 and in2 00052 idx_dotc(energy.x, 0.5, energy.x); // multiply by .5 00053 } 00054 00055 template <typename T, class Tstate> 00056 void l2_energy<T,Tstate>::bprop(Tstate &in1, Tstate &in2, Tstate &energy) { 00057 idx_checkorder1(energy.x, 0); // energy.x must have an order of 0 00058 idx_sub(in1.x, in2.x, in1.dx); // derivative with respect to in1 00059 idx_dotc(in1.dx, energy.dx.get(), in1.dx); // multiply by energy derivative 00060 idx_minus(in1.dx, in2.dx); // derivative with respect to in2 00061 } 00062 00063 template <typename T, class Tstate> 00064 void l2_energy<T,Tstate>::bbprop(Tstate &in1, Tstate &in2, Tstate &energy) { 00065 idx_addc(in1.ddx, energy.dx.get(), in1.ddx); 00066 idx_addc(in2.ddx, energy.dx.get(), in2.ddx); 00067 } 00068 00069 template <typename T, class Tstate> 00070 void 
l2_energy<T,Tstate>::infer2_copy(Tstate &in1, Tstate &in2, 00071 Tstate &energy) { 00072 idx_copy(in1.x, in2.x); 00073 idx_clear(energy.x); 00074 } 00075 00076 template <typename T, class Tstate> 00077 std::string l2_energy<T,Tstate>::describe() { 00078 std::string s; 00079 s << "energy " << this->name() 00080 << " is the euclidean distance between inputs"; 00081 return s; 00082 } 00083 00085 // l1_penalty 00086 00087 template <typename T, class Tstate> 00088 l1_penalty<T,Tstate>::l1_penalty(T t, T c) : threshold(t), coeff(c) { 00089 } 00090 00091 template <typename T, class Tstate> 00092 l1_penalty<T,Tstate>::~l1_penalty() { 00093 } 00094 00095 template <typename T, class Tstate> 00096 void l1_penalty<T,Tstate>::fprop(Tstate &in, Tstate &energy) { 00097 idx_sumabs(in.x, energy.x.idx_ptr()); 00098 energy.x.set(energy.x.get() * coeff); 00099 } 00100 00101 template <typename T, class Tstate> 00102 void l1_penalty<T,Tstate>::bprop(Tstate &in, Tstate &energy) { 00103 idx_thresdotc_acc(in.x, energy.dx.get() * coeff, threshold, in.dx); 00104 } 00105 00106 template <typename T, class Tstate> 00107 void l1_penalty<T,Tstate>::bbprop(Tstate &in, Tstate &energy) { 00108 idx_addc(in.ddx, energy.ddx.get() * coeff * coeff, in.ddx); 00109 } 00110 00111 template <typename T, class Tstate> 00112 std::string l1_penalty<T,Tstate>::describe() { 00113 std::string s; 00114 s << "l1 penalty " << this->name() 00115 << " with threshold " << threshold << " and coefficient " << coeff; 00116 return s; 00117 } 00118 00120 // cross_entropy_energy 00121 00122 template <typename T, class Tstate> 00123 cross_entropy_energy<T,Tstate>::cross_entropy_energy(const char *name_) 00124 : ebm_2<Tstate>(name_) { 00125 } 00126 00127 template <typename T, class Tstate> 00128 cross_entropy_energy<T,Tstate>::~cross_entropy_energy() { 00129 } 00130 00131 template <typename T, class Tstate> 00132 void cross_entropy_energy<T,Tstate>::fprop(Tstate &in1, Tstate &in2, Tstate &energy) { 00133 idx_sqrdist(in1.x, 
in2.x, energy.x); // squared distance between in1 and in2 00134 idx_dotc(energy.x, 0.5, energy.x); // multiply by .5 00135 } 00136 00137 template <typename T, class Tstate> 00138 void cross_entropy_energy<T,Tstate>::bprop(Tstate &in1, Tstate &in2, Tstate &energy) { 00139 idx_checkorder1(energy.x, 0); // energy.x must have an order of 0 00140 idx_sub(in1.x, in2.x, in1.dx); // derivative with respect to in1 00141 idx_dotc(in1.dx, energy.dx.get(), in1.dx); // multiply by energy derivative 00142 idx_minus(in1.dx, in2.dx); // derivative with respect to in2 00143 } 00144 00145 template <typename T, class Tstate> 00146 void cross_entropy_energy<T,Tstate>::bbprop(Tstate &in1, Tstate &in2, Tstate &energy) { 00147 idx_addc(in1.ddx, energy.dx.get(), in1.ddx); 00148 idx_addc(in2.ddx, energy.dx.get(), in2.ddx); 00149 } 00150 00151 template <typename T, class Tstate> 00152 void cross_entropy_energy<T,Tstate>::infer2_copy(Tstate &in1, Tstate &in2, 00153 Tstate &energy) { 00154 idx_copy(in1.x, in2.x); 00155 idx_clear(energy.x); 00156 } 00157 00158 template <typename T, class Tstate> 00159 std::string cross_entropy_energy<T,Tstate>::describe() { 00160 std::string s; 00161 s << "energy " << this->name() 00162 << " is the euclidean distance between inputs"; 00163 return s; 00164 } 00165 00167 // scalerclass_energy 00168 00169 template <typename T, class Tstate> 00170 scalerclass_energy<T,Tstate>:: 00171 scalerclass_energy(bool apply_tanh_, uint jsize_, uint jselection, 00172 float dist_coeff_, float scale_coeff_, 00173 bool predict_conf_, bool predict_bconf_, 00174 idx<T> *biases_, idx<T> *coeffs_, 00175 const char *name_) 00176 : l2_energy<T,Tstate>(name_), jsize(jsize_), apply_tanh(apply_tanh_), 00177 jitter_selection(jselection), dist_coeff(dist_coeff_), 00178 scale_coeff(scale_coeff_), predict_conf(predict_conf_), 00179 predict_bconf(predict_bconf_), biases(NULL), coeffs(NULL) { 00180 if (biases_) biases = new idx<T>(*biases_); 00181 if (coeffs_) coeffs = new idx<T>(*coeffs_); 
00182 } 00183 00184 template <typename T, class Tstate> 00185 scalerclass_energy<T,Tstate>::~scalerclass_energy() { 00186 if (biases) delete biases; 00187 if (coeffs) delete coeffs; 00188 } 00189 00190 template <typename T, class Tstate> 00191 void scalerclass_energy<T,Tstate>:: 00192 fprop(Tstate &in, Tstate &in2, Tstate &energy) { 00193 // determine sizes 00194 int nclass = in.x.dim(0) - jsize; 00195 if (predict_conf) nclass--; 00196 // sanity checks 00197 // if (in.x.get_idxdim() != in2.x.get_idxdim()) 00198 // eblerror("expected same dimensions but got " << in.x << " and " << in2.x); 00199 // narrow inputs for regular l2 energy: class inputs 00200 tmp = in.narrow(0, nclass, 0); 00201 // apply tanh if requested 00202 if (apply_tanh) { 00203 if (tmp.x.get_idxdim() != tmp2.x.get_idxdim()) 00204 tmp2 = Tstate(tmp.x.get_idxdim()); 00205 mtanh.fprop(tmp, tmp2); 00206 tmp = tmp2; 00207 } else { // if no tanh, cap with -1/1 to avoid penalties beyond these 00208 idx_threshold(tmp.x, (T)-1); // cap below by -1 00209 idx_threshold2(tmp.x, (T)1); // cap above by 1 00210 } 00211 // select jitter target among all possible ones 00212 if (in2.x.dim(0) == 1) { // only 1 possible target 00213 best_target = in2.x.select(0, 0); 00214 } else { // multiple targets 00215 T minscore = limits<T>::max(); 00216 switch (jitter_selection) { 00217 case 0: // select highest confidence target 00218 { uint jindex = 0; 00219 idx<T> tgt; 00220 if (predict_conf) { // use predict conf feature 00221 tgt = in2.x.narrow(1, 1, in2.x.dim(1) - 1); 00222 jindex = idx_indexmax(tgt); 00223 } else { // use class target 00224 idx<T> tgt = in2.x.narrow(1, nclass, 0); 00225 uint i = 0; 00226 T max_val = limits<T>::min(); 00227 { idx_bloop1(t, tgt, T) { 00228 T val = idx_max(t); 00229 if (val > max_val) { 00230 max_val = val; 00231 jindex = i; 00232 } 00233 i++; 00234 }} 00235 } 00236 // select the highest confidence target 00237 best_target = in2.x.select(0, jindex); 00238 } 00239 break ; 00240 case 1: // 
select closest to center and scale 1 00241 // loop on all possible jitter 00242 { idx_bloop1(tgt, in2.x, T) { 00243 T s1 = tgt.gget(nclass); 00244 T h1 = tgt.gget(nclass + 1); 00245 T w1 = tgt.gget(nclass + 2); 00246 T score = fabs(s1 - 1) * scale_coeff // ~[.8,2.0] 00247 + sqrt(h1 * h1 + w1 * w1) * dist_coeff; 00248 if (score < minscore) { // we found a better match 00249 minscore = score; 00250 best_target = tgt; 00251 } 00252 }} 00253 break ; 00254 case 2: // select closest to current answer 00255 { T s = in.x.gget(nclass); // predicted scale 00256 T h = in.x.gget(nclass + 1); // predicted h 00257 T w = in.x.gget(nclass + 2); // predicted w 00258 // loop on all possible jitter 00259 { idx_bloop1(tgt, in2.x, T) { 00260 T s1 = tgt.gget(nclass); 00261 T h1 = tgt.gget(nclass + 1); 00262 T w1 = tgt.gget(nclass + 2); 00263 T score = sqrt((h1 - h) * (h1 - h) + (w1 - w) * (w1 - w)) 00264 * dist_coeff + fabs(s1 - s) * scale_coeff; 00265 if (score < minscore) { // we found a better match 00266 minscore = score; 00267 best_target = tgt; 00268 } 00269 }} 00270 } 00271 break ; 00272 default: 00273 eblerror("unknown selection mode " << jitter_selection); 00274 } 00275 } 00276 // resize target buffer 00277 idxdim d(best_target.get_idxdim()); 00278 if (last_target.x.get_idxdim() != d) { 00279 if (last_target.x.order() != d.order()) { 00280 last_target = Tstate(d); 00281 last_target_raw = idx<T>(d); 00282 } else { 00283 last_target.resize(d); 00284 last_target_raw.resize(d); 00285 } 00286 last_class_target = last_target.narrow(0, nclass,0); 00287 last_jitt_target = last_target.narrow(0, jsize, nclass); 00288 if (predict_conf) 00289 last_conf_target = last_target.narrow(0, 1, in.x.dim(0) - 1); 00290 } 00291 idx_copy(best_target, last_target.x); 00292 // make confidence target binary if required 00293 uint conf_offset = idx_indexmax(last_class_target.x); 00294 if (predict_conf) 00295 conf_offset = in.x.dim(0) - 1; 00296 if (predict_bconf) { // make confidence binary (0, 1) 00297 
if (last_target.x.gget(conf_offset) > .5) 00298 last_target.x.sset((T)1, conf_offset); 00299 else 00300 last_target.x.sset((T)0, conf_offset); 00301 } 00302 // save raw target 00303 idx_copy(last_target.x, last_target_raw); 00304 T s = last_target_raw.gget(nclass); // scale target 00305 // normalize jitt with bias then coeff 00306 if (biases) { 00307 idx<T> tmpbias = biases->narrow(0, jsize, 0); 00308 idx_add(last_jitt_target.x, tmpbias, last_jitt_target.x); 00309 } 00310 if (coeffs) { 00311 idx<T> tmpcoeff = coeffs->narrow(0, jsize, 0); 00312 idx_mul(last_jitt_target.x, tmpcoeff, last_jitt_target.x); 00313 } 00314 // normalize prediction with bias then coeff, only if using extra component 00315 // (otherwise, target uses the full -1,1 range already 00316 if (predict_conf) { 00317 if (biases) { 00318 idx<T> tmpbias = biases->narrow(0, 1, jsize); 00319 idx_add(last_conf_target.x, tmpbias, last_conf_target.x); 00320 } 00321 if (coeffs) { 00322 idx<T> tmpcoeff = coeffs->narrow(0, 1, jsize); 00323 idx_mul(last_conf_target.x, tmpcoeff, last_conf_target.x); 00324 } 00325 } 00326 // l2 energy 00327 l2_energy<T,Tstate>::fprop(tmp, last_class_target, energy); 00328 // energy of scale component 00329 T e = 0; 00330 // penalize quadraticaly only if scale is > 0 00331 if (s > 0 && last_target_raw.gget(conf_offset) > .5) { 00332 // narrow inputs for jitter energy 00333 tmp = in.narrow(0, jsize, nclass); 00334 e = .5 * idx_sqrdist(tmp.x, last_jitt_target.x); 00335 energy.x.set(energy.x.get() + e); 00336 } 00337 // penalize predicted confidence only if positive (i.e. 
scale > 0) 00338 if (predict_conf && s > 0) { 00339 tmp = in.narrow(0, 1, conf_offset); 00340 // cap below by 0 and above by 1 (or corresponding normalized values) 00341 T low = 0, high = 1; 00342 if (biases) { low += biases->gget(3); high += biases->gget(3); } 00343 if (coeffs) { low *= biases->gget(3); high *= biases->gget(3); } 00344 idx_threshold(tmp.x, low); 00345 idx_threshold2(tmp.x, high); 00346 e = .5 * idx_sqrdist(tmp.x, last_conf_target.x); 00347 energy.x.set(energy.x.get() + e); 00348 } 00349 EDEBUG("energy: " << energy.x.get() << " in: " << in.x.str() << " norm tgt: " 00350 << last_target.x.str() << " raw tgt: " << last_target_raw.str() 00351 << " conf penalized: " << ((predict_conf && s > 0) ? "yes":"no") 00352 << " jitt penalized: " 00353 << ((s > 0 && last_target_raw.gget(conf_offset) > .5) ? "yes":"no")); 00354 } 00355 00356 template <typename T, class Tstate> 00357 void scalerclass_energy<T,Tstate>:: 00358 bprop(Tstate &in, Tstate &in2, Tstate &energy) { 00359 idx_checkorder1(energy.x, 0); // energy.x must have an order of 0 00360 int nclass = in.x.dim(0) - jsize; 00361 if (predict_conf) nclass--; 00362 uint conf_offset = idx_indexmax(last_class_target.x); 00363 if (predict_conf) 00364 conf_offset = in.x.dim(0) - 1; 00365 // narrow inputs for regular l2 energy 00366 tmp = in.narrow(0, nclass, 0); 00367 l2_energy<T,Tstate>::bprop(tmp, last_class_target, energy); 00368 // get values 00369 T s = last_target_raw.gget(nclass); 00370 // penalize quadraticaly only if scale is > 0 00371 if (s > 0 && last_target_raw.gget(conf_offset) > .5) { 00372 // narrow inputs for jitter energy 00373 tmp = in.narrow(0, jsize, nclass); 00374 idx_sub(tmp.x, last_jitt_target.x, tmp.dx); // derivative w.r.t in1 00375 idx_dotc(tmp.dx, energy.dx.get(), tmp.dx);// multiply by energy derivative 00376 } 00377 // penalize predicted confidence only if positive (i.e. 
scale > 0) 00378 if (predict_conf && s > 0) { 00379 tmp = in.narrow(0, 1, conf_offset); 00380 idx_sub(tmp.x, last_conf_target.x, tmp.dx); 00381 idx_dotc(tmp.dx, energy.dx.get(), tmp.dx);// multiply by energy derivative 00382 } 00383 } 00384 00385 template <typename T, class Tstate> 00386 void scalerclass_energy<T,Tstate>:: 00387 bbprop(Tstate &in, Tstate &in2, Tstate &energy) { 00388 last_target.clear_ddx(); 00389 // derivatives are all the same for everybody 00390 l2_energy<T,Tstate>::bbprop(in, last_target, energy); 00391 } 00392 00393 template <typename T, class Tstate> 00394 void scalerclass_energy<T,Tstate>::infer2(Tstate &in, Tstate &in2, 00395 infer_param &ip, Tstate *energy) { 00396 idx_copy(in.x, in2.x); 00397 idx_clear(energy->x); 00398 } 00399 00400 template <typename T, class Tstate> 00401 std::string scalerclass_energy<T,Tstate>::describe() { 00402 std::string s; 00403 s << "energy " << this->name() 00404 << " is the squared distance with target for class components and the " 00405 << "squared distance to the scale component when scale > 0 and jsize " 00406 << jsize; 00407 if (predict_conf) 00408 s << ", predicting confidence"; 00409 s << ", target confidence is " << (predict_bconf?"binary":"continuous"); 00410 s << ", biases: "; 00411 if (biases) 00412 biases->printElems(s); 00413 else 00414 s <<"none"; 00415 s << ", coeffs: "; 00416 if (coeffs) 00417 coeffs->printElems(s); 00418 else 00419 s <<"none"; 00420 return s; 00421 } 00422 00424 // scaler_energy 00425 00426 template <typename T, class Tstate> 00427 scaler_energy<T,Tstate>::scaler_energy(const char *name_) 00428 : ebm_2<Tstate>(name_) { 00429 } 00430 00431 template <typename T, class Tstate> 00432 scaler_energy<T,Tstate>::~scaler_energy() { 00433 } 00434 00435 template <typename T, class Tstate> 00436 void scaler_energy<T,Tstate>:: 00437 fprop(Tstate &in, Tstate &in2, Tstate &energy) { 00438 // sanity checks 00439 idx_checknelems2_all(in, in2); 00440 // get values 00441 T i = in.x.gget(); 
00442 T s = in2.x.gget(); 00443 T e = 0; 00444 // no scale case: penalize quadraticaly only if above -1 00445 if (s == 0) { 00446 e = std::max((T) 0, i + 1); 00447 e = e * e * .5; 00448 } else // positive case: penalize quadraticaly with distance to value 00449 e = .5 * idx_sqrdist(in.x, in2.x); 00450 energy.x.set(e); 00451 } 00452 00453 template <typename T, class Tstate> 00454 void scaler_energy<T,Tstate>:: 00455 bprop(Tstate &in, Tstate &in2, Tstate &energy) { 00456 idx_checkorder1(energy.x, 0); // energy.x must have an order of 0 00457 // get values 00458 T i = in.x.gget(); 00459 T s = in2.x.gget(); 00460 // no scale case: penalize quadraticaly only if above -1 00461 if (s == 0) { 00462 in.dx.sset(std::max((T) 0, i + 1)); 00463 in.dx.sset(in.dx.gget() * energy.dx.get()); 00464 } else { // scale case: penalize quadraticaly with distance to value 00465 idx_sub(in.x, in2.x, in.dx); 00466 idx_dotc(in.dx, energy.dx.get(), in.dx); // multiply by energy derivative 00467 } 00468 } 00469 00470 template <typename T, class Tstate> 00471 void scaler_energy<T,Tstate>:: 00472 bbprop(Tstate &in1, Tstate &in2, Tstate &energy) { 00473 idx_addc(in1.ddx, energy.dx.get(), in1.ddx); 00474 } 00475 00476 template <typename T, class Tstate> 00477 void scaler_energy<T,Tstate>::infer2(Tstate &in, Tstate &scale, 00478 infer_param &ip, Tstate *energy) { 00479 T i = in.x.gget(); 00480 if (i <= 0) // negative class 00481 scale.x.sset((T) 0); 00482 else // positive class 00483 scale.x.sset(i); 00484 } 00485 00486 template <typename T, class Tstate> 00487 std::string scaler_energy<T,Tstate>::describe() { 00488 std::string s; 00489 s << "energy " << this->name() 00490 << " is the squared distance to -1 when input is > -1 for the " 00491 << "negative class and the squared distance to scale target for the " 00492 << "positive class"; 00493 return s; 00494 } 00495 00496 } // end namespace ebl 00497 00498 #endif /* EBL_ENERGY_HPP */