libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_nonlinearity.hpp
00001 /***************************************************************************
00002  *   Copyright (C) 2008 by Yann LeCun and Pierre Sermanet *
00003  *   yann@cs.nyu.edu, pierre.sermanet@gmail.com *
00004  *
00005  * Redistribution and use in source and binary forms, with or without
00006  * modification, are permitted provided that the following conditions are met:
00007  *     * Redistributions of source code must retain the above copyright
00008  *       notice, this list of conditions and the following disclaimer.
00009  *     * Redistributions in binary form must reproduce the above copyright
00010  *       notice, this list of conditions and the following disclaimer in the
00011  *       documentation and/or other materials provided with the distribution.
00012  *     * Redistribution under a license not approved by the Open Source
00013  *       Initiative (http://www.opensource.org) must display the
00014  *       following acknowledgement in all advertising material:
00015  *        This product includes software developed at the Courant
00016  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00017  *     * The names of the authors may not be used to endorse or promote products
00018  *       derived from this software without specific prior written permission.
00019  *
00020  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00021  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00022  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00023  * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY
00024  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00025  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00026  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00027  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00028  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00029  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00030  ***************************************************************************/
00031 
00032 namespace ebl {
00033 
00035 
00036   template <typename T, class Tstate>
00037   stdsigmoid_module<T,Tstate>::stdsigmoid_module()
00038   : module_1_1<T,Tstate>("stdsigmoid") {
00039   }
00040 
00041   template <typename T, class Tstate>
00042   stdsigmoid_module<T,Tstate>::~stdsigmoid_module() {
00043   }
00044 
00045   // standard sigmoid module
00046   template <typename T, class Tstate>
00047   void stdsigmoid_module<T,Tstate>::fprop(Tstate &in, Tstate &out) {
00048     this->resize_output(in, out); // resize iff necessary
00049     this->resize_output(in, tmp); // resize iff necessary
00050     idx_stdsigmoid(in.x, out.x);
00051   }
00052 
00053   template <typename T, class Tstate>
00054   void stdsigmoid_module<T,Tstate>::bprop(Tstate &in, Tstate &out) {
00055     idx_dstdsigmoid(in.x, tmp);
00056     idx_mulacc(tmp, out.dx, in.dx);
00057   }
00058 
00059   template <typename T, class Tstate>
00060   void stdsigmoid_module<T,Tstate>::bbprop(Tstate &in, Tstate &out) {
00061     idx_dstdsigmoid(in.x, tmp);
00062     idx_mul(tmp, tmp, tmp);
00063     idx_mulacc(tmp, out.ddx, in.ddx);
00064   }
00065 
00066   template <typename T, class Tstate>
00067   stdsigmoid_module<T,Tstate>* stdsigmoid_module<T,Tstate>::copy() {
00068     return new stdsigmoid_module<T,Tstate>();
00069   }
00070   
00072 
00073   template <typename T, class Tstate>
00074   tanh_module<T,Tstate>::tanh_module() : module_1_1<T,Tstate>("tanh") {
00075   }
00076 
00077   template <typename T, class Tstate>
00078   tanh_module<T,Tstate>::~tanh_module() {
00079   }
00080 
00081   // tanh module
00082   template <typename T, class Tstate>
00083   void tanh_module<T,Tstate>::fprop(Tstate &in, Tstate &out) {
00084     this->resize_output(in, out); // resize iff necessary
00085     this->resize_output(in, tmp); // resize iff necessary
00086     idx_tanh(in.x, out.x);
00087   }
00088 
00089   template <typename T, class Tstate>
00090   void tanh_module<T,Tstate>::bprop(Tstate &in, Tstate &out) {
00091     idx_dtanh(in.x, tmp);
00092     idx_mulacc(tmp, out.dx, in.dx);
00093   }
00094 
00095   template <typename T, class Tstate>
00096   void tanh_module<T,Tstate>::bbprop(Tstate &in, Tstate &out) {
00097     idx_dtanh(in.x, tmp);
00098     idx_mul(tmp, tmp, tmp);
00099     idx_mulacc(tmp, out.ddx, in.ddx);
00100   }
00101 
00102   template <typename T, class Tstate>
00103   tanh_module<T,Tstate>* tanh_module<T,Tstate>::copy() {
00104     return new tanh_module<T,Tstate>();
00105   }
00106   
00108 
00109   template <typename T, class Tstate>
00110   softmax<T,Tstate>::softmax(double b) : module_1_1<T,Tstate>("softmax") {
00111     beta = b;
00112   }
00113 
00114   template <typename T, class Tstate>
00115   void softmax<T,Tstate>::resize_nsame(Tstate &in, Tstate &out, int n){
00116     int nmax = in.x.order();
00117     if(n==0||n>nmax) {eblerror("illegal type")}
00118     else{
00119       switch(n){
00120       case 1: out.resize(in.x.dim(0));
00121         break;
00122       case 2: out.resize(in.x.dim(0), in.x.dim(1));
00123         break;
00124       case 3: out.resize(in.x.dim(0), in.x.dim(1), in.x.dim(2));
00125         break;
00126       case 4: out.resize(in.x.dim(0), in.x.dim(1), in.x.dim(2), 
00127                           in.x.dim(3));
00128         break;
00129       case 5: out.resize(in.x.dim(0), in.x.dim(1), in.x.dim(2), 
00130                           in.x.dim(3), in.x.dim(4));
00131         break;
00132       case 6: out.resize(in.x.dim(0), in.x.dim(1), in.x.dim(2), 
00133                           in.x.dim(3), in.x.dim(4), in.x.dim(5));
00134         break;
00135       }
00136     }
00137   }
00138 
00139   template <typename T, class Tstate>
00140   void softmax<T,Tstate>::fprop(Tstate &in, Tstate &out){
00141     int n=in.x.order();
00142     if(n==0){
00143       idx<double> ib;
00144       ib.set(1);
00145       idx_copy(ib, out.x);
00146     }
00147     else {
00148       resize_nsame(in, out, n);
00149       if( n > 6) {eblerror("illegal type")}
00150       else{
00151         idx<double> pp(new srg<double>(), in.x.spec);
00152         idx<double> dot(new srg<double>(), in.x.spec);
00153         double mm = idx_max(in.x);
00154         idx_addc(in.x, -mm, pp);
00155         idx_dotc(pp, beta, dot);
00156         double out_sum = 0.0;
00157         double d = idx_sum(dot, &out_sum);
00158         idx_dotc(dot, (double)(1/d), out.x);
00159       }
00160     }
00161   }
00162 
00163   template <typename T, class Tstate>
00164   void softmax<T,Tstate>::bprop(Tstate &in, Tstate &out){
00165     int n = in.x.order();
00166     if( n == 0) return;
00167     if( n > 6 ) { eblerror("illegal type")}
00168     else{
00169       idx<double> pp(new srg<double>(), out.dx.spec);
00170       idx<double> mul(new srg<double>(), out.dx.spec);
00171       double dot = idx_dot(out.dx, out.x);
00172       idx_addc(out.dx, -dot, pp);
00173       idx_mul(out.x, pp, mul);
00174       idx_dotcacc(mul, beta, in.x);
00175     }
00176   }
00177 
00178   template <typename T, class Tstate>
00179   void softmax<T,Tstate>::bbprop(Tstate &in, Tstate &out){
00180     int n = in.x.order();
00181     if( n == 0) return;
00182     if( n > 6 ) { eblerror("illegal type")}
00183     else{
00184       idx<double> mul(new srg<double>(), out.x.spec);
00185       idx<double> dot(new srg<double>(), out.x.spec);
00186       idx<double> pp(new srg<double>(), out.x.spec);
00187       idx<double> mul2(new srg<double>(), out.x.spec);
00188       idx<double> pp2(new srg<double>(), out.x.spec);
00189       idx<double> mul3(new srg<double>(), out.x.spec);
00190       idx_mul(out.x, out.x, mul);
00191       idx_dotc(out.x, (double)-2, dot);
00192       idx_addc(dot, (double)1, pp);
00193       idx_mul(pp, out.ddx, mul2);
00194       idx_addc(mul2, idx_dot(out.ddx, mul), pp2);
00195       idx_mul(mul, pp2, mul3);
00196 
00197       idx_dotcacc(mul3, beta*beta, in.ddx);
00198     }
00199   }
00200 
00202   // abs_module
00203 
00204   template <typename T, class Tstate>
00205   abs_module<T,Tstate>::abs_module(double thres) : module_1_1<T,Tstate>("abs") {
00206     threshold = thres;
00207   }
00208 
00209   template <typename T, class Tstate>
00210   abs_module<T,Tstate>::~abs_module() {
00211   }
00212 
00213   template <typename T, class Tstate>
00214   void abs_module<T,Tstate>::fprop(Tstate& in, Tstate& out) {
00215     this->resize_output(in, out); // resize iff necessary
00216     idx_abs(in.x, out.x);
00217   }
00218 
00219   template <typename T, class Tstate>
00220   void abs_module<T,Tstate>::bprop(Tstate& in, Tstate& out) {
00221     state_idx_check_different(in, out); // forbid same in and out
00222     idx_checknelems2_all(in.dx, out.dx); // must have same dimensions
00223     
00224     idx_aloopf3(inx, in.x, T, indx, in.dx, T, outdx, out.dx, T, {
00225         if (*inx > threshold)
00226           *indx = *indx + *outdx;
00227         else if (*inx < -threshold)
00228           *indx = *indx - *outdx;
00229       });
00230   }
00231 
00232   template <typename T, class Tstate>
00233   void abs_module<T,Tstate>::bbprop(Tstate& in, Tstate& out) {
00234     state_idx_check_different(in, out); // forbid same in and out
00235     idx_checknelems2_all(in.ddx, out.ddx); // must have same dimensions
00236     
00237     idx_add(in.ddx, out.ddx, in.ddx);
00238   }
00239   
00240   template <typename T, class Tstate>
00241   abs_module<T,Tstate>* abs_module<T,Tstate>::copy() {
00242     return new abs_module<T,Tstate>();
00243   }
00244 
00246   // linear_shrink_module
00247 
00248   template <typename T, class Tstate>
00249   linear_shrink_module<T,Tstate>::linear_shrink_module(parameter<T,Tstate> *p,
00250                                                        intg nf, T bs)
00251     : module_1_1<T,Tstate>("linear_shrink"), bias(p,nf), default_bias(bs) {
00252     idx_fill(bias.x, bs);
00253   }
00254   
00255   template <typename T, class Tstate>
00256   linear_shrink_module<T,Tstate>::~linear_shrink_module(){
00257   }
00258 
00259   template <typename T, class Tstate>
00260   void linear_shrink_module<T,Tstate>::fprop(Tstate& in, Tstate& out) {
00261     if (&in != &out) eblerror("in and out should be different buffers");
00262     this->resize_output(in, out); // resize iff necessary
00263 
00264     idx_bloop3(inx, in.x, T, outx, out.x, T, biasx, bias.x, T) {
00265       T b = biasx.get();
00266       idx_aloopf2(i, inx, T, o, outx, T, {
00267           if (*i > b) *o = *i - b;
00268           else if (*i < -b) *o = *i + b;
00269           else *o = 0; });
00270     }
00271   }
00272   
00273   template <typename T, class Tstate>
00274   void linear_shrink_module<T,Tstate>::bprop(Tstate& in, Tstate& out) {
00275     idx_bloop5(inx, in.x, T, indx, in.dx, T, outdx, out.dx, T, 
00276                biasx, bias.x, T, biasdx, bias.dx, T) {
00277       T b = biasx.get();
00278       idx_aloopf3(i, inx, T, id, indx, T, od, outdx, T, {
00279           if (*i > b) {
00280             *id += *od;
00281             biasdx.set(biasdx.get() - *od);
00282           } else if (*i < -b) {
00283             *id += *od;
00284             biasdx.set(biasdx.get() - *od);
00285           }});
00286     }
00287   }
00288   
00289   template <typename T, class Tstate>
00290   void linear_shrink_module<T,Tstate>::bbprop(Tstate& in, Tstate& out){    
00291     idx_bloop5(inx, in.x, T, inddx, in.ddx, T, outddx, out.ddx, T, 
00292                biasx, bias.x, T, biasddx, bias.ddx, T) {
00293       T b = biasx.get();
00294       idx_aloopf3(i, inx, T, idd, inddx, T, odd, outddx, T, {
00295           if (*i > b) {
00296             *idd += *odd;
00297             biasddx.set(biasddx.get() - *odd);
00298           } else if (*i < -b) {
00299             *idd += *odd;
00300             biasddx.set(biasddx.get() - *odd);
00301           }});
00302     }
00303   }
00304   
00305   template <typename T, class Tstate>
00306   linear_shrink_module<T,Tstate>* linear_shrink_module<T,Tstate>::copy() {
00307     linear_shrink_module<T,Tstate>* s2 =
00308       new linear_shrink_module<T,Tstate>(NULL, bias.x.dim(0), default_bias);
00309     // assign same parameter state
00310     s2->bias = bias;
00311     return s2;
00312   }
00313 
00314   template <typename T, class Tstate>
00315   std::string linear_shrink_module<T,Tstate>::describe() {
00316     std::string desc;
00317     desc << "linear_shrink module " << this->name() << " with biases: " 
00318          << bias.x << " min: " << idx_min(bias.x) 
00319          << " max: " << idx_max(bias.x);
00320     return desc;
00321   }
00322 
00324   // smooth_shrink_module
00325 
00326   template <typename T, class Tstate>
00327   smooth_shrink_module<T,Tstate>::smooth_shrink_module(parameter<T,Tstate> *p,
00328                                                        intg nf, T bt, T bs)
00329     : module_1_1<T,Tstate>("smooth_shrink"), 
00330       beta(p,nf), bias(p,nf), ebb(1), ebx(1,1,1), tin(1,1,1), absmod(0.0),
00331       default_beta(bt), default_bias(bs) {
00332     idx_fill(beta.x, bt);
00333     idx_fill(bias.x, bs);
00334   }
00335   
00336   template <typename T, class Tstate>
00337   smooth_shrink_module<T,Tstate>::~smooth_shrink_module(){
00338   }
00339 
00340   template <typename T, class Tstate>
00341   void smooth_shrink_module<T,Tstate>::fprop(Tstate& in, Tstate& out) {
00342     if (&in != &out) { // resize only when input and output are different
00343       idxdim d(in.x.spec); // use same dimensions as in
00344       out.resize(d);
00345     } else
00346       eblerror("in and out should be different buffers");
00347     absmod.fprop(in,tin);
00348     // failsafe
00349     idx_aloopf1(x, in.x, T, {
00350         if (*x > 20)
00351           *x = 20;
00352       });
00353     ebb.resize(bias.x.dim(0));
00354     ebx.resize(in.x.get_idxdim());
00355     
00356     idx_mul(beta.x, bias.x, ebb.x);
00357     idx_exp(ebb.x);
00358 
00359     idx_bloop5(inx, tin.x, T, outx, out.x, T, ebbx, ebb.x, T,
00360                betax, beta.x, T, biasx, bias.x, T) {
00361       idx_dotc(inx, betax.get(), outx);
00362       idx_exp(outx);
00363       idx_addc(outx, ebbx.get()-1, outx);
00364       idx_log(outx);
00365       idx_dotc(outx, 1/betax.get(), outx);
00366       idx_addc(outx, -biasx.get(), outx);
00367     }
00368     idx_aloopf2(x, in.x, T, y, out.x, T, {
00369         if (abs((int)*x) > 20)
00370           *y = *x;
00371         if(*x < 0.0) {
00372           *y = -(*y);
00373         }
00374       });
00375   }
00376   
00377   template <typename T, class Tstate>
00378   void smooth_shrink_module<T,Tstate>::bprop(Tstate& in, Tstate& out) {
00379     absmod.fprop(in,tin);
00380     // failsafe
00381     idx_aloopf1(x, in.x, T, {
00382         if (*x > 20)
00383           *x = 20;
00384       });
00385     tin.clear_dx();
00386     beta.clear_dx();
00387     bias.clear_dx();
00388 
00389     // bb = exp (beta .* bias)
00390     idx_mul(beta.x, bias.x, ebb.x);
00391     idx_exp(ebb.x);
00392     intg nf = bias.x.dim(0);
00393     
00394     idx<T> ttx(ebx.x[0].spec);
00395     idx<T> tty(ebx.x[0].spec);
00396     for (intg i=0; i< nf; i++) {
00397       // ebx = exp(beta * x)
00398       idx<T> ebxxi = ebx.x[i];
00399       idx<T> ebxdxi = ebx.dx[i];
00400       idx<T> ebxddxi = ebx.ddx[i];
00401       idx<T> tinxi = tin.x[i];
00402       idx<T> tindxi = tin.dx[i];
00403       idx<T> outdxi = out.dx[i];
00404 
00405       idx_dotc(tinxi,beta.x[i].get(),ebxxi);
00406       idx_exp(ebxxi);
00407 
00408       // ebdx = exp(beta*x) + exp(beta*bias) -1
00409       idx_addc(ebxxi,ebb.x[i].get()-1,ebxdxi);
00410       // ebddx = exp (beta*x)/ (exp(beta*x) + exp(beta*bias)-1)
00411       idx_div(ebxxi,ebxdxi,ebxddxi);
00412 
00413       // df/dx
00414       idx_mul(ebxddxi,outdxi,tindxi);
00415       
00416       //cout << tinxi.get(0,0) << tindxi.get(0,0) << endl;
00417 
00418       // ebddx = 1/ebdx
00419       idx_inv(ebxdxi,ebxddxi);
00420 
00421       // df/dbias
00422       idx_dotc(ebxddxi,ebb.x[i].get(),ttx);
00423       idx_addc(ttx,(T)-1.0,ttx);
00424       bias.dx[i].set(idx_dot(outdxi,ttx));
00425       
00426       // df/dbeta
00427       idx_mul(tinxi,ebxxi,ttx);
00428       idx_addc(ttx, bias.x[i].get() * ebb.x[i].get(),ttx);
00429       idx_mul(ttx,ebxddxi,ttx);
00430       idx_dotc(ttx, 1/beta.x[i].get(),ttx);
00431       idx_log(ebxdxi);
00432       idx_dotc(ebxdxi,-1/(beta.x[i].get()*beta.x[i].get()),tty);
00433       idx_add(ttx,tty,ttx);
00434       beta.dx[i].set((T)idx_dot(outdxi,ttx));
00435     }
00436     idx_add(in.dx,tin.dx,in.dx);
00437   }
00438   
00439   template <typename T, class Tstate>
00440   void smooth_shrink_module<T,Tstate>::bbprop(Tstate& in, Tstate& out){    
00441     absmod.fprop(in,tin);
00442     // failsafe
00443     idx_aloopf1(x, in.x, T, {
00444         if (*x > 20)
00445           *x = 20;
00446       });
00447     tin.clear_ddx();
00448     beta.clear_ddx();
00449     bias.clear_ddx();
00450 
00451     // bb = exp (beta .* bias)
00452     idx_mul(beta.x, bias.x, ebb.x);
00453     idx_exp(ebb.x);
00454     intg nf = bias.x.dim(0);
00455     
00456     idx<T> ttx(ebx.x[0].spec);
00457     idx<T> tty(ebx.x[0].spec);
00458     for (intg i=0; i< nf; i++) {
00459       // ebx = exp(beta * x)
00460       idx<T> ebxxi = ebx.x[i];
00461       idx<T> ebxdxi = ebx.dx[i];
00462       idx<T> ebxddxi = ebx.ddx[i];
00463       idx<T> tinxi = tin.x[i];
00464       idx<T> tindxi = tin.ddx[i];
00465       idx<T> outdxi = out.ddx[i];
00466 
00467       idx_dotc(tinxi,beta.x[i].get(),ebxxi);
00468       idx_exp(ebxxi);
00469 
00470       // ebdx = exp(beta*x) + exp(beta*bias) -1
00471       idx_addc(ebxxi,ebb.x[i].get()-1,ebxdxi);
00472       // ebddx = exp (beta*x)/ (exp(beta*x) + exp(beta*bias)-1)
00473       idx_div(ebxxi,ebxdxi,ebxddxi);
00474 
00475       // df/dx
00476       idx_mul(ebxddxi,ebxddxi,ebxddxi);
00477       idx_mul(ebxddxi,outdxi,tindxi);
00478       
00479       //cout << tinxi.get(0,0) << tindxi.get(0,0) << endl;
00480 
00481       // ebddx = 1/ebdx
00482       idx_inv(ebxdxi,ebxddxi);
00483 
00484       // df/dbias
00485       idx_dotc(ebxddxi,ebb.x[i].get(),ttx);
00486       idx_addc(ttx,(T)-1.0,ttx);
00487       idx_mul(ttx,ttx,ttx);
00488       bias.ddx[i].set((T)idx_dot(outdxi,ttx));
00489       
00490       // df/dbeta
00491       idx_mul(tinxi,ebxxi,ttx);
00492       idx_addc(ttx, bias.x[i].get() * ebb.x[i].get(),ttx);
00493       idx_mul(ttx,ebxddxi,ttx);
00494       idx_dotc(ttx, 1/beta.x[i].get(),ttx);
00495       idx_log(ebxdxi);
00496       idx_dotc(ebxdxi,-1/(beta.x[i].get()*beta.x[i].get()),tty);
00497       idx_add(ttx,tty,ttx);
00498       idx_mul(ttx,ttx,ttx);
00499       beta.ddx[i].set((T)idx_dot(outdxi,ttx));
00500     }
00501     idx_add(in.ddx,tin.ddx,in.ddx);
00502   }
00503   
00504   template <typename T, class Tstate>
00505   smooth_shrink_module<T,Tstate>* smooth_shrink_module<T,Tstate>::copy() {
00506     smooth_shrink_module<T,Tstate>* s2 =
00507       new smooth_shrink_module<T,Tstate>(NULL, beta.x.dim(0),
00508                                          default_beta, default_bias);
00509     // assign same parameter state
00510     s2->beta = beta;
00511     s2->bias = bias;
00512     return s2;
00513   }
00514 
00516   // tanh_shrink_module
00517 
00518   template <typename T, class Tstate>
00519   tanh_shrink_module<T,Tstate>::
00520   tanh_shrink_module(parameter<T,Tstate> *p, intg nf, bool diags_)
00521     : module_1_1<T,Tstate>("tanh_shrink"),
00522       nfeatures(nf), alpha(NULL), beta(NULL), diags(diags_) {
00523     if (diags) {
00524       alpha = new diag_module<T,Tstate>(p, nf);
00525       beta = new diag_module<T,Tstate>(p, nf);
00526     }
00527   }
00528   
00529   template <typename T, class Tstate>
00530   tanh_shrink_module<T,Tstate>::~tanh_shrink_module() {
00531     if (alpha) delete alpha;
00532     if (beta) delete beta;
00533   }
00534 
00535   template <typename T, class Tstate>
00536   void tanh_shrink_module<T,Tstate>::fprop(Tstate& in, Tstate& out) {
00537     if (&in != &out) { // resize only when input and output are different
00538       this->resize_output(in, out); // resize iff necessary
00539     } else eblerror("in and out should be different buffers");
00540     // fprop
00541     if (diags) { // use coefficients
00542       // x * alpha
00543       alpha->fprop(in, abuf);
00544       // tanh(x * alpha)
00545       mtanh.fprop(abuf, tbuf);
00546       // (x * alpha) - tanh(x * alpha)
00547       difmod.fprop(in, tbuf, bbuf);
00548       // beta * ((x * alpha) - tanh(x * alpha))
00549       beta->fprop(bbuf, out);
00550     } else { // no coefficients
00551       // tanh(x)
00552       mtanh.fprop(in, tbuf);
00553       // x - tanh(x)
00554       difmod.fprop(in, tbuf, out);
00555     }
00556   }
00557   
00558   template <typename T, class Tstate>
00559   void tanh_shrink_module<T,Tstate>::bprop(Tstate& in, Tstate& out) {
00560     // clear derivatives
00561     tbuf.clear_dx();
00562     // bprop
00563     if (diags) { // use coefficients
00564       // clear derivatives
00565       abuf.clear_dx();
00566       bbuf.clear_dx();
00567       // bprop
00568       beta->bprop(bbuf, out);
00569       difmod.bprop(in, tbuf, bbuf);
00570       mtanh.bprop(abuf, tbuf);
00571       alpha->bprop(in, abuf);
00572     } else { // no coefficients
00573       difmod.bprop(in, tbuf, out);
00574       mtanh.bprop(in, tbuf);
00575     }
00576   }
00577   
00578   template <typename T, class Tstate>
00579   void tanh_shrink_module<T,Tstate>::bbprop(Tstate& in, Tstate& out) {
00580     tbuf.clear_ddx();
00581     // bbprop
00582     if (diags) { // use coefficients
00583       // clear derivatives
00584       abuf.clear_ddx();
00585       bbuf.clear_ddx();
00586       // bprop
00587       beta->bbprop(bbuf, out);
00588       difmod.bbprop(in, tbuf, bbuf);
00589       mtanh.bbprop(abuf, tbuf);
00590       alpha->bbprop(in, abuf);
00591     } else { // no coefficients
00592       difmod.bbprop(in, tbuf, out);
00593       mtanh.bbprop(in, tbuf);
00594     }
00595   }
00596   
00597   template <typename T, class Tstate>
00598   tanh_shrink_module<T,Tstate>* tanh_shrink_module<T,Tstate>::copy() {
00599     tanh_shrink_module<T,Tstate>* s2 =
00600       new tanh_shrink_module<T,Tstate>(NULL, nfeatures);
00601     // assign same parameter state
00602     if (s2->alpha) delete s2->alpha;
00603     if (s2->beta) delete s2->beta;
00604     s2->alpha = alpha->copy();
00605     s2->beta = beta->copy();
00606     return s2;
00607   }
00608 
00609   template <typename T, class Tstate>
00610   std::string tanh_shrink_module<T,Tstate>::describe() {
00611     std::string desc;
00612     desc << "tanh_shrink module " << this->name() 
00613          << (diags ? " with" : " without") << " scaling coefficients";
00614     return desc;
00615   }
00616 
00617 } // end namespace ebl