libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2011 by Pierre Sermanet * 00003 * pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 namespace ebl { 00034 00036 // flat_merge_module 00037 00038 template <typename T, class Tstate> 00039 flat_merge_module<T, Tstate>:: 00040 flat_merge_module(std::vector<Tstate**> &inputs_, idxdim &in_, midxdim &ins_, 00041 fidxdim &stride_, mfidxdim &strides_, 00042 const char *name_, const char *list) 00043 : m2s_module<T,Tstate>(inputs_.size() + 1, name_), inputs(inputs_), 00044 din(in_), dins(ins_), stride(stride_), strides(strides_), in0(NULL), 00045 use_pinputs(false) { 00046 if (list) 00047 merge_list = list; 00048 // allocate zpad vector 00049 zpads.assign(strides.size(), NULL); 00050 } 00051 00052 template <typename T, class Tstate> 00053 flat_merge_module<T, Tstate>:: 00054 flat_merge_module(std::vector<mstate<Tstate>**> &inputs_, 00055 idxdim &in_, midxdim &ins_, fidxdim &stride_, 00056 mfidxdim &strides_, const char *name_, const char *list) 00057 : m2s_module<T,Tstate>(inputs_.size() + 1, name_), //msinputs(inputs_), 00058 din(in_), dins(ins_), stride(stride_), strides(strides_), 00059 use_pinputs(false) { 00060 if (list) 00061 merge_list = list; 00062 // allocate zpad vector 00063 zpads.assign(strides.size(), NULL); 00064 eblerror("not implemented"); 00065 } 00066 00067 template <typename T, class Tstate> 00068 flat_merge_module<T, Tstate>:: 00069 flat_merge_module(midxdim &ins_, mfidxdim &strides_, bool bpad_, 00070 const char *name_, mfidxdim *scales_, 00071 /*TEMP*/ intg hextra_, intg wextra_, float ss_, float edge_) 00072 : m2s_module<T,Tstate>(ins_.size(), name_), in0(NULL), use_pinputs(true), 00073 bpad(bpad_), 00074 //TEMP 00075 hextra(hextra_), wextra(wextra_), subsampling(ss_), edge(edge_) 00076 { 00077 if (scales_) 00078 scales = *scales_; 00079 // check there are enough elements 00080 if (ins_.size() < 1 || strides_.size() < 1 00081 || ins_.size() != strides_.size()) 00082 eblerror("expected at least 1 dim and stride (matching) but got: dims: " 00083 << ins_.size() << " strides: " << strides_.size()); 00084 // separate first dim/strides from rest 00085 din = ins_[0]; 00086 stride = strides_[0]; 00087 // add remaining ones 00088 for (uint i = 0; i < ins_.size(); ++i) { 00089 dins.push_back(ins_[i]); 00090 strides.push_back(strides_[i]); 00091 } 00092 // allocate zpad vector 00093 zpads.assign(strides.size(), NULL); 00094 } 00095 00096 template <typename T, class Tstate> 00097 flat_merge_module<T, Tstate>::~flat_merge_module() { 00098 // clean up zero padding modules 00099 for (uint i = 0; i < zpads.size(); ++i) 00100 if (zpads[i]) 00101 delete zpads[i]; 00102 } 00103 00105 // generic state methods 00106 00107 template <typename T, class Tstate> 00108 void flat_merge_module<T, Tstate>::fprop(Tstate &in, Tstate &out) { 00109 // pad each input so that all windows start centered on the first actual 00110 // pixel top right 00111 00112 // if (inputs.size() == 0) 00113 // eblerror("no inputs to merge"); 00114 // feature size for main input 00115 intg fsize = din.dim(0) * din.dim(1) * in.x.dim(0); 00116 // number of possible windows 00117 intg nh = 1 + (intg) ((in.x.dim(1) - din.dim(0)) / stride.dim(0)); 00118 intg nw = 1 + (intg) ((in.x.dim(2) - din.dim(1)) / stride.dim(1)); 00119 // compute new size and resize output if necessary 00120 for (uint i = 0; i < std::max(inputs.size(), pinputs.size()); ++i) { 00121 idxdim &d = dins[i]; 00122 fidxdim &s = strides[i]; 00123 idx<T> input = (use_pinputs ? pinputs[i]->x : (*inputs[i])->x); 00124 fsize += d.nelements() * input.dim(0); 00125 // check that strides match possible windows 00126 intg nh2 = (intg) (1 + (input.dim(1) - d.dim(0)) / s.dim(0)); 00127 intg nw2 = (intg) (1 + (input.dim(2) - d.dim(1)) / s.dim(1)); 00128 if (nh2 < nh || nw2 < nw) { 00129 eblerror("input " << input << " and window " << d << " with stride " <<s 00130 << " produce " << nh2 << "x" << nw2 00131 << " outputs but expected at least " << nh << "x" << nw); 00132 } else if (nh2 != nh || nw2 != nw) 00133 EDEBUG("warning: input " << input << " and window " << d 00134 << " with stride " <<s << " produce " << nh2 << "x" << nw2 00135 << ", ignoring extra cells and using only " <<nh << "x" << nw); 00136 } 00137 idxdim d(fsize, nh, nw); 00138 if (!out.x.same_dim(d)) { 00139 if (out.x.order() != d.order()) 00140 out = Tstate(d); 00141 else 00142 out.resize(d); 00143 } 00144 intg offset = 0; 00145 // copy main input to out 00146 fsize = din.nelements() * in.x.dim(0); // feat size for main input 00147 // loop on all possible windows for this state 00148 float fh, fw; 00149 uint uh, uw, h, w; 00150 for (h = 0, fh = 0; h < nh; h++, fh += stride.dim(0)) { 00151 for (w = 0, fw = 0; w < nw; w++, fw += stride.dim(1)) { 00152 // integer positions 00153 uh = (uint) fh; 00154 uw = (uint) fw; 00155 // select 1 output pixel in the correct feature range 00156 idx<T> o = out.x.select(2, w); 00157 o = o.select(1, h); 00158 o = o.narrow(0, fsize, offset); 00159 // select input window 00160 idx<T> iw = in.x.narrow(2, din.dim(1), uw); 00161 iw = iw.narrow(1, din.dim(0), uh); 00162 // copy flat input to output 00163 // TODO: tmp buffer less efficient than direct copy which but requires 00164 // continuous data, make idx pointing to oo with flat's dims? 00165 idx<T> tmp(iw.get_idxdim()); 00166 idx_copy(iw, tmp); 00167 iw = tmp.view_as_order(1); 00168 idx_copy(iw, o); 00169 } 00170 } 00171 offset += fsize; 00172 // copy inputs to out 00173 for (uint i = 0; i < std::max(inputs.size(), pinputs.size()); ++i) { 00174 idxdim dd = dins[i]; 00175 fidxdim s = strides[i]; 00176 idx<T> input = (use_pinputs ? pinputs[i]->x : (*inputs[i])->x); 00177 fsize = dd.nelements() * input.dim(0); // feature size from input 00178 // copy 00179 for (h = 0, fh = 0; h < nh; h++, fh += s.dim(0)) { 00180 for (w = 0, fw = 0; w < nw; w++, fw += s.dim(1)) { 00181 // integer positions 00182 uh = (uint) fh; 00183 uw = (uint) fw; 00184 // select 1 output pixel in the correct feature range 00185 idx<T> o = out.x.select(2, w); 00186 o = o.select(1, h); 00187 o = o.narrow(0, fsize, offset); 00188 // select input window 00189 idx<T> iw = input.narrow(2, dd.dim(1), uw); 00190 iw = iw.narrow(1, dd.dim(0), uh); 00191 // copy flat input to output 00192 // TODO: tmp buffer less efficient than direct copy which but requires 00193 // continuous data, make idx pointing to oo with flat's dims? 00194 idx<T> tmp(iw.get_idxdim()); 00195 idx_copy(iw, tmp); 00196 iw = tmp.view_as_order(1); 00197 idx_copy(iw, o); 00198 } 00199 } 00200 offset += fsize; 00201 } 00202 #ifdef __DEBUG_PRINT__ 00203 cout << describe() << ": " << in.x << " (in " << din 00204 << " stride " << stride << ")"; 00205 for (uint i = 0; i < std::max(inputs.size(), pinputs.size()); ++i) 00206 cout << " + " << (use_pinputs ? pinputs[i]->x : (*inputs[i])->x) 00207 << " (in " << dins[i] << " stride " << strides[i] << ")"; 00208 cout << " -> " << out.x << endl; 00209 #endif 00210 } 00211 00212 template <typename T, class Tstate> 00213 void flat_merge_module<T, Tstate>::bprop(Tstate &in, Tstate &out) { 00214 // if (inputs.size() == 0) 00215 // eblerror("no inputs to merge"); 00216 // copy out to main input 00217 intg offset = 0; 00218 idx<T> o1 = out.dx.view_as_order(1); 00219 idx<T> o = o1.narrow(0, in.dx.nelements(), offset); 00220 idx<T> input = in.dx.view_as_order(1); 00221 idx_add(o, input, input); 00222 offset += input.nelements(); 00223 // copy out to inputs 00224 for (uint i = 0; i < std::max(inputs.size(), pinputs.size()); ++i) { 00225 input = (use_pinputs ? pinputs[i]->dx : (*inputs[i])->dx); 00226 input = input.view_as_order(1); 00227 o = o1.narrow(0, input.nelements(), offset); 00228 idx_add(o, input, input); 00229 offset += input.nelements(); 00230 } 00231 } 00232 00233 template <typename T, class Tstate> 00234 void flat_merge_module<T, Tstate>::bbprop(Tstate &in, 00235 Tstate &out) { 00236 // if (inputs.size() == 0) 00237 // eblerror("no inputs to merge"); 00238 // copy out to main input 00239 intg offset = 0; 00240 idx<T> o1 = out.ddx.view_as_order(1); 00241 idx<T> o = o1.narrow(0, in.ddx.nelements(), offset); 00242 idx<T> input = in.ddx.view_as_order(1); 00243 idx_add(o, input, input); 00244 offset += input.nelements(); 00245 // copy out to inputs 00246 for (uint i = 0; i < std::max(inputs.size(), pinputs.size()); ++i) { 00247 input = (use_pinputs ? pinputs[i]->ddx : (*inputs[i])->ddx); 00248 input = input.view_as_order(1); 00249 o = o1.narrow(0, input.nelements(), offset); 00250 idx_add(o, input, input); 00251 offset += input.nelements(); 00252 } 00253 } 00254 00256 00257 template <typename T, class Tstate> 00258 idxdim flat_merge_module<T, Tstate>:: 00259 compute_pad(idxdim &window, float subsampling, float edge, 00260 float scale, fidxdim &stride) { 00261 float hoff = (edge * scale * stride.dim(0)) / subsampling + .5; 00262 float woff = (edge * scale * stride.dim(1)) / subsampling + .5; 00263 idxdim d = window; 00264 d.setdim(0, (int) (d.dim(0) + hoff * 2)); 00265 d.setdim(1, (int) (d.dim(1) + woff * 2)); 00266 return d; 00267 } 00268 00269 // template <typename T, class Tstate> 00270 // void flat_merge_module<T, Tstate>::set_paddings(mfidxdim &pads) { 00271 // paddings = pads; 00272 // EDEBUG(this->name() << ": setting paddings to " << paddings); 00273 // } 00274 00275 template <typename T, class Tstate> 00276 void flat_merge_module<T, Tstate>::set_offsets(vector<vector<int> > &off) { 00277 offsets = off; 00278 EDEBUG(this->name() << ": setting offsets to " << offsets); 00279 } 00280 00281 template <typename T, class Tstate> 00282 void flat_merge_module<T, Tstate>::set_strides(mfidxdim &s) { 00283 strides = s; 00284 cout << this->name() << ": setting strides to " << strides << endl; 00285 } 00286 00287 template <typename T, class Tstate> 00288 void flat_merge_module<T, Tstate>:: 00289 fprop(mstate<Tstate> &in, Tstate &out) { 00290 LOCAL_TIMING_START(); // profiling 00291 EDEBUG(this->name() << ": " << in << ", wins: " << dins << ", strides: " 00292 << strides << ", scales: " << scales << ", paddings: " << paddings); 00293 padded.resize(in); 00294 // strides.clear(); 00295 idxdim dref, dref2; 00296 // loop on each state 00297 //float pix1 = .5; 00298 for (uint k = 0; k < in.size(); ++k) { 00299 Tstate i = in[k]; 00300 Tstate &p = padded[k]; 00301 idxdim d = dins[k]; 00302 fidxdim &s = strides[k]; 00303 // pad each input so that all windows start centered on the first actual 00304 // pixel top right 00305 if (bpad) { 00306 // narrow input if specified 00307 if (k < offsets.size()) { 00308 vector<int> &off = offsets[k]; 00309 if (off.size() != 4) eblerror("expected 4"); 00310 int oh0 = (int) (off[0]); 00311 int ow0 = (int) (off[1]); 00312 int oh1 = (int) (off[2]); 00313 int ow1 = (int) (off[3]); 00314 // int oh0 = (int) (off[0] * s.dim(0)); 00315 // int ow0 = (int) (off[1] * s.dim(1)); 00316 // int oh1 = (int) (off[2] * s.dim(0)); 00317 // int ow1 = (int) (off[3] * s.dim(1)); 00318 if (oh0 < 0) { 00319 int sz2 = i.x.dim(1) + oh0 + oh1; 00320 if (sz2 <= 0) { 00321 eblwarn("trying to narrow dim 1 of " << i.x << " to size "<< sz2); 00322 } else { 00323 EDEBUG("narrowing height with offset " << -oh0 << " in " << i.x); 00324 i = i.narrow(1, sz2, -oh0); 00325 } 00326 } 00327 if (ow0 < 0) { 00328 int sz2 = i.x.dim(2) + ow0 + ow1; 00329 if (sz2 <= 0) { 00330 eblwarn("trying to narrow dim 2 of " << i.x << " to size "<< sz2); 00331 } else { 00332 EDEBUG("narrowing width with offset " << -ow0 << " in " << i.x); 00333 i = i.narrow(2, sz2, -ow0); 00334 } 00335 } 00336 } 00337 00338 if (scales.size() != strides.size()) 00339 eblerror("expected scales to be the same size as strides but got " 00340 << scales << " when strides are " << strides); 00341 //if (k % 4 == 0) pix1 *= 2; 00342 //float pix1 = scales[k].dim(0); 00343 //idxdim d4 = compute_pad(d, subsampling, edge, pix1, s); 00344 // fidxdim fpads; 00345 // if (k < paddings.size()) fpads = paddings[k]; 00346 // idxdim d4 = fpads; 00347 00348 // //idxdim d4 = compute_pad(d, 1, 0, pix1, s); 00349 // // idxdim d4 = compute_pad(d, 4, 6, pix1, s); 00350 // // idxdim d4 = compute_pad(d, 6, 12, pix1, s); 00351 // //padder.set_kernel(d4); 00352 // padder.set_paddings(d4); 00353 // // add extra padding at ends if necessary to match target 00354 // // intg hextra = 2, wextra = 2; 00355 // idxdim pads = padder.get_paddings(); 00356 // // if (k > 0) { 00357 // // hextra = (intg) (dref.dim(0) - ((i.x.dim(1) + pads.dim(0) + pads.dim(2) 00358 // // - (d.dim(0) -d.dim(0) % 2)) / s.dim(0))); 00359 // // wextra = (intg) (dref.dim(1) - ((i.x.dim(2) + pads.dim(1) + pads.dim(3) 00360 // // - (d.dim(1) -d.dim(1) % 2)) / s.dim(1))); 00361 // // } 00362 // if (k > 0) { 00363 // pads.setdim(2, pads.dim(2) + std::max(0, (int) hextra)); 00364 // pads.setdim(3, pads.dim(3) + std::max(0, (int) wextra)); 00365 // } 00366 // padder.set_paddings(pads); 00367 00368 // add padding if missing to reach target (failsafe) 00369 idxdim pads(0,0,0,0);// = padder.get_paddings(); 00370 intg sh = (intg) ((i.x.dim(1) + pads.dim(0) + pads.dim(2) 00371 - d.dim(0) + 1) / s.dim(0)); 00372 intg sw = (intg) ((i.x.dim(2) + pads.dim(1) + pads.dim(3) 00373 - d.dim(1) + 1) / s.dim(1)); 00374 bool w = false; 00375 if (k > 0 && sh < dref.dim(0)) { 00376 pads.setdim(2, pads.dim(2) + std::max(0, (int) (dref.dim(0) - sh))); 00377 w = true; 00378 } 00379 if (k > 0 && sw < dref.dim(1)) { 00380 pads.setdim(3, pads.dim(3) + std::max(0, (int) (dref.dim(1) - sw))); 00381 if (!this->silent && pads.maxdim() > 2) 00382 w = true; 00383 } 00384 if (w && !this->silent && pads.maxdim() > 2) 00385 eblwarn("adding extra padding "<<pads<<" to match target " << dref); 00386 00387 padder.set_paddings(pads); 00388 00389 // EDEBUG("before adding padding: " << pads); 00390 // // add fixed extra padding 00391 // pads = padder.get_paddings(); 00392 // sh = (intg) (4 * s.dim(0)); 00393 // sw = (intg) (4 * s.dim(1)); 00394 // pads.setdim(0, pads.dim(0) + sh); 00395 // pads.setdim(2, pads.dim(2) + sh); 00396 // pads.setdim(1, pads.dim(1) + sw); 00397 // pads.setdim(3, pads.dim(3) + sw); 00398 // padder.set_paddings(pads); 00399 // EDEBUG("after adding padding: " << pads); 00400 // pad 00401 padder.fprop(i, p); 00402 // } else { 00403 // if (i.x.contiguousp()) padded[k] = i; 00404 // else { 00405 // if (padded[k].x.get_idxdim() != i.x.get_idxdim()) 00406 // padded[k] = Tstate(i.x.get_idxdim()); 00407 // idx_copy(i.x, padded[k].x); 00408 // } 00409 // } 00410 } else padded[k] = in[k]; 00411 // compute number of outputs for this kernel 00412 idxdim dout((intg) ((p.x.dim(1) - d.dim(0) + 1) / s.dim(0)), 00413 (intg) ((p.x.dim(2) - d.dim(1) + 1) / s.dim(1))); 00414 idxdim dout2((intg) ((p.x.dim(1) - d.dim(0) + 1)), 00415 (intg) ((p.x.dim(2) - d.dim(1) + 1))); 00416 // idxdim dout((p.x.dim(1) - (d.dim(0) - d.dim(0) % 2)) / s.dim(0), 00417 // (p.x.dim(2) - (d.dim(1) - d.dim(1) % 2)) / s.dim(1)); 00418 // use 1st dout as reference 00419 if (k == 0) { 00420 dref = dout; 00421 dref2 = dout2; 00422 } 00423 EDEBUG(this->name() << ": in " << p.x << " (min: " << idx_min(p.x) 00424 << ", max: " << idx_max(p.x) << ") with window " << d 00425 << " and stride " << s << " -> " << dout); 00426 00427 00428 // // adjust strides so that all states produce dref outputs 00429 // fidxdim ss(1.0, 1.0); 00430 // if (k > 0) 00431 // ss = fidxdim(dout2.dim(0) / (float) dref2.dim(0), 00432 // dout2.dim(1) / (float) dref2.dim(1)); 00433 // strides[k] = ss; 00434 // EDEBUG("setting stride to " << ss << " for input " << p.x << " and window " 00435 // << d << " to produce " << dref << " outputs"); 00436 } 00437 LOCAL_TIMING_REPORT("merge padding"); 00438 00439 00440 // if (inputs.size() == 0) 00441 // eblerror("no inputs to merge"); 00442 // feature size for main input 00443 idx<T> &in0 = padded[0].x; 00444 intg fsize = din.dim(0) * din.dim(1) * in0.dim(0); 00445 // number of possible windows 00446 // intg nh = 1 + (intg) ((in0.dim(1) - din.dim(0)) / stride.dim(0)); 00447 // intg nw = 1 + (intg) ((in0.dim(2) - din.dim(1)) / stride.dim(1)); 00448 intg nh = dref.dim(0), nw = dref.dim(1); 00449 // compute new size and resize output if necessary 00450 for (uint i = 1; i < padded.size(); ++i) { 00451 idxdim &d = dins[i]; 00452 fidxdim &s = strides[i]; 00453 idx<T> &input = padded[i].x; 00454 fsize += d.nelements() * input.dim(0); 00455 // check that strides match possible windows 00456 intg nh2 = (intg) ceil((input.dim(1) - d.dim(0) + 1) 00457 / std::max(10e-9, (double) s.dim(0))); 00458 intg nw2 = (intg) ceil((input.dim(2) - d.dim(1) + 1) 00459 / std::max(10e-9, (double) s.dim(1))); 00460 if (nh2 < nh || nw2 < nw) { 00461 *(this->mout) << "COUT input " << input << " and window " << d << " with stride " <<s 00462 << " produce " << nh2 << "x" << nw2 00463 << " outputs but expected at least " << nh << "x" << nw << endl; 00464 eblerror("input " << input << " and window " << d << " with stride " <<s 00465 << " produce " << nh2 << "x" << nw2 00466 << " outputs but expected at least " << nh << "x" << nw); 00467 } else if (nh2 != nh || nw2 != nw) 00468 EDEBUG("warning: input " << input << " and window " << d << " with stride " <<s 00469 << " produce " << nh2 << "x" << nw2 00470 << ", ignoring extra cells and using only " <<nh << "x" << nw); 00471 EDEBUG("input " << i << " " << input << ", min " << idx_min(input) 00472 << " max " << idx_max(input)); 00473 } 00474 LOCAL_TIMING_REPORT("merge check"); 00475 idxdim d(fsize, nh, nw); 00476 if (!out.x.same_dim(d)) { 00477 if (out.x.order() != d.order()) 00478 out = Tstate(d); 00479 else 00480 out.resize(d); 00481 } 00482 LOCAL_TIMING_REPORT("merge resize"); 00483 idx_clear(out.x); 00484 LOCAL_TIMING_REPORT("merge clear of " << out.x); 00485 00486 intg offset = 0; 00487 int h = 0, w = 0; 00488 float fh, fw; 00489 uint uh = 0, uw = 0, uw0 = 0; 00490 idx<T> iw, ow, onarrowed, inarrowed; 00491 // copy inputs to out 00492 for (uint i = 0; i < padded.size(); ++i) { 00493 idxdim dd = dins[i]; 00494 intg dd0 = dd.dim(0), dd1 = dd.dim(1); 00495 fidxdim s = strides[i]; 00496 float s0 = s.dim(0), s1 = s.dim(1); 00497 idx<T> &input = padded[i].x; 00498 if (!input.contiguousp()) eblerror("expected contiguous"); 00499 fsize = dd.nelements() * input.dim(0); // feature size from input 00500 onarrowed = out.x.narrow(0, fsize, offset); 00501 h = 0; w = 0; 00502 intg wmod = onarrowed.mod(2); 00503 // copy 00504 for (h = 0, fh = 0; h < nh; h++, fh += s0) { 00505 uh = (uint) fh; 00506 // select 1 output pixel in the correct feature range 00507 ow = onarrowed.select(2, 0); 00508 ow = ow.select(1, h); 00509 inarrowed = input.narrow(1, dd0, uh); 00510 intg iwmod = inarrowed.mod(2); 00511 uw = 0; uw0 = 0; 00512 iw = inarrowed.narrow(2, dd1, uw); 00513 for (w = 0, fw = 0; w < nw; ++w, fw += s1) { 00514 // integer positions 00515 uw = (uint) fw; 00516 // select input window 00517 if (uw != uw0) 00518 iw.add_offset(iwmod); 00519 // copy flat input to output 00520 // TODO: tmp buffer less efficient than direct copy which but requires 00521 // continuous data, make idx pointing to oo with flat's dims? 00522 // idx<T> tmp(iw.get_idxdim()); 00523 // idx_copy(iw, tmp); 00524 // iw = tmp.view_as_order(1); 00525 idx_copy(iw, ow); 00526 00527 uw0 = uw; 00528 ow.add_offset(wmod); 00529 } 00530 } 00531 offset += fsize; 00532 } 00533 00534 LOCAL_TIMING_REPORT("merge copies"); 00535 00536 #ifdef __DEBUG_PRINT__ 00537 cout << describe() << ": " << in0 << " (in " << din 00538 << " stride " << stride << ")"; 00539 for (uint i = 1; i < padded.size(); ++i) 00540 cout << " + " << padded[i].x 00541 << " (in " << dins[i] << " stride " << strides[i] << ")"; 00542 cout << " -> " << out.x << endl; 00543 cout << "output min: " << idx_min(out.x) << " max: " << idx_max(out.x) 00544 << endl; 00545 #endif 00546 } 00547 00548 template <typename T, class Tstate> 00549 void flat_merge_module<T, Tstate>::bprop(mstate<Tstate> &in, Tstate &out) { 00550 idx<T> o, input, o1 = out.dx.view_as_order(1); 00551 intg offset = 0; 00552 // copy out to inputs 00553 for (uint i = 0; i < in.size(); ++i) { 00554 input = in[i].dx; 00555 input = input.view_as_order(1); 00556 o = o1.narrow(0, input.nelements(), offset); 00557 idx_add(o, input, input); 00558 offset += input.nelements(); 00559 } 00560 } 00561 00562 template <typename T, class Tstate> 00563 void flat_merge_module<T, Tstate>::bbprop(mstate<Tstate> &in, 00564 Tstate &out) { 00565 idx<T> o, input, o1 = out.ddx.view_as_order(1); 00566 intg offset = 0; 00567 // copy out to inputs 00568 for (uint i = 0; i < in.size(); ++i) { 00569 input = in[i].ddx; 00570 input = input.view_as_order(1); 00571 o = o1.narrow(0, input.nelements(), offset); 00572 idx_add(o, input, input); 00573 offset += input.nelements(); 00574 } 00575 } 00576 00577 // template <typename T, class Tstate> 00578 // void flat_merge_module<T, Tstate>::bprop(mstate<Tstate> &in, Tstate &out) { 00579 // bprop(*in0, out); 00580 // } 00581 00582 // template <typename T, class Tstate> 00583 // void flat_merge_module<T, Tstate>::bbprop(mstate<Tstate> &in, Tstate &out) { 00584 // bbprop(*in0, out); 00585 // } 00586 00587 template <typename T, class Tstate> 00588 flat_merge_module<T,Tstate>* flat_merge_module<T,Tstate>::copy() { 00589 flat_merge_module<T,Tstate> *l2 = 00590 new flat_merge_module<T,Tstate>(dins, strides, bpad, this->name(), 00591 &scales); 00592 return l2; 00593 } 00594 00596 00597 template <typename T, class Tstate> 00598 idxdim flat_merge_module<T,Tstate>::fprop_size(idxdim &isize) { 00599 // feature size for main input 00600 intg fsize = din.dim(0) * din.dim(1) * isize.dim(0); 00601 // number of possible windows 00602 intg nh = 1 + (intg) ((isize.dim(1) - din.dim(0)) / stride.dim(0)); 00603 intg nw = 1 + (intg) ((isize.dim(2) - din.dim(1)) / stride.dim(1)); 00605 idxdim osize(fsize, std::max((intg) 1, nh), 00606 std::max((intg) 1, nw)); 00607 fidxdim os = osize; 00608 isize = bprop_size(os); 00609 return osize; 00610 } 00611 00612 template <typename T, class Tstate> 00613 fidxdim flat_merge_module<T,Tstate>::bprop_size(const fidxdim &osize) { 00614 //EDEBUG(this->name() << ": " << osize << " -> ..."); 00615 // feature size for main input 00616 intg fsize = (intg) (osize.dim(0) / din.dim(0) / din.dim(1)); 00617 // number of possible windows 00618 intg ih = (intg) (((osize.dim(1) - 1) * stride.dim(0)) + din.dim(0)); 00619 intg iw = (intg) (((osize.dim(2) - 1) * stride.dim(1)) + din.dim(1)); 00620 // extract its dimensions, update output size 00621 fidxdim isize(fsize, ih, iw); 00622 // set offsets 00623 for (uint j = 1; j < isize.order(); ++j) 00624 isize.setoffset(j, (intg) (osize.offset(j) * stride.dim(j - 1))); 00625 return isize; 00626 } 00627 00628 template <typename T, class Tstate> 00629 mfidxdim flat_merge_module<T,Tstate>::bprop_size(mfidxdim &osize) { 00630 //EDEBUG(this->name() << ": " << osize << " -> ..."); 00631 if (osize.size() == 0) eblerror("expected at least 1 idxdim"); 00632 mfidxdim isize; 00633 idxdim pa, d; 00634 fidxdim s; 00635 // all inputs 00636 //float pix1 = .5; 00637 for (uint i = 0; i < dins.size(); ++i) { 00638 if (!osize.exists(0)) { 00639 isize.push_back_empty(); 00640 continue ; 00641 } 00642 idxdim o0 = osize[0]; 00643 d = dins[i]; 00644 s = strides[i]; 00645 if (bpad) { 00646 //if (i % 4 == 0) pix1 *= 2; // TODO: get from user 00647 //float pix1 = scales[i].dim(0); 00648 // idxdim d4 = compute_pad(d, subsampling, edge, pix1, s); 00649 if (i < paddings.size()) { 00650 fidxdim fpads = paddings[i]; 00651 pa = fpads; 00652 } 00653 if (i < offsets.size()) { 00654 vector<int> &off = offsets[i]; 00655 d.setoffset(0, (int) (-off[0] / s.dim(0))); 00656 d.setoffset(1, (int) (-off[1] / s.dim(1))); 00657 } 00658 00659 // idxdim d4 = compute_pad(d, 1, 0, pix1, s); 00660 //idxdim d4 = compute_pad(d, 4, 6, pix1, s); 00661 // idxdim d4 = compute_pad(d, 6, 12, pix1, s); 00662 //pa = padder.get_paddings(d4); 00663 00664 // TMP 00665 // intg sh = (intg) (4 * s.dim(0)); 00666 // intg sw = (intg) (4 * s.dim(1)); 00667 // pa.setdim(0, pa.dim(0) + sh); 00668 // pa.setdim(1, pa.dim(1) + sh); 00669 00670 } 00671 d.insert_dim(0, o0.dim(0)); // add feature dimension 00672 // set offsets 00673 fidxdim fd(d); 00674 for (uint j = 1; j < d.order(); ++j) { 00675 float o = (o0.offset(j) + d.offset(j)) * s.dim(j - 1); 00676 if (j-1 < pa.order()) o -= pa.dim(j-1); 00677 fd.setoffset(j, o); 00678 } 00679 isize.push_back(fd); 00680 } 00681 //EDEBUG(this->name() << ": " << osize << " -> " << isize); 00682 return isize; 00683 } 00684 00685 template <typename T, class Tstate> 00686 std::string flat_merge_module<T, Tstate>::describe() { 00687 std::string desc; 00688 desc << "flat_merge module " << this->name() << ", merging " 00689 << (int) dins.size() << " inputs: "; 00690 for (uint i = 0; i < dins.size(); ++i) { 00691 desc << " (in " << dins[i] << " stride " << strides[i]; 00692 if (i < scales.size()) desc << " scale " << scales[i]; 00693 desc << "), "; 00694 } 00695 if (bpad) 00696 desc << ", inputs are padded to center windows on borders"; 00697 return desc; 00698 } 00699 00700 template <typename T, class Tstate> 00701 uint flat_merge_module<T, Tstate>::get_ninputs() { 00702 return (uint) dins.size(); 00703 } 00704 00705 template <typename T, class Tstate> 00706 mfidxdim flat_merge_module<T, Tstate>::get_strides() { 00707 return strides; 00708 } 00709 00710 template <typename T, class Tstate> 00711 mfidxdim flat_merge_module<T, Tstate>::get_scales() { 00712 return scales; 00713 } 00714 00716 // mstate_merge_module 00717 00718 template <typename T, class Tstate> 00719 mstate_merge_module<T, Tstate>:: 00720 mstate_merge_module(midxdim &ins, mfidxdim &strides, const char *name_) 00721 : module_1_1<T,Tstate>(name_), dins(ins), dstrides(strides) { 00722 } 00723 00724 template <typename T, class Tstate> 00725 mstate_merge_module<T, Tstate>::~mstate_merge_module() { 00726 } 00727 00729 // multi-state methods 00730 00731 template <typename T, class Tstate> 00732 void mstate_merge_module<T, Tstate>:: 00733 fprop(mstate<Tstate> &in, mstate<Tstate> &out) { 00734 // use state 0 as base for sizes 00735 Tstate &in0 = in[0]; 00736 Tstate o0 = out[0]; 00737 idxdim &d0 = dins[0]; 00738 fidxdim &s0 = dstrides[0]; 00739 // number of possible windows 00740 intg nh = (intg) (1 + (in0.x.dim(1) - d0.dim(0)) / s0.dim(0)); 00741 intg nw = (intg) (1 + (in0.x.dim(2) - d0.dim(1)) / s0.dim(1)); 00742 // compute new size and resize output if necessary 00743 intg fsize = 0; 00744 for (uint i = 0; i < dins.size(); ++i) { 00745 idxdim &d = dins[i]; 00746 fidxdim &s = dstrides[i]; 00747 Tstate &tin = in[i]; 00748 fsize += d.nelements() * tin.x.dim(0); 00749 // check that strides match possible windows 00750 if (tin.x.dim(1) / s.dim(0) != nh || tin.x.dim(2) / s.dim(1) != nw) 00751 eblerror("input " << tin.x << " with stride " << s 00752 << " does not produce " << nh << "x" << nw << " windows"); 00753 } 00754 // resize output (only 1 state) 00755 idxdim d(fsize, nh, nw); 00756 if (out.size() != 1) { 00757 out.clear(); 00758 out.push_back(d); 00759 } else { 00760 if (!o0.x.same_dim(d)) 00761 o0.resize(d); 00762 } 00763 intg offset = 0; 00764 // copy all inputs to outputs 00765 for (uint i = 0; i < dins.size(); ++i) { 00766 idxdim &d = dins[i]; 00767 fidxdim &s = dstrides[i]; 00768 Tstate &tin = in[i]; 00769 // feature size for this state 00770 fsize = d.nelements() * in0.x.dim(0); 00771 // loop on all possible windows for this state 00772 float fh, fw; 00773 uint uh, uw, h, w; 00774 for (h = 0, fh = 0; h < nh; h++, fh += s.dim(0)) { 00775 for (w = 0, fw = 0; w < nw; w++, fw += s.dim(1)) { 00776 // integer positions 00777 uh = (uint) h; 00778 uw = (uint) w; 00779 // select 1 output pixel in the corect feature range 00780 idx<T> o = o0.x.select(2, w); 00781 o = o.select(1, h); 00782 o = o.narrow(0, fsize, offset); 00783 // select input window 00784 idx<T> iw = tin.x.select(2, uw); 00785 iw = iw.select(1, uh); 00786 // copy flat input to output 00787 // TODO: tmp buffer less efficient than direct copy which but requires 00788 // continuous data, make idx pointing to oo with flat's dims? 00789 idx<T> tmp(iw.get_idxdim()); 00790 idx_copy(iw, tmp); 00791 iw = tmp.view_as_order(1); 00792 idx_copy(iw, o); 00793 } 00794 } 00795 offset += fsize; 00796 } 00797 #ifdef __DEBUG_PRINT__ 00798 // cout << describe() << ": " << in.x << " (in " << din 00799 // << " stride " << stride << ")"; 00800 // for (uint i = 0; i < inputs.size(); ++i) 00801 // cout << " + " << (*inputs[i])->x << " (in " << dins[i] 00802 // << " stride " << strides[i] << ")"; 00803 // cout << " -> " << out.x << endl; 00804 #endif 00805 } 00806 00807 template <typename T, class Tstate> 00808 void mstate_merge_module<T, Tstate>:: 00809 bprop(mstate<Tstate> &in, mstate<Tstate> &out) { 00810 // expect only 1 state in output 00811 if (out.size() != 1) 00812 eblerror("expected only 1 state in output but found " << out.size()); 00813 Tstate &to = out[0]; 00814 idx<T> o = to.dx.view_as_order(1); 00815 // copy out to inputs 00816 intg offset = 0; 00817 for (uint i = 0; i < in.size(); ++i) { 00818 Tstate &tin = in[i]; 00819 idx<T> ii = tin.dx.view_as_order(1); 00820 idx<T> oo = o.narrow(0, ii.nelements(), offset); 00821 idx_add(oo, ii, ii); 00822 offset += ii.nelements(); 00823 } 00824 } 00825 00826 template <typename T, class Tstate> 00827 void mstate_merge_module<T, Tstate>:: 00828 bbprop(mstate<Tstate> &in, mstate<Tstate> &out) { 00829 // expect only 1 state in output 00830 if (out.size() != 1) 00831 eblerror("expected only 1 state in output but found " << out.size()); 00832 Tstate &to = out[0]; 00833 idx<T> o = to.ddx.view_as_order(1); 00834 // copy out to inputs 00835 intg offset = 0; 00836 for (uint i = 0; i < in.size(); ++i) { 00837 Tstate &tin = in[i]; 00838 idx<T> ii = tin.ddx.view_as_order(1); 00839 idx<T> oo = o.narrow(0, ii.nelements(), offset); 00840 idx_add(oo, ii, ii); 00841 offset += ii.nelements(); 00842 } 00843 } 00844 00846 00847 template <typename T, class Tstate> 00848 idxdim mstate_merge_module<T,Tstate>::fprop_size(idxdim &isize) { 00849 // use state 0 as base for sizes 00850 idxdim &d0 = dins[0]; 00851 fidxdim &s0 = dstrides[0]; 00852 // number of possible windows 00853 intg nh = (intg) (1 + (isize.dim(1) - d0.dim(0)) / s0.dim(0)); 00854 intg nw = (intg) (1 + (isize.dim(2) - d0.dim(1)) / s0.dim(1)); 00855 // compute new size and resize output if necessary 00856 intg fsize = 0; 00857 for (uint i = 0; i < dins.size(); ++i) { 00858 idxdim &d = dins[i]; 00859 fsize += d.nelements() * isize.dim(0); 00860 } 00862 idxdim osize(fsize, std::max((intg) 1, nh), 00863 std::max((intg) 1, nw)); 00864 fidxdim os = osize; 00865 isize = bprop_size(os); 00866 return osize; 00867 } 00868 00869 template <typename T, class Tstate> 00870 fidxdim mstate_merge_module<T,Tstate>::bprop_size(const fidxdim &osize) { 00871 // use state 0 as base for sizes 00872 fidxdim &d0 = dins[0]; 00873 fidxdim &s0 = dstrides[0]; 00874 // number of possible windows 00875 intg ih = (intg) (((osize.dim(1) - 1) * s0.dim(0)) + d0.dim(0)); 00876 intg iw = (intg) (((osize.dim(2) - 1) * s0.dim(1)) + d0.dim(1)); 00877 // compute new size and resize output if necessary 00878 intg fsize = osize.dim(0) / d0.dim(0) / d0.dim(1); 00880 fidxdim isize(fsize, ih, iw); 00881 return isize; 00882 } 00883 00884 template <typename T, class Tstate> 00885 std::string mstate_merge_module<T, Tstate>::describe() { 00886 std::string desc; 00887 desc << "mstate_merge module " << this->name() << ", merging states "; 00888 for (uint i = 0; i < dins.size(); ++i) 00889 desc << " (in " << dins[i] << " stride " << dstrides[i] << "), "; 00890 return desc; 00891 } 00892 00894 // merge 00895 00896 template <typename T, class Tstate> 00897 merge_module<T, Tstate>::merge_module(std::vector<Tstate**> &ins, 00898 intg concat_dim_, 00899 const char *name_, 00900 const char *list) 00901 : module_1_1<T,Tstate>(name_), inputs(ins), concat_dim(concat_dim_), 00902 merge_list(list) { 00903 // for (uint i = 0; i < ins.size(); ++i) 00904 // inputs.push_back(ins[i]); 00905 } 00906 00907 template <typename T, class Tstate> 00908 merge_module<T, Tstate>::merge_module(std::vector<mstate<Tstate>**> &ins, 00909 intg concat_dim_, 00910 const char *name_, 00911 const char *list) 00912 : module_1_1<T,Tstate>(name_), msinputs(ins), merge_list(list), 00913 concat_dim(concat_dim_) { 00914 eblerror("not implemented"); 00915 } 00916 00917 template <typename T, class Tstate> 00918 merge_module<T, Tstate>:: 00919 merge_module(std::vector<std::vector<uint> > &states, intg concat_dim_, 00920 const char *name_) 00921 : module_1_1<T,Tstate>(name_), states_list(states), concat_dim(concat_dim_){ 00922 } 00923 00924 template <typename T, class Tstate> 00925 merge_module<T, Tstate>::~merge_module() { 00926 } 00927 00928 template <typename T, class Tstate> 00929 void merge_module<T, Tstate>::fprop(mstate<Tstate> &in, mstate<Tstate> &out) { 00930 if (states_list.size() == 0) eblerror("expected non-empty states_list"); 00931 // resize out if necessary 00932 out.resize(in, states_list.size()); 00933 // loop on each merging 00934 for (uint i = 0; i < states_list.size(); ++i) { 00935 vector<uint> ids = states_list[i]; 00936 mstate<Tstate> mi; 00937 // create multi-state of states to merge 00938 for (uint j = 0; j < ids.size(); ++j) { 00939 uint id = ids[j]; 00940 if (id >= in.size()) 00941 eblerror("trying to access state " << id << " but multi-state only " 00942 << "contains " << in.size() << " states: " << in); 00943 mi.push_back(new Tstate(in[id])); 00944 } 00945 EDEBUG("merging states with ids " << ids << ": " << mi); 00946 // merge them 00947 merge(mi, out[i]); 00948 } 00949 } 00950 00951 template <typename T, class Tstate> 00952 void merge_module<T, Tstate>::bprop(mstate<Tstate> &in, mstate<Tstate> &out) { 00953 // TODO: implement 00954 } 00955 00956 template <typename T, class Tstate> 00957 void merge_module<T, Tstate>::bbprop(mstate<Tstate> &in, mstate<Tstate> &out) { 00958 // TODO: implement 00959 } 00960 00961 template <typename T, class Tstate> 00962 void merge_module<T, Tstate>::fprop(Tstate &in, Tstate &out) { 00963 idxdim d(in.x), dtmp(in.x); 00964 // check that all inputs are compatible and compute output size 00965 for (uint i = 0; i < inputs.size(); ++i) { 00966 Tstate *input = *(inputs[i]); 00967 dtmp.setdim(concat_dim, input->x.dim(concat_dim)); 00968 if (!input->x.same_dim(dtmp)) 00969 eblerror("expected same dimensions but got " << input->x.get_idxdim() 00970 << " and " << dtmp); 00971 // increment dimension 00972 d.setdim(concat_dim, d.dim(concat_dim) + input->x.dim(concat_dim)); 00973 } 00974 // check that output has the right size, if not, resize 00975 if (out.x.get_idxdim() != d) 00976 out.resize(d); 00977 // copy main input to out 00978 intg offset = 0; 00979 idx<T> o = out.x.narrow(concat_dim, in.x.dim(concat_dim), offset); 00980 idx_copy(in.x, o); 00981 offset += in.x.dim(concat_dim); 00982 // copy inputs to out 00983 for (uint i = 0; i < inputs.size(); ++i) { 00984 Tstate *input = *(inputs[i]); 00985 o = out.x.narrow(concat_dim, input->x.dim(concat_dim), offset); 00986 idx_copy(input->x, o); 00987 offset += input->x.dim(concat_dim); 00988 } 00989 #ifdef __DEBUG_PRINT__ 00990 cout << describe() << ": " << in.x; 00991 for (uint i = 0; i < inputs.size(); ++i) 00992 cout << " + " << (*inputs[i])->x; 00993 cout << " -> " << out.x << endl; 00994 #endif 00995 } 00996 00997 template <typename T, class Tstate> 00998 void merge_module<T, Tstate>::bprop(Tstate &in, Tstate &out) { 00999 // TODO: implement 01000 } 01001 01002 template <typename T, class Tstate> 01003 void merge_module<T, Tstate>::bbprop(Tstate &in, Tstate &out) { 01004 // TODO: implement 01005 } 01006 01007 // template <typename T, class Tstate> 01008 // void merge_module<T, Tstate>::fprop(Tstate &in, Tstate &out) { 01009 // idxdim d(in.x), dtmp(in.x); 01010 // // check that all inputs are compatible and compute output size 01011 // for (uint i = 0; i < inputs.size(); ++i) { 01012 // Tstate *input = *(inputs[i]); 01013 // dtmp.setdim(concat_dim, input->x.dim(concat_dim)); 01014 // if (!input->x.same_dim(dtmp)) 01015 // eblerror("expected same dimensions but got " << input->x.get_idxdim() 01016 // << " and " << dtmp); 01017 // // increment dimension 01018 // d.setdim(concat_dim, d.dim(concat_dim) + input->x.dim(concat_dim)); 01019 // } 01020 // // check that output has the right size, if not, resize 01021 // if (out.x.get_idxdim() != d) 01022 // out.resize(d); 01023 // // copy main input to out 01024 // intg offset = 0; 01025 // idx<T> o = out.x.narrow(concat_dim, in.x.dim(concat_dim), offset); 01026 // idx_copy(in.x, o); 01027 // offset += in.x.dim(concat_dim); 01028 // // copy inputs to out 01029 // for (uint i = 0; i < inputs.size(); ++i) { 01030 // Tstate *input = *(inputs[i]); 01031 // o = out.x.narrow(concat_dim, input->x.dim(concat_dim), offset); 01032 // idx_copy(input->x, o); 01033 // offset += input->x.dim(concat_dim); 01034 // } 01035 // #ifdef __DEBUG__ 01036 // cout << describe() << ": " << in.x; 01037 // for (uint i = 0; i < inputs.size(); ++i) 01038 // cout << " + " << (*inputs[i])->x; 01039 // cout << " -> " << out.x << endl; 01040 // #endif 01041 // } 01042 01043 // template <typename T, class Tstate> 01044 // void merge_module<T, Tstate>::bprop(Tstate &in, Tstate &out) { 01045 // // copy out to main input 01046 // intg offset = 0; 01047 // idx<T> o = out.dx.narrow(concat_dim, in.dx.dim(concat_dim), offset); 01048 // idx_add(o, in.dx, in.dx); 01049 // offset += in.dx.dim(concat_dim); 01050 // // copy out to inputs 01051 // for (uint i = 0; i < inputs.size(); ++i) { 01052 // Tstate *input = *(inputs[i]); 01053 // o = out.dx.narrow(concat_dim, input->dx.dim(concat_dim), offset); 01054 // idx_add(o, input->dx, input->dx); 01055 // offset += input->dx.dim(concat_dim); 01056 // } 01057 // } 01058 01059 // template <typename T, class Tstate> 01060 // void merge_module<T, Tstate>::bbprop(Tstate &in, 01061 // Tstate &out) { 01062 // // copy out to main input 01063 // intg offset = 0; 01064 // idx<T> o = out.ddx.narrow(concat_dim, in.ddx.dim(concat_dim), offset); 01065 // idx_add(o, in.ddx, in.ddx); 01066 // offset += in.ddx.dim(concat_dim); 01067 // // copy out to inputs 01068 // for (uint i = 0; i < inputs.size(); ++i) { 01069 // Tstate *input = *(inputs[i]); 01070 // o = out.ddx.narrow(concat_dim, input->ddx.dim(concat_dim), offset); 01071 // idx_add(o, input->ddx, input->ddx); 01072 // offset += input->ddx.dim(concat_dim); 01073 // } 01074 // } 01075 01076 template <typename T, class Tstate> 01077 std::string merge_module<T, Tstate>::describe() { 01078 std::string desc; 01079 desc << "merge module " << this->name(); 01080 return desc; 01081 } 01082 01083 template <typename T, class Tstate> 01084 void merge_module<T, Tstate>::merge(mstate<Tstate> &in, Tstate &out) { 01085 if (in.size() == 0) eblerror("expected at least 1 state in input"); 01086 idxdim d(in[0].x), dtmp(in[0].x); 01087 // check that all inputs are compatible and compute output size 01088 for (uint i = 1; i < in.size(); ++i) { 01089 Tstate &s = in[i]; 01090 dtmp.setdim(concat_dim, s.x.dim(concat_dim)); 01091 if (!s.x.same_dim(dtmp)) 01092 eblerror("expected same dimensions but got " << s.x.get_idxdim() 01093 << " and " << dtmp); 01094 // increment dimension 01095 d.setdim(concat_dim, d.dim(concat_dim) + s.x.dim(concat_dim)); 01096 } 01097 // check that output has the right size, if not, resize 01098 if (out.x.get_idxdim() != d) out.resize(d); 01099 // copy inputs to out 01100 intg offset = 0; 01101 idx<T> o; 01102 for (uint i = 0; i < in.size(); ++i) { 01103 Tstate &s = in[i]; 01104 o = out.x.narrow(concat_dim, s.x.dim(concat_dim), offset); 01105 idx_copy(s.x, o); 01106 offset += s.x.dim(concat_dim); 01107 } 01108 #ifdef __DEBUG_PRINT__ 01109 cout << describe() << ": " << in[0].x; 01110 for (uint i = 1; i < in.size(); ++i) cout << " + " << in[i].x; 01111 cout << " -> " << out.x << endl; 01112 #endif 01113 } 01114 01116 // interlace 01117 01118 template <typename T, class Tstate> 01119 interlace_module<T, Tstate>::interlace_module(uint stride_, const char *name_) 01120 : module_1_1<T,Tstate>(name_), stride(stride_) { 01121 } 01122 01123 template <typename T, class Tstate> 01124 interlace_module<T, Tstate>::~interlace_module() { 01125 } 01126 01127 template <typename T, class Tstate> 01128 void interlace_module<T, Tstate>:: 01129 fprop(mstate<Tstate> &in, mstate<Tstate> &out) { 01130 if (in.size() % stride != 0) 01131 eblerror("expected number of states to be a multiple of " << stride 01132 << " but got: " << in); 01133 out.clear(); 01134 // interlace 01135 for (uint i = 0; i < stride; ++i) { 01136 for (uint j = 0; j < in.size() / stride; ++j) { 01137 out.push_back(in[j * stride + i]); 01138 } 01139 } 01140 EDEBUG(this->name() << ": " << in << " -> " << out); 01141 } 01142 01143 template <typename T, class Tstate> 01144 void interlace_module<T, Tstate>:: 01145 bprop(mstate<Tstate> &in, mstate<Tstate> &out) { 01146 not_implemented(); 01147 } 01148 01149 template <typename T, class Tstate> 01150 void interlace_module<T, Tstate>:: 01151 bbprop(mstate<Tstate> &in, mstate<Tstate> &out) { 01152 not_implemented(); 01153 } 01154 01155 template <typename T, class Tstate> 01156 mfidxdim interlace_module<T,Tstate>::bprop_size(mfidxdim &osize) { 01157 if (osize.size() % stride != 0) { 01158 eblwarn(this->name() << ": expected midxdim size to be a multiple of " 01159 << stride << " but got " << osize); 01160 return osize; 01161 } 01162 mfidxdim isize; 01163 uint step = osize.size() / stride; 01164 // interlace 01165 for (uint i = 0; i < step; ++i) { 01166 for (uint j = 0; j < stride; ++j) { 01167 if (osize.exists(j * step + i)) 01168 isize.push_back(osize[j * step + i]); 01169 else 01170 isize.push_back_empty(); 01171 } 01172 } 01173 EDEBUG(this->name() << ": " << osize << " -> " << isize); 01174 return isize; 01175 } 01176 01177 template <typename T, class Tstate> 01178 std::string interlace_module<T, Tstate>::describe() { 01179 std::string desc; 01180 desc << "interlacing module " << this->name() << " with stride " << stride; 01181 return desc; 01182 } 01183 01184 template <typename T, class Tstate> 01185 interlace_module<T,Tstate>* interlace_module<T,Tstate>::copy() { 01186 interlace_module<T,Tstate> *l2 = 01187 new interlace_module<T,Tstate>(stride, this->name()); 01188 return l2; 01189 } 01190 01191 } // end namespace ebl