CMS 3D CMS Logo

SprBaggerApp.cc

Go to the documentation of this file.
00001 //$Id: SprBaggerApp.cc,v 1.5 2007/11/12 06:19:11 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprBagger.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprArcE4.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedBagger.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00022 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00023 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00024 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00025 
00026 #include <stdlib.h>
00027 #include <unistd.h>
00028 #include <iostream>
00029 #include <fstream>
00030 #include <vector>
00031 #include <set>
00032 #include <string>
00033 #include <memory>
00034 
00035 using namespace std;
00036 
00037 
00038 
// Print the command-line usage summary of this application to standard output.
//
// prog - program name shown in the "Usage:" line (normally argv[0]).
void help(const char* prog) 
{
  std::cout << "Usage:  " << prog 
            << " training_data_file"
            << " file_of_classifier_parameters(see booster.config for syntax)" 
            << std::endl;
  std::cout << "\t Options: " << std::endl;
  std::cout << "\t-h --- help                                        " << std::endl;
  std::cout << "\t-a input ascii file mode (see SprSimpleReader.hh)  " << std::endl;
  std::cout << "\t-A save output data in ascii instead of Root       " << std::endl;
  std::cout << "\t-n number of Bagger training cycles                " << std::endl;
  std::cout << "\t-y list of input classes (see SprAbsFilter.hh)     " << std::endl;
  std::cout << "\t-Q apply variable transformation saved in file     " << std::endl;
  std::cout << "\t-b use a version of Breiman's arc-x4 algorithm     " << std::endl;
  std::cout << "\t-g per-event loss for (cross-)validation           " << std::endl;
  std::cout << "\t\t 1 - quadratic loss (y-f(x))^2                   " << std::endl;
  std::cout << "\t\t 2 - exponential loss exp(-y*f(x))               " << std::endl;
  std::cout << "\t-m replace data values below this cutoff with medians" << std::endl;
  std::cout << "\t-v verbose level (0=silent default,1,2)            " << std::endl;
  // This application trains a Bagger, not AdaBoost (old copy-paste text).
  std::cout << "\t-f store trained Bagger to file                    " << std::endl;
  std::cout << "\t-r resume training for Bagger stored in file       " << std::endl;
  std::cout << "\t-K keep this fraction in training set and          " << std::endl;
  std::cout << "\t\t put the rest into validation set                " << std::endl;
  std::cout << "\t-D randomize training set split-up                 " << std::endl;
  std::cout << "\t-G generate seed from time of day for bootstrap    " << std::endl;
  std::cout << "\t\t (this option is required for parallelization)   " << std::endl;
  std::cout << "\t-t read validation/test data from a file           " << std::endl;
  std::cout << "\t\t (must be in same format as input data!!!)       " << std::endl;
  std::cout << "\t-d frequency of print-outs for validation data     " << std::endl;
  std::cout << "\t-w scale all signal weights by this factor         " << std::endl;
  std::cout << "\t-V include only these input variables              " << std::endl;
  std::cout << "\t-z exclude input variables from the list           " << std::endl;
  std::cout << "\t\t Variables must be listed in quotes and separated by commas." 
            << std::endl;
}
00074 
00075 
00076 void prepareExit(vector<SprAbsTwoClassCriterion*>& criteria,
00077                  vector<SprAbsClassifier*>& classifiers,
00078                  vector<SprIntegerBootstrap*>& bstraps) 
00079 {
00080   for( int i=0;i<criteria.size();i++ ) delete criteria[i];
00081   for( int i=0;i<classifiers.size();i++ ) delete classifiers[i];
00082   for( int i=0;i<bstraps.size();i++ ) delete bstraps[i];
00083 }
00084 
00085 
00086 int main(int argc, char ** argv)
00087 {
00088   // check command line
00089   if( argc < 3 ) {
00090     help(argv[0]);
00091     return 1;
00092   }
00093 
00094   // init
00095   int readMode = 0;
00096   SprRWFactory::DataType writeMode = SprRWFactory::Root;
00097   unsigned cycles = 0;
00098   int verbose = 0;
00099   string outFile;
00100   string resumeFile;
00101   string valFile;
00102   unsigned valPrint = 0;
00103   bool scaleWeights = false;
00104   double sW = 1.;
00105   bool setLowCutoff = false;
00106   double lowCutoff = 0;
00107   string includeList, excludeList;
00108   int iLoss = 1;
00109   string inputClassesString;
00110   bool useArcE4 = false;
00111   bool split = false;
00112   double splitFactor = 0;
00113   bool splitRandomize = false;
00114   bool initBootstrapFromTimeOfDay = false;
00115   string transformerFile;
00116 
00117   // decode command line
00118   int c;
00119   extern char* optarg;
00120   //  extern int optind;
00121   while((c = getopt(argc,argv,"ha:An:y:Q:bg:m:v:f:r:K:DGt:d:w:V:z:")) != EOF ) {
00122     switch( c )
00123       {
00124       case 'h' :
00125         help(argv[0]);
00126         return 1;
00127       case 'a' :
00128         readMode = (optarg==0 ? 0 : atoi(optarg));
00129         break;
00130       case 'A' :
00131         writeMode = SprRWFactory::Ascii;
00132         break;
00133       case 'n' :
00134         cycles = (optarg==0 ? 1 : atoi(optarg));
00135         break;
00136       case 'y' :
00137         inputClassesString = optarg;
00138         break;
00139       case 'Q' :
00140         transformerFile = optarg;
00141         break;
00142       case 'b' :
00143         useArcE4 = true;
00144         break;
00145       case 'g' :
00146         iLoss = (optarg==0 ? 1 : atoi(optarg));
00147         break;
00148       case 'm' :
00149         if( optarg != 0 ) {
00150           setLowCutoff = true;
00151           lowCutoff = atof(optarg);
00152         }
00153         break;
00154       case 'v' :
00155         verbose = (optarg==0 ? 0 : atoi(optarg));
00156         break;
00157       case 'f' :
00158         outFile = optarg;
00159         break;
00160       case 'r' :
00161         resumeFile = optarg;
00162         break;
00163       case 'K' :
00164         split = true;
00165         splitFactor = (optarg==0 ? 0 : atof(optarg));
00166         break;
00167       case 'D' :
00168         splitRandomize = true;
00169         break;
00170       case 'G' :
00171         initBootstrapFromTimeOfDay = true;
00172         break;
00173       case 't' :
00174         valFile = optarg;
00175         break;
00176       case 'd' :
00177         valPrint = (optarg==0 ? 0 : atoi(optarg));
00178         break;
00179       case 'w' :
00180         if( optarg != 0 ) {
00181           scaleWeights = true;
00182           sW = atof(optarg);
00183         }
00184         break;
00185       case 'V' :
00186         includeList = optarg;
00187         break;
00188       case 'z' :
00189         excludeList = optarg;
00190         break;
00191       }
00192   }
00193 
00194   // Must have 2 arguments after all options.
00195   string trFile = argv[argc-2];
00196   string configFile = argv[argc-1]; 
00197   if( trFile.empty() ) {
00198     cerr << "No training file is specified." << endl;
00199     return 1;
00200   }
00201   if( configFile.empty() ) {
00202     cerr << "No classifier configuration file specified." << endl;
00203     return 1;
00204   }
00205 
00206   // make reader
00207   SprRWFactory::DataType inputType 
00208     = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00209   auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00210 
00211   // include variables
00212   set<string> includeSet;
00213   if( !includeList.empty() ) {
00214     vector<vector<string> > includeVars;
00215     SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00216     assert( !includeVars.empty() );
00217     for( int i=0;i<includeVars[0].size();i++ ) 
00218       includeSet.insert(includeVars[0][i]);
00219     if( !reader->chooseVars(includeSet) ) {
00220       cerr << "Unable to include variables in training set." << endl;
00221       return 2;
00222     }
00223     else {
00224       cout << "Following variables have been included in optimization: ";
00225       for( set<string>::const_iterator 
00226              i=includeSet.begin();i!=includeSet.end();i++ )
00227         cout << "\"" << *i << "\"" << " ";
00228       cout << endl;
00229     }
00230   }
00231 
00232   // exclude variables
00233   set<string> excludeSet;
00234   if( !excludeList.empty() ) {
00235     vector<vector<string> > excludeVars;
00236     SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00237     assert( !excludeVars.empty() );
00238     for( int i=0;i<excludeVars[0].size();i++ ) 
00239       excludeSet.insert(excludeVars[0][i]);
00240     if( !reader->chooseAllBut(excludeSet) ) {
00241       cerr << "Unable to exclude variables from training set." << endl;
00242       return 2;
00243     }
00244     else {
00245       cout << "Following variables have been excluded from optimization: ";
00246       for( set<string>::const_iterator 
00247              i=excludeSet.begin();i!=excludeSet.end();i++ )
00248         cout << "\"" << *i << "\"" << " ";
00249       cout << endl;
00250     }
00251   }
00252 
00253   // read training data from file
00254   auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00255   if( filter.get() == 0 ) {
00256     cerr << "Unable to read data from file " << trFile.c_str() << endl;
00257     return 2;
00258   }
00259   vector<string> vars;
00260   filter->vars(vars);
00261   cout << "Read data from file " << trFile.c_str() << " for variables";
00262   for( int i=0;i<vars.size();i++ ) 
00263     cout << " \"" << vars[i].c_str() << "\"";
00264   cout << endl;
00265   cout << "Total number of points read: " << filter->size() << endl;
00266 
00267   // filter training data by class
00268   vector<SprClass> inputClasses;
00269   if( !filter->filterByClass(inputClassesString.c_str()) ) {
00270     cerr << "Cannot choose input classes for string " 
00271          << inputClassesString << endl;
00272     return 2;
00273   }
00274   filter->classes(inputClasses);
00275   assert( inputClasses.size() > 1 );
00276   cout << "Training data filtered by class." << endl;
00277   for( int i=0;i<inputClasses.size();i++ ) {
00278     cout << "Points in class " << inputClasses[i] << ":   " 
00279          << filter->ptsInClass(inputClasses[i]) << endl;
00280   }
00281 
00282   // scale weights
00283   if( scaleWeights ) {
00284     cout << "Signal weights are multiplied by " << sW << endl;
00285     filter->scaleWeights(inputClasses[1],sW);
00286   }
00287 
00288   // apply low cutoff
00289   if( setLowCutoff ) {
00290     if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00291       cerr << "Unable to replace missing values in training data." << endl;
00292       return 2;
00293     }
00294     else
00295       cout << "Values below " << lowCutoff << " in training data"
00296            << " have been replaced with medians." << endl;
00297   }
00298 
00299   // read validation data from file
00300   auto_ptr<SprAbsFilter> valFilter;
00301   if( split && !valFile.empty() ) {
00302     cerr << "Unable to split training data and use validation data " 
00303          << "from a separate file." << endl;
00304     return 2;
00305   }
00306   if( split && valPrint!=0 ) {
00307     cout << "Splitting training data with factor " << splitFactor << endl;
00308     if( splitRandomize )
00309       cout << "Will use randomized splitting." << endl;
00310     vector<double> weights;
00311     SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00312     if( splitted == 0 ) {
00313       cerr << "Unable to split training data." << endl;
00314       return 2;
00315     }
00316     bool ownData = true;
00317     valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00318     cout << "Training data re-filtered:" << endl;
00319     for( int i=0;i<inputClasses.size();i++ ) {
00320       cout << "Points in class " << inputClasses[i] << ":   " 
00321            << filter->ptsInClass(inputClasses[i]) << endl;
00322     }
00323   }
00324   if( !valFile.empty() && valPrint!=0 ) {
00325     auto_ptr<SprAbsReader> 
00326       valReader(SprRWFactory::makeReader(inputType,readMode));
00327     if( !includeSet.empty() ) {
00328       if( !valReader->chooseVars(includeSet) ) {
00329         cerr << "Unable to include variables in validation set." << endl;
00330         return 2;
00331       }
00332     }
00333     if( !excludeSet.empty() ) {
00334       if( !valReader->chooseAllBut(excludeSet) ) {
00335         cerr << "Unable to exclude variables from validation set." << endl;
00336         return 2;
00337       }
00338     }
00339     valFilter.reset(valReader->read(valFile.c_str()));
00340     if( valFilter.get() == 0 ) {
00341       cerr << "Unable to read data from file " << valFile.c_str() << endl;
00342       return 2;
00343     }
00344     vector<string> valVars;
00345     valFilter->vars(valVars);
00346     cout << "Read validation data from file " << valFile.c_str() 
00347          << " for variables";
00348     for( int i=0;i<valVars.size();i++ ) 
00349       cout << " \"" << valVars[i].c_str() << "\"";
00350     cout << endl;
00351     cout << "Total number of points read: " << valFilter->size() << endl;
00352     cout << "Points in class 0: " << valFilter->ptsInClass(inputClasses[0])
00353          << " 1: " << valFilter->ptsInClass(inputClasses[1]) << endl;
00354   }
00355 
00356   // filter validation data by class
00357   if( valFilter.get() != 0 ) {
00358     if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00359       cerr << "Cannot choose input classes for string " 
00360            << inputClassesString << endl;
00361       return 2;
00362     }
00363     valFilter->classes(inputClasses);
00364     cout << "Validation data filtered by class." << endl;
00365     for( int i=0;i<inputClasses.size();i++ ) {
00366       cout << "Points in class " << inputClasses[i] << ":   " 
00367            << valFilter->ptsInClass(inputClasses[i]) << endl;
00368     }
00369   }
00370 
00371   // scale weights
00372   if( scaleWeights && valFilter.get()!=0 )
00373     valFilter->scaleWeights(inputClasses[1],sW);
00374 
00375   // apply low cutoff
00376   if( setLowCutoff && valFilter.get()!=0 ) {
00377     if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00378       cerr << "Unable to replace missing values in validation data." << endl;
00379       return 2;
00380     }
00381     else
00382       cout << "Values below " << lowCutoff << " in validation data"
00383            << " have been replaced with medians." << endl;
00384   }
00385 
00386   // apply transformation of variables to training and test data
00387   auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00388   if( !transformerFile.empty() ) {
00389     SprVarTransformerReader transReader;
00390     const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00391     if( t == 0 ) {
00392       cerr << "Unable to read VarTransformer from file "
00393            << transformerFile.c_str() << endl;
00394       return 2;
00395     }
00396     SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00397     SprTransformerFilter* t_valid = 0;
00398     if( valFilter.get() != 0 )
00399       t_valid = new SprTransformerFilter(valFilter.get());
00400     bool replaceOriginalData = true;
00401     if( !t_train->transform(t,replaceOriginalData) ) {
00402       cerr << "Unable to apply VarTransformer to training data." << endl;
00403       return 2;
00404     }
00405     if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00406       cerr << "Unable to apply VarTransformer to validation data." << endl;
00407       return 2;
00408     }
00409     cout << "Variable transformation from file "
00410          << transformerFile.c_str() << " has been applied to "
00411          << "training and validation data." << endl;
00412     garbage_train.reset(filter.release());
00413     garbage_valid.reset(valFilter.release());
00414     filter.reset(t_train);
00415     valFilter.reset(t_valid);
00416   }
00417 
00418   // make per-event loss
00419   auto_ptr<SprAverageLoss> loss;
00420   switch( iLoss )
00421     {
00422     case 1 :
00423       loss.reset(new SprAverageLoss(&SprLoss::quadratic));
00424       cout << "Per-event loss set to "
00425            << "Quadratic loss (y-f(x))^2 " << endl;
00426       break;
00427     case 2 :
00428       loss.reset(new SprAverageLoss(&SprLoss::purity_ratio));
00429       cout << "Per-event loss set to "
00430            << "Exponential loss exp(-y*f(x)) " << endl;
00431       break;
00432     default :
00433       cout << "No per-event loss is chosen. Will use the default." << endl;
00434       break;
00435     }
00436 
00437   // open file with classifier configs
00438   ifstream file(configFile.c_str());
00439   if( !file ) {
00440     cerr << "Unable to open file " << configFile.c_str() << endl;
00441     return 3;
00442   }
00443 
00444   // prepare vectors of objects
00445   vector<SprAbsTwoClassCriterion*> criteria;
00446   vector<SprAbsClassifier*> destroyC;// classifiers to be deleted
00447   vector<SprIntegerBootstrap*> bstraps;
00448   vector<SprCCPair> useC;// classifiers and cuts to be used
00449 
00450   // read classifier params
00451   unsigned nLine = 0;
00452   bool discreteTree = false;
00453   bool mixedNodesTree = false;
00454   bool readOneEntry = false;
00455   bool fastSort = true;
00456   if( !SprClassifierReader::readTrainableConfig(file,nLine,filter.get(),
00457                                                 discreteTree,mixedNodesTree,
00458                                                 fastSort,criteria,
00459                                                 bstraps,destroyC,useC,
00460                                                 readOneEntry) ) {
00461     cerr << "Unable to read weak classifier configurations from file " 
00462          << configFile.c_str() << endl;
00463     prepareExit(criteria,destroyC,bstraps);
00464     return 4;
00465   }
00466   cout << "Finished reading " << useC.size() << " classifiers from file "
00467        << configFile.c_str() << endl;
00468 
00469   // make Bagger
00470   auto_ptr<SprBagger> bagger;
00471   bool discrete = false;
00472   if( useArcE4 )
00473     bagger.reset(new SprArcE4(filter.get(),cycles,discrete));
00474   else
00475     bagger.reset(new SprBagger(filter.get(),cycles,discrete));
00476 
00477   // set seed for bootstrap if necessary
00478   if( initBootstrapFromTimeOfDay && !bagger->initBootstrapFromTimeOfDay() ) {
00479     cerr << "Unable to generate seed from time of day for Bagger." << endl;
00480     return 4;
00481   }
00482 
00483   // set validation
00484   if( valFilter.get()!=0 && !valFilter->empty() )
00485     bagger->setValidation(valFilter.get(),valPrint,0,loss.get());
00486 
00487   // read saved Bagger from file
00488   if( !resumeFile.empty() ) {
00489     if( !SprClassifierReader::readTrainable(resumeFile.c_str(),
00490                                             bagger.get(),verbose) ) {
00491       cerr << "Failed to read saved Bagger from file " 
00492            << resumeFile.c_str() << endl;
00493       prepareExit(criteria,destroyC,bstraps);
00494       return 5;
00495     }
00496     cout << "Read saved Bagger from file " << resumeFile.c_str()
00497          << " with " << bagger->nTrained() << " trained classifiers." << endl;
00498   }
00499 
00500   // add trainable classifiers
00501   for( int i=0;i<useC.size();i++ ) {
00502     if( !bagger->addTrainable(useC[i].first) ) {
00503       cerr << "Unable to add classifier " << i << " of type " 
00504            << useC[i].first->name() << " to Bagger." << endl;
00505       prepareExit(criteria,destroyC,bstraps);
00506       return 6;
00507     }
00508   }
00509 
00510   // train
00511   if( !bagger->train(verbose) )
00512     cerr << "Bagger terminated with error." << endl;
00513   if( bagger->nTrained() == 0 ) {
00514     cerr << "Unable to train Bagger." << endl;
00515     prepareExit(criteria,destroyC,bstraps);
00516     return 7;
00517   }
00518   else {
00519     cout << "Bagger finished training with " << bagger->nTrained() 
00520          << " classifiers." << endl;
00521   }
00522 
00523   // save trained Bagger
00524   if( !outFile.empty() ) {
00525     if( !bagger->store(outFile.c_str()) ) {
00526       cerr << "Cannot store Bagger in file " << outFile.c_str() << endl;
00527       prepareExit(criteria,destroyC,bstraps);
00528       return 8;
00529     }
00530   }
00531 
00532   // exit
00533   prepareExit(criteria,destroyC,bstraps);
00534   return 0;
00535 }

Generated on Tue Jun 9 17:41:59 2009 for CMSSW by  doxygen 1.5.4