00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassReader.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassLearner.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedMultiClassLearner.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00009
00010 #include <iostream>
00011 #include <fstream>
00012 #include <string>
00013 #include <sstream>
00014
00015 using namespace std;
00016
00017
00018 SprMultiClassReader::~SprMultiClassReader()
00019 {
00020 for( int i=0;i<classifiers_.size();i++ ) {
00021 if( classifiers_[i].second )
00022 delete classifiers_[i].first;
00023 }
00024 }
00025
00026
00027 bool SprMultiClassReader::read(const char* filename)
00028 {
00029
00030 if( !classifiers_.empty() ) {
00031 cerr << "You are attempting to re-read the saved multi class learner "
00032 << "configuration without using the previous one." << endl;
00033 return false;
00034 }
00035
00036
00037 string fname = filename;
00038 ifstream file(fname.c_str());
00039 if( !file ) {
00040 cerr << "Unable to open file " << fname.c_str() << endl;
00041 return false;
00042 }
00043 cout << "Reading MultiClassLearner from file " << fname.c_str() << endl;
00044
00045
00046 return this->read(file);
00047 }
00048
00049
00050 bool SprMultiClassReader::read(std::istream& input)
00051 {
00052
00053 string line;
00054 unsigned nLine = 0;
00055 for( int i=0;i<2;i++ ) {
00056 nLine++;
00057 if( !getline(input,line) ) {
00058 cerr << "Cannot read from line " << nLine << endl;
00059 return false;
00060 }
00061 }
00062 nLine++;
00063 if( !getline(input,line) ) {
00064 cerr << "Cannot read from line " << nLine << endl;
00065 return false;
00066 }
00067 if( line.find(':') != string::npos )
00068 line.erase(0,line.find_first_of(':')+1);
00069 else {
00070 cerr << "Cannot read from line " << nLine << endl;
00071 return false;
00072 }
00073 istringstream ist(line);
00074 unsigned nClasses(0), nClassifiers(0);
00075 ist >> nClasses >> nClassifiers;
00076 if( nClasses == 0 ) {
00077 cerr << "No classes found." << endl;
00078 return false;
00079 }
00080 if( nClassifiers == 0 ) {
00081 cerr << "No classifiers found." << endl;
00082 return false;
00083 }
00084 nLine++;
00085 if( !getline(input,line) ) {
00086 cerr << "Cannot read from line " << nLine << endl;
00087 return false;
00088 }
00089 mapper_.clear();
00090 mapper_.resize(nClasses);
00091 SprMatrix mat(nClasses,nClassifiers,0);
00092 indicator_ = mat;
00093 for( int i=0;i<nClasses;i++ ) {
00094 nLine++;
00095 if( !getline(input,line) ) {
00096 cerr << "Cannot read from line " << nLine << endl;
00097 return false;
00098 }
00099 string sclass, srow;
00100 if( line.find(':') != string::npos ) {
00101 sclass = line.substr(0,line.find_first_of(':'));
00102 srow = line.substr(line.find_first_of(':')+1);
00103 }
00104 else {
00105 cerr << "Cannot read from line " << nLine << endl;
00106 return false;
00107 }
00108 if( sclass.empty() ) {
00109 cerr << "Cannot read class on line " << nLine << endl;
00110 return false;
00111 }
00112 if( srow.empty() ) {
00113 cerr << "Cannot read matrix row on line " << nLine << endl;
00114 return false;
00115 }
00116 istringstream istclass(sclass), istrow(srow);
00117 istclass >> mapper_[i];
00118 for( int j=0;j<nClassifiers;j++ )
00119 istrow >> indicator_[i][j];
00120 }
00121 nLine++;
00122 if( !getline(input,line) ) {
00123 cerr << "Cannot read from line " << nLine << endl;
00124 return false;
00125 }
00126
00127
00128 classifiers_.clear();
00129 classifiers_.resize(nClassifiers);
00130 for( int n=0;n<nClassifiers;n++ ) {
00131
00132 nLine++;
00133 if( !getline(input,line) ) {
00134 cerr << "Cannot read from line " << nLine << endl;
00135 return false;
00136 }
00137 if( line.find(':') != string::npos )
00138 line.erase(0,line.find_first_of(':')+1);
00139 else {
00140 cerr << "Cannot read from line " << nLine << endl;
00141 return false;
00142 }
00143 istringstream istc(line);
00144 unsigned iClassifiers = 0;
00145 istc >> iClassifiers;
00146 if( iClassifiers != n ) {
00147 cerr << "Wrong classifier index on line " << nLine << endl;
00148 return false;
00149 }
00150
00151
00152 string requested;
00153 SprAbsTrainedClassifier* trained =
00154 SprClassifierReader::readTrainedFromStream(input,requested,nLine);
00155 if( trained == 0 ) {
00156 cerr << "Unable to read trained classifier " << n << endl;
00157 return false;
00158 }
00159
00160
00161 classifiers_[n] = pair<const SprAbsTrainedClassifier*,bool>(trained,true);
00162 }
00163
00164
00165 if( !SprClassifierReader::readVars(input,vars_,nLine) ) {
00166 cerr << "Unable to read variables." << endl;
00167 return false;
00168 }
00169
00170
00171 return true;
00172 }
00173
00174
00175 void SprMultiClassReader::setTrainable(SprMultiClassLearner* multi)
00176 {
00177 if( classifiers_.empty() ) {
00178 cerr << "Classifier list is empty in multi class reader." << endl;
00179 return;
00180 }
00181 assert( multi != 0 );
00182 multi->reset();
00183 multi->setTrained(indicator_,mapper_,classifiers_);
00184 classifiers_.clear();
00185 }
00186
00187
00188 SprTrainedMultiClassLearner* SprMultiClassReader::makeTrained()
00189 {
00190 if( classifiers_.empty() ) {
00191 cerr << "Classifier list is empty in multi class reader." << endl;
00192 return 0;
00193 }
00194 SprTrainedMultiClassLearner* t
00195 = new SprTrainedMultiClassLearner(indicator_,mapper_,classifiers_);
00196 classifiers_.clear();
00197 t->setVars(vars_);
00198 return t;
00199 }
00200
00201
00202 bool SprMultiClassReader::readIndicatorMatrix(const char* filename,
00203 SprMatrix& indicator)
00204 {
00205
00206 string fname = filename;
00207 ifstream input(fname.c_str());
00208 if( !input ) {
00209 cerr << "Unable to open file " << fname.c_str() << endl;
00210 return false;
00211 }
00212 cout << "Reading indicator matrix from file " << fname.c_str() << endl;
00213
00214
00215 unsigned N(0), M(0);
00216 string line;
00217 unsigned nLine = 0;
00218 while( getline(input,line) ) {
00219
00220 nLine++;
00221
00222
00223 if( line.find('#') != string::npos )
00224 line.erase( line.find_first_of('#') );
00225
00226
00227 if( line.find_first_not_of(' ') == string::npos ) continue;
00228
00229
00230 istringstream ist(line);
00231
00232
00233 ist >> N >> M;
00234 break;
00235 }
00236 if( N==0 || M==0 ) {
00237 cerr << "Unable to read indicator matrix dimensionality: "
00238 << N << " " << M << " on line " << nLine << endl;
00239 return false;
00240 }
00241
00242
00243 SprMatrix temp(N,M,0);
00244 for( int n=0;n<N;n++ ) {
00245 nLine++;
00246 if( !getline(input,line) ) {
00247 cerr << "Unable to read line " << nLine << endl;
00248 return false;
00249 }
00250 istringstream ist(line);
00251 for( int m=0;m<M;m++ ) ist >> temp[n][m];
00252 }
00253
00254
00255 for( int m=0;m<M;m++ ) {
00256 unsigned countPlus(0), countMinus(0);
00257 for( int n=0;n<N;n++ ) {
00258 int elem = int(temp[n][m]);
00259 if( elem == -1 )
00260 countMinus++;
00261 else if( elem == +1 )
00262 countPlus++;
00263 else if( elem != 0 ) {
00264 cerr << "Invalid indicator matrix element [" << n+1 << "]"
00265 << "[" << m+1 << "]=" << elem << endl;
00266 return false;
00267 }
00268 }
00269 if( countPlus==0 || countMinus==0 ) {
00270 cerr << "Column " << m+1 << " of the indicator matrix does not "
00271 << "have background and signal labels present." << endl;
00272 return false;
00273 }
00274 }
00275
00276
00277 for( int n=0;n<N;n++ ) {
00278 unsigned sum = 0;
00279 for( int m=0;m<M;m++ )
00280 sum += abs(int(temp[n][m]));
00281 if( sum == 0 ) {
00282 cerr << "Row " << n+1 << " of the indicator matrix has nothing "
00283 << "but zeros." << endl;
00284 return false;
00285 }
00286 }
00287
00288
00289 indicator = temp;
00290 return true;
00291 }