00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprAdaBoost.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedAdaBoost.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00022 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00023 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00024
00025 #include <stdlib.h>
00026 #include <unistd.h>
00027 #include <iostream>
00028 #include <fstream>
00029 #include <vector>
00030 #include <set>
00031 #include <string>
00032 #include <memory>
00033
00034 using namespace std;
00035
00036
00037
/// Prints the command-line usage summary to standard output.
///
/// The first line shows the program name followed by the two expected
/// positional arguments; the remaining lines list the supported options.
///
/// @param prog  Program name to show in the usage line (normally argv[0]).
void help(const char* prog)
{
  // One table entry per printed line; each is emitted with a trailing endl.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-M AdaBoost mode ",
    "\t\t 1 = Discrete AdaBoost (default) ",
    "\t\t 2 = Real AdaBoost ",
    "\t\t 3 = Epsilon AdaBoost ",
    "\t-E epsilon for Epsilon and Real AdaBoosts (def=0.01)",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-A save output data in ascii instead of Root ",
    "\t-n number of AdaBoost training cycles ",
    "\t-y list of input classes (see SprAbsFilter.hh) ",
    "\t-Q apply variable transformation saved in file ",
    "\t-g per-event loss for (cross-)validation ",
    "\t\t 1 - quadratic loss (y-f(x))^2 ",
    "\t\t 2 - exponential loss exp(-y*f(x)) ",
    "\t-b bootstrap training sample ",
    "\t-m replace data values below this cutoff with medians",
    "\t-e skip initial event reweighting when resuming ",
    "\t-u store data with modified weights to file ",
    "\t-v verbose level (0=silent default,1,2) ",
    "\t-f store trained AdaBoost to file ",
    "\t-r resume training for AdaBoost stored in file ",
    "\t-K keep this fraction in training set and ",
    "\t\t put the rest into validation set ",
    "\t-D randomize training set split-up ",
    "\t-t read validation/test data from a file ",
    "\t\t (must be in same format as input data!!! ",
    "\t-d frequency of print-outs for validation data ",
    "\t-w scale all signal weights by this factor ",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const unsigned lineCount = sizeof(optionLines) / sizeof(optionLines[0]);

  std::cout << "Usage: " << prog
            << " training_data_file"
            << " file_of_classifier_parameters(see booster.config for syntax)"
            << std::endl;

  for( unsigned k = 0; k < lineCount; ++k )
    std::cout << optionLines[k] << std::endl;
}
00078
00079
00080 void prepareExit(vector<SprAbsTwoClassCriterion*>& criteria,
00081 vector<SprAbsClassifier*>& classifiers,
00082 vector<SprIntegerBootstrap*>& bstraps)
00083 {
00084 for( int i=0;i<criteria.size();i++ ) delete criteria[i];
00085 for( int i=0;i<classifiers.size();i++ ) delete classifiers[i];
00086 for( int i=0;i<bstraps.size();i++ ) delete bstraps[i];
00087 }
00088
00089
00090 int main(int argc, char ** argv)
00091 {
00092
00093 if( argc < 3 ) {
00094 help(argv[0]);
00095 return 1;
00096 }
00097
00098
00099 int readMode = 0;
00100 SprRWFactory::DataType writeMode = SprRWFactory::Root;
00101 unsigned cycles = 0;
00102 int verbose = 0;
00103 string outFile;
00104 string resumeFile;
00105 string valFile;
00106 unsigned valPrint = 0;
00107 bool scaleWeights = false;
00108 double sW = 1.;
00109 bool useStandardAB = false;
00110 int iAdaBoostMode = 1;
00111 double epsilon = 0.01;
00112 bool skipInitialEventReweighting = false;
00113 string weightedDataOut;
00114 bool setLowCutoff = false;
00115 double lowCutoff = 0;
00116 string includeList, excludeList;
00117 int iLoss = 0;
00118 bool bagInput = false;
00119 string inputClassesString;
00120 bool split = false;
00121 double splitFactor = 0;
00122 bool splitRandomize = false;
00123 string transformerFile;
00124
00125
00126 int c;
00127 extern char* optarg;
00128
00129 while((c = getopt(argc,argv,"hM:E:a:An:y:Q:g:bm:eu:v:f:r:K:Dt:d:w:V:z:")) != EOF ) {
00130 switch( c )
00131 {
00132 case 'h' :
00133 help(argv[0]);
00134 return 1;
00135 case 'M' :
00136 iAdaBoostMode = (optarg==0 ? 1 : atoi(optarg));
00137 break;
00138 case 'E' :
00139 epsilon = (optarg==0 ? 0.01 : atof(optarg));
00140 break;
00141 case 'a' :
00142 readMode = (optarg==0 ? 0 : atoi(optarg));
00143 break;
00144 case 'A' :
00145 writeMode = SprRWFactory::Ascii;
00146 break;
00147 case 'n' :
00148 cycles = (optarg==0 ? 1 : atoi(optarg));
00149 break;
00150 case 'y' :
00151 inputClassesString = optarg;
00152 break;
00153 case 'g' :
00154 iLoss = (optarg==0 ? 0 : atoi(optarg));
00155 break;
00156 case 'b' :
00157 bagInput = true;
00158 break;
00159 case 'm' :
00160 if( optarg != 0 ) {
00161 setLowCutoff = true;
00162 lowCutoff = atof(optarg);
00163 }
00164 break;
00165 case 'e' :
00166 skipInitialEventReweighting = true;
00167 break;
00168 case 'u' :
00169 weightedDataOut = optarg;
00170 break;
00171 case 'v' :
00172 verbose = (optarg==0 ? 0 : atoi(optarg));
00173 break;
00174 case 'f' :
00175 outFile = optarg;
00176 break;
00177 case 'r' :
00178 resumeFile = optarg;
00179 break;
00180 case 'K' :
00181 split = true;
00182 splitFactor = (optarg==0 ? 0 : atof(optarg));
00183 break;
00184 case 'D' :
00185 splitRandomize = true;
00186 break;
00187 case 't' :
00188 valFile = optarg;
00189 break;
00190 case 'd' :
00191 valPrint = (optarg==0 ? 0 : atoi(optarg));
00192 break;
00193 case 'w' :
00194 if( optarg != 0 ) {
00195 scaleWeights = true;
00196 sW = atof(optarg);
00197 }
00198 break;
00199 case 'V' :
00200 includeList = optarg;
00201 break;
00202 case 'z' :
00203 excludeList = optarg;
00204 break;
00205 }
00206 }
00207
00208
00209 string trFile = argv[argc-2];
00210 string configFile = argv[argc-1];
00211 if( trFile.empty() ) {
00212 cerr << "No training file is specified." << endl;
00213 return 1;
00214 }
00215 if( configFile.empty() ) {
00216 cerr << "No classifier configuration file specified." << endl;
00217 return 1;
00218 }
00219
00220
00221 SprRWFactory::DataType inputType
00222 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00223 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00224
00225
00226 set<string> includeSet;
00227 if( !includeList.empty() ) {
00228 vector<vector<string> > includeVars;
00229 SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00230 assert( !includeVars.empty() );
00231 for( int i=0;i<includeVars[0].size();i++ )
00232 includeSet.insert(includeVars[0][i]);
00233 if( !reader->chooseVars(includeSet) ) {
00234 cerr << "Unable to include variables in training set." << endl;
00235 return 2;
00236 }
00237 else {
00238 cout << "Following variables have been included in optimization: ";
00239 for( set<string>::const_iterator
00240 i=includeSet.begin();i!=includeSet.end();i++ )
00241 cout << "\"" << *i << "\"" << " ";
00242 cout << endl;
00243 }
00244 }
00245
00246
00247 set<string> excludeSet;
00248 if( !excludeList.empty() ) {
00249 vector<vector<string> > excludeVars;
00250 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00251 assert( !excludeVars.empty() );
00252 for( int i=0;i<excludeVars[0].size();i++ )
00253 excludeSet.insert(excludeVars[0][i]);
00254 if( !reader->chooseAllBut(excludeSet) ) {
00255 cerr << "Unable to exclude variables from training set." << endl;
00256 return 2;
00257 }
00258 else {
00259 cout << "Following variables have been excluded from optimization: ";
00260 for( set<string>::const_iterator
00261 i=excludeSet.begin();i!=excludeSet.end();i++ )
00262 cout << "\"" << *i << "\"" << " ";
00263 cout << endl;
00264 }
00265 }
00266
00267
00268 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00269 if( filter.get() == 0 ) {
00270 cerr << "Unable to read data from file " << trFile.c_str() << endl;
00271 return 2;
00272 }
00273 vector<string> vars;
00274 filter->vars(vars);
00275 cout << "Read data from file " << trFile.c_str() << " for variables";
00276 for( int i=0;i<vars.size();i++ )
00277 cout << " \"" << vars[i].c_str() << "\"";
00278 cout << endl;
00279 cout << "Total number of points read: " << filter->size() << endl;
00280
00281
00282 vector<SprClass> inputClasses;
00283 if( !filter->filterByClass(inputClassesString.c_str()) ) {
00284 cerr << "Cannot choose input classes for string "
00285 << inputClassesString << endl;
00286 return 2;
00287 }
00288 filter->classes(inputClasses);
00289 assert( inputClasses.size() > 1 );
00290 cout << "Training data filtered by class." << endl;
00291 for( int i=0;i<inputClasses.size();i++ ) {
00292 cout << "Points in class " << inputClasses[i] << ": "
00293 << filter->ptsInClass(inputClasses[i]) << endl;
00294 }
00295
00296
00297 if( scaleWeights ) {
00298 cout << "Signal weights are multiplied by " << sW << endl;
00299 filter->scaleWeights(inputClasses[1],sW);
00300 }
00301
00302
00303 if( setLowCutoff ) {
00304 if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00305 cerr << "Unable to replace missing values in training data." << endl;
00306 return 2;
00307 }
00308 else
00309 cout << "Values below " << lowCutoff << " in training data"
00310 << " have been replaced with medians." << endl;
00311 }
00312
00313
00314 auto_ptr<SprAbsFilter> valFilter;
00315 if( split && !valFile.empty() ) {
00316 cerr << "Unable to split training data and use validation data "
00317 << "from a separate file." << endl;
00318 return 2;
00319 }
00320 if( split && valPrint!=0 ) {
00321 cout << "Splitting training data with factor " << splitFactor << endl;
00322 if( splitRandomize )
00323 cout << "Will use randomized splitting." << endl;
00324 vector<double> weights;
00325 SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00326 if( splitted == 0 ) {
00327 cerr << "Unable to split training data." << endl;
00328 return 2;
00329 }
00330 bool ownData = true;
00331 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00332 cout << "Training data re-filtered:" << endl;
00333 for( int i=0;i<inputClasses.size();i++ ) {
00334 cout << "Points in class " << inputClasses[i] << ": "
00335 << filter->ptsInClass(inputClasses[i]) << endl;
00336 }
00337 }
00338 if( !valFile.empty() && valPrint!=0 ) {
00339 auto_ptr<SprAbsReader>
00340 valReader(SprRWFactory::makeReader(inputType,readMode));
00341 if( !includeSet.empty() ) {
00342 if( !valReader->chooseVars(includeSet) ) {
00343 cerr << "Unable to include variables in validation set." << endl;
00344 return 2;
00345 }
00346 }
00347 if( !excludeSet.empty() ) {
00348 if( !valReader->chooseAllBut(excludeSet) ) {
00349 cerr << "Unable to exclude variables from validation set." << endl;
00350 return 2;
00351 }
00352 }
00353 valFilter.reset(valReader->read(valFile.c_str()));
00354 if( valFilter.get() == 0 ) {
00355 cerr << "Unable to read data from file " << valFile.c_str() << endl;
00356 return 2;
00357 }
00358 vector<string> valVars;
00359 valFilter->vars(valVars);
00360 cout << "Read validation data from file " << valFile.c_str()
00361 << " for variables";
00362 for( int i=0;i<valVars.size();i++ )
00363 cout << " \"" << valVars[i].c_str() << "\"";
00364 cout << endl;
00365 cout << "Total number of points read: " << valFilter->size() << endl;
00366 cout << "Points in class 0: " << valFilter->ptsInClass(inputClasses[0])
00367 << " 1: " << valFilter->ptsInClass(inputClasses[1]) << endl;
00368 }
00369
00370
00371 if( valFilter.get() != 0 ) {
00372 if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00373 cerr << "Cannot choose input classes for string "
00374 << inputClassesString << endl;
00375 return 2;
00376 }
00377 valFilter->classes(inputClasses);
00378 cout << "Validation data filtered by class." << endl;
00379 for( int i=0;i<inputClasses.size();i++ ) {
00380 cout << "Points in class " << inputClasses[i] << ": "
00381 << valFilter->ptsInClass(inputClasses[i]) << endl;
00382 }
00383 }
00384
00385
00386 if( scaleWeights && valFilter.get()!=0 )
00387 valFilter->scaleWeights(inputClasses[1],sW);
00388
00389
00390 if( setLowCutoff && valFilter.get()!=0 ) {
00391 if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00392 cerr << "Unable to replace missing values in validation data." << endl;
00393 return 2;
00394 }
00395 else
00396 cout << "Values below " << lowCutoff << " in validation data"
00397 << " have been replaced with medians." << endl;
00398 }
00399
00400
00401 auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00402 if( !transformerFile.empty() ) {
00403 SprVarTransformerReader transReader;
00404 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00405 if( t == 0 ) {
00406 cerr << "Unable to read VarTransformer from file "
00407 << transformerFile.c_str() << endl;
00408 return 2;
00409 }
00410 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00411 SprTransformerFilter* t_valid = 0;
00412 if( valFilter.get() != 0 )
00413 t_valid = new SprTransformerFilter(valFilter.get());
00414 bool replaceOriginalData = true;
00415 if( !t_train->transform(t,replaceOriginalData) ) {
00416 cerr << "Unable to apply VarTransformer to training data." << endl;
00417 return 2;
00418 }
00419 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00420 cerr << "Unable to apply VarTransformer to validation data." << endl;
00421 return 2;
00422 }
00423 cout << "Variable transformation from file "
00424 << transformerFile.c_str() << " has been applied to "
00425 << "training and validation data." << endl;
00426 garbage_train.reset(filter.release());
00427 garbage_valid.reset(valFilter.release());
00428 filter.reset(t_train);
00429 valFilter.reset(t_valid);
00430 }
00431
00432
00433 auto_ptr<SprAverageLoss> loss;
00434 switch( iLoss )
00435 {
00436 case 1 :
00437 loss.reset(new SprAverageLoss(&SprLoss::quadratic,
00438 &SprTransformation::logit));
00439 cout << "Per-event loss set to "
00440 << "Quadratic loss (y-f(x))^2 " << endl;
00441 break;
00442 case 2 :
00443 loss.reset(new SprAverageLoss(&SprLoss::exponential));
00444 cout << "Per-event loss set to "
00445 << "Exponential loss exp(-y*f(x)) " << endl;
00446 break;
00447 default :
00448 cout << "No per-event loss is chosen. Will use the default." << endl;
00449 break;
00450 }
00451
00452
00453 SprTrainedAdaBoost::AdaBoostMode abMode = SprTrainedAdaBoost::Discrete;
00454 switch( iAdaBoostMode )
00455 {
00456 case 1 :
00457 abMode = SprTrainedAdaBoost::Discrete;
00458 cout << "Will train Discrete AdaBoost." << endl;
00459 break;
00460 case 2 :
00461 abMode = SprTrainedAdaBoost::Real;
00462 cout << "Will train Real AdaBoost." << endl;
00463 break;
00464 case 3 :
00465 abMode = SprTrainedAdaBoost::Epsilon;
00466 cout << "Will train Epsilon AdaBoost." << endl;
00467 break;
00468 default :
00469 cout << "Will train Discrete AdaBoost." << endl;
00470 break;
00471 }
00472
00473
00474 ifstream file(configFile.c_str());
00475 if( !file ) {
00476 cerr << "Unable to open file " << configFile.c_str() << endl;
00477 return 3;
00478 }
00479
00480
00481 vector<SprAbsTwoClassCriterion*> criteria;
00482 vector<SprAbsClassifier*> destroyC;
00483 vector<SprIntegerBootstrap*> bstraps;
00484 vector<SprCCPair> useC;
00485
00486
00487 unsigned nLine = 0;
00488 bool discreteTree = (abMode!=SprTrainedAdaBoost::Real);
00489 bool mixedNodesTree = (abMode==SprTrainedAdaBoost::Real);
00490 bool fastSort = true;
00491 bool readOneEntry = false;
00492 if( !SprClassifierReader::readTrainableConfig(file,nLine,filter.get(),
00493 discreteTree,mixedNodesTree,
00494 fastSort,criteria,
00495 bstraps,destroyC,useC,
00496 readOneEntry) ) {
00497 cerr << "Unable to read weak classifier configurations from file "
00498 << configFile.c_str() << endl;
00499 prepareExit(criteria,destroyC,bstraps);
00500 return 4;
00501 }
00502 cout << "Finished reading " << useC.size() << " classifiers from file "
00503 << configFile.c_str() << endl;
00504
00505
00506 SprAdaBoost ab(filter.get(),cycles,useStandardAB,abMode,bagInput);
00507 cout << "Setting epsilon to " << epsilon << endl;
00508 ab.setEpsilon(epsilon);
00509
00510
00511 if( valFilter.get()!=0 && !valFilter->empty() )
00512 ab.setValidation(valFilter.get(),valPrint,loss.get());
00513
00514
00515 if( !resumeFile.empty() ) {
00516 if( !SprClassifierReader::readTrainable(resumeFile.c_str(),
00517 &ab,verbose) ) {
00518 cerr << "Failed to read saved AdaBoost from file "
00519 << resumeFile.c_str() << endl;
00520 prepareExit(criteria,destroyC,bstraps);
00521 return 5;
00522 }
00523 cout << "Read saved AdaBoost from file " << resumeFile.c_str()
00524 << " with " << ab.nTrained() << " trained classifiers." << endl;
00525 }
00526 if( skipInitialEventReweighting ) ab.skipInitialEventReweighting(true);
00527
00528
00529 for( int i=0;i<useC.size();i++ ) {
00530 if( !ab.addTrainable(useC[i].first,useC[i].second) ) {
00531 cerr << "Unable to add classifier " << i << " of type "
00532 << useC[i].first->name() << " to AdaBoost." << endl;
00533 prepareExit(criteria,destroyC,bstraps);
00534 return 6;
00535 }
00536 }
00537
00538
00539 if( !ab.train(verbose) )
00540 cerr << "AdaBoost terminated with error." << endl;
00541 if( ab.nTrained() == 0 ) {
00542 cerr << "Unable to train AdaBoost." << endl;
00543 prepareExit(criteria,destroyC,bstraps);
00544 return 7;
00545 }
00546 else {
00547 cout << "AdaBoost finished training with " << ab.nTrained()
00548 << " classifiers." << endl;
00549 }
00550
00551
00552 if( !outFile.empty() ) {
00553 if( !ab.store(outFile.c_str()) ) {
00554 cerr << "Cannot store AdaBoost in file " << outFile.c_str() << endl;
00555 prepareExit(criteria,destroyC,bstraps);
00556 return 8;
00557 }
00558 }
00559
00560
00561 if( !weightedDataOut.empty() ) {
00562 if( !ab.storeData(weightedDataOut.c_str()) ) {
00563 cerr << "Cannot store weighted AdaBoost data to file "
00564 << weightedDataOut.c_str() << endl;
00565 prepareExit(criteria,destroyC,bstraps);
00566 return 9;
00567 }
00568 }
00569
00570
00571 prepareExit(criteria,destroyC,bstraps);
00572 return 0;
00573 }