00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprBumpHunter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedDecisionTree.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassSignalSignif.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassIDFraction.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassTaggerEff.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPurity.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassCrossEntropy.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassUniformPriorUL90.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassBKDiscovery.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPunzi.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassBgrndSmoother.hh"
00022 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00023 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00024 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00025 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00026 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00027
00028 #include <stdlib.h>
00029 #include <unistd.h>
00030 #include <iostream>
00031 #include <vector>
00032 #include <set>
00033 #include <string>
00034 #include <memory>
00035
00036 using namespace std;
00037
00038
// Prints the command-line usage of this bump-hunting application to stdout:
// one "Usage:" line naming the program, followed by one line per option.
//   prog - program name to show in the usage line (main() passes argv[0]).
void help(const char* prog)
{
  // Option descriptions, emitted verbatim, one per output line.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-o output Tuple file ",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-A save output data in ascii instead of Root ",
    "\t-n minimal number of events per bump (def=1) ",
    "\t-b requested number of bumps (def=1) ",
    "\t-x max fraction of events peeled off in one try ",
    "\t-y list of input classes (see SprAbsFilter.hh) ",
    "\t-Q apply variable transformation saved in file ",
    "\t-v verbose level (0=silent default,1,2) ",
    "\t-f store trained bump hunter to file ",
    "\t-c criterion for optimization ",
    "\t\t 1 = correctly classified fraction ",
    "\t\t 2 = signal significance s/sqrt(s+b) ",
    "\t\t 3 = purity s/(s+b) (default) ",
    "\t\t 4 = tagger efficiency Q ",
    "\t\t 5 = Gini index ",
    "\t\t 6 = cross-entropy ",
    "\t\t 7 = 90% Bayesian upper limit with uniform prior ",
    "\t\t 8 = discovery potential 2*(sqrt(s+b)-sqrt(b)) ",
    "\t\t 9 = Punzi's sensitivity s/(0.5*nSigma+sqrt(b)) ",
    "\t\t 10= background-smoothed Punzi's sensitivity ",
    "\t\t -P background normalization factor for Punzi FOM",
    "\t\t -L lambda for the background-smoothed FOM ",
    "\t\t -O omega for the background-smoothed FOM ",
    "\t-K keep this fraction in training set and ",
    "\t\t put the rest into validation set ",
    "\t-D randomize training set split-up ",
    "\t-t read validation/test data from a file ",
    "\t\t (must be in same format as input data!!! ",
    "\t-p output file to store validation/test data ",
    "\t-w scale all signal weights by this factor ",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const size_t nOptionLines = sizeof(optionLines)/sizeof(optionLines[0]);

  cout << "Usage: " << prog << " training_data_file" << endl;
  for( size_t k=0; k<nOptionLines; ++k )
    cout << optionLines[k] << endl;
}
00081
00082
00083 int main(int argc, char ** argv)
00084 {
00085
00086 if( argc < 2 ) {
00087 help(argv[0]);
00088 return 1;
00089 }
00090
00091
00092 string tupleFile;
00093 int readMode = 0;
00094 SprRWFactory::DataType writeMode = SprRWFactory::Root;
00095 unsigned nmin = 1;
00096 int verbose = 0;
00097 string outFile;
00098 string resumeFile;
00099 int iCrit = 3;
00100 string valFile;
00101 string valHbkFile;
00102 int nbump = 1;
00103 double apeel = 1.;
00104 bool scaleWeights = false;
00105 double sW = 1.;
00106 string includeList, excludeList;
00107 string inputClassesString;
00108 double bW = 1.;
00109 double lambda = 2.;
00110 double omega = 5.;
00111 bool split = false;
00112 double splitFactor = 0;
00113 bool splitRandomize = false;
00114 string transformerFile;
00115
00116
00117 int c;
00118 extern char* optarg;
00119 while( (c = getopt(argc,argv,"ho:a:An:v:f:c:P:L:O:K:Dt:p:b:x:y:Q:w:V:z:")) != EOF ) {
00120 switch( c )
00121 {
00122 case 'h' :
00123 help(argv[0]);
00124 return 1;
00125 case 'o' :
00126 tupleFile = optarg;
00127 break;
00128 case 'a' :
00129 readMode = (optarg==0 ? 0 : atoi(optarg));
00130 break;
00131 case 'A' :
00132 writeMode = SprRWFactory::Ascii;
00133 break;
00134 case 'n' :
00135 nmin = (optarg==0 ? 1 : atoi(optarg));
00136 break;
00137 case 'v' :
00138 verbose = (optarg==0 ? 0 : atoi(optarg));
00139 break;
00140 case 'f' :
00141 outFile = optarg;
00142 break;
00143 case 'c' :
00144 iCrit = (optarg==0 ? 3 : atoi(optarg));
00145 break;
00146 case 'P' :
00147 bW = (optarg==0 ? 1. : atof(optarg));
00148 break;
00149 case 'L' :
00150 lambda = (optarg==0 ? 2. : atof(optarg));
00151 break;
00152 case 'O' :
00153 omega = (optarg==0 ? 5. : atof(optarg));
00154 break;
00155 case 'K' :
00156 split = true;
00157 splitFactor = (optarg==0 ? 0 : atof(optarg));
00158 break;
00159 case 'D' :
00160 splitRandomize = true;
00161 break;
00162 case 't' :
00163 valFile = optarg;
00164 break;
00165 case 'p' :
00166 valHbkFile = optarg;
00167 break;
00168 case 'b' :
00169 nbump = (optarg==0 ? 1 : atoi(optarg));
00170 break;
00171 case 'x' :
00172 apeel = (optarg==0 ? 1. : atof(optarg));
00173 break;
00174 case 'y' :
00175 inputClassesString = optarg;
00176 break;
00177 case 'Q' :
00178 transformerFile = optarg;
00179 break;
00180 case 'w' :
00181 if( optarg != 0 ) {
00182 scaleWeights = true;
00183 sW = atof(optarg);
00184 }
00185 break;
00186 case 'V' :
00187 includeList = optarg;
00188 break;
00189 case 'z' :
00190 excludeList = optarg;
00191 break;
00192 }
00193 }
00194
00195
00196 string trFile = argv[argc-1];
00197 if( trFile.empty() ) {
00198 cerr << "No training file is specified." << endl;
00199 return 1;
00200 }
00201
00202
00203 SprRWFactory::DataType inputType
00204 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00205 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00206
00207
00208 set<string> includeSet;
00209 if( !includeList.empty() ) {
00210 vector<vector<string> > includeVars;
00211 SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00212 assert( !includeVars.empty() );
00213 for( int i=0;i<includeVars[0].size();i++ )
00214 includeSet.insert(includeVars[0][i]);
00215 if( !reader->chooseVars(includeSet) ) {
00216 cerr << "Unable to include variables in training set." << endl;
00217 return 2;
00218 }
00219 else {
00220 cout << "Following variables have been included in optimization: ";
00221 for( set<string>::const_iterator
00222 i=includeSet.begin();i!=includeSet.end();i++ )
00223 cout << "\"" << *i << "\"" << " ";
00224 cout << endl;
00225 }
00226 }
00227
00228
00229 set<string> excludeSet;
00230 if( !excludeList.empty() ) {
00231 vector<vector<string> > excludeVars;
00232 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00233 assert( !excludeVars.empty() );
00234 for( int i=0;i<excludeVars[0].size();i++ )
00235 excludeSet.insert(excludeVars[0][i]);
00236 if( !reader->chooseAllBut(excludeSet) ) {
00237 cerr << "Unable to exclude variables from training set." << endl;
00238 return 2;
00239 }
00240 else {
00241 cout << "Following variables have been excluded from optimization: ";
00242 for( set<string>::const_iterator
00243 i=excludeSet.begin();i!=excludeSet.end();i++ )
00244 cout << "\"" << *i << "\"" << " ";
00245 cout << endl;
00246 }
00247 }
00248
00249
00250 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00251 if( filter.get() == 0 ) {
00252 cerr << "Unable to read data from file " << trFile.c_str() << endl;
00253 return 2;
00254 }
00255 vector<string> vars;
00256 filter->vars(vars);
00257 cout << "Read data from file " << trFile.c_str()
00258 << " for variables";
00259 for( int i=0;i<vars.size();i++ )
00260 cout << " \"" << vars[i].c_str() << "\"";
00261 cout << endl;
00262 cout << "Total number of points read: " << filter->size() << endl;
00263
00264
00265 vector<SprClass> inputClasses;
00266 if( !filter->filterByClass(inputClassesString.c_str()) ) {
00267 cerr << "Cannot choose input classes for string "
00268 << inputClassesString << endl;
00269 return 2;
00270 }
00271 filter->classes(inputClasses);
00272 assert( inputClasses.size() > 1 );
00273 cout << "Training data filtered by class." << endl;
00274 for( int i=0;i<inputClasses.size();i++ ) {
00275 cout << "Points in class " << inputClasses[i] << ": "
00276 << filter->ptsInClass(inputClasses[i]) << endl;
00277 }
00278
00279
00280 if( scaleWeights ) {
00281 cout << "Signal weights are multiplied by " << sW << endl;
00282 filter->scaleWeights(inputClasses[1],sW);
00283 }
00284
00285
00286 auto_ptr<SprAbsFilter> valFilter;
00287 if( split && !valFile.empty() ) {
00288 cerr << "Unable to split training data and use validation data "
00289 << "from a separate file." << endl;
00290 return 2;
00291 }
00292 if( split ) {
00293 cout << "Splitting training data with factor " << splitFactor << endl;
00294 if( splitRandomize )
00295 cout << "Will use randomized splitting." << endl;
00296 vector<double> weights;
00297 SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00298 if( splitted == 0 ) {
00299 cerr << "Unable to split training data." << endl;
00300 return 2;
00301 }
00302 bool ownData = true;
00303 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00304 cout << "Training data re-filtered:" << endl;
00305 for( int i=0;i<inputClasses.size();i++ ) {
00306 cout << "Points in class " << inputClasses[i] << ": "
00307 << filter->ptsInClass(inputClasses[i]) << endl;
00308 }
00309 }
00310 if( !valFile.empty() ) {
00311 auto_ptr<SprAbsReader>
00312 valReader(SprRWFactory::makeReader(inputType,readMode));
00313 if( !includeSet.empty() ) {
00314 if( !valReader->chooseVars(includeSet) ) {
00315 cerr << "Unable to include variables in validation set." << endl;
00316 return 2;
00317 }
00318 }
00319 if( !excludeSet.empty() ) {
00320 if( !valReader->chooseAllBut(excludeSet) ) {
00321 cerr << "Unable to exclude variables from validation set." << endl;
00322 return 2;
00323 }
00324 }
00325 valFilter.reset(valReader->read(valFile.c_str()));
00326 if( valFilter.get() == 0 ) {
00327 cerr << "Unable to read data from file " << valFile.c_str() << endl;
00328 return 2;
00329 }
00330 vector<string> valVars;
00331 valFilter->vars(valVars);
00332 cout << "Read validation data from file " << valFile.c_str()
00333 << " for variables";
00334 for( int i=0;i<valVars.size();i++ )
00335 cout << " \"" << valVars[i].c_str() << "\"";
00336 cout << endl;
00337 cout << "Total number of points read: " << valFilter->size() << endl;
00338 }
00339
00340
00341 if( valFilter.get() != 0 ) {
00342 if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00343 cerr << "Cannot choose input classes for string "
00344 << inputClassesString << endl;
00345 return 2;
00346 }
00347 valFilter->classes(inputClasses);
00348 cout << "Validation data filtered by class." << endl;
00349 for( int i=0;i<inputClasses.size();i++ ) {
00350 cout << "Points in class " << inputClasses[i] << ": "
00351 << valFilter->ptsInClass(inputClasses[i]) << endl;
00352 }
00353 }
00354
00355
00356 if( scaleWeights && valFilter.get()!=0 )
00357 valFilter->scaleWeights(inputClasses[1],sW);
00358
00359
00360 auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00361 if( !transformerFile.empty() ) {
00362 SprVarTransformerReader transReader;
00363 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00364 if( t == 0 ) {
00365 cerr << "Unable to read VarTransformer from file "
00366 << transformerFile.c_str() << endl;
00367 return 2;
00368 }
00369 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00370 SprTransformerFilter* t_valid = 0;
00371 if( valFilter.get() != 0 )
00372 t_valid = new SprTransformerFilter(valFilter.get());
00373 bool replaceOriginalData = true;
00374 if( !t_train->transform(t,replaceOriginalData) ) {
00375 cerr << "Unable to apply VarTransformer to training data." << endl;
00376 return 2;
00377 }
00378 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00379 cerr << "Unable to apply VarTransformer to validation data." << endl;
00380 return 2;
00381 }
00382 cout << "Variable transformation from file "
00383 << transformerFile.c_str() << " has been applied to "
00384 << "training and validation data." << endl;
00385 garbage_train.reset(filter.release());
00386 garbage_valid.reset(valFilter.release());
00387 filter.reset(t_train);
00388 valFilter.reset(t_valid);
00389 }
00390
00391
00392 auto_ptr<SprAbsTwoClassCriterion> crit;
00393 switch( iCrit )
00394 {
00395 case 1 :
00396 crit.reset(new SprTwoClassIDFraction);
00397 cout << "Optimization criterion set to "
00398 << "Fraction of correctly classified events " << endl;
00399 break;
00400 case 2 :
00401 crit.reset(new SprTwoClassSignalSignif);
00402 cout << "Optimization criterion set to "
00403 << "Signal significance S/sqrt(S+B) " << endl;
00404 break;
00405 case 3 :
00406 crit.reset(new SprTwoClassPurity);
00407 cout << "Optimization criterion set to "
00408 << "Purity S/(S+B) " << endl;
00409 break;
00410 case 4 :
00411 crit.reset(new SprTwoClassTaggerEff);
00412 cout << "Optimization criterion set to "
00413 << "Tagging efficiency Q = e*(1-2w)^2 " << endl;
00414 break;
00415 case 5 :
00416 crit.reset(new SprTwoClassGiniIndex);
00417 cout << "Optimization criterion set to "
00418 << "Gini index -1+p^2+q^2 " << endl;
00419 break;
00420 case 6 :
00421 crit.reset(new SprTwoClassCrossEntropy);
00422 cout << "Optimization criterion set to "
00423 << "Cross-entropy p*log(p)+q*log(q) " << endl;
00424 break;
00425 case 7 :
00426 crit.reset(new SprTwoClassUniformPriorUL90);
00427 cout << "Optimization criterion set to "
00428 << "Inverse of 90% Bayesian upper limit with uniform prior" << endl;
00429 break;
00430 case 8 :
00431 crit.reset(new SprTwoClassBKDiscovery);
00432 cout << "Optimization criterion set to "
00433 << "Discovery potential 2*(sqrt(S+B)-sqrt(B))" << endl;
00434 break;
00435 case 9 :
00436 crit.reset(new SprTwoClassPunzi(bW));
00437 cout << "Optimization criterion set to "
00438 << "Punzi's sensitivity S/(0.5*nSigma+sqrt(B))" << endl;
00439 break;
00440 case 10 :
00441 crit.reset(new SprTwoClassBgrndSmoother(bW,lambda,omega));
00442 cout << "Optimization criterion set to "
00443 << "background-smoothed Punzi's sensitivity" << endl;
00444 break;
00445 default :
00446 cerr << "Unable to make initialization criterion." << endl;
00447 return 3;
00448 }
00449
00450
00451 SprBumpHunter bump(filter.get(),crit.get(),nbump,nmin,apeel);
00452
00453
00454 if( !bump.train(verbose) ) {
00455 cerr << "Unable to find bumps." << endl;
00456 return 4;
00457 }
00458
00459
00460 if( !outFile.empty() ) {
00461 if( !bump.store(outFile.c_str()) ) {
00462 cerr << "Cannot store bump hunter in file " << outFile.c_str() << endl;
00463 return 5;
00464 }
00465 }
00466
00467
00468 auto_ptr<SprTrainedDecisionTree> trainedTree(bump.makeTrained());
00469
00470
00471 if( valFilter.get() != 0 ) {
00472 double wcor0(0), wmis0(0), wcor1(0), wmis1(0);
00473 int ncor0(0), nmis0(0), ncor1(0), nmis1(0);
00474 for( int i=0;i<valFilter->size();i++ ) {
00475 const SprPoint* p = (*valFilter.get())[i];
00476 double w = valFilter->w(i);
00477 if( trainedTree->accept(p) ) {
00478 if( p->class_ == inputClasses[0] ) {
00479 wmis0 += w;
00480 nmis0++;
00481 }
00482 else if( p->class_ == inputClasses[1] ) {
00483 wcor1 += w;
00484 ncor1++;
00485 }
00486 }
00487 else {
00488 if( p->class_ == inputClasses[0] ) {
00489 wcor0 += w;
00490 ncor0++;
00491 }
00492 else if( p->class_ == inputClasses[1] ) {
00493 wmis1 += w;
00494 nmis1++;
00495 }
00496 }
00497 }
00498 double vFom = crit->fom(wcor0,wmis0,wcor1,wmis1);
00499 cout << "=====================================================" << endl;
00500 cout << "Validation FOM=" << vFom << endl;
00501 cout << "Content of the signal region:"
00502 << " W0=" << wmis0 << " W1=" << wcor1
00503 << " N0=" << nmis0 << " N1=" << ncor1
00504 << endl;
00505 cout << "=====================================================" << endl;
00506 }
00507
00508
00509 if( tupleFile.empty() && valHbkFile.empty() ) return 0;
00510
00511
00512 class BoxNumberWrapper : public SprTrainedDecisionTree {
00513 public:
00514 virtual ~BoxNumberWrapper() {}
00515 BoxNumberWrapper(const SprTrainedDecisionTree& tree)
00516 : SprTrainedDecisionTree(tree) {}
00517 double response(const std::vector<double>& v) const {
00518 return this->nBox(v);
00519 }
00520 };
00521
00522
00523 if( !tupleFile.empty() ) {
00524
00525 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00526 if( !tuple->init(tupleFile.c_str()) ) {
00527 cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00528 return 6;
00529 }
00530
00531 BoxNumberWrapper boxNumber(*(trainedTree.get()));
00532
00533 SprDataFeeder feeder(filter.get(),tuple.get());
00534 feeder.addClassifier(trainedTree.get(),"bump");
00535 feeder.addClassifier(&boxNumber,"box");
00536 if( !feeder.feed(1000) ) {
00537 cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00538 return 6;
00539 }
00540 }
00541
00542 if( !valHbkFile.empty() ) {
00543
00544 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"test"));
00545 if( !tuple->init(valHbkFile.c_str()) ) {
00546 cerr << "Unable to open output file " << valHbkFile.c_str() << endl;
00547 return 7;
00548 }
00549
00550 BoxNumberWrapper boxNumber(*(trainedTree.get()));
00551
00552 SprDataFeeder feeder(valFilter.get(),tuple.get());
00553 feeder.addClassifier(trainedTree.get(),"bump");
00554 feeder.addClassifier(&boxNumber,"box");
00555 if( !feeder.feed(1000) ) {
00556 cerr << "Cannot feed data into file " << valHbkFile.c_str() << endl;
00557 return 7;
00558 }
00559 }
00560
00561
00562 return 0;
00563 }