00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprBagger.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprArcE4.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedBagger.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00022 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00023 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00024 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00025
00026 #include <stdlib.h>
00027 #include <unistd.h>
00028 #include <iostream>
00029 #include <fstream>
00030 #include <vector>
00031 #include <set>
00032 #include <string>
00033 #include <memory>
00034
00035 using namespace std;
00036
00037
00038
/// Prints the command-line usage summary of this executable to standard output.
///
/// The first line shows the expected positional arguments; the remaining
/// lines list every option letter accepted by the getopt() string in main().
///
/// @param prog  Program name to show after "Usage: " (main() passes argv[0]).
void help(const char* prog)
{
  std::cout << "Usage: " << prog
            << " training_data_file"
            << " file_of_classifier_parameters(see booster.config for syntax)"
            << std::endl;

  // One entry per output line; each line is flushed with std::endl below.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-A save output data in ascii instead of Root ",
    "\t-n number of Bagger training cycles ",
    "\t-y list of input classes (see SprAbsFilter.hh) ",
    "\t-Q apply variable transformation saved in file ",
    "\t-b use a version of Breiman's arc-x4 algorithm ",
    "\t-g per-event loss for (cross-)validation ",
    "\t\t 1 - quadratic loss (y-f(x))^2 ",
    "\t\t 2 - exponential loss exp(-y*f(x)) ",
    "\t-m replace data values below this cutoff with medians",
    "\t-v verbose level (0=silent default,1,2) ",
    // This program stores a Bagger; the text used to say "AdaBoost".
    "\t-f store trained Bagger to file ",
    "\t-r resume training for Bagger stored in file ",
    "\t-K keep this fraction in training set and ",
    "\t\t put the rest into validation set ",
    "\t-D randomize training set split-up ",
    "\t-G generate seed from time of day for bootstrap ",
    "\t\t (this option is required for parallelization) ",
    "\t-t read validation/test data from a file ",
    "\t\t (must be in same format as input data!!!) ",
    "\t-d frequency of print-outs for validation data ",
    "\t-w scale all signal weights by this factor ",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };

  const std::size_t nLines = sizeof(optionLines)/sizeof(optionLines[0]);
  for( std::size_t i=0; i<nLines; ++i )
    std::cout << optionLines[i] << std::endl;
}
00074
00075
00076 void prepareExit(vector<SprAbsTwoClassCriterion*>& criteria,
00077 vector<SprAbsClassifier*>& classifiers,
00078 vector<SprIntegerBootstrap*>& bstraps)
00079 {
00080 for( int i=0;i<criteria.size();i++ ) delete criteria[i];
00081 for( int i=0;i<classifiers.size();i++ ) delete classifiers[i];
00082 for( int i=0;i<bstraps.size();i++ ) delete bstraps[i];
00083 }
00084
00085
00086 int main(int argc, char ** argv)
00087 {
00088
00089 if( argc < 3 ) {
00090 help(argv[0]);
00091 return 1;
00092 }
00093
00094
00095 int readMode = 0;
00096 SprRWFactory::DataType writeMode = SprRWFactory::Root;
00097 unsigned cycles = 0;
00098 int verbose = 0;
00099 string outFile;
00100 string resumeFile;
00101 string valFile;
00102 unsigned valPrint = 0;
00103 bool scaleWeights = false;
00104 double sW = 1.;
00105 bool setLowCutoff = false;
00106 double lowCutoff = 0;
00107 string includeList, excludeList;
00108 int iLoss = 1;
00109 string inputClassesString;
00110 bool useArcE4 = false;
00111 bool split = false;
00112 double splitFactor = 0;
00113 bool splitRandomize = false;
00114 bool initBootstrapFromTimeOfDay = false;
00115 string transformerFile;
00116
00117
00118 int c;
00119 extern char* optarg;
00120
00121 while((c = getopt(argc,argv,"ha:An:y:Q:bg:m:v:f:r:K:DGt:d:w:V:z:")) != EOF ) {
00122 switch( c )
00123 {
00124 case 'h' :
00125 help(argv[0]);
00126 return 1;
00127 case 'a' :
00128 readMode = (optarg==0 ? 0 : atoi(optarg));
00129 break;
00130 case 'A' :
00131 writeMode = SprRWFactory::Ascii;
00132 break;
00133 case 'n' :
00134 cycles = (optarg==0 ? 1 : atoi(optarg));
00135 break;
00136 case 'y' :
00137 inputClassesString = optarg;
00138 break;
00139 case 'Q' :
00140 transformerFile = optarg;
00141 break;
00142 case 'b' :
00143 useArcE4 = true;
00144 break;
00145 case 'g' :
00146 iLoss = (optarg==0 ? 1 : atoi(optarg));
00147 break;
00148 case 'm' :
00149 if( optarg != 0 ) {
00150 setLowCutoff = true;
00151 lowCutoff = atof(optarg);
00152 }
00153 break;
00154 case 'v' :
00155 verbose = (optarg==0 ? 0 : atoi(optarg));
00156 break;
00157 case 'f' :
00158 outFile = optarg;
00159 break;
00160 case 'r' :
00161 resumeFile = optarg;
00162 break;
00163 case 'K' :
00164 split = true;
00165 splitFactor = (optarg==0 ? 0 : atof(optarg));
00166 break;
00167 case 'D' :
00168 splitRandomize = true;
00169 break;
00170 case 'G' :
00171 initBootstrapFromTimeOfDay = true;
00172 break;
00173 case 't' :
00174 valFile = optarg;
00175 break;
00176 case 'd' :
00177 valPrint = (optarg==0 ? 0 : atoi(optarg));
00178 break;
00179 case 'w' :
00180 if( optarg != 0 ) {
00181 scaleWeights = true;
00182 sW = atof(optarg);
00183 }
00184 break;
00185 case 'V' :
00186 includeList = optarg;
00187 break;
00188 case 'z' :
00189 excludeList = optarg;
00190 break;
00191 }
00192 }
00193
00194
00195 string trFile = argv[argc-2];
00196 string configFile = argv[argc-1];
00197 if( trFile.empty() ) {
00198 cerr << "No training file is specified." << endl;
00199 return 1;
00200 }
00201 if( configFile.empty() ) {
00202 cerr << "No classifier configuration file specified." << endl;
00203 return 1;
00204 }
00205
00206
00207 SprRWFactory::DataType inputType
00208 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00209 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00210
00211
00212 set<string> includeSet;
00213 if( !includeList.empty() ) {
00214 vector<vector<string> > includeVars;
00215 SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00216 assert( !includeVars.empty() );
00217 for( int i=0;i<includeVars[0].size();i++ )
00218 includeSet.insert(includeVars[0][i]);
00219 if( !reader->chooseVars(includeSet) ) {
00220 cerr << "Unable to include variables in training set." << endl;
00221 return 2;
00222 }
00223 else {
00224 cout << "Following variables have been included in optimization: ";
00225 for( set<string>::const_iterator
00226 i=includeSet.begin();i!=includeSet.end();i++ )
00227 cout << "\"" << *i << "\"" << " ";
00228 cout << endl;
00229 }
00230 }
00231
00232
00233 set<string> excludeSet;
00234 if( !excludeList.empty() ) {
00235 vector<vector<string> > excludeVars;
00236 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00237 assert( !excludeVars.empty() );
00238 for( int i=0;i<excludeVars[0].size();i++ )
00239 excludeSet.insert(excludeVars[0][i]);
00240 if( !reader->chooseAllBut(excludeSet) ) {
00241 cerr << "Unable to exclude variables from training set." << endl;
00242 return 2;
00243 }
00244 else {
00245 cout << "Following variables have been excluded from optimization: ";
00246 for( set<string>::const_iterator
00247 i=excludeSet.begin();i!=excludeSet.end();i++ )
00248 cout << "\"" << *i << "\"" << " ";
00249 cout << endl;
00250 }
00251 }
00252
00253
00254 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00255 if( filter.get() == 0 ) {
00256 cerr << "Unable to read data from file " << trFile.c_str() << endl;
00257 return 2;
00258 }
00259 vector<string> vars;
00260 filter->vars(vars);
00261 cout << "Read data from file " << trFile.c_str() << " for variables";
00262 for( int i=0;i<vars.size();i++ )
00263 cout << " \"" << vars[i].c_str() << "\"";
00264 cout << endl;
00265 cout << "Total number of points read: " << filter->size() << endl;
00266
00267
00268 vector<SprClass> inputClasses;
00269 if( !filter->filterByClass(inputClassesString.c_str()) ) {
00270 cerr << "Cannot choose input classes for string "
00271 << inputClassesString << endl;
00272 return 2;
00273 }
00274 filter->classes(inputClasses);
00275 assert( inputClasses.size() > 1 );
00276 cout << "Training data filtered by class." << endl;
00277 for( int i=0;i<inputClasses.size();i++ ) {
00278 cout << "Points in class " << inputClasses[i] << ": "
00279 << filter->ptsInClass(inputClasses[i]) << endl;
00280 }
00281
00282
00283 if( scaleWeights ) {
00284 cout << "Signal weights are multiplied by " << sW << endl;
00285 filter->scaleWeights(inputClasses[1],sW);
00286 }
00287
00288
00289 if( setLowCutoff ) {
00290 if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00291 cerr << "Unable to replace missing values in training data." << endl;
00292 return 2;
00293 }
00294 else
00295 cout << "Values below " << lowCutoff << " in training data"
00296 << " have been replaced with medians." << endl;
00297 }
00298
00299
00300 auto_ptr<SprAbsFilter> valFilter;
00301 if( split && !valFile.empty() ) {
00302 cerr << "Unable to split training data and use validation data "
00303 << "from a separate file." << endl;
00304 return 2;
00305 }
00306 if( split && valPrint!=0 ) {
00307 cout << "Splitting training data with factor " << splitFactor << endl;
00308 if( splitRandomize )
00309 cout << "Will use randomized splitting." << endl;
00310 vector<double> weights;
00311 SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00312 if( splitted == 0 ) {
00313 cerr << "Unable to split training data." << endl;
00314 return 2;
00315 }
00316 bool ownData = true;
00317 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00318 cout << "Training data re-filtered:" << endl;
00319 for( int i=0;i<inputClasses.size();i++ ) {
00320 cout << "Points in class " << inputClasses[i] << ": "
00321 << filter->ptsInClass(inputClasses[i]) << endl;
00322 }
00323 }
00324 if( !valFile.empty() && valPrint!=0 ) {
00325 auto_ptr<SprAbsReader>
00326 valReader(SprRWFactory::makeReader(inputType,readMode));
00327 if( !includeSet.empty() ) {
00328 if( !valReader->chooseVars(includeSet) ) {
00329 cerr << "Unable to include variables in validation set." << endl;
00330 return 2;
00331 }
00332 }
00333 if( !excludeSet.empty() ) {
00334 if( !valReader->chooseAllBut(excludeSet) ) {
00335 cerr << "Unable to exclude variables from validation set." << endl;
00336 return 2;
00337 }
00338 }
00339 valFilter.reset(valReader->read(valFile.c_str()));
00340 if( valFilter.get() == 0 ) {
00341 cerr << "Unable to read data from file " << valFile.c_str() << endl;
00342 return 2;
00343 }
00344 vector<string> valVars;
00345 valFilter->vars(valVars);
00346 cout << "Read validation data from file " << valFile.c_str()
00347 << " for variables";
00348 for( int i=0;i<valVars.size();i++ )
00349 cout << " \"" << valVars[i].c_str() << "\"";
00350 cout << endl;
00351 cout << "Total number of points read: " << valFilter->size() << endl;
00352 cout << "Points in class 0: " << valFilter->ptsInClass(inputClasses[0])
00353 << " 1: " << valFilter->ptsInClass(inputClasses[1]) << endl;
00354 }
00355
00356
00357 if( valFilter.get() != 0 ) {
00358 if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00359 cerr << "Cannot choose input classes for string "
00360 << inputClassesString << endl;
00361 return 2;
00362 }
00363 valFilter->classes(inputClasses);
00364 cout << "Validation data filtered by class." << endl;
00365 for( int i=0;i<inputClasses.size();i++ ) {
00366 cout << "Points in class " << inputClasses[i] << ": "
00367 << valFilter->ptsInClass(inputClasses[i]) << endl;
00368 }
00369 }
00370
00371
00372 if( scaleWeights && valFilter.get()!=0 )
00373 valFilter->scaleWeights(inputClasses[1],sW);
00374
00375
00376 if( setLowCutoff && valFilter.get()!=0 ) {
00377 if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00378 cerr << "Unable to replace missing values in validation data." << endl;
00379 return 2;
00380 }
00381 else
00382 cout << "Values below " << lowCutoff << " in validation data"
00383 << " have been replaced with medians." << endl;
00384 }
00385
00386
00387 auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00388 if( !transformerFile.empty() ) {
00389 SprVarTransformerReader transReader;
00390 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00391 if( t == 0 ) {
00392 cerr << "Unable to read VarTransformer from file "
00393 << transformerFile.c_str() << endl;
00394 return 2;
00395 }
00396 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00397 SprTransformerFilter* t_valid = 0;
00398 if( valFilter.get() != 0 )
00399 t_valid = new SprTransformerFilter(valFilter.get());
00400 bool replaceOriginalData = true;
00401 if( !t_train->transform(t,replaceOriginalData) ) {
00402 cerr << "Unable to apply VarTransformer to training data." << endl;
00403 return 2;
00404 }
00405 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00406 cerr << "Unable to apply VarTransformer to validation data." << endl;
00407 return 2;
00408 }
00409 cout << "Variable transformation from file "
00410 << transformerFile.c_str() << " has been applied to "
00411 << "training and validation data." << endl;
00412 garbage_train.reset(filter.release());
00413 garbage_valid.reset(valFilter.release());
00414 filter.reset(t_train);
00415 valFilter.reset(t_valid);
00416 }
00417
00418
00419 auto_ptr<SprAverageLoss> loss;
00420 switch( iLoss )
00421 {
00422 case 1 :
00423 loss.reset(new SprAverageLoss(&SprLoss::quadratic));
00424 cout << "Per-event loss set to "
00425 << "Quadratic loss (y-f(x))^2 " << endl;
00426 break;
00427 case 2 :
00428 loss.reset(new SprAverageLoss(&SprLoss::purity_ratio));
00429 cout << "Per-event loss set to "
00430 << "Exponential loss exp(-y*f(x)) " << endl;
00431 break;
00432 default :
00433 cout << "No per-event loss is chosen. Will use the default." << endl;
00434 break;
00435 }
00436
00437
00438 ifstream file(configFile.c_str());
00439 if( !file ) {
00440 cerr << "Unable to open file " << configFile.c_str() << endl;
00441 return 3;
00442 }
00443
00444
00445 vector<SprAbsTwoClassCriterion*> criteria;
00446 vector<SprAbsClassifier*> destroyC;
00447 vector<SprIntegerBootstrap*> bstraps;
00448 vector<SprCCPair> useC;
00449
00450
00451 unsigned nLine = 0;
00452 bool discreteTree = false;
00453 bool mixedNodesTree = false;
00454 bool readOneEntry = false;
00455 bool fastSort = true;
00456 if( !SprClassifierReader::readTrainableConfig(file,nLine,filter.get(),
00457 discreteTree,mixedNodesTree,
00458 fastSort,criteria,
00459 bstraps,destroyC,useC,
00460 readOneEntry) ) {
00461 cerr << "Unable to read weak classifier configurations from file "
00462 << configFile.c_str() << endl;
00463 prepareExit(criteria,destroyC,bstraps);
00464 return 4;
00465 }
00466 cout << "Finished reading " << useC.size() << " classifiers from file "
00467 << configFile.c_str() << endl;
00468
00469
00470 auto_ptr<SprBagger> bagger;
00471 bool discrete = false;
00472 if( useArcE4 )
00473 bagger.reset(new SprArcE4(filter.get(),cycles,discrete));
00474 else
00475 bagger.reset(new SprBagger(filter.get(),cycles,discrete));
00476
00477
00478 if( initBootstrapFromTimeOfDay && !bagger->initBootstrapFromTimeOfDay() ) {
00479 cerr << "Unable to generate seed from time of day for Bagger." << endl;
00480 return 4;
00481 }
00482
00483
00484 if( valFilter.get()!=0 && !valFilter->empty() )
00485 bagger->setValidation(valFilter.get(),valPrint,0,loss.get());
00486
00487
00488 if( !resumeFile.empty() ) {
00489 if( !SprClassifierReader::readTrainable(resumeFile.c_str(),
00490 bagger.get(),verbose) ) {
00491 cerr << "Failed to read saved Bagger from file "
00492 << resumeFile.c_str() << endl;
00493 prepareExit(criteria,destroyC,bstraps);
00494 return 5;
00495 }
00496 cout << "Read saved Bagger from file " << resumeFile.c_str()
00497 << " with " << bagger->nTrained() << " trained classifiers." << endl;
00498 }
00499
00500
00501 for( int i=0;i<useC.size();i++ ) {
00502 if( !bagger->addTrainable(useC[i].first) ) {
00503 cerr << "Unable to add classifier " << i << " of type "
00504 << useC[i].first->name() << " to Bagger." << endl;
00505 prepareExit(criteria,destroyC,bstraps);
00506 return 6;
00507 }
00508 }
00509
00510
00511 if( !bagger->train(verbose) )
00512 cerr << "Bagger terminated with error." << endl;
00513 if( bagger->nTrained() == 0 ) {
00514 cerr << "Unable to train Bagger." << endl;
00515 prepareExit(criteria,destroyC,bstraps);
00516 return 7;
00517 }
00518 else {
00519 cout << "Bagger finished training with " << bagger->nTrained()
00520 << " classifiers." << endl;
00521 }
00522
00523
00524 if( !outFile.empty() ) {
00525 if( !bagger->store(outFile.c_str()) ) {
00526 cerr << "Cannot store Bagger in file " << outFile.c_str() << endl;
00527 prepareExit(criteria,destroyC,bstraps);
00528 return 8;
00529 }
00530 }
00531
00532
00533 prepareExit(criteria,destroyC,bstraps);
00534 return 0;
00535 }