libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_merge.hpp
00001 /***************************************************************************
00002  *   Copyright (C) 2011 by Pierre Sermanet *
00003  *   pierre.sermanet@gmail.com *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 namespace ebl {
00034 
00036   // flat_merge_module
00037 
00038   template <typename T, class Tstate>
00039   flat_merge_module<T, Tstate>::
00040   flat_merge_module(std::vector<Tstate**> &inputs_, idxdim &in_, midxdim &ins_,
00041                     fidxdim &stride_, mfidxdim &strides_,
00042                     const char *name_, const char *list)
00043     : m2s_module<T,Tstate>(inputs_.size() + 1, name_), inputs(inputs_),
00044       din(in_), dins(ins_), stride(stride_), strides(strides_), in0(NULL),
00045       use_pinputs(false) {
00046     if (list)
00047       merge_list = list;
00048     // allocate zpad vector
00049     zpads.assign(strides.size(), NULL);
00050   }
00051 
00052   template <typename T, class Tstate>
00053   flat_merge_module<T, Tstate>::
00054   flat_merge_module(std::vector<mstate<Tstate>**> &inputs_,
00055                     idxdim &in_, midxdim &ins_, fidxdim &stride_,
00056                     mfidxdim &strides_, const char *name_, const char *list)
00057     : m2s_module<T,Tstate>(inputs_.size() + 1, name_), //msinputs(inputs_),
00058       din(in_), dins(ins_), stride(stride_), strides(strides_),
00059       use_pinputs(false)  {
00060     if (list)
00061       merge_list = list;
00062     // allocate zpad vector
00063     zpads.assign(strides.size(), NULL);
00064     eblerror("not implemented");
00065   }
00066 
00067   template <typename T, class Tstate>
00068   flat_merge_module<T, Tstate>::
00069   flat_merge_module(midxdim &ins_, mfidxdim &strides_, bool bpad_,
00070                     const char *name_, mfidxdim *scales_,
00071                     /*TEMP*/ intg hextra_, intg wextra_, float ss_, float edge_)
00072     : m2s_module<T,Tstate>(ins_.size(), name_), in0(NULL), use_pinputs(true),
00073       bpad(bpad_),
00074       //TEMP
00075       hextra(hextra_), wextra(wextra_), subsampling(ss_), edge(edge_)
00076   {
00077     if (scales_)
00078       scales = *scales_;
00079     // check there are enough elements
00080     if (ins_.size() < 1 || strides_.size() < 1
00081         || ins_.size() != strides_.size())
00082       eblerror("expected at least 1 dim and stride (matching) but got: dims: "
00083                << ins_.size() << " strides: " << strides_.size());
00084     // separate first dim/strides from rest
00085     din = ins_[0];
00086     stride = strides_[0];
00087     // add remaining ones
00088     for (uint i = 0; i < ins_.size(); ++i) {
00089       dins.push_back(ins_[i]);
00090       strides.push_back(strides_[i]);
00091     }
00092     // allocate zpad vector
00093     zpads.assign(strides.size(), NULL);
00094   }
00095 
00096   template <typename T, class Tstate>
00097   flat_merge_module<T, Tstate>::~flat_merge_module() {
00098     // clean up zero padding modules
00099     for (uint i = 0; i < zpads.size(); ++i)
00100       if (zpads[i])
00101         delete zpads[i];
00102   }
00103 
00105   // generic state methods
00106 
00107   template <typename T, class Tstate>
00108   void flat_merge_module<T, Tstate>::fprop(Tstate &in, Tstate &out) {
00109     // pad each input so that all windows start centered on the first actual
00110     // pixel top right
00111 
00112     // if (inputs.size() == 0)
00113     //   eblerror("no inputs to merge");
00114     // feature size for main input
00115     intg fsize = din.dim(0) * din.dim(1) * in.x.dim(0);
00116     // number of possible windows
00117     intg nh = 1 + (intg) ((in.x.dim(1) - din.dim(0)) / stride.dim(0));
00118     intg nw = 1 + (intg) ((in.x.dim(2) - din.dim(1)) / stride.dim(1));
00119     // compute new size and resize output if necessary
00120     for (uint i = 0; i < std::max(inputs.size(), pinputs.size()); ++i) {
00121       idxdim &d = dins[i];
00122       fidxdim &s = strides[i];
00123       idx<T> input = (use_pinputs ? pinputs[i]->x : (*inputs[i])->x);
00124       fsize += d.nelements() * input.dim(0);
00125       // check that strides match possible windows
00126       intg nh2 = (intg) (1 + (input.dim(1) - d.dim(0)) / s.dim(0));
00127       intg nw2 = (intg) (1 + (input.dim(2) - d.dim(1)) / s.dim(1));
00128       if (nh2 < nh || nw2 < nw) {
00129         eblerror("input " << input << " and window " << d << " with stride " <<s
00130                  << " produce " << nh2 << "x" << nw2
00131                  << " outputs but expected at least " << nh << "x" << nw);
00132       } else if (nh2 != nh || nw2 != nw)
00133         EDEBUG("warning: input " << input << " and window " << d
00134               << " with stride " <<s << " produce " << nh2 << "x" << nw2
00135               << ", ignoring extra cells and using only " <<nh << "x" << nw);
00136     }
00137     idxdim d(fsize, nh, nw);
00138     if (!out.x.same_dim(d)) {
00139       if (out.x.order() != d.order())
00140         out = Tstate(d);
00141       else
00142         out.resize(d);
00143     }
00144     intg offset = 0;
00145     // copy main input to out
00146     fsize = din.nelements() * in.x.dim(0); // feat size for main input
00147     // loop on all possible windows for this state
00148     float fh, fw;
00149     uint uh, uw, h, w;
00150     for (h = 0, fh = 0; h < nh; h++, fh += stride.dim(0)) {
00151       for (w = 0, fw = 0; w < nw; w++, fw += stride.dim(1)) {
00152         // integer positions
00153         uh = (uint) fh;
00154         uw = (uint) fw;
00155         // select 1 output pixel in the correct feature range
00156         idx<T> o = out.x.select(2, w);
00157         o = o.select(1, h);
00158         o = o.narrow(0, fsize, offset);
00159         // select input window
00160         idx<T> iw = in.x.narrow(2, din.dim(1), uw);
00161         iw = iw.narrow(1, din.dim(0), uh);
00162         // copy flat input to output
00163         // TODO: tmp buffer less efficient than direct copy which but requires
00164         // continuous data, make idx pointing to oo with flat's dims?
00165         idx<T> tmp(iw.get_idxdim());
00166         idx_copy(iw, tmp);
00167         iw = tmp.view_as_order(1);
00168         idx_copy(iw, o);
00169       }
00170     }
00171     offset += fsize;
00172     // copy inputs to out
00173     for (uint i = 0; i < std::max(inputs.size(), pinputs.size()); ++i) {
00174       idxdim dd = dins[i];
00175       fidxdim s = strides[i];
00176       idx<T> input = (use_pinputs ? pinputs[i]->x : (*inputs[i])->x);
00177       fsize = dd.nelements() * input.dim(0); // feature size from input
00178       // copy
00179       for (h = 0, fh = 0; h < nh; h++, fh += s.dim(0)) {
00180         for (w = 0, fw = 0; w < nw; w++, fw += s.dim(1)) {
00181           // integer positions
00182           uh = (uint) fh;
00183           uw = (uint) fw;
00184           // select 1 output pixel in the correct feature range
00185           idx<T> o = out.x.select(2, w);
00186           o = o.select(1, h);
00187           o = o.narrow(0, fsize, offset);
00188           // select input window
00189           idx<T> iw = input.narrow(2, dd.dim(1), uw);
00190           iw = iw.narrow(1, dd.dim(0), uh);
00191           // copy flat input to output
00192           // TODO: tmp buffer less efficient than direct copy which but requires
00193           // continuous data, make idx pointing to oo with flat's dims?
00194           idx<T> tmp(iw.get_idxdim());
00195           idx_copy(iw, tmp);
00196           iw = tmp.view_as_order(1);
00197           idx_copy(iw, o);
00198         }
00199       }
00200       offset += fsize;
00201     }
00202 #ifdef __DEBUG_PRINT__
00203     cout << describe() << ": " << in.x << " (in " << din
00204          << " stride " << stride << ")";
00205     for (uint i = 0; i < std::max(inputs.size(), pinputs.size()); ++i)
00206       cout << " + " << (use_pinputs ? pinputs[i]->x : (*inputs[i])->x)
00207            << " (in " << dins[i] << " stride " << strides[i] << ")";
00208     cout << " -> " << out.x << endl;
00209 #endif
00210   }
00211 
00212   template <typename T, class Tstate>
00213   void flat_merge_module<T, Tstate>::bprop(Tstate &in, Tstate &out) {
00214     // if (inputs.size() == 0)
00215     //   eblerror("no inputs to merge");
00216     // copy out to main input
00217     intg offset = 0;
00218     idx<T> o1 = out.dx.view_as_order(1);
00219     idx<T> o = o1.narrow(0, in.dx.nelements(), offset);
00220     idx<T> input = in.dx.view_as_order(1);
00221     idx_add(o, input, input);
00222     offset += input.nelements();
00223     // copy out to inputs
00224     for (uint i = 0; i < std::max(inputs.size(), pinputs.size()); ++i) {
00225       input = (use_pinputs ? pinputs[i]->dx : (*inputs[i])->dx);
00226       input = input.view_as_order(1);
00227       o = o1.narrow(0, input.nelements(), offset);
00228       idx_add(o, input, input);
00229       offset += input.nelements();
00230     }
00231   }
00232 
00233   template <typename T, class Tstate>
00234   void flat_merge_module<T, Tstate>::bbprop(Tstate &in,
00235                                         Tstate &out) {
00236     // if (inputs.size() == 0)
00237     //   eblerror("no inputs to merge");
00238     // copy out to main input
00239     intg offset = 0;
00240     idx<T> o1 = out.ddx.view_as_order(1);
00241     idx<T> o = o1.narrow(0, in.ddx.nelements(), offset);
00242     idx<T> input = in.ddx.view_as_order(1);
00243     idx_add(o, input, input);
00244     offset += input.nelements();
00245     // copy out to inputs
00246     for (uint i = 0; i < std::max(inputs.size(), pinputs.size()); ++i) {
00247       input = (use_pinputs ? pinputs[i]->ddx : (*inputs[i])->ddx);
00248       input = input.view_as_order(1);
00249       o = o1.narrow(0, input.nelements(), offset);
00250       idx_add(o, input, input);
00251       offset += input.nelements();
00252     }
00253   }
00254 
00256 
00257   template <typename T, class Tstate>
00258   idxdim flat_merge_module<T, Tstate>::
00259   compute_pad(idxdim &window, float subsampling, float edge,
00260               float scale, fidxdim &stride) {
00261     float hoff = (edge * scale * stride.dim(0)) / subsampling + .5;
00262     float woff = (edge * scale * stride.dim(1)) / subsampling + .5;
00263     idxdim d = window;
00264     d.setdim(0, (int) (d.dim(0) + hoff * 2));
00265     d.setdim(1, (int) (d.dim(1) + woff * 2));
00266     return d;
00267   }
00268 
00269   // template <typename T, class Tstate>
00270   // void flat_merge_module<T, Tstate>::set_paddings(mfidxdim &pads) {
00271   //   paddings = pads;
00272   //   EDEBUG(this->name() << ": setting paddings to " << paddings);
00273   // }
00274 
00275   template <typename T, class Tstate>
00276   void flat_merge_module<T, Tstate>::set_offsets(vector<vector<int> > &off) {
00277     offsets = off;
00278     EDEBUG(this->name() << ": setting offsets to " << offsets);
00279   }
00280 
00281   template <typename T, class Tstate>
00282   void flat_merge_module<T, Tstate>::set_strides(mfidxdim &s) {
00283     strides = s;
00284     cout << this->name() << ": setting strides to " << strides << endl;
00285   }
00286 
00287   template <typename T, class Tstate>
00288   void flat_merge_module<T, Tstate>::
00289   fprop(mstate<Tstate> &in, Tstate &out) {
00290     LOCAL_TIMING_START(); // profiling
00291     EDEBUG(this->name() << ": " << in << ", wins: " << dins << ", strides: "
00292           << strides << ", scales: " << scales << ", paddings: " << paddings);
00293     padded.resize(in);
00294     //    strides.clear();
00295     idxdim dref, dref2;
00296     // loop on each state
00297     //float pix1 = .5;
00298     for (uint k = 0; k < in.size(); ++k) {
00299       Tstate i = in[k];
00300       Tstate &p = padded[k];
00301       idxdim d = dins[k];
00302       fidxdim &s = strides[k];
00303       // pad each input so that all windows start centered on the first actual
00304       // pixel top right
00305       if (bpad) {
00306         // narrow input if specified
00307         if (k < offsets.size()) {
00308           vector<int> &off = offsets[k];
00309           if (off.size() != 4) eblerror("expected 4");
00310           int oh0 = (int) (off[0]);
00311           int ow0 = (int) (off[1]);
00312           int oh1 = (int) (off[2]);
00313           int ow1 = (int) (off[3]);
00314           // int oh0 = (int) (off[0] * s.dim(0));
00315           // int ow0 = (int) (off[1] * s.dim(1));
00316           // int oh1 = (int) (off[2] * s.dim(0));
00317           // int ow1 = (int) (off[3] * s.dim(1));
00318           if (oh0 < 0) {
00319             int sz2 = i.x.dim(1) + oh0 + oh1;
00320             if (sz2 <= 0) {
00321               eblwarn("trying to narrow dim 1 of " << i.x << " to size "<< sz2);
00322             } else {
00323               EDEBUG("narrowing height with offset " << -oh0 << " in " << i.x);
00324               i = i.narrow(1, sz2, -oh0);
00325             }
00326           }
00327           if (ow0 < 0) {
00328             int sz2 = i.x.dim(2) + ow0 + ow1;
00329             if (sz2 <= 0) {
00330               eblwarn("trying to narrow dim 2 of " << i.x << " to size "<< sz2);
00331             } else {
00332               EDEBUG("narrowing width with offset " << -ow0 << " in " << i.x);
00333               i = i.narrow(2, sz2, -ow0);
00334             }
00335           }
00336         }
00337 
00338         if (scales.size() != strides.size())
00339           eblerror("expected scales to be the same size as strides but got "
00340                    << scales << " when strides are " << strides);
00341         //if (k % 4 == 0) pix1 *= 2;
00342         //float pix1 = scales[k].dim(0);
00343         //idxdim d4 = compute_pad(d, subsampling, edge, pix1, s);
00344         // fidxdim fpads;
00345         // if (k < paddings.size()) fpads = paddings[k];
00346         // idxdim d4 = fpads;
00347 
00348         // //idxdim d4 = compute_pad(d, 1, 0, pix1, s);
00349         // //   idxdim d4 = compute_pad(d, 4, 6, pix1, s);
00350         // // idxdim d4 = compute_pad(d, 6, 12, pix1, s);
00351         // //padder.set_kernel(d4);
00352         // padder.set_paddings(d4);
00353         // // add extra padding at ends if necessary to match target
00354         // //   intg hextra = 2, wextra = 2;
00355         // idxdim pads = padder.get_paddings();
00356         // // if (k > 0) {
00357         // //   hextra = (intg) (dref.dim(0) - ((i.x.dim(1) + pads.dim(0) + pads.dim(2)
00358         // //                             - (d.dim(0) -d.dim(0) % 2)) / s.dim(0)));
00359         // //   wextra = (intg) (dref.dim(1) - ((i.x.dim(2) + pads.dim(1) + pads.dim(3)
00360         // //                              - (d.dim(1) -d.dim(1) % 2)) / s.dim(1)));
00361         // // }
00362         // if (k > 0) {
00363         //   pads.setdim(2, pads.dim(2) + std::max(0, (int) hextra));
00364         //   pads.setdim(3, pads.dim(3) + std::max(0, (int) wextra));
00365         // }
00366         // padder.set_paddings(pads);
00367 
00368         // add padding if missing to reach target (failsafe)
00369         idxdim pads(0,0,0,0);// = padder.get_paddings();
00370         intg sh = (intg) ((i.x.dim(1) + pads.dim(0) + pads.dim(2)
00371                            - d.dim(0) + 1) / s.dim(0));
00372         intg sw = (intg) ((i.x.dim(2) + pads.dim(1) + pads.dim(3)
00373                            - d.dim(1) + 1) / s.dim(1));
00374         bool w = false;
00375         if (k > 0 && sh < dref.dim(0)) {
00376           pads.setdim(2, pads.dim(2) + std::max(0, (int) (dref.dim(0) - sh)));
00377           w = true;
00378         }
00379         if (k > 0 && sw < dref.dim(1)) {
00380           pads.setdim(3, pads.dim(3) + std::max(0, (int) (dref.dim(1) - sw)));
00381           if (!this->silent && pads.maxdim() > 2)
00382           w = true;
00383         }
00384         if (w && !this->silent && pads.maxdim() > 2)
00385             eblwarn("adding extra padding "<<pads<<" to match target " << dref);
00386 
00387         padder.set_paddings(pads);
00388 
00389         // EDEBUG("before adding padding: " << pads);
00390         // // add fixed extra padding
00391         // pads = padder.get_paddings();
00392         // sh = (intg) (4 * s.dim(0));
00393         // sw = (intg) (4 * s.dim(1));
00394         // pads.setdim(0, pads.dim(0) + sh);
00395         // pads.setdim(2, pads.dim(2) + sh);
00396         // pads.setdim(1, pads.dim(1) + sw);
00397         // pads.setdim(3, pads.dim(3) + sw);
00398         // padder.set_paddings(pads);
00399         // EDEBUG("after adding padding: " << pads);
00400         // pad
00401         padder.fprop(i, p);
00402         // } else {
00403         //   if (i.x.contiguousp()) padded[k] = i;
00404         //   else {
00405         //     if (padded[k].x.get_idxdim() != i.x.get_idxdim())
00406         //       padded[k] = Tstate(i.x.get_idxdim());
00407         //     idx_copy(i.x, padded[k].x);
00408         //   }
00409         // }
00410       } else padded[k] = in[k];
00411       // compute number of outputs for this kernel
00412       idxdim dout((intg) ((p.x.dim(1) - d.dim(0) + 1) / s.dim(0)),
00413                   (intg) ((p.x.dim(2) - d.dim(1) + 1) / s.dim(1)));
00414       idxdim dout2((intg) ((p.x.dim(1) - d.dim(0) + 1)),
00415                    (intg) ((p.x.dim(2) - d.dim(1) + 1)));
00416 //       idxdim dout((p.x.dim(1) - (d.dim(0) - d.dim(0) % 2)) / s.dim(0),
00417 //                (p.x.dim(2) - (d.dim(1) - d.dim(1) % 2)) / s.dim(1));
00418       // use 1st dout as reference
00419       if (k == 0) {
00420         dref = dout;
00421         dref2 = dout2;
00422       }
00423       EDEBUG(this->name() << ": in " << p.x << " (min: " << idx_min(p.x)
00424             << ", max: " << idx_max(p.x) << ") with window " << d
00425             << " and stride " << s << " -> " << dout);
00426 
00427 
00428       // // adjust strides so that all states produce dref outputs
00429       // fidxdim ss(1.0, 1.0);
00430       // if (k > 0)
00431       //        ss = fidxdim(dout2.dim(0) / (float) dref2.dim(0),
00432       //                     dout2.dim(1) / (float) dref2.dim(1));
00433       // strides[k] = ss;
00434       // EDEBUG("setting stride to " << ss << " for input " << p.x << " and window "
00435       //            << d << " to produce " << dref << " outputs");
00436     }
00437     LOCAL_TIMING_REPORT("merge padding");
00438 
00439 
00440     // if (inputs.size() == 0)
00441     //   eblerror("no inputs to merge");
00442     // feature size for main input
00443     idx<T> &in0 = padded[0].x;
00444     intg fsize = din.dim(0) * din.dim(1) * in0.dim(0);
00445     // number of possible windows
00446     // intg nh = 1 + (intg) ((in0.dim(1) - din.dim(0)) / stride.dim(0));
00447     // intg nw = 1 + (intg) ((in0.dim(2) - din.dim(1)) / stride.dim(1));
00448     intg nh = dref.dim(0), nw = dref.dim(1);
00449     // compute new size and resize output if necessary
00450     for (uint i = 1; i < padded.size(); ++i) {
00451       idxdim &d = dins[i];
00452       fidxdim &s = strides[i];
00453       idx<T> &input = padded[i].x;
00454       fsize += d.nelements() * input.dim(0);
00455       // check that strides match possible windows
00456       intg nh2 = (intg) ceil((input.dim(1) - d.dim(0) + 1)
00457                              / std::max(10e-9, (double) s.dim(0)));
00458       intg nw2 = (intg) ceil((input.dim(2) - d.dim(1) + 1)
00459                              / std::max(10e-9, (double) s.dim(1)));
00460       if (nh2 < nh || nw2 < nw) {
00461         *(this->mout) << "COUT input " << input << " and window " << d << " with stride " <<s
00462               << " produce " << nh2 << "x" << nw2
00463              << " outputs but expected at least " << nh << "x" << nw << endl;
00464         eblerror("input " << input << " and window " << d << " with stride " <<s
00465               << " produce " << nh2 << "x" << nw2
00466               << " outputs but expected at least " << nh << "x" << nw);
00467       } else if (nh2 != nh || nw2 != nw)
00468         EDEBUG("warning: input " << input << " and window " << d << " with stride " <<s
00469               << " produce " << nh2 << "x" << nw2
00470               << ", ignoring extra cells and using only " <<nh << "x" << nw);
00471       EDEBUG("input " << i << " " << input << ", min " << idx_min(input)
00472             << " max " << idx_max(input));
00473     }
00474     LOCAL_TIMING_REPORT("merge check");
00475     idxdim d(fsize, nh, nw);
00476     if (!out.x.same_dim(d)) {
00477       if (out.x.order() != d.order())
00478         out = Tstate(d);
00479       else
00480         out.resize(d);
00481     }
00482     LOCAL_TIMING_REPORT("merge resize");
00483     idx_clear(out.x);
00484     LOCAL_TIMING_REPORT("merge clear of " << out.x);
00485 
00486     intg offset = 0;
00487     int h = 0, w = 0;
00488     float fh, fw;
00489     uint uh = 0, uw = 0, uw0 = 0;
00490     idx<T> iw, ow, onarrowed, inarrowed;
00491     // copy inputs to out
00492     for (uint i = 0; i < padded.size(); ++i) {
00493       idxdim dd = dins[i];
00494       intg dd0 = dd.dim(0), dd1 = dd.dim(1);
00495       fidxdim s = strides[i];
00496       float s0 = s.dim(0), s1 = s.dim(1);
00497       idx<T> &input = padded[i].x;
00498       if (!input.contiguousp()) eblerror("expected contiguous");
00499       fsize = dd.nelements() * input.dim(0); // feature size from input
00500       onarrowed = out.x.narrow(0, fsize, offset);
00501       h = 0; w = 0;
00502       intg wmod = onarrowed.mod(2);
00503       // copy
00504       for (h = 0, fh = 0; h < nh; h++, fh += s0) {
00505         uh = (uint) fh;
00506         // select 1 output pixel in the correct feature range
00507         ow = onarrowed.select(2, 0);
00508         ow = ow.select(1, h);
00509         inarrowed = input.narrow(1, dd0, uh);
00510         intg iwmod = inarrowed.mod(2);
00511         uw = 0; uw0 = 0;
00512         iw = inarrowed.narrow(2, dd1, uw);
00513         for (w = 0, fw = 0; w < nw; ++w, fw += s1) {
00514           // integer positions
00515           uw = (uint) fw;
00516           // select input window
00517           if (uw != uw0)
00518             iw.add_offset(iwmod);
00519           // copy flat input to output
00520           // TODO: tmp buffer less efficient than direct copy which but requires
00521           // continuous data, make idx pointing to oo with flat's dims?
00522           // idx<T> tmp(iw.get_idxdim());
00523           // idx_copy(iw, tmp);
00524           // iw = tmp.view_as_order(1);
00525           idx_copy(iw, ow);
00526 
00527           uw0 = uw;
00528           ow.add_offset(wmod);
00529         }
00530       }
00531       offset += fsize;
00532     }
00533 
00534     LOCAL_TIMING_REPORT("merge copies");
00535 
00536 #ifdef __DEBUG_PRINT__
00537     cout << describe() << ": " << in0 << " (in " << din
00538          << " stride " << stride << ")";
00539     for (uint i = 1; i < padded.size(); ++i)
00540       cout << " + " << padded[i].x
00541            << " (in " << dins[i] << " stride " << strides[i] << ")";
00542     cout << " -> " << out.x << endl;
00543     cout << "output min: " << idx_min(out.x) << " max: " << idx_max(out.x)
00544          << endl;
00545 #endif
00546   }
00547 
00548   template <typename T, class Tstate>
00549   void flat_merge_module<T, Tstate>::bprop(mstate<Tstate> &in, Tstate &out) {
00550     idx<T> o, input, o1 = out.dx.view_as_order(1);
00551     intg offset = 0;
00552     // copy out to inputs
00553     for (uint i = 0; i < in.size(); ++i) {
00554       input = in[i].dx;
00555       input = input.view_as_order(1);
00556       o = o1.narrow(0, input.nelements(), offset);
00557       idx_add(o, input, input);
00558       offset += input.nelements();
00559     }
00560   }
00561 
00562   template <typename T, class Tstate>
00563   void flat_merge_module<T, Tstate>::bbprop(mstate<Tstate> &in,
00564                                             Tstate &out) {
00565     idx<T> o, input, o1 = out.ddx.view_as_order(1);
00566     intg offset = 0;
00567     // copy out to inputs
00568     for (uint i = 0; i < in.size(); ++i) {
00569       input = in[i].ddx;
00570       input = input.view_as_order(1);
00571       o = o1.narrow(0, input.nelements(), offset);
00572       idx_add(o, input, input);
00573       offset += input.nelements();
00574     }
00575   }
00576 
00577 //   template <typename T, class Tstate>
00578 //   void flat_merge_module<T, Tstate>::bprop(mstate<Tstate> &in, Tstate &out) {
00579 //     bprop(*in0, out);
00580 //   }
00581 
00582 //   template <typename T, class Tstate>
00583 //   void flat_merge_module<T, Tstate>::bbprop(mstate<Tstate> &in, Tstate &out) {
00584 //     bbprop(*in0, out);
00585 //   }
00586 
00587   template <typename T, class Tstate>
00588   flat_merge_module<T,Tstate>* flat_merge_module<T,Tstate>::copy() {
00589     flat_merge_module<T,Tstate> *l2 =
00590       new flat_merge_module<T,Tstate>(dins, strides, bpad, this->name(),
00591                                       &scales);
00592     return l2;
00593   }
00594 
00596 
00597   template <typename T, class Tstate>
00598   idxdim flat_merge_module<T,Tstate>::fprop_size(idxdim &isize) {
00599     // feature size for main input
00600     intg fsize = din.dim(0) * din.dim(1) * isize.dim(0);
00601     // number of possible windows
00602     intg nh = 1 + (intg) ((isize.dim(1) - din.dim(0)) / stride.dim(0));
00603     intg nw = 1 + (intg) ((isize.dim(2) - din.dim(1)) / stride.dim(1));
00605     idxdim osize(fsize, std::max((intg) 1, nh),
00606                  std::max((intg) 1, nw));
00607     fidxdim os = osize;
00608     isize = bprop_size(os);
00609     return osize;
00610   }
00611 
00612   template <typename T, class Tstate>
00613   fidxdim flat_merge_module<T,Tstate>::bprop_size(const fidxdim &osize) {
00614     //EDEBUG(this->name() << ": " << osize << " -> ...");
00615     // feature size for main input
00616     intg fsize = (intg) (osize.dim(0) / din.dim(0) / din.dim(1));
00617     // number of possible windows
00618     intg ih = (intg) (((osize.dim(1) - 1) * stride.dim(0)) + din.dim(0));
00619     intg iw = (intg) (((osize.dim(2) - 1) * stride.dim(1)) + din.dim(1));
00620     // extract its dimensions, update output size
00621     fidxdim isize(fsize, ih, iw);
00622     // set offsets
00623     for (uint j = 1; j < isize.order(); ++j)
00624       isize.setoffset(j, (intg) (osize.offset(j) * stride.dim(j - 1)));
00625     return isize;
00626   }
00627 
00628   template <typename T, class Tstate>
00629   mfidxdim flat_merge_module<T,Tstate>::bprop_size(mfidxdim &osize) {
00630     //EDEBUG(this->name() << ": " << osize << " -> ...");
00631     if (osize.size() == 0) eblerror("expected at least 1 idxdim");
00632     mfidxdim isize;
00633     idxdim pa, d;
00634     fidxdim s;
00635     // all inputs
00636     //float pix1 = .5;
00637     for (uint i = 0; i < dins.size(); ++i) {
00638       if (!osize.exists(0)) {
00639         isize.push_back_empty();
00640         continue ;
00641       }
00642       idxdim o0 = osize[0];
00643       d = dins[i];
00644       s = strides[i];
00645       if (bpad) {
00646         //if (i % 4 == 0) pix1 *= 2; // TODO: get from user
00647         //float pix1 = scales[i].dim(0);
00648         //      idxdim d4 = compute_pad(d, subsampling, edge, pix1, s);
00649         if (i < paddings.size()) {
00650           fidxdim fpads = paddings[i];
00651           pa = fpads;
00652         }
00653         if (i < offsets.size()) {
00654           vector<int> &off = offsets[i];
00655           d.setoffset(0, (int) (-off[0] / s.dim(0)));
00656           d.setoffset(1, (int) (-off[1] / s.dim(1)));
00657         }
00658 
00659         // idxdim d4 = compute_pad(d, 1, 0, pix1, s);
00660         //idxdim d4 = compute_pad(d, 4, 6, pix1, s);
00661         //      idxdim d4 = compute_pad(d, 6, 12, pix1, s);
00662         //pa = padder.get_paddings(d4);
00663 
00664         // TMP
00665         // intg sh = (intg) (4 * s.dim(0));
00666         // intg sw = (intg) (4 * s.dim(1));
00667         // pa.setdim(0, pa.dim(0) + sh);
00668         // pa.setdim(1, pa.dim(1) + sh);
00669 
00670       }
00671       d.insert_dim(0, o0.dim(0)); // add feature dimension
00672       // set offsets
00673       fidxdim fd(d);
00674       for (uint j = 1; j < d.order(); ++j) {
00675         float o = (o0.offset(j) + d.offset(j)) * s.dim(j - 1);
00676         if (j-1 < pa.order()) o -= pa.dim(j-1);
00677         fd.setoffset(j, o);
00678       }
00679       isize.push_back(fd);
00680     }
00681     //EDEBUG(this->name() << ": " << osize << " -> " << isize);
00682     return isize;
00683   }
00684 
00685   template <typename T, class Tstate>
00686   std::string flat_merge_module<T, Tstate>::describe() {
00687     std::string desc;
00688     desc << "flat_merge module " << this->name() << ", merging "
00689          << (int) dins.size() << " inputs: ";
00690     for (uint i = 0; i < dins.size(); ++i) {
00691       desc << " (in " << dins[i] << " stride " << strides[i];
00692       if (i < scales.size()) desc << " scale " << scales[i];
00693       desc << "), ";
00694     }
00695     if (bpad)
00696       desc << ", inputs are padded to center windows on borders";
00697     return desc;
00698   }
00699 
00700   template <typename T, class Tstate>
00701   uint flat_merge_module<T, Tstate>::get_ninputs() {
00702     return (uint) dins.size();
00703   }
00704 
00705   template <typename T, class Tstate>
00706   mfidxdim flat_merge_module<T, Tstate>::get_strides() {
00707     return strides;
00708   }
00709 
00710   template <typename T, class Tstate>
00711   mfidxdim flat_merge_module<T, Tstate>::get_scales() {
00712     return scales;
00713   }
00714 
00716   // mstate_merge_module
00717 
00718   template <typename T, class Tstate>
00719   mstate_merge_module<T, Tstate>::
00720   mstate_merge_module(midxdim &ins, mfidxdim &strides, const char *name_)
00721     : module_1_1<T,Tstate>(name_), dins(ins), dstrides(strides) {
00722   }
00723 
00724   template <typename T, class Tstate>
00725   mstate_merge_module<T, Tstate>::~mstate_merge_module() {
00726   }
00727 
00729   // multi-state methods
00730 
00731   template <typename T, class Tstate>
00732   void mstate_merge_module<T, Tstate>::
00733   fprop(mstate<Tstate> &in, mstate<Tstate> &out) {
00734     // use state 0 as base for sizes
00735     Tstate &in0 = in[0];
00736     Tstate o0 = out[0];
00737     idxdim &d0 = dins[0];
00738     fidxdim &s0 = dstrides[0];
00739     // number of possible windows
00740     intg nh = (intg) (1 + (in0.x.dim(1) - d0.dim(0)) / s0.dim(0));
00741     intg nw = (intg) (1 + (in0.x.dim(2) - d0.dim(1)) / s0.dim(1));
00742     // compute new size and resize output if necessary
00743     intg fsize = 0;
00744     for (uint i = 0; i < dins.size(); ++i) {
00745       idxdim &d = dins[i];
00746       fidxdim &s = dstrides[i];
00747       Tstate &tin = in[i];
00748       fsize += d.nelements() * tin.x.dim(0);
00749       // check that strides match possible windows
00750       if (tin.x.dim(1) / s.dim(0) != nh || tin.x.dim(2) / s.dim(1) != nw)
00751         eblerror("input " << tin.x << " with stride " << s
00752                  << " does not produce " << nh << "x" << nw << " windows");
00753     }
00754     // resize output (only 1 state)
00755     idxdim d(fsize, nh, nw);
00756     if (out.size() != 1) {
00757       out.clear();
00758       out.push_back(d);
00759     } else {
00760       if (!o0.x.same_dim(d))
00761         o0.resize(d);
00762     }
00763     intg offset = 0;
00764     // copy all inputs to outputs
00765     for (uint i = 0; i < dins.size(); ++i) {
00766       idxdim &d = dins[i];
00767       fidxdim &s = dstrides[i];
00768       Tstate &tin = in[i];
00769       // feature size for this state
00770       fsize = d.nelements() * in0.x.dim(0);
00771       // loop on all possible windows for this state
00772       float fh, fw;
00773       uint uh, uw, h, w;
00774       for (h = 0, fh = 0; h < nh; h++, fh += s.dim(0)) {
00775         for (w = 0, fw = 0; w < nw; w++, fw += s.dim(1)) {
00776           // integer positions
00777           uh = (uint) h;
00778           uw = (uint) w;
00779           // select 1 output pixel in the corect feature range
00780           idx<T> o = o0.x.select(2, w);
00781           o = o.select(1, h);
00782           o = o.narrow(0, fsize, offset);
00783           // select input window
00784           idx<T> iw = tin.x.select(2, uw);
00785           iw = iw.select(1, uh);
00786           // copy flat input to output
00787           // TODO: tmp buffer less efficient than direct copy which but requires
00788           // continuous data, make idx pointing to oo with flat's dims?
00789           idx<T> tmp(iw.get_idxdim());
00790           idx_copy(iw, tmp);
00791           iw = tmp.view_as_order(1);
00792           idx_copy(iw, o);
00793         }
00794       }
00795       offset += fsize;
00796     }
00797 #ifdef __DEBUG_PRINT__
00798     // cout << describe() << ": " << in.x << " (in " << din
00799     //   << " stride " << stride << ")";
00800     // for (uint i = 0; i < inputs.size(); ++i)
00801     //   cout << " + " << (*inputs[i])->x << " (in " << dins[i]
00802     //     << " stride " << strides[i] << ")";
00803     // cout << " -> " << out.x << endl;
00804 #endif
00805   }
00806 
00807   template <typename T, class Tstate>
00808   void mstate_merge_module<T, Tstate>::
00809   bprop(mstate<Tstate> &in, mstate<Tstate> &out) {
00810     // expect only 1 state in output
00811     if (out.size() != 1)
00812       eblerror("expected only 1 state in output but found " << out.size());
00813     Tstate &to = out[0];
00814     idx<T> o = to.dx.view_as_order(1);
00815     // copy out to inputs
00816     intg offset = 0;
00817     for (uint i = 0; i < in.size(); ++i) {
00818       Tstate &tin = in[i];
00819       idx<T> ii = tin.dx.view_as_order(1);
00820       idx<T> oo = o.narrow(0, ii.nelements(), offset);
00821       idx_add(oo, ii, ii);
00822       offset += ii.nelements();
00823     }
00824   }
00825 
00826   template <typename T, class Tstate>
00827   void mstate_merge_module<T, Tstate>::
00828   bbprop(mstate<Tstate> &in, mstate<Tstate> &out) {
00829     // expect only 1 state in output
00830     if (out.size() != 1)
00831       eblerror("expected only 1 state in output but found " << out.size());
00832     Tstate &to = out[0];
00833     idx<T> o = to.ddx.view_as_order(1);
00834     // copy out to inputs
00835     intg offset = 0;
00836     for (uint i = 0; i < in.size(); ++i) {
00837       Tstate &tin = in[i];
00838       idx<T> ii = tin.ddx.view_as_order(1);
00839       idx<T> oo = o.narrow(0, ii.nelements(), offset);
00840       idx_add(oo, ii, ii);
00841       offset += ii.nelements();
00842     }
00843   }
00844 
00846 
00847   template <typename T, class Tstate>
00848   idxdim mstate_merge_module<T,Tstate>::fprop_size(idxdim &isize) {
00849     // use state 0 as base for sizes
00850     idxdim &d0 = dins[0];
00851     fidxdim &s0 = dstrides[0];
00852     // number of possible windows
00853     intg nh = (intg) (1 + (isize.dim(1) - d0.dim(0)) / s0.dim(0));
00854     intg nw = (intg) (1 + (isize.dim(2) - d0.dim(1)) / s0.dim(1));
00855     // compute new size and resize output if necessary
00856     intg fsize = 0;
00857     for (uint i = 0; i < dins.size(); ++i) {
00858       idxdim &d = dins[i];
00859       fsize += d.nelements() * isize.dim(0);
00860     }
00862     idxdim osize(fsize, std::max((intg) 1, nh),
00863                  std::max((intg) 1, nw));
00864     fidxdim os = osize;
00865     isize = bprop_size(os);
00866     return osize;
00867   }
00868 
00869   template <typename T, class Tstate>
00870   fidxdim mstate_merge_module<T,Tstate>::bprop_size(const fidxdim &osize) {
00871     // use state 0 as base for sizes
00872     fidxdim &d0 = dins[0];
00873     fidxdim &s0 = dstrides[0];
00874     // number of possible windows
00875     intg ih = (intg) (((osize.dim(1) - 1) * s0.dim(0)) + d0.dim(0));
00876     intg iw = (intg) (((osize.dim(2) - 1) * s0.dim(1)) + d0.dim(1));
00877     // compute new size and resize output if necessary
00878     intg fsize = osize.dim(0) / d0.dim(0) / d0.dim(1);
00880     fidxdim isize(fsize, ih, iw);
00881     return isize;
00882   }
00883 
00884   template <typename T, class Tstate>
00885   std::string mstate_merge_module<T, Tstate>::describe() {
00886     std::string desc;
00887     desc << "mstate_merge module " << this->name() << ", merging states ";
00888     for (uint i = 0; i < dins.size(); ++i)
00889       desc << " (in " << dins[i] << " stride " << dstrides[i] << "), ";
00890     return desc;
00891   }
00892 
00894   // merge
00895 
00896   template <typename T, class Tstate>
00897   merge_module<T, Tstate>::merge_module(std::vector<Tstate**> &ins,
00898                                         intg concat_dim_,
00899                                         const char *name_,
00900                                         const char *list)
00901     : module_1_1<T,Tstate>(name_), inputs(ins), concat_dim(concat_dim_),
00902       merge_list(list) {
00903     // for (uint i = 0; i < ins.size(); ++i)
00904     //   inputs.push_back(ins[i]);
00905   }
00906 
00907   template <typename T, class Tstate>
00908   merge_module<T, Tstate>::merge_module(std::vector<mstate<Tstate>**> &ins,
00909                                         intg concat_dim_,
00910                                         const char *name_,
00911                                         const char *list)
00912     : module_1_1<T,Tstate>(name_), msinputs(ins), merge_list(list),
00913       concat_dim(concat_dim_) {
00914     eblerror("not implemented");
00915   }
00916 
00917   template <typename T, class Tstate>
00918   merge_module<T, Tstate>::
00919   merge_module(std::vector<std::vector<uint> > &states, intg concat_dim_,
00920                const char *name_)
00921     : module_1_1<T,Tstate>(name_), states_list(states), concat_dim(concat_dim_){
00922   }
00923 
00924   template <typename T, class Tstate>
00925   merge_module<T, Tstate>::~merge_module() {
00926   }
00927 
00928   template <typename T, class Tstate>
00929   void merge_module<T, Tstate>::fprop(mstate<Tstate> &in, mstate<Tstate> &out) {
00930     if (states_list.size() == 0) eblerror("expected non-empty states_list");
00931     // resize out if necessary
00932     out.resize(in, states_list.size());
00933     // loop on each merging
00934     for (uint i = 0; i < states_list.size(); ++i) {
00935       vector<uint> ids = states_list[i];
00936       mstate<Tstate> mi;
00937       // create multi-state of states to merge
00938       for (uint j = 0; j < ids.size(); ++j) {
00939         uint id = ids[j];
00940         if (id >= in.size())
00941           eblerror("trying to access state " << id << " but multi-state only "
00942                    << "contains " << in.size() << " states: " << in);
00943         mi.push_back(new Tstate(in[id]));
00944       }
00945       EDEBUG("merging states with ids " << ids << ": " << mi);
00946       // merge them
00947       merge(mi, out[i]);
00948     }
00949   }
00950 
00951   template <typename T, class Tstate>
00952   void merge_module<T, Tstate>::bprop(mstate<Tstate> &in, mstate<Tstate> &out) {
00953     // TODO: implement
00954   }
00955 
00956   template <typename T, class Tstate>
00957   void merge_module<T, Tstate>::bbprop(mstate<Tstate> &in, mstate<Tstate> &out) {
00958     // TODO: implement
00959   }
00960 
00961   template <typename T, class Tstate>
00962   void merge_module<T, Tstate>::fprop(Tstate &in, Tstate &out) {
00963     idxdim d(in.x), dtmp(in.x);
00964     // check that all inputs are compatible and compute output size
00965     for (uint i = 0; i < inputs.size(); ++i) {
00966       Tstate *input = *(inputs[i]);
00967       dtmp.setdim(concat_dim, input->x.dim(concat_dim));
00968       if (!input->x.same_dim(dtmp))
00969         eblerror("expected same dimensions but got " << input->x.get_idxdim()
00970                  << " and " << dtmp);
00971       // increment dimension
00972       d.setdim(concat_dim, d.dim(concat_dim) + input->x.dim(concat_dim));
00973     }
00974     // check that output has the right size, if not, resize
00975     if (out.x.get_idxdim() != d)
00976       out.resize(d);
00977     // copy main input to out
00978     intg offset = 0;
00979     idx<T> o = out.x.narrow(concat_dim, in.x.dim(concat_dim), offset);
00980     idx_copy(in.x, o);
00981     offset += in.x.dim(concat_dim);
00982     // copy inputs to out
00983     for (uint i = 0; i < inputs.size(); ++i) {
00984       Tstate *input = *(inputs[i]);
00985       o = out.x.narrow(concat_dim, input->x.dim(concat_dim), offset);
00986       idx_copy(input->x, o);
00987       offset += input->x.dim(concat_dim);
00988     }
00989 #ifdef __DEBUG_PRINT__
00990     cout << describe() << ": " << in.x;
00991     for (uint i = 0; i < inputs.size(); ++i)
00992       cout << " + " << (*inputs[i])->x;
00993     cout << " -> " << out.x << endl;
00994 #endif
00995   }
00996 
00997   template <typename T, class Tstate>
00998   void merge_module<T, Tstate>::bprop(Tstate &in, Tstate &out) {
00999     // TODO: implement
01000   }
01001 
01002   template <typename T, class Tstate>
01003   void merge_module<T, Tstate>::bbprop(Tstate &in, Tstate &out) {
01004     // TODO: implement
01005   }
01006 
01007 //   template <typename T, class Tstate>
01008 //   void merge_module<T, Tstate>::fprop(Tstate &in, Tstate &out) {
01009 //     idxdim d(in.x), dtmp(in.x);
01010 //     // check that all inputs are compatible and compute output size
01011 //     for (uint i = 0; i < inputs.size(); ++i) {
01012 //       Tstate *input = *(inputs[i]);
01013 //       dtmp.setdim(concat_dim, input->x.dim(concat_dim));
01014 //       if (!input->x.same_dim(dtmp))
01015 //      eblerror("expected same dimensions but got " << input->x.get_idxdim()
01016 //               << " and " << dtmp);
01017 //       // increment dimension
01018 //       d.setdim(concat_dim, d.dim(concat_dim) + input->x.dim(concat_dim));
01019 //     }
01020 //     // check that output has the right size, if not, resize
01021 //     if (out.x.get_idxdim() != d)
01022 //       out.resize(d);
01023 //     // copy main input to out
01024 //     intg offset = 0;
01025 //     idx<T> o = out.x.narrow(concat_dim, in.x.dim(concat_dim), offset);
01026 //     idx_copy(in.x, o);
01027 //     offset += in.x.dim(concat_dim);
01028 //     // copy inputs to out
01029 //     for (uint i = 0; i < inputs.size(); ++i) {
01030 //       Tstate *input = *(inputs[i]);
01031 //       o = out.x.narrow(concat_dim, input->x.dim(concat_dim), offset);
01032 //       idx_copy(input->x, o);
01033 //       offset += input->x.dim(concat_dim);
01034 //     }
01035 // #ifdef __DEBUG__
01036 //     cout << describe() << ": " << in.x;
01037 //     for (uint i = 0; i < inputs.size(); ++i)
01038 //       cout << " + " << (*inputs[i])->x;
01039 //     cout << " -> " << out.x << endl;
01040 // #endif
01041 //   }
01042 
01043   // template <typename T, class Tstate>
01044   // void merge_module<T, Tstate>::bprop(Tstate &in, Tstate &out) {
01045   //   // copy out to main input
01046   //   intg offset = 0;
01047   //   idx<T> o = out.dx.narrow(concat_dim, in.dx.dim(concat_dim), offset);
01048   //   idx_add(o, in.dx, in.dx);
01049   //   offset += in.dx.dim(concat_dim);
01050   //   // copy out to inputs
01051   //   for (uint i = 0; i < inputs.size(); ++i) {
01052   //     Tstate *input = *(inputs[i]);
01053   //     o = out.dx.narrow(concat_dim, input->dx.dim(concat_dim), offset);
01054   //     idx_add(o, input->dx, input->dx);
01055   //     offset += input->dx.dim(concat_dim);
01056   //   }
01057   // }
01058 
01059   // template <typename T, class Tstate>
01060   // void merge_module<T, Tstate>::bbprop(Tstate &in,
01061   //                                    Tstate &out) {
01062   //   // copy out to main input
01063   //   intg offset = 0;
01064   //   idx<T> o = out.ddx.narrow(concat_dim, in.ddx.dim(concat_dim), offset);
01065   //   idx_add(o, in.ddx, in.ddx);
01066   //   offset += in.ddx.dim(concat_dim);
01067   //   // copy out to inputs
01068   //   for (uint i = 0; i < inputs.size(); ++i) {
01069   //     Tstate *input = *(inputs[i]);
01070   //     o = out.ddx.narrow(concat_dim, input->ddx.dim(concat_dim), offset);
01071   //     idx_add(o, input->ddx, input->ddx);
01072   //     offset += input->ddx.dim(concat_dim);
01073   //   }
01074   // }
01075 
01076   template <typename T, class Tstate>
01077   std::string merge_module<T, Tstate>::describe() {
01078     std::string desc;
01079     desc << "merge module " << this->name();
01080     return desc;
01081   }
01082 
01083   template <typename T, class Tstate>
01084   void merge_module<T, Tstate>::merge(mstate<Tstate> &in, Tstate &out) {
01085     if (in.size() == 0) eblerror("expected at least 1 state in input");
01086     idxdim d(in[0].x), dtmp(in[0].x);
01087     // check that all inputs are compatible and compute output size
01088     for (uint i = 1; i < in.size(); ++i) {
01089       Tstate &s = in[i];
01090       dtmp.setdim(concat_dim, s.x.dim(concat_dim));
01091       if (!s.x.same_dim(dtmp))
01092         eblerror("expected same dimensions but got " << s.x.get_idxdim()
01093                  << " and " << dtmp);
01094       // increment dimension
01095       d.setdim(concat_dim, d.dim(concat_dim) + s.x.dim(concat_dim));
01096     }
01097     // check that output has the right size, if not, resize
01098     if (out.x.get_idxdim() != d) out.resize(d);
01099     // copy inputs to out
01100     intg offset = 0;
01101     idx<T> o;
01102     for (uint i = 0; i < in.size(); ++i) {
01103       Tstate &s = in[i];
01104       o = out.x.narrow(concat_dim, s.x.dim(concat_dim), offset);
01105       idx_copy(s.x, o);
01106       offset += s.x.dim(concat_dim);
01107     }
01108 #ifdef __DEBUG_PRINT__
01109     cout << describe() << ": " << in[0].x;
01110     for (uint i = 1; i < in.size(); ++i) cout << " + " << in[i].x;
01111     cout << " -> " << out.x << endl;
01112 #endif
01113   }
01114 
01116   // interlace
01117 
01118   template <typename T, class Tstate>
01119   interlace_module<T, Tstate>::interlace_module(uint stride_, const char *name_)
01120     : module_1_1<T,Tstate>(name_), stride(stride_) {
01121   }
01122 
01123   template <typename T, class Tstate>
01124   interlace_module<T, Tstate>::~interlace_module() {
01125   }
01126 
01127   template <typename T, class Tstate>
01128   void interlace_module<T, Tstate>::
01129   fprop(mstate<Tstate> &in, mstate<Tstate> &out) {
01130     if (in.size() % stride != 0)
01131       eblerror("expected number of states to be a multiple of " << stride
01132                << " but got: " << in);
01133     out.clear();
01134     // interlace
01135     for (uint i = 0; i < stride; ++i) {
01136       for (uint j = 0; j < in.size() / stride; ++j) {
01137         out.push_back(in[j * stride + i]);
01138       }
01139     }
01140     EDEBUG(this->name() << ": " << in << " -> " << out);
01141   }
01142 
01143   template <typename T, class Tstate>
01144   void interlace_module<T, Tstate>::
01145   bprop(mstate<Tstate> &in, mstate<Tstate> &out) {
01146     not_implemented();
01147   }
01148 
01149   template <typename T, class Tstate>
01150   void interlace_module<T, Tstate>::
01151     bbprop(mstate<Tstate> &in, mstate<Tstate> &out) {
01152     not_implemented();
01153   }
01154 
01155   template <typename T, class Tstate>
01156   mfidxdim interlace_module<T,Tstate>::bprop_size(mfidxdim &osize) {
01157     if (osize.size() % stride != 0) {
01158       eblwarn(this->name() << ": expected midxdim size to be a multiple of "
01159               << stride << " but got " << osize);
01160       return osize;
01161     }
01162     mfidxdim isize;
01163     uint step = osize.size() / stride;
01164     // interlace
01165     for (uint i = 0; i < step; ++i) {
01166       for (uint j = 0; j < stride; ++j) {
01167         if (osize.exists(j * step + i))
01168           isize.push_back(osize[j * step + i]);
01169         else
01170           isize.push_back_empty();
01171       }
01172     }
01173     EDEBUG(this->name() << ": " << osize << " -> " << isize);
01174     return isize;
01175   }
01176 
01177   template <typename T, class Tstate>
01178   std::string interlace_module<T, Tstate>::describe() {
01179     std::string desc;
01180     desc << "interlacing module " << this->name() << " with stride " << stride;
01181     return desc;
01182   }
01183 
01184   template <typename T, class Tstate>
01185   interlace_module<T,Tstate>* interlace_module<T,Tstate>::copy() {
01186     interlace_module<T,Tstate> *l2 =
01187       new interlace_module<T,Tstate>(stride, this->name());
01188     return l2;
01189   }
01190 
01191 } // end namespace ebl