libeblearn
|
/***************************************************************************
 *   Copyright (C) 2011 by Pierre Sermanet   *
 *   pierre.sermanet@gmail.com   *
 *   All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Redistribution under a license not approved by the Open Source
 *       Initiative (http://www.opensource.org) must display the
 *       following acknowledgement in all advertising material:
 *        This product includes software developed at the Courant
 *        Institute of Mathematical Sciences (http://cims.nyu.edu).
 *     * The names of the authors may not be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED.
IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #include <stdlib.h> 00034 00035 #ifdef __LUA__ 00036 00037 #include "TH/TH.h" 00038 #undef inline 00039 #include <lua.h> 00040 #include <luaT.h> 00041 #include <lualib.h> 00042 #include <lauxlib.h> 00043 00044 #ifdef __cplusplus 00045 extern "C" { 00046 #endif 00047 00048 extern lua_State *sl_stack; 00049 00050 void sl_init(void); 00051 void sl_cleanup(void); 00052 int sl_dofile(const char *file); 00053 int sl_dostring(const char *string); 00054 lua_State * sl_getstack(void); 00055 void sl_callfunc(const char *funcname, int nargs, int nouts); 00056 00057 #ifdef __cplusplus 00058 } 00059 #endif 00060 00061 static bool lua_init_done = false; 00062 #endif 00063 00064 namespace ebl { 00065 00067 // lua_module 00068 00069 template <typename T, class Tstate> 00070 lua_module<T,Tstate>::lua_module(const char *script, const char *name_) 00071 : module_1_1<T,Tstate>(name_), script_fname(script), ninputs(0), noutputs(0) { 00072 #ifdef __LUA__ 00073 // map neuflow 00074 bool neuflow = false; 00075 00076 // init 00077 if (!lua_init_done) { 00078 sl_init(); 00079 lua_init_done = true; 00080 } 00081 00082 // load script 00083 printf("parsing Lua file: %s \n", script); 00084 int err = sl_dofile(script); 00085 if (err) { 00086 printf("could not load lua file: verify script \n"); 00087 exit(1); 00088 } 00089 00090 // initialize neuFlow: init() 00091 
lua_getfield(sl_stack, LUA_GLOBALSINDEX, script); 00092 lua_getfield(sl_stack, -1, "init"); 00093 lua_pushboolean(sl_stack, neuflow); 00094 lua_call(sl_stack, 1, 0); 00095 lua_pop(sl_stack, 1); 00096 00097 // get pointers to inputs 00098 lua_getfield(sl_stack, LUA_GLOBALSINDEX, script); 00099 lua_getfield(sl_stack, -1, "inputs"); 00100 int nbstates = lua_objlen(sl_stack, -1); 00101 this->inputs = (float **)malloc(sizeof(float *) * nbstates); 00102 this->inputs_size = (long **)malloc(sizeof(long *) * nbstates); 00103 printf("--> network has %d inputs: \n", nbstates); 00104 for (int i=0; i<nbstates; i++) { 00105 // grab input i 00106 lua_rawgeti(sl_stack, -1, i+1); 00107 THFloatTensor *tensor = (THFloatTensor *)luaT_toudata(sl_stack, -1, 00108 luaT_checktypename2id(sl_stack, "torch.FloatTensor")); 00109 lua_pop(sl_stack, 1); 00110 // get raw pointer 00111 this->inputs[i] = THFloatTensor_data(tensor); 00112 // print dims 00113 this->inputs_size[i] = (long *)malloc(sizeof(long)*3); 00114 this->inputs_size[i][2] = 1; 00115 for (int s=0; s<THFloatTensor_nDimension(tensor); s++) { 00116 this->inputs_size[i][s] = tensor->size[s]; 00117 } 00118 this->ninputs = nbstates; 00119 printf(" @state %d : %ldx%ldx%ld \n", i, tensor->size[0], tensor->size[1], tensor->size[2]); 00120 } 00121 printf("\n"); 00122 lua_pop(sl_stack, 2); 00123 00124 // get pointers to outputs 00125 lua_getfield(sl_stack, LUA_GLOBALSINDEX, script); 00126 lua_getfield(sl_stack, -1, "outputs"); 00127 this->outputs = (float **)malloc(sizeof(float *) * nbstates); 00128 this->outputs_size = (long **)malloc(sizeof(long *) * nbstates); 00129 printf("--> network has %d outputs: \n", nbstates); 00130 for (int i=0; i<nbstates; i++) { 00131 // grab output i 00132 lua_rawgeti(sl_stack, -1, i+1); 00133 THFloatTensor *tensor = (THFloatTensor *)luaT_toudata(sl_stack, -1, 00134 luaT_checktypename2id(sl_stack, "torch.FloatTensor")); 00135 lua_pop(sl_stack, 1); 00136 // get raw pointer 00137 this->outputs[i] = 
THFloatTensor_data(tensor); 00138 // print dims 00139 this->outputs_size[i] = (long *)malloc(sizeof(long)*3); 00140 this->outputs_size[i][2] = 1; 00141 for (int s=0; s<THFloatTensor_nDimension(tensor); s++) { 00142 this->outputs_size[i][s] = tensor->size[s]; 00143 } 00144 this->noutputs = nbstates; 00145 printf(" @state %d : %ldx%ldx%ld \n", i, tensor->size[0], tensor->size[1], tensor->size[2]); 00146 } 00147 printf("\n"); 00148 lua_pop(sl_stack, 2); 00149 #endif 00150 } 00151 00152 template <typename T, class Tstate> 00153 lua_module<T,Tstate>::~lua_module() { 00154 free(this->inputs); 00155 free(this->outputs); 00156 free(this->inputs_size); 00157 free(this->outputs_size); 00158 } 00159 00160 template <typename T, class Tstate> 00161 void lua_module<T,Tstate>:: 00162 fprop(mstate<Tstate> &in, mstate<Tstate> &out) { 00163 EDEBUG(this->name() << ": in " << in << " out " << out); 00164 #ifdef __LUA__ 00165 // get inputs 00166 long p = 0; 00167 cout << "ebl inputs -> lua" << endl; 00168 for (typename mstate<Tstate>::iterator i = in.begin(); i != in.end(); ++i) { 00169 T *src = i->x.idx_ptr(); 00170 cout << i->x << endl; 00171 long n = inputs_size[p][0] * inputs_size[p][1] * inputs_size[p][2]; 00172 for (long k=0; k < n; k++) 00173 this->inputs[p][k] = src[k]; 00174 p++; 00175 } 00176 00177 // do fprop 00178 lua_getfield(sl_stack, LUA_GLOBALSINDEX, this->script_fname.c_str()); 00179 lua_getfield(sl_stack, -1, "fprop"); 00180 lua_pushboolean(sl_stack, true); 00181 lua_call(sl_stack, 1, 0); 00182 lua_pop(sl_stack, 1); 00183 00184 // copy output 00185 p = 0; 00186 out.resize(in); 00187 cout << "lua outputs -> ebl" << endl; 00188 for (typename mstate<Tstate>::iterator o = out.begin(); 00189 o != out.end(); ++o) { 00190 // resize output 00191 Tstate &oo = *o; 00192 00193 idxdim d(this->outputs_size[p][0], this->outputs_size[p][1], this->outputs_size[p][2]); 00194 if (oo.x.get_idxdim() != d) oo.resize(d); 00195 T *dst = oo.x.idx_ptr(); 00196 cout << o->x << endl; 00197 
00198 long n = outputs_size[p][0] * outputs_size[p][1] * outputs_size[p][2]; 00199 for (long k = 0; k < n; k++) 00200 dst[k] = this->outputs[p][k]; 00201 p++; 00202 } 00203 #else 00204 eblerror("trying to use lua_module but this was not compiled with lua" 00205 << ", recompile with it"); 00206 #endif 00207 } 00208 00209 template <typename T, class Tstate> 00210 void lua_module<T,Tstate>:: 00211 bprop(mstate<Tstate> &in, mstate<Tstate> &out) { 00212 not_implemented(); 00213 } 00214 00215 template <typename T, class Tstate> 00216 void lua_module<T,Tstate>:: 00217 bbprop(mstate<Tstate> &in, mstate<Tstate> &out) { 00218 not_implemented(); 00219 } 00220 00221 template <typename T, class Tstate> 00222 std::string lua_module<T, Tstate>::describe() { 00223 std::string desc; 00224 desc << "lua module " << this->name(); 00225 return desc; 00226 } 00227 00228 } // end namespace ebl