#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAdaBoost.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedAdaBoost.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include <stdlib.h>
#include <unistd.h>
#include <iostream>
#include <fstream>
#include <vector>
#include <set>
#include <string>
#include <memory>
Go to the source code of this file.
Functions
void help (const char *prog)
int main (int argc, char **argv)
void prepareExit (vector< SprAbsTwoClassCriterion * > &criteria, vector< SprAbsClassifier * > &classifiers, vector< SprIntegerBootstrap * > &bstraps)
void help (const char *prog)
Definition at line 38 of file SprBoosterApp.cc.
References GenMuonPlsPt100GeV_cfg::cout, and lat::endl().
/**
 * Print the command-line usage summary to standard output.
 *
 * The first line shows the invocation synopsis for the given program name;
 * the remaining lines list every supported option, one per line.
 *
 * @param prog  Program name to show in the synopsis (normally argv[0]).
 */
void help(const char* prog)
{
  // One entry per printed line, in display order. Trailing blanks inside
  // the literals are part of the established output and are kept as-is.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-M AdaBoost mode ",
    "\t\t 1 = Discrete AdaBoost (default) ",
    "\t\t 2 = Real AdaBoost ",
    "\t\t 3 = Epsilon AdaBoost ",
    "\t-E epsilon for Epsilon and Real AdaBoosts (def=0.01)",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-A save output data in ascii instead of Root ",
    "\t-n number of AdaBoost training cycles ",
    "\t-y list of input classes (see SprAbsFilter.hh) ",
    "\t-Q apply variable transformation saved in file ",
    "\t-g per-event loss for (cross-)validation ",
    "\t\t 1 - quadratic loss (y-f(x))^2 ",
    "\t\t 2 - exponential loss exp(-y*f(x)) ",
    "\t-b bootstrap training sample ",
    "\t-m replace data values below this cutoff with medians",
    "\t-e skip initial event reweighting when resuming ",
    "\t-u store data with modified weights to file ",
    "\t-v verbose level (0=silent default,1,2) ",
    "\t-f store trained AdaBoost to file ",
    "\t-r resume training for AdaBoost stored in file ",
    "\t-K keep this fraction in training set and ",
    "\t\t put the rest into validation set ",
    "\t-D randomize training set split-up ",
    "\t-t read validation/test data from a file ",
    "\t\t (must be in same format as input data!!! ",
    "\t-d frequency of print-outs for validation data ",
    "\t-w scale all signal weights by this factor ",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const size_t nLines = sizeof(optionLines) / sizeof(optionLines[0]);

  std::cout << "Usage: " << prog
            << " training_data_file"
            << " file_of_classifier_parameters(see booster.config for syntax)"
            << '\n';
  for( size_t i = 0; i < nLines; ++i )
    std::cout << optionLines[i] << '\n';

  // Single flush at the end instead of one per line.
  std::cout.flush();
}
int main (int argc, char **argv)
Definition at line 90 of file SprBoosterApp.cc.
References c, TestMuL1L2Filter_cff::cerr, GenMuonPlsPt100GeV_cfg::cout, lat::endl(), HLT_VtxMuL3::Epsilon, geometryDiff::epsilon, file, filter, help(), i, prepareExit(), split, t, vars, and weights.
00091 { 00092 // check command line 00093 if( argc < 3 ) { 00094 help(argv[0]); 00095 return 1; 00096 } 00097 00098 // init 00099 int readMode = 0; 00100 SprRWFactory::DataType writeMode = SprRWFactory::Root; 00101 unsigned cycles = 0; 00102 int verbose = 0; 00103 string outFile; 00104 string resumeFile; 00105 string valFile; 00106 unsigned valPrint = 0; 00107 bool scaleWeights = false; 00108 double sW = 1.; 00109 bool useStandardAB = false; 00110 int iAdaBoostMode = 1; 00111 double epsilon = 0.01; 00112 bool skipInitialEventReweighting = false; 00113 string weightedDataOut; 00114 bool setLowCutoff = false; 00115 double lowCutoff = 0; 00116 string includeList, excludeList; 00117 int iLoss = 0; 00118 bool bagInput = false; 00119 string inputClassesString; 00120 bool split = false; 00121 double splitFactor = 0; 00122 bool splitRandomize = false; 00123 string transformerFile; 00124 00125 // decode command line 00126 int c; 00127 extern char* optarg; 00128 // extern int optind; 00129 while((c = getopt(argc,argv,"hM:E:a:An:y:Q:g:bm:eu:v:f:r:K:Dt:d:w:V:z:")) != EOF ) { 00130 switch( c ) 00131 { 00132 case 'h' : 00133 help(argv[0]); 00134 return 1; 00135 case 'M' : 00136 iAdaBoostMode = (optarg==0 ? 1 : atoi(optarg)); 00137 break; 00138 case 'E' : 00139 epsilon = (optarg==0 ? 0.01 : atof(optarg)); 00140 break; 00141 case 'a' : 00142 readMode = (optarg==0 ? 0 : atoi(optarg)); 00143 break; 00144 case 'A' : 00145 writeMode = SprRWFactory::Ascii; 00146 break; 00147 case 'n' : 00148 cycles = (optarg==0 ? 1 : atoi(optarg)); 00149 break; 00150 case 'y' : 00151 inputClassesString = optarg; 00152 break; 00153 case 'g' : 00154 iLoss = (optarg==0 ? 
0 : atoi(optarg)); 00155 break; 00156 case 'b' : 00157 bagInput = true; 00158 break; 00159 case 'm' : 00160 if( optarg != 0 ) { 00161 setLowCutoff = true; 00162 lowCutoff = atof(optarg); 00163 } 00164 break; 00165 case 'e' : 00166 skipInitialEventReweighting = true; 00167 break; 00168 case 'u' : 00169 weightedDataOut = optarg; 00170 break; 00171 case 'v' : 00172 verbose = (optarg==0 ? 0 : atoi(optarg)); 00173 break; 00174 case 'f' : 00175 outFile = optarg; 00176 break; 00177 case 'r' : 00178 resumeFile = optarg; 00179 break; 00180 case 'K' : 00181 split = true; 00182 splitFactor = (optarg==0 ? 0 : atof(optarg)); 00183 break; 00184 case 'D' : 00185 splitRandomize = true; 00186 break; 00187 case 't' : 00188 valFile = optarg; 00189 break; 00190 case 'd' : 00191 valPrint = (optarg==0 ? 0 : atoi(optarg)); 00192 break; 00193 case 'w' : 00194 if( optarg != 0 ) { 00195 scaleWeights = true; 00196 sW = atof(optarg); 00197 } 00198 break; 00199 case 'V' : 00200 includeList = optarg; 00201 break; 00202 case 'z' : 00203 excludeList = optarg; 00204 break; 00205 } 00206 } 00207 00208 // Must have 2 arguments after all options. 00209 string trFile = argv[argc-2]; 00210 string configFile = argv[argc-1]; 00211 if( trFile.empty() ) { 00212 cerr << "No training file is specified." << endl; 00213 return 1; 00214 } 00215 if( configFile.empty() ) { 00216 cerr << "No classifier configuration file specified." << endl; 00217 return 1; 00218 } 00219 00220 // make reader 00221 SprRWFactory::DataType inputType 00222 = ( readMode==0 ? 
SprRWFactory::Root : SprRWFactory::Ascii ); 00223 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode)); 00224 00225 // include variables 00226 set<string> includeSet; 00227 if( !includeList.empty() ) { 00228 vector<vector<string> > includeVars; 00229 SprStringParser::parseToStrings(includeList.c_str(),includeVars); 00230 assert( !includeVars.empty() ); 00231 for( int i=0;i<includeVars[0].size();i++ ) 00232 includeSet.insert(includeVars[0][i]); 00233 if( !reader->chooseVars(includeSet) ) { 00234 cerr << "Unable to include variables in training set." << endl; 00235 return 2; 00236 } 00237 else { 00238 cout << "Following variables have been included in optimization: "; 00239 for( set<string>::const_iterator 00240 i=includeSet.begin();i!=includeSet.end();i++ ) 00241 cout << "\"" << *i << "\"" << " "; 00242 cout << endl; 00243 } 00244 } 00245 00246 // exclude variables 00247 set<string> excludeSet; 00248 if( !excludeList.empty() ) { 00249 vector<vector<string> > excludeVars; 00250 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars); 00251 assert( !excludeVars.empty() ); 00252 for( int i=0;i<excludeVars[0].size();i++ ) 00253 excludeSet.insert(excludeVars[0][i]); 00254 if( !reader->chooseAllBut(excludeSet) ) { 00255 cerr << "Unable to exclude variables from training set." 
<< endl; 00256 return 2; 00257 } 00258 else { 00259 cout << "Following variables have been excluded from optimization: "; 00260 for( set<string>::const_iterator 00261 i=excludeSet.begin();i!=excludeSet.end();i++ ) 00262 cout << "\"" << *i << "\"" << " "; 00263 cout << endl; 00264 } 00265 } 00266 00267 // read training data from file 00268 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str())); 00269 if( filter.get() == 0 ) { 00270 cerr << "Unable to read data from file " << trFile.c_str() << endl; 00271 return 2; 00272 } 00273 vector<string> vars; 00274 filter->vars(vars); 00275 cout << "Read data from file " << trFile.c_str() << " for variables"; 00276 for( int i=0;i<vars.size();i++ ) 00277 cout << " \"" << vars[i].c_str() << "\""; 00278 cout << endl; 00279 cout << "Total number of points read: " << filter->size() << endl; 00280 00281 // filter training data by class 00282 vector<SprClass> inputClasses; 00283 if( !filter->filterByClass(inputClassesString.c_str()) ) { 00284 cerr << "Cannot choose input classes for string " 00285 << inputClassesString << endl; 00286 return 2; 00287 } 00288 filter->classes(inputClasses); 00289 assert( inputClasses.size() > 1 ); 00290 cout << "Training data filtered by class." << endl; 00291 for( int i=0;i<inputClasses.size();i++ ) { 00292 cout << "Points in class " << inputClasses[i] << ": " 00293 << filter->ptsInClass(inputClasses[i]) << endl; 00294 } 00295 00296 // scale weights 00297 if( scaleWeights ) { 00298 cout << "Signal weights are multiplied by " << sW << endl; 00299 filter->scaleWeights(inputClasses[1],sW); 00300 } 00301 00302 // apply low cutoff 00303 if( setLowCutoff ) { 00304 if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) { 00305 cerr << "Unable to replace missing values in training data." << endl; 00306 return 2; 00307 } 00308 else 00309 cout << "Values below " << lowCutoff << " in training data" 00310 << " have been replaced with medians." 
<< endl; 00311 } 00312 00313 // read validation data from file 00314 auto_ptr<SprAbsFilter> valFilter; 00315 if( split && !valFile.empty() ) { 00316 cerr << "Unable to split training data and use validation data " 00317 << "from a separate file." << endl; 00318 return 2; 00319 } 00320 if( split && valPrint!=0 ) { 00321 cout << "Splitting training data with factor " << splitFactor << endl; 00322 if( splitRandomize ) 00323 cout << "Will use randomized splitting." << endl; 00324 vector<double> weights; 00325 SprData* splitted = filter->split(splitFactor,weights,splitRandomize); 00326 if( splitted == 0 ) { 00327 cerr << "Unable to split training data." << endl; 00328 return 2; 00329 } 00330 bool ownData = true; 00331 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData)); 00332 cout << "Training data re-filtered:" << endl; 00333 for( int i=0;i<inputClasses.size();i++ ) { 00334 cout << "Points in class " << inputClasses[i] << ": " 00335 << filter->ptsInClass(inputClasses[i]) << endl; 00336 } 00337 } 00338 if( !valFile.empty() && valPrint!=0 ) { 00339 auto_ptr<SprAbsReader> 00340 valReader(SprRWFactory::makeReader(inputType,readMode)); 00341 if( !includeSet.empty() ) { 00342 if( !valReader->chooseVars(includeSet) ) { 00343 cerr << "Unable to include variables in validation set." << endl; 00344 return 2; 00345 } 00346 } 00347 if( !excludeSet.empty() ) { 00348 if( !valReader->chooseAllBut(excludeSet) ) { 00349 cerr << "Unable to exclude variables from validation set." 
<< endl; 00350 return 2; 00351 } 00352 } 00353 valFilter.reset(valReader->read(valFile.c_str())); 00354 if( valFilter.get() == 0 ) { 00355 cerr << "Unable to read data from file " << valFile.c_str() << endl; 00356 return 2; 00357 } 00358 vector<string> valVars; 00359 valFilter->vars(valVars); 00360 cout << "Read validation data from file " << valFile.c_str() 00361 << " for variables"; 00362 for( int i=0;i<valVars.size();i++ ) 00363 cout << " \"" << valVars[i].c_str() << "\""; 00364 cout << endl; 00365 cout << "Total number of points read: " << valFilter->size() << endl; 00366 cout << "Points in class 0: " << valFilter->ptsInClass(inputClasses[0]) 00367 << " 1: " << valFilter->ptsInClass(inputClasses[1]) << endl; 00368 } 00369 00370 // filter validation data by class 00371 if( valFilter.get() != 0 ) { 00372 if( !valFilter->filterByClass(inputClassesString.c_str()) ) { 00373 cerr << "Cannot choose input classes for string " 00374 << inputClassesString << endl; 00375 return 2; 00376 } 00377 valFilter->classes(inputClasses); 00378 cout << "Validation data filtered by class." << endl; 00379 for( int i=0;i<inputClasses.size();i++ ) { 00380 cout << "Points in class " << inputClasses[i] << ": " 00381 << valFilter->ptsInClass(inputClasses[i]) << endl; 00382 } 00383 } 00384 00385 // scale weights 00386 if( scaleWeights && valFilter.get()!=0 ) 00387 valFilter->scaleWeights(inputClasses[1],sW); 00388 00389 // apply low cutoff 00390 if( setLowCutoff && valFilter.get()!=0 ) { 00391 if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) { 00392 cerr << "Unable to replace missing values in validation data." << endl; 00393 return 2; 00394 } 00395 else 00396 cout << "Values below " << lowCutoff << " in validation data" 00397 << " have been replaced with medians." 
<< endl; 00398 } 00399 00400 // apply transformation of variables to training and test data 00401 auto_ptr<SprAbsFilter> garbage_train, garbage_valid; 00402 if( !transformerFile.empty() ) { 00403 SprVarTransformerReader transReader; 00404 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str()); 00405 if( t == 0 ) { 00406 cerr << "Unable to read VarTransformer from file " 00407 << transformerFile.c_str() << endl; 00408 return 2; 00409 } 00410 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get()); 00411 SprTransformerFilter* t_valid = 0; 00412 if( valFilter.get() != 0 ) 00413 t_valid = new SprTransformerFilter(valFilter.get()); 00414 bool replaceOriginalData = true; 00415 if( !t_train->transform(t,replaceOriginalData) ) { 00416 cerr << "Unable to apply VarTransformer to training data." << endl; 00417 return 2; 00418 } 00419 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) { 00420 cerr << "Unable to apply VarTransformer to validation data." << endl; 00421 return 2; 00422 } 00423 cout << "Variable transformation from file " 00424 << transformerFile.c_str() << " has been applied to " 00425 << "training and validation data." << endl; 00426 garbage_train.reset(filter.release()); 00427 garbage_valid.reset(valFilter.release()); 00428 filter.reset(t_train); 00429 valFilter.reset(t_valid); 00430 } 00431 00432 // make per-event loss 00433 auto_ptr<SprAverageLoss> loss; 00434 switch( iLoss ) 00435 { 00436 case 1 : 00437 loss.reset(new SprAverageLoss(&SprLoss::quadratic, 00438 &SprTransformation::logit)); 00439 cout << "Per-event loss set to " 00440 << "Quadratic loss (y-f(x))^2 " << endl; 00441 break; 00442 case 2 : 00443 loss.reset(new SprAverageLoss(&SprLoss::exponential)); 00444 cout << "Per-event loss set to " 00445 << "Exponential loss exp(-y*f(x)) " << endl; 00446 break; 00447 default : 00448 cout << "No per-event loss is chosen. Will use the default." 
<< endl; 00449 break; 00450 } 00451 00452 // make AdaBoost mode 00453 SprTrainedAdaBoost::AdaBoostMode abMode = SprTrainedAdaBoost::Discrete; 00454 switch( iAdaBoostMode ) 00455 { 00456 case 1 : 00457 abMode = SprTrainedAdaBoost::Discrete; 00458 cout << "Will train Discrete AdaBoost." << endl; 00459 break; 00460 case 2 : 00461 abMode = SprTrainedAdaBoost::Real; 00462 cout << "Will train Real AdaBoost." << endl; 00463 break; 00464 case 3 : 00465 abMode = SprTrainedAdaBoost::Epsilon; 00466 cout << "Will train Epsilon AdaBoost." << endl; 00467 break; 00468 default : 00469 cout << "Will train Discrete AdaBoost." << endl; 00470 break; 00471 } 00472 00473 // open file with classifier configs 00474 ifstream file(configFile.c_str()); 00475 if( !file ) { 00476 cerr << "Unable to open file " << configFile.c_str() << endl; 00477 return 3; 00478 } 00479 00480 // prepare vectors of objects 00481 vector<SprAbsTwoClassCriterion*> criteria; 00482 vector<SprAbsClassifier*> destroyC;// classifiers to be deleted 00483 vector<SprIntegerBootstrap*> bstraps; 00484 vector<SprCCPair> useC;// classifiers and cuts to be used 00485 00486 // read classifier params 00487 unsigned nLine = 0; 00488 bool discreteTree = (abMode!=SprTrainedAdaBoost::Real); 00489 bool mixedNodesTree = (abMode==SprTrainedAdaBoost::Real); 00490 bool fastSort = true; 00491 bool readOneEntry = false; 00492 if( !SprClassifierReader::readTrainableConfig(file,nLine,filter.get(), 00493 discreteTree,mixedNodesTree, 00494 fastSort,criteria, 00495 bstraps,destroyC,useC, 00496 readOneEntry) ) { 00497 cerr << "Unable to read weak classifier configurations from file " 00498 << configFile.c_str() << endl; 00499 prepareExit(criteria,destroyC,bstraps); 00500 return 4; 00501 } 00502 cout << "Finished reading " << useC.size() << " classifiers from file " 00503 << configFile.c_str() << endl; 00504 00505 // make AdaBoost 00506 SprAdaBoost ab(filter.get(),cycles,useStandardAB,abMode,bagInput); 00507 cout << "Setting epsilon to " << 
epsilon << endl; 00508 ab.setEpsilon(epsilon); 00509 00510 // set validation 00511 if( valFilter.get()!=0 && !valFilter->empty() ) 00512 ab.setValidation(valFilter.get(),valPrint,loss.get()); 00513 00514 // read saved Boost from file 00515 if( !resumeFile.empty() ) { 00516 if( !SprClassifierReader::readTrainable(resumeFile.c_str(), 00517 &ab,verbose) ) { 00518 cerr << "Failed to read saved AdaBoost from file " 00519 << resumeFile.c_str() << endl; 00520 prepareExit(criteria,destroyC,bstraps); 00521 return 5; 00522 } 00523 cout << "Read saved AdaBoost from file " << resumeFile.c_str() 00524 << " with " << ab.nTrained() << " trained classifiers." << endl; 00525 } 00526 if( skipInitialEventReweighting ) ab.skipInitialEventReweighting(true); 00527 00528 // add trainable classifiers 00529 for( int i=0;i<useC.size();i++ ) { 00530 if( !ab.addTrainable(useC[i].first,useC[i].second) ) { 00531 cerr << "Unable to add classifier " << i << " of type " 00532 << useC[i].first->name() << " to AdaBoost." << endl; 00533 prepareExit(criteria,destroyC,bstraps); 00534 return 6; 00535 } 00536 } 00537 00538 // train 00539 if( !ab.train(verbose) ) 00540 cerr << "AdaBoost terminated with error." << endl; 00541 if( ab.nTrained() == 0 ) { 00542 cerr << "Unable to train AdaBoost." << endl; 00543 prepareExit(criteria,destroyC,bstraps); 00544 return 7; 00545 } 00546 else { 00547 cout << "AdaBoost finished training with " << ab.nTrained() 00548 << " classifiers." 
<< endl; 00549 } 00550 00551 // save trained AdaBoost 00552 if( !outFile.empty() ) { 00553 if( !ab.store(outFile.c_str()) ) { 00554 cerr << "Cannot store AdaBoost in file " << outFile.c_str() << endl; 00555 prepareExit(criteria,destroyC,bstraps); 00556 return 8; 00557 } 00558 } 00559 00560 // save reweighted data 00561 if( !weightedDataOut.empty() ) { 00562 if( !ab.storeData(weightedDataOut.c_str()) ) { 00563 cerr << "Cannot store weighted AdaBoost data to file " 00564 << weightedDataOut.c_str() << endl; 00565 prepareExit(criteria,destroyC,bstraps); 00566 return 9; 00567 } 00568 } 00569 00570 // exit 00571 prepareExit(criteria,destroyC,bstraps); 00572 return 0; 00573 }
void prepareExit (vector< SprAbsTwoClassCriterion * > &criteria, vector< SprAbsClassifier * > &classifiers, vector< SprIntegerBootstrap * > &bstraps)