libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2008 by Yann LeCun and Pierre Sermanet * 00003 * yann@cs.nyu.edu, pierre.sermanet@gmail.com * 00004 * 00005 * Redistribution and use in source and binary forms, with or without 00006 * modification, are permitted provided that the following conditions are met: 00007 * * Redistributions of source code must retain the above copyright 00008 * notice, this list of conditions and the following disclaimer. 00009 * * Redistributions in binary form must reproduce the above copyright 00010 * notice, this list of conditions and the following disclaimer in the 00011 * documentation and/or other materials provided with the distribution. 00012 * * Redistribution under a license not approved by the Open Source 00013 * Initiative (http://www.opensource.org) must display the 00014 * following acknowledgement in all advertising material: 00015 * This product includes software developed at the Courant 00016 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00017 * * The names of the authors may not be used to endorse or promote products 00018 * derived from this software without specific prior written permission. 00019 * 00020 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00021 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00022 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00023 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00024 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00025 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00026 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00027 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00028 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00029 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00030 ***************************************************************************/ 00031 00032 #ifndef EBL_LOGGER_H_ 00033 #define EBL_LOGGER_H_ 00034 00035 #include "ebl_defines.h" 00036 #include "libidx.h" 00037 #include "ebl_states.h" 00038 00039 #ifndef __NOSTL__ 00040 #include <vector> 00041 #endif 00042 00043 namespace ebl { 00044 00046 00051 class class_state { 00052 public: 00053 ubyte output_class; 00054 float confidence; 00055 idx<ubyte> *sorted_classes; 00056 idx<float> *sorted_scores; 00057 00058 class_state(ubyte n); 00059 ~class_state(); 00060 void resize(ubyte n); 00061 }; 00062 00084 00088 // TODO: allow definition of different comparison functions. 00089 class EXPORT classifier_meter { 00090 public: 00097 // TODO: allow passing of comparison function 00098 classifier_meter(); 00099 00101 ~classifier_meter(); 00102 00105 void init(uint nclasses); 00106 00109 int correctp(ubyte co, ubyte cd); 00110 00113 void clear(); 00114 void resize(intg sz); 00115 00121 char update(intg a, class_state *co, ubyte cd, double energy); 00122 void update(intg age_, bool correct, double energy); 00123 // TODO: clean up design 00124 // TODO: add confusion matrix computation 00125 void update(intg age, uint desired, uint infered, double energy); 00126 00127 void test(class_state *co, ubyte cd, double energy); 00128 00132 double class_normalized_average_error(idx<int> &confu); 00133 00136 double overall_average_error(idx<int> &confu); 00137 00140 double class_normalized_average_success(idx<int> &confu); 00141 00144 int get_class_samples(idx<int> &confu, intg classid); 00145 00148 int get_class_errors(idx<int> &confu, intg classid); 00151 double get_normalized_error(); 00153 idx<int>& get_confusion(); 00154 00160 void info(); 00161 void info_sprint(); 00162 void info_print(); 00163 00172 void display(int iteration, string &dsname, 00173 std::vector<string*> *lblstr = NULL, 00174 bool ds_is_test = false); 00175 00177 void display_average(string &dsname, std::vector<string*> *lblstr = NULL, 00178 bool ds_is_test = false); 00179 00182 void display_positive_rates(double threshold, 00183 std::vector<string*> *lblstr = NULL); 00184 00185 bool save(); 00186 bool load(); 00187 00188 public: 00189 double energy; 00190 float confidence; 00191 intg size; 00192 intg age; 00193 intg total_correct; 00194 intg total_error; 00195 intg total_punt; 00196 double total_energy; 00197 std::vector<uint> class_errors; 00198 std::vector<uint> class_totals; 00199 std::vector<uint> class_tpr; 00200 std::vector<uint> class_fpr; 00201 std::vector<std::string> log_fields; 00202 std::vector<double> log_values; 00203 std::vector<double> total_values; 00204 private: 00205 idx<int> confusion; 00206 idx<int> total_confusion; 00207 uint nclasses; 00208 }; 00209 00211 00217 template <class T> class max_classer { // TODO: idx3-classer 00218 public: 00220 idx<ubyte> *classindex2label; 00221 00224 max_classer(idx<ubyte> *classes); 00225 ~max_classer() { 00226 } 00227 ; 00228 00229 void fprop(fstate_idx<T> *in, class_state *out); 00230 }; 00231 00232 } // namespace ebl { 00233 00234 #endif /* EBL_LOGGER_H_ */