libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_lua.hpp
00001 /***************************************************************************
00002 *   Copyright (C) 2011 by Pierre Sermanet *
00003 *   pierre.sermanet@gmail.com *
00004 *   All rights reserved.
00005 *
00006 * Redistribution and use in source and binary forms, with or without
00007 * modification, are permitted provided that the following conditions are met:
00008 *     * Redistributions of source code must retain the above copyright
00009 *       notice, this list of conditions and the following disclaimer.
00010 *     * Redistributions in binary form must reproduce the above copyright
00011 *       notice, this list of conditions and the following disclaimer in the
00012 *       documentation and/or other materials provided with the distribution.
00013 *     * Redistribution under a license not approved by the Open Source
00014 *       Initiative (http://www.opensource.org) must display the
00015 *       following acknowledgement in all advertising material:
00016 *        This product includes software developed at the Courant
00017 *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018 *     * The names of the authors may not be used to endorse or promote products
00019 *       derived from this software without specific prior written permission.
00020 *
00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024 * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031 ***************************************************************************/
00032 
00033 #include <stdlib.h>
00034 
00035 #ifdef __LUA__
00036 
00037 #include "TH/TH.h"
00038 #undef inline
00039 #include <lua.h>
00040 #include <luaT.h>
00041 #include <lualib.h>
00042 #include <lauxlib.h>
00043 
00044 #ifdef __cplusplus 
00045 extern "C" {
00046 #endif
00047 
00048 extern lua_State *sl_stack;
00049 
00050 void sl_init(void);
00051 void sl_cleanup(void);
00052 int sl_dofile(const char *file);
00053 int sl_dostring(const char *string);
00054 lua_State * sl_getstack(void);
00055 void sl_callfunc(const char *funcname, int nargs, int nouts);
00056 
00057 #ifdef __cplusplus 
00058 }
00059 #endif
00060 
00061 static bool lua_init_done = false;
00062 #endif
00063 
00064 namespace ebl {
00065 
00067   // lua_module
00068 
00069   template <typename T, class Tstate>
00070   lua_module<T,Tstate>::lua_module(const char *script, const char *name_)
00071     : module_1_1<T,Tstate>(name_), script_fname(script), ninputs(0), noutputs(0) {
00072 #ifdef __LUA__
00073     // map neuflow
00074     bool neuflow = false;
00075 
00076     // init
00077     if (!lua_init_done) {
00078       sl_init();
00079       lua_init_done = true;
00080     }
00081 
00082     // load script
00083     printf("parsing Lua file: %s \n", script);
00084     int err = sl_dofile(script);
00085     if (err) {
00086       printf("could not load lua file: verify script \n");
00087       exit(1);
00088     }
00089 
00090     // initialize neuFlow: init()
00091     lua_getfield(sl_stack, LUA_GLOBALSINDEX, script);
00092     lua_getfield(sl_stack, -1, "init");
00093     lua_pushboolean(sl_stack, neuflow);
00094     lua_call(sl_stack, 1, 0);
00095     lua_pop(sl_stack, 1);
00096 
00097     // get pointers to inputs
00098     lua_getfield(sl_stack, LUA_GLOBALSINDEX, script);
00099     lua_getfield(sl_stack, -1, "inputs");
00100     int nbstates = lua_objlen(sl_stack, -1);
00101     this->inputs = (float **)malloc(sizeof(float *) * nbstates);
00102     this->inputs_size = (long **)malloc(sizeof(long *) * nbstates);
00103     printf("--> network has %d inputs: \n", nbstates);
00104     for (int i=0; i<nbstates; i++) {
00105       // grab input i
00106       lua_rawgeti(sl_stack, -1, i+1);
00107       THFloatTensor *tensor = (THFloatTensor *)luaT_toudata(sl_stack, -1,
00108                                                             luaT_checktypename2id(sl_stack, "torch.FloatTensor"));
00109       lua_pop(sl_stack, 1);
00110       // get raw pointer
00111       this->inputs[i] = THFloatTensor_data(tensor);
00112       // print dims
00113       this->inputs_size[i] = (long *)malloc(sizeof(long)*3);
00114       this->inputs_size[i][2] = 1;
00115       for (int s=0; s<THFloatTensor_nDimension(tensor); s++) {
00116         this->inputs_size[i][s] = tensor->size[s];
00117       }
00118       this->ninputs = nbstates;
00119       printf("    @state %d : %ldx%ldx%ld \n", i, tensor->size[0], tensor->size[1], tensor->size[2]);
00120     }
00121     printf("\n");
00122     lua_pop(sl_stack, 2);
00123 
00124     // get pointers to outputs
00125     lua_getfield(sl_stack, LUA_GLOBALSINDEX, script);
00126     lua_getfield(sl_stack, -1, "outputs");
00127     this->outputs = (float **)malloc(sizeof(float *) * nbstates);
00128     this->outputs_size = (long **)malloc(sizeof(long *) * nbstates);
00129     printf("--> network has %d outputs: \n", nbstates);
00130     for (int i=0; i<nbstates; i++) {
00131       // grab output i
00132       lua_rawgeti(sl_stack, -1, i+1);
00133       THFloatTensor *tensor = (THFloatTensor *)luaT_toudata(sl_stack, -1,
00134                                                             luaT_checktypename2id(sl_stack, "torch.FloatTensor"));
00135       lua_pop(sl_stack, 1);
00136       // get raw pointer
00137       this->outputs[i] = THFloatTensor_data(tensor);
00138       // print dims
00139       this->outputs_size[i] = (long *)malloc(sizeof(long)*3);
00140       this->outputs_size[i][2] = 1;
00141       for (int s=0; s<THFloatTensor_nDimension(tensor); s++) {
00142         this->outputs_size[i][s] = tensor->size[s];
00143       }
00144       this->noutputs = nbstates;
00145       printf("    @state %d : %ldx%ldx%ld \n", i, tensor->size[0], tensor->size[1], tensor->size[2]);
00146     }
00147     printf("\n");
00148     lua_pop(sl_stack, 2);
00149 #endif
00150   }
00151 
00152   template <typename T, class Tstate>
00153   lua_module<T,Tstate>::~lua_module() {
00154     free(this->inputs);
00155     free(this->outputs);
00156     free(this->inputs_size);
00157     free(this->outputs_size);
00158   }
00159 
00160   template <typename T, class Tstate>
00161   void lua_module<T,Tstate>::
00162   fprop(mstate<Tstate> &in, mstate<Tstate> &out) {
00163     EDEBUG(this->name() << ": in " << in << " out " << out);
00164 #ifdef __LUA__
00165     // get inputs
00166     long p = 0;
00167     cout << "ebl inputs -> lua" << endl;
00168     for (typename mstate<Tstate>::iterator i = in.begin(); i != in.end(); ++i) {
00169       T *src = i->x.idx_ptr();
00170       cout << i->x << endl;
00171       long n = inputs_size[p][0] * inputs_size[p][1] * inputs_size[p][2];
00172       for (long k=0; k < n; k++)
00173         this->inputs[p][k] = src[k];
00174       p++;
00175     }
00176 
00177     // do fprop
00178     lua_getfield(sl_stack, LUA_GLOBALSINDEX, this->script_fname.c_str());
00179     lua_getfield(sl_stack, -1, "fprop");
00180     lua_pushboolean(sl_stack, true);
00181     lua_call(sl_stack, 1, 0);
00182     lua_pop(sl_stack, 1);
00183 
00184     // copy output
00185     p = 0;
00186     out.resize(in);
00187     cout << "lua outputs -> ebl" << endl;
00188     for (typename mstate<Tstate>::iterator o = out.begin();
00189          o != out.end(); ++o) {
00190       // resize output
00191       Tstate &oo = *o;
00192 
00193       idxdim d(this->outputs_size[p][0], this->outputs_size[p][1], this->outputs_size[p][2]);
00194       if (oo.x.get_idxdim() != d) oo.resize(d);
00195       T *dst = oo.x.idx_ptr();
00196       cout << o->x << endl;
00197 
00198       long n = outputs_size[p][0] * outputs_size[p][1] * outputs_size[p][2];
00199       for (long k = 0; k < n; k++)
00200         dst[k] = this->outputs[p][k];
00201       p++;
00202     }
00203 #else
00204     eblerror("trying to use lua_module but this was not compiled with lua"
00205              << ", recompile with it");
00206 #endif
00207   }
00208 
00209   template <typename T, class Tstate>
00210   void lua_module<T,Tstate>::
00211   bprop(mstate<Tstate> &in, mstate<Tstate> &out) {
00212     not_implemented();
00213   }
00214 
00215   template <typename T, class Tstate>
00216   void lua_module<T,Tstate>::
00217   bbprop(mstate<Tstate> &in, mstate<Tstate> &out) {
00218     not_implemented();
00219   }
00220 
00221   template <typename T, class Tstate>
00222   std::string lua_module<T, Tstate>::describe() {
00223     std::string desc;
00224     desc << "lua module " << this->name();
00225     return desc;
00226   }
00227 
00228 } // end namespace ebl