CMS 3D CMS Logo

SprMultiClassReader.cc

Go to the documentation of this file.
00001 //$Id: SprMultiClassReader.cc,v 1.2 2007/09/21 22:32:10 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassReader.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassLearner.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedMultiClassLearner.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00009 
00010 #include <iostream>
00011 #include <fstream>
00012 #include <string>
00013 #include <sstream>
00014 
00015 using namespace std;
00016 
00017 
00018 SprMultiClassReader::~SprMultiClassReader()
00019 {
00020   for( int i=0;i<classifiers_.size();i++ ) {
00021     if( classifiers_[i].second )
00022       delete classifiers_[i].first;
00023   }
00024 }
00025 
00026 
00027 bool SprMultiClassReader::read(const char* filename)
00028 {
00029   // sanity check
00030   if( !classifiers_.empty() ) {
00031     cerr << "You are attempting to re-read the saved multi class learner "
00032          << "configuration without using the previous one." << endl;
00033     return false;
00034   }
00035 
00036   // open file
00037   string fname = filename;
00038   ifstream file(fname.c_str());
00039   if( !file ) {
00040     cerr << "Unable to open file " << fname.c_str() << endl;
00041     return false;
00042   }
00043   cout << "Reading MultiClassLearner from file " << fname.c_str() << endl;
00044 
00045   // read
00046   return this->read(file);
00047 }
00048 
00049 
00050 bool SprMultiClassReader::read(std::istream& input)
00051 {
00052   // read indicator matrix
00053   string line;
00054   unsigned nLine = 0;
00055   for( int i=0;i<2;i++ ) {
00056     nLine++;
00057     if( !getline(input,line) ) {
00058       cerr << "Cannot read from line " << nLine << endl;
00059       return false;
00060     }
00061   }
00062   nLine++;
00063   if( !getline(input,line) ) {
00064     cerr << "Cannot read from line " << nLine << endl;
00065     return false;
00066   }
00067   if( line.find(':') != string::npos )
00068     line.erase(0,line.find_first_of(':')+1);
00069   else {
00070     cerr << "Cannot read from line " << nLine << endl;
00071     return false;
00072   }
00073   istringstream ist(line);
00074   unsigned nClasses(0), nClassifiers(0);
00075   ist >> nClasses >> nClassifiers;
00076   if( nClasses == 0 ) {
00077     cerr << "No classes found." << endl;
00078     return false;
00079   }
00080   if( nClassifiers == 0 ) {
00081     cerr << "No classifiers found." << endl;
00082     return false;
00083   }
00084   nLine++;
00085   if( !getline(input,line) ) {
00086     cerr << "Cannot read from line " << nLine << endl;
00087     return false;
00088   }
00089   mapper_.clear();
00090   mapper_.resize(nClasses);
00091   SprMatrix mat(nClasses,nClassifiers,0);
00092   indicator_ = mat;
00093   for( int i=0;i<nClasses;i++ ) {
00094     nLine++;
00095     if( !getline(input,line) ) {
00096       cerr << "Cannot read from line " << nLine << endl;
00097       return false;
00098     }
00099     string sclass, srow;
00100     if( line.find(':') != string::npos ) {
00101       sclass = line.substr(0,line.find_first_of(':'));
00102       srow = line.substr(line.find_first_of(':')+1);
00103     }
00104     else {
00105       cerr << "Cannot read from line " << nLine << endl;
00106       return false;
00107     }
00108     if( sclass.empty() ) {
00109       cerr << "Cannot read class on line " << nLine << endl;
00110       return false;
00111     }
00112     if( srow.empty() ) {
00113       cerr << "Cannot read matrix row on line " << nLine << endl;
00114       return false;
00115     }
00116     istringstream istclass(sclass), istrow(srow);
00117     istclass >> mapper_[i];
00118     for( int j=0;j<nClassifiers;j++ )
00119       istrow >> indicator_[i][j];
00120   }
00121   nLine++;
00122   if( !getline(input,line) ) {
00123     cerr << "Cannot read from line " << nLine << endl;
00124     return false;
00125   }
00126   
00127   // read trained classifiers
00128   classifiers_.clear();
00129   classifiers_.resize(nClassifiers);
00130   for( int n=0;n<nClassifiers;n++ ) {
00131     // read index of the current classifier
00132     nLine++;
00133     if( !getline(input,line) ) {
00134       cerr << "Cannot read from line " << nLine << endl;
00135       return false;
00136     }
00137     if( line.find(':') != string::npos )
00138       line.erase(0,line.find_first_of(':')+1);
00139     else {
00140       cerr << "Cannot read from line " << nLine << endl;
00141       return false;
00142     }
00143     istringstream istc(line);
00144     unsigned iClassifiers = 0;
00145     istc >> iClassifiers;
00146     if( iClassifiers != n ) {
00147       cerr << "Wrong classifier index on line " << nLine << endl;
00148       return false;
00149     }
00150 
00151     // read each classifier
00152     string requested;
00153     SprAbsTrainedClassifier* trained =
00154       SprClassifierReader::readTrainedFromStream(input,requested,nLine);
00155     if( trained == 0 ) {
00156       cerr << "Unable to read trained classifier " << n << endl;
00157       return false;
00158     }
00159 
00160     // add classifier to the list
00161     classifiers_[n] = pair<const SprAbsTrainedClassifier*,bool>(trained,true);
00162   }// end of loop over classifiers
00163 
00164   // read variables
00165   if( !SprClassifierReader::readVars(input,vars_,nLine) ) {
00166     cerr << "Unable to read variables." << endl;
00167     return false;
00168   }
00169 
00170   // exit
00171   return true;
00172 }
00173 
00174 
00175 void SprMultiClassReader::setTrainable(SprMultiClassLearner* multi)
00176 {
00177   if( classifiers_.empty() ) {
00178     cerr << "Classifier list is empty in multi class reader." << endl;
00179     return;
00180   }
00181   assert( multi != 0 );
00182   multi->reset();
00183   multi->setTrained(indicator_,mapper_,classifiers_);
00184   classifiers_.clear();
00185 }
00186 
00187 
00188 SprTrainedMultiClassLearner* SprMultiClassReader::makeTrained()
00189 {
00190   if( classifiers_.empty() ) {
00191     cerr << "Classifier list is empty in multi class reader." << endl;
00192     return 0;
00193   }
00194   SprTrainedMultiClassLearner* t 
00195     = new SprTrainedMultiClassLearner(indicator_,mapper_,classifiers_);
00196   classifiers_.clear();
00197   t->setVars(vars_);
00198   return t;
00199 }
00200 
00201 
00202 bool SprMultiClassReader::readIndicatorMatrix(const char* filename, 
00203                                               SprMatrix& indicator)
00204 {
00205   // open file
00206   string fname = filename;
00207   ifstream input(fname.c_str());
00208   if( !input ) {
00209     cerr << "Unable to open file " << fname.c_str() << endl;
00210     return false;
00211   }
00212   cout << "Reading indicator matrix from file " << fname.c_str() << endl;
00213 
00214   // read indicator matrix dimensionality
00215   unsigned N(0), M(0);
00216   string line;
00217   unsigned nLine = 0;
00218   while( getline(input,line) ) {
00219     // update line counter
00220     nLine++;
00221 
00222     // remove comments
00223     if( line.find('#') != string::npos )
00224       line.erase( line.find_first_of('#') );
00225 
00226     // skip empty line
00227     if( line.find_first_not_of(' ') == string::npos ) continue;
00228 
00229     // make stream
00230     istringstream ist(line);
00231 
00232     // read matrix dimensions
00233     ist >> N >> M;
00234     break;
00235   }
00236   if( N==0 || M==0 ) {
00237     cerr << "Unable to read indicator matrix dimensionality: " 
00238          << N << " " << M << "    on line " << nLine << endl;
00239     return false;
00240   }
00241 
00242   // read the matrix itself
00243   SprMatrix temp(N,M,0);
00244   for( int n=0;n<N;n++ ) {
00245     nLine++;
00246     if( !getline(input,line) ) {
00247       cerr << "Unable to read line " << nLine << endl;
00248       return false;
00249     }
00250     istringstream ist(line);
00251     for( int m=0;m<M;m++ ) ist >> temp[n][m];
00252   }
00253 
00254   // check columns of indicator matrix
00255   for( int m=0;m<M;m++ ) {
00256     unsigned countPlus(0), countMinus(0);
00257     for( int n=0;n<N;n++ ) {
00258       int elem = int(temp[n][m]);
00259       if(      elem == -1 )
00260         countMinus++;
00261       else if( elem == +1 )
00262         countPlus++;
00263       else if( elem != 0 ) {
00264         cerr << "Invalid indicator matrix element [" << n+1 << "]" 
00265              << "[" << m+1 << "]=" << elem << endl;
00266         return false;
00267       }
00268     }
00269     if( countPlus==0 || countMinus==0 ) {
00270       cerr << "Column " << m+1 << " of the indicator matrix does not " 
00271            << "have background and signal labels present." << endl;
00272       return false;
00273     }
00274   }
00275 
00276   // check rows
00277   for( int n=0;n<N;n++ ) {
00278     unsigned sum = 0;
00279     for( int m=0;m<M;m++ )
00280       sum += abs(int(temp[n][m]));
00281     if( sum == 0 ) {
00282       cerr << "Row " << n+1 << " of the indicator matrix has nothing "
00283            << "but zeros." << endl;
00284       return false;
00285     }
00286   }
00287 
00288   // exit
00289   indicator = temp;
00290   return true;
00291 }

Generated on Tue Jun 9 17:42:03 2009 for CMSSW by  doxygen 1.5.4