libidx
/home/rex/ebltrunk/core/libidx/include/thops.h
00001 /***************************************************************************
00002  *   Copyright (C) 2012 by Soumith Chintala   *
00003  *   soumith@gmail.com                        *
00004  *
00005  * Redistribution and use in source and binary forms, with or without
00006  * modification, are permitted provided that the following conditions are met:
00007  *     * Redistributions of source code must retain the above copyright
00008  *       notice, this list of conditions and the following disclaimer.
00009  *     * Redistributions in binary form must reproduce the above copyright
00010  *       notice, this list of conditions and the following disclaimer in the
00011  *       documentation and/or other materials provided with the distribution.
00012  *     * Redistribution under a license not approved by the Open Source
00013  *       Initiative (http://www.opensource.org) must display the
00014  *       following acknowledgement in all advertising material:
00015  *        This product includes software developed at the Courant
00016  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00017  *     * The names of the authors may not be used to endorse or promote products
00018  *       derived from this software without specific prior written permission.
00019  *
00020  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00021  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00022  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00023  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00024  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00025  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00026  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00027  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00028  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00029  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00030  ***************************************************************************/
00031 
00032 
00033 #ifndef THOPS_H_
00034 #define THOPS_H_
00035 
00036 #ifdef __TH__
00037 #include "idx.h"
00038 
00039 namespace ebl {
00040 
00041 
00042   // th_add ///////////////////////////////////////////////////////////////////
00043   
00048   template <typename T> void th_add(idx<T> &in1, idx<T> &in2);
00049 
00053   template <> void th_add(idx<float32> &in1, idx<float32> &in2);
00054 
00058   template <> void th_add(idx<float64> &in1, idx<float64> &in2);
00059 
00060 
00061   
00066   template <typename T> void th_add(idx<T> &in1,
00067                                             idx<T> &in2, idx<T> &out);
00068 
00072   template <> void th_add(idx<float32> &in1,
00073                                   idx<float32> &in2, idx<float32> &out);
00077   template <> void th_add(idx<float64> &in1,
00078                                   idx<float64> &in2, idx<float64> &out);
00079 
00080   // th_copy /////////////////////////////////////////////////////////////
00082   template <typename T> void th_copy(idx<T> &in, idx<T> &out);
00084     template <> void th_copy(idx<float32> &in, idx<float32> &out);
00086     template <> void th_copy(idx<float64> &in, idx<float64> &out);
00087 
00088   // th_convolution ///////////////////////////////////////////////////////////
00089 
00094   template <typename T>
00095     void th_convolution(idx<T> &in, idx<T> &ker, idx<T> &out, 
00096                         intg stride_x=1, intg stride_y=1);
00097 
00102   template <>
00103   void th_convolution(idx<float32> &in, idx<float32> &ker,
00104                       idx<float32> &out, intg stride_x, intg stride_y);
00107   template <>
00108   void th_convolution(idx<float64> &in, idx<float64> &ker,
00109                       idx<float64> &out, intg stride_x, intg stride_y);
00110 
00113   template <typename T>
00114   void th_convolution_3d(idx<T> &in, idx<T> &ker,
00115                               idx<T> &out, 
00116                         intg stride_x=1, intg stride_y=1);
00117 
00120   template <>
00121   void th_convolution_3d(idx<float32> &in, idx<float32> &ker,
00122                               idx<float32> &out, 
00123                         intg stride_x, intg stride_y);
00124 
00125 
00128   template <>
00129   void th_convolution_3d(idx<float64> &in, idx<float64> &ker,
00130                               idx<float64> &out, 
00131                         intg stride_x, intg stride_y);
00135   template <typename T>
00136   void th_convolution_3dmap(idx<T> &in, idx<T> &ker,
00137                             idx<T> &out, idx<intg> &table,
00138                             intg stride_x=1, intg stride_y=1);
00139 
00142   template <>
00143   void th_convolution_3dmap(idx<float32> &in, idx<float32> &ker,
00144                             idx<float32> &out, idx<intg> &table,
00145                             intg stride_x, intg stride_y);
00146 
00147 
00150   template <>
00151   void th_convolution_3dmap(idx<float64> &in, idx<float64> &ker,
00152                               idx<float64> &out, idx<intg> &table,
00153                             intg stride_x, intg stride_y);
00157   template <typename T>
00158   void th_convolution_3dmap_bprop(idx<T> &inx, idx<T> &kerx,
00159                                   idx<T> &outdx, idx<T> &indx, 
00160                                   idx<T> &kerdx, idx<intg> &table,
00161                             intg stride_w, intg stride_h);
00162 
00165   template <>
00166   void th_convolution_3dmap_bprop(idx<float32> &inx, idx<float32> &kerx,
00167                                   idx<float32> &outdx, idx<float32> &indx, 
00168                                   idx<float32> &kerdx,idx<intg> &table,
00169                             intg stride_w, intg stride_h);
00170 
00171 
00174   template <>
00175   void th_convolution_3dmap_bprop(idx<float64> &inx, idx<float64> &kerx,
00176                                   idx<float64> &outdx, idx<float64> &indx, 
00177                                   idx<float64> &kerdx, idx<intg> &table,
00178                             intg stride_w, intg stride_h);
00179 
00182   template <typename T>
00183     void th_tanh(idx<T> &in, idx<T> &out);
00184   template <>
00185     void th_tanh(idx<float32> &in, idx<float32> &out);
00186   template <>
00187     void th_tanh(idx<float64> &in, idx<float64> &out);
00188 
00191   template <typename T>
00192     void th_pow(idx<T> &in, idx<T> &out, T p);
00193   template <>
00194     void th_pow(idx<float32> &in, idx<float32> &out, float32 p);
00195   template <>
00196     void th_pow(idx<float64> &in, idx<float64> &out, float64 p);
00197 
00198 
00202   template <typename T>
00203   void th_maxpool_3d(idx<T> &in, intg kernel_w, intg kernel_h,
00204                             idx<T> &out, 
00205                      intg stride_x, intg stride_y, idx<T> &indices_e);
00206 
00209   template <>
00210   void th_maxpool_3d(idx<float32> &in, intg kernel_w, intg kernel_h,
00211                             idx<float32> &out, 
00212                             intg stride_x, intg stride_y, idx<float32> &indices_e);
00213 
00214 
00217   template <>
00218   void th_maxpool_3d(idx<float64> &in, intg kernel_w, intg kernel_h,
00219                               idx<float64> &out, 
00220                             intg stride_x, intg stride_y, idx<float64> &indices_e);
00224   template <typename T>
00225   void th_maxpool_3d_bprop(idx<T> &inx, intg kernel_w, intg kernel_h,
00226                                   idx<T> &outdx, idx<T> &indx, 
00227                             intg stride_w, intg stride_h, idx<T> &indices_e);
00228 
00231   template <>
00232   void th_maxpool_3d_bprop(idx<float32> &inx, intg kernel_w, intg kernel_h,
00233                                   idx<float32> &outdx, idx<float32> &indx, 
00234                             intg stride_w, intg stride_h, idx<float32> &indices_e);
00235 
00236 
00239   template <>
00240   void th_maxpool_3d_bprop(idx<float64> &inx, intg kernel_w, intg kernel_h,
00241                                   idx<float64> &outdx, idx<float64> &indx, 
00242                             intg stride_w, intg stride_h, idx<float64> &indices_e);
00243 
00244 
00245 } // end namespace ebl
00246 
00247 #include "thops.hpp"
00248 #endif /* __TH__ */
00249 
00250 #endif /* THOPS_H_ */