![]() |
![]() |
#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprBagger.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprArcE4.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedBagger.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include <stdlib.h>
#include <unistd.h>
#include <iostream>
#include <fstream>
#include <vector>
#include <set>
#include <string>
#include <memory>
Go to the source code of this file.
Functions | |
void | help (const char *prog) |
int | main (int argc, char **argv) |
void | prepareExit (vector< SprAbsTwoClassCriterion * > &criteria, vector< SprAbsClassifier * > &classifiers, vector< SprIntegerBootstrap * > &bstraps) |
void help | ( | const char * | prog | ) |
Definition at line 39 of file SprBaggerApp.cc.
References GenMuonPlsPt100GeV_cfg::cout, and lat::endl().
/**
 * Print the command-line usage summary for the Bagger training executable
 * to standard output.
 *
 * @param prog  Program name shown in the "Usage:" line (main passes argv[0]).
 */
void help(const char* prog)
{
  std::cout << "Usage: " << prog
            << " training_data_file"
            << " file_of_classifier_parameters(see booster.config for syntax)"
            << std::endl;
  std::cout << "\t Options: " << std::endl;
  std::cout << "\t-h --- help " << std::endl;
  std::cout << "\t-a input ascii file mode (see SprSimpleReader.hh) " << std::endl;
  std::cout << "\t-A save output data in ascii instead of Root " << std::endl;
  std::cout << "\t-n number of Bagger training cycles " << std::endl;
  std::cout << "\t-y list of input classes (see SprAbsFilter.hh) " << std::endl;
  std::cout << "\t-Q apply variable transformation saved in file " << std::endl;
  std::cout << "\t-b use a version of Breiman's arc-x4 algorithm " << std::endl;
  std::cout << "\t-g per-event loss for (cross-)validation " << std::endl;
  std::cout << "\t\t 1 - quadratic loss (y-f(x))^2 " << std::endl;
  std::cout << "\t\t 2 - exponential loss exp(-y*f(x)) " << std::endl;
  std::cout << "\t-m replace data values below this cutoff with medians" << std::endl;
  std::cout << "\t-v verbose level (0=silent default,1,2) " << std::endl;
  // This executable trains a Bagger; the message used to say "AdaBoost".
  std::cout << "\t-f store trained Bagger to file " << std::endl;
  std::cout << "\t-r resume training for Bagger stored in file " << std::endl;
  std::cout << "\t-K keep this fraction in training set and " << std::endl;
  std::cout << "\t\t put the rest into validation set " << std::endl;
  std::cout << "\t-D randomize training set split-up " << std::endl;
  std::cout << "\t-G generate seed from time of day for bootstrap " << std::endl;
  std::cout << "\t\t (this option is required for parallelization) " << std::endl;
  std::cout << "\t-t read validation/test data from a file " << std::endl;
  std::cout << "\t\t (must be in same format as input data!!!) " << std::endl;
  std::cout << "\t-d frequency of print-outs for validation data " << std::endl;
  std::cout << "\t-w scale all signal weights by this factor " << std::endl;
  std::cout << "\t-V include only these input variables " << std::endl;
  std::cout << "\t-z exclude input variables from the list " << std::endl;
  std::cout << "\t\t Variables must be listed in quotes and separated by commas."
            << std::endl;
}
Definition at line 86 of file SprBaggerApp.cc.
References c, TestMuL1L2Filter_cff::cerr, GenMuonPlsPt100GeV_cfg::cout, lat::endl(), file, filter, help(), i, prepareExit(), split, t, vars, and weights.
00087 { 00088 // check command line 00089 if( argc < 3 ) { 00090 help(argv[0]); 00091 return 1; 00092 } 00093 00094 // init 00095 int readMode = 0; 00096 SprRWFactory::DataType writeMode = SprRWFactory::Root; 00097 unsigned cycles = 0; 00098 int verbose = 0; 00099 string outFile; 00100 string resumeFile; 00101 string valFile; 00102 unsigned valPrint = 0; 00103 bool scaleWeights = false; 00104 double sW = 1.; 00105 bool setLowCutoff = false; 00106 double lowCutoff = 0; 00107 string includeList, excludeList; 00108 int iLoss = 1; 00109 string inputClassesString; 00110 bool useArcE4 = false; 00111 bool split = false; 00112 double splitFactor = 0; 00113 bool splitRandomize = false; 00114 bool initBootstrapFromTimeOfDay = false; 00115 string transformerFile; 00116 00117 // decode command line 00118 int c; 00119 extern char* optarg; 00120 // extern int optind; 00121 while((c = getopt(argc,argv,"ha:An:y:Q:bg:m:v:f:r:K:DGt:d:w:V:z:")) != EOF ) { 00122 switch( c ) 00123 { 00124 case 'h' : 00125 help(argv[0]); 00126 return 1; 00127 case 'a' : 00128 readMode = (optarg==0 ? 0 : atoi(optarg)); 00129 break; 00130 case 'A' : 00131 writeMode = SprRWFactory::Ascii; 00132 break; 00133 case 'n' : 00134 cycles = (optarg==0 ? 1 : atoi(optarg)); 00135 break; 00136 case 'y' : 00137 inputClassesString = optarg; 00138 break; 00139 case 'Q' : 00140 transformerFile = optarg; 00141 break; 00142 case 'b' : 00143 useArcE4 = true; 00144 break; 00145 case 'g' : 00146 iLoss = (optarg==0 ? 1 : atoi(optarg)); 00147 break; 00148 case 'm' : 00149 if( optarg != 0 ) { 00150 setLowCutoff = true; 00151 lowCutoff = atof(optarg); 00152 } 00153 break; 00154 case 'v' : 00155 verbose = (optarg==0 ? 0 : atoi(optarg)); 00156 break; 00157 case 'f' : 00158 outFile = optarg; 00159 break; 00160 case 'r' : 00161 resumeFile = optarg; 00162 break; 00163 case 'K' : 00164 split = true; 00165 splitFactor = (optarg==0 ? 
0 : atof(optarg)); 00166 break; 00167 case 'D' : 00168 splitRandomize = true; 00169 break; 00170 case 'G' : 00171 initBootstrapFromTimeOfDay = true; 00172 break; 00173 case 't' : 00174 valFile = optarg; 00175 break; 00176 case 'd' : 00177 valPrint = (optarg==0 ? 0 : atoi(optarg)); 00178 break; 00179 case 'w' : 00180 if( optarg != 0 ) { 00181 scaleWeights = true; 00182 sW = atof(optarg); 00183 } 00184 break; 00185 case 'V' : 00186 includeList = optarg; 00187 break; 00188 case 'z' : 00189 excludeList = optarg; 00190 break; 00191 } 00192 } 00193 00194 // Must have 2 arguments after all options. 00195 string trFile = argv[argc-2]; 00196 string configFile = argv[argc-1]; 00197 if( trFile.empty() ) { 00198 cerr << "No training file is specified." << endl; 00199 return 1; 00200 } 00201 if( configFile.empty() ) { 00202 cerr << "No classifier configuration file specified." << endl; 00203 return 1; 00204 } 00205 00206 // make reader 00207 SprRWFactory::DataType inputType 00208 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii ); 00209 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode)); 00210 00211 // include variables 00212 set<string> includeSet; 00213 if( !includeList.empty() ) { 00214 vector<vector<string> > includeVars; 00215 SprStringParser::parseToStrings(includeList.c_str(),includeVars); 00216 assert( !includeVars.empty() ); 00217 for( int i=0;i<includeVars[0].size();i++ ) 00218 includeSet.insert(includeVars[0][i]); 00219 if( !reader->chooseVars(includeSet) ) { 00220 cerr << "Unable to include variables in training set." 
<< endl; 00221 return 2; 00222 } 00223 else { 00224 cout << "Following variables have been included in optimization: "; 00225 for( set<string>::const_iterator 00226 i=includeSet.begin();i!=includeSet.end();i++ ) 00227 cout << "\"" << *i << "\"" << " "; 00228 cout << endl; 00229 } 00230 } 00231 00232 // exclude variables 00233 set<string> excludeSet; 00234 if( !excludeList.empty() ) { 00235 vector<vector<string> > excludeVars; 00236 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars); 00237 assert( !excludeVars.empty() ); 00238 for( int i=0;i<excludeVars[0].size();i++ ) 00239 excludeSet.insert(excludeVars[0][i]); 00240 if( !reader->chooseAllBut(excludeSet) ) { 00241 cerr << "Unable to exclude variables from training set." << endl; 00242 return 2; 00243 } 00244 else { 00245 cout << "Following variables have been excluded from optimization: "; 00246 for( set<string>::const_iterator 00247 i=excludeSet.begin();i!=excludeSet.end();i++ ) 00248 cout << "\"" << *i << "\"" << " "; 00249 cout << endl; 00250 } 00251 } 00252 00253 // read training data from file 00254 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str())); 00255 if( filter.get() == 0 ) { 00256 cerr << "Unable to read data from file " << trFile.c_str() << endl; 00257 return 2; 00258 } 00259 vector<string> vars; 00260 filter->vars(vars); 00261 cout << "Read data from file " << trFile.c_str() << " for variables"; 00262 for( int i=0;i<vars.size();i++ ) 00263 cout << " \"" << vars[i].c_str() << "\""; 00264 cout << endl; 00265 cout << "Total number of points read: " << filter->size() << endl; 00266 00267 // filter training data by class 00268 vector<SprClass> inputClasses; 00269 if( !filter->filterByClass(inputClassesString.c_str()) ) { 00270 cerr << "Cannot choose input classes for string " 00271 << inputClassesString << endl; 00272 return 2; 00273 } 00274 filter->classes(inputClasses); 00275 assert( inputClasses.size() > 1 ); 00276 cout << "Training data filtered by class." 
<< endl; 00277 for( int i=0;i<inputClasses.size();i++ ) { 00278 cout << "Points in class " << inputClasses[i] << ": " 00279 << filter->ptsInClass(inputClasses[i]) << endl; 00280 } 00281 00282 // scale weights 00283 if( scaleWeights ) { 00284 cout << "Signal weights are multiplied by " << sW << endl; 00285 filter->scaleWeights(inputClasses[1],sW); 00286 } 00287 00288 // apply low cutoff 00289 if( setLowCutoff ) { 00290 if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) { 00291 cerr << "Unable to replace missing values in training data." << endl; 00292 return 2; 00293 } 00294 else 00295 cout << "Values below " << lowCutoff << " in training data" 00296 << " have been replaced with medians." << endl; 00297 } 00298 00299 // read validation data from file 00300 auto_ptr<SprAbsFilter> valFilter; 00301 if( split && !valFile.empty() ) { 00302 cerr << "Unable to split training data and use validation data " 00303 << "from a separate file." << endl; 00304 return 2; 00305 } 00306 if( split && valPrint!=0 ) { 00307 cout << "Splitting training data with factor " << splitFactor << endl; 00308 if( splitRandomize ) 00309 cout << "Will use randomized splitting." << endl; 00310 vector<double> weights; 00311 SprData* splitted = filter->split(splitFactor,weights,splitRandomize); 00312 if( splitted == 0 ) { 00313 cerr << "Unable to split training data." 
<< endl; 00314 return 2; 00315 } 00316 bool ownData = true; 00317 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData)); 00318 cout << "Training data re-filtered:" << endl; 00319 for( int i=0;i<inputClasses.size();i++ ) { 00320 cout << "Points in class " << inputClasses[i] << ": " 00321 << filter->ptsInClass(inputClasses[i]) << endl; 00322 } 00323 } 00324 if( !valFile.empty() && valPrint!=0 ) { 00325 auto_ptr<SprAbsReader> 00326 valReader(SprRWFactory::makeReader(inputType,readMode)); 00327 if( !includeSet.empty() ) { 00328 if( !valReader->chooseVars(includeSet) ) { 00329 cerr << "Unable to include variables in validation set." << endl; 00330 return 2; 00331 } 00332 } 00333 if( !excludeSet.empty() ) { 00334 if( !valReader->chooseAllBut(excludeSet) ) { 00335 cerr << "Unable to exclude variables from validation set." << endl; 00336 return 2; 00337 } 00338 } 00339 valFilter.reset(valReader->read(valFile.c_str())); 00340 if( valFilter.get() == 0 ) { 00341 cerr << "Unable to read data from file " << valFile.c_str() << endl; 00342 return 2; 00343 } 00344 vector<string> valVars; 00345 valFilter->vars(valVars); 00346 cout << "Read validation data from file " << valFile.c_str() 00347 << " for variables"; 00348 for( int i=0;i<valVars.size();i++ ) 00349 cout << " \"" << valVars[i].c_str() << "\""; 00350 cout << endl; 00351 cout << "Total number of points read: " << valFilter->size() << endl; 00352 cout << "Points in class 0: " << valFilter->ptsInClass(inputClasses[0]) 00353 << " 1: " << valFilter->ptsInClass(inputClasses[1]) << endl; 00354 } 00355 00356 // filter validation data by class 00357 if( valFilter.get() != 0 ) { 00358 if( !valFilter->filterByClass(inputClassesString.c_str()) ) { 00359 cerr << "Cannot choose input classes for string " 00360 << inputClassesString << endl; 00361 return 2; 00362 } 00363 valFilter->classes(inputClasses); 00364 cout << "Validation data filtered by class." 
<< endl; 00365 for( int i=0;i<inputClasses.size();i++ ) { 00366 cout << "Points in class " << inputClasses[i] << ": " 00367 << valFilter->ptsInClass(inputClasses[i]) << endl; 00368 } 00369 } 00370 00371 // scale weights 00372 if( scaleWeights && valFilter.get()!=0 ) 00373 valFilter->scaleWeights(inputClasses[1],sW); 00374 00375 // apply low cutoff 00376 if( setLowCutoff && valFilter.get()!=0 ) { 00377 if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) { 00378 cerr << "Unable to replace missing values in validation data." << endl; 00379 return 2; 00380 } 00381 else 00382 cout << "Values below " << lowCutoff << " in validation data" 00383 << " have been replaced with medians." << endl; 00384 } 00385 00386 // apply transformation of variables to training and test data 00387 auto_ptr<SprAbsFilter> garbage_train, garbage_valid; 00388 if( !transformerFile.empty() ) { 00389 SprVarTransformerReader transReader; 00390 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str()); 00391 if( t == 0 ) { 00392 cerr << "Unable to read VarTransformer from file " 00393 << transformerFile.c_str() << endl; 00394 return 2; 00395 } 00396 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get()); 00397 SprTransformerFilter* t_valid = 0; 00398 if( valFilter.get() != 0 ) 00399 t_valid = new SprTransformerFilter(valFilter.get()); 00400 bool replaceOriginalData = true; 00401 if( !t_train->transform(t,replaceOriginalData) ) { 00402 cerr << "Unable to apply VarTransformer to training data." << endl; 00403 return 2; 00404 } 00405 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) { 00406 cerr << "Unable to apply VarTransformer to validation data." << endl; 00407 return 2; 00408 } 00409 cout << "Variable transformation from file " 00410 << transformerFile.c_str() << " has been applied to " 00411 << "training and validation data." 
<< endl; 00412 garbage_train.reset(filter.release()); 00413 garbage_valid.reset(valFilter.release()); 00414 filter.reset(t_train); 00415 valFilter.reset(t_valid); 00416 } 00417 00418 // make per-event loss 00419 auto_ptr<SprAverageLoss> loss; 00420 switch( iLoss ) 00421 { 00422 case 1 : 00423 loss.reset(new SprAverageLoss(&SprLoss::quadratic)); 00424 cout << "Per-event loss set to " 00425 << "Quadratic loss (y-f(x))^2 " << endl; 00426 break; 00427 case 2 : 00428 loss.reset(new SprAverageLoss(&SprLoss::purity_ratio)); 00429 cout << "Per-event loss set to " 00430 << "Exponential loss exp(-y*f(x)) " << endl; 00431 break; 00432 default : 00433 cout << "No per-event loss is chosen. Will use the default." << endl; 00434 break; 00435 } 00436 00437 // open file with classifier configs 00438 ifstream file(configFile.c_str()); 00439 if( !file ) { 00440 cerr << "Unable to open file " << configFile.c_str() << endl; 00441 return 3; 00442 } 00443 00444 // prepare vectors of objects 00445 vector<SprAbsTwoClassCriterion*> criteria; 00446 vector<SprAbsClassifier*> destroyC;// classifiers to be deleted 00447 vector<SprIntegerBootstrap*> bstraps; 00448 vector<SprCCPair> useC;// classifiers and cuts to be used 00449 00450 // read classifier params 00451 unsigned nLine = 0; 00452 bool discreteTree = false; 00453 bool mixedNodesTree = false; 00454 bool readOneEntry = false; 00455 bool fastSort = true; 00456 if( !SprClassifierReader::readTrainableConfig(file,nLine,filter.get(), 00457 discreteTree,mixedNodesTree, 00458 fastSort,criteria, 00459 bstraps,destroyC,useC, 00460 readOneEntry) ) { 00461 cerr << "Unable to read weak classifier configurations from file " 00462 << configFile.c_str() << endl; 00463 prepareExit(criteria,destroyC,bstraps); 00464 return 4; 00465 } 00466 cout << "Finished reading " << useC.size() << " classifiers from file " 00467 << configFile.c_str() << endl; 00468 00469 // make Bagger 00470 auto_ptr<SprBagger> bagger; 00471 bool discrete = false; 00472 if( useArcE4 ) 
00473 bagger.reset(new SprArcE4(filter.get(),cycles,discrete)); 00474 else 00475 bagger.reset(new SprBagger(filter.get(),cycles,discrete)); 00476 00477 // set seed for bootstrap if necessary 00478 if( initBootstrapFromTimeOfDay && !bagger->initBootstrapFromTimeOfDay() ) { 00479 cerr << "Unable to generate seed from time of day for Bagger." << endl; 00480 return 4; 00481 } 00482 00483 // set validation 00484 if( valFilter.get()!=0 && !valFilter->empty() ) 00485 bagger->setValidation(valFilter.get(),valPrint,0,loss.get()); 00486 00487 // read saved Bagger from file 00488 if( !resumeFile.empty() ) { 00489 if( !SprClassifierReader::readTrainable(resumeFile.c_str(), 00490 bagger.get(),verbose) ) { 00491 cerr << "Failed to read saved Bagger from file " 00492 << resumeFile.c_str() << endl; 00493 prepareExit(criteria,destroyC,bstraps); 00494 return 5; 00495 } 00496 cout << "Read saved Bagger from file " << resumeFile.c_str() 00497 << " with " << bagger->nTrained() << " trained classifiers." << endl; 00498 } 00499 00500 // add trainable classifiers 00501 for( int i=0;i<useC.size();i++ ) { 00502 if( !bagger->addTrainable(useC[i].first) ) { 00503 cerr << "Unable to add classifier " << i << " of type " 00504 << useC[i].first->name() << " to Bagger." << endl; 00505 prepareExit(criteria,destroyC,bstraps); 00506 return 6; 00507 } 00508 } 00509 00510 // train 00511 if( !bagger->train(verbose) ) 00512 cerr << "Bagger terminated with error." << endl; 00513 if( bagger->nTrained() == 0 ) { 00514 cerr << "Unable to train Bagger." << endl; 00515 prepareExit(criteria,destroyC,bstraps); 00516 return 7; 00517 } 00518 else { 00519 cout << "Bagger finished training with " << bagger->nTrained() 00520 << " classifiers." 
<< endl; 00521 } 00522 00523 // save trained Bagger 00524 if( !outFile.empty() ) { 00525 if( !bagger->store(outFile.c_str()) ) { 00526 cerr << "Cannot store Bagger in file " << outFile.c_str() << endl; 00527 prepareExit(criteria,destroyC,bstraps); 00528 return 8; 00529 } 00530 } 00531 00532 // exit 00533 prepareExit(criteria,destroyC,bstraps); 00534 return 0; 00535 }
void prepareExit | ( | vector< SprAbsTwoClassCriterion * > & | criteria, | |
vector< SprAbsClassifier * > & | classifiers, | |||
vector< SprIntegerBootstrap * > & | bstraps | |||
) |