CMS 3D CMS Logo

SprBumpHunterApp.cc File Reference

#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprBumpHunter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedDecisionTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassSignalSignif.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassIDFraction.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassTaggerEff.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPurity.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassCrossEntropy.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassUniformPriorUL90.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassBKDiscovery.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPunzi.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassBgrndSmoother.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include <stdlib.h>
#include <unistd.h>
#include <iostream>
#include <vector>
#include <set>
#include <string>
#include <memory>

Go to the source code of this file.

Functions

void help (const char *prog)
int main (int argc, char **argv)


Function Documentation

void help ( const char *  prog  ) 

Definition at line 39 of file SprBumpHunterApp.cc.

References GenMuonPlsPt100GeV_cfg::cout, and lat::endl().

/**
 * Print the command-line usage summary of the bump hunter executable
 * to standard output.
 *
 * @param prog  Name of the executable as it should appear in the usage
 *              line (normally argv[0]).  A null pointer is replaced by a
 *              generic name, because streaming a null C string is
 *              undefined behaviour.
 */
void help(const char* prog)
{
  // One entry per printed line; the trailing padding is part of the
  // established output and is kept on purpose.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help                                        ",
    "\t-o output Tuple file                                 ",
    "\t-a input ascii file mode (see SprSimpleReader.hh)  ",
    "\t-A save output data in ascii instead of Root       ",
    "\t-n minimal number of events per bump (def=1)       ",
    "\t-b requested number of bumps (def=1)               ",
    "\t-x max fraction of events peeled off in one try    ",
    "\t-y list of input classes (see SprAbsFilter.hh)     ",
    "\t-Q apply variable transformation saved in file     ",
    "\t-v verbose level (0=silent default,1,2)            ",
    "\t-f store trained bump hunter to file               ",
    "\t-c criterion for optimization                      ",
    "\t\t 1 = correctly classified fraction               ",
    "\t\t 2 = signal significance s/sqrt(s+b)             ",
    "\t\t 3 = purity s/(s+b) (default)                    ",
    "\t\t 4 = tagger efficiency Q                         ",
    "\t\t 5 = Gini index                                  ",
    "\t\t 6 = cross-entropy                               ",
    "\t\t 7 = 90% Bayesian upper limit with uniform prior ",
    "\t\t 8 = discovery potential 2*(sqrt(s+b)-sqrt(b))   ",
    "\t\t 9 = Punzi's sensitivity s/(0.5*nSigma+sqrt(b))  ",
    "\t\t 10= background-smoothed Punzi's sensitivity     ",
    "\t\t -P background normalization factor for Punzi FOM",
    "\t\t -L lambda for the background-smoothed FOM       ",
    "\t\t -O omega for the background-smoothed FOM        ",
    "\t-K keep this fraction in training set and          ",
    "\t\t put the rest into validation set                ",
    "\t-D randomize training set split-up                 ",
    "\t-t read validation/test data from a file           ",
    // The closing parenthesis was missing in the original text.
    "\t\t (must be in same format as input data!!!)       ",
    "\t-p output file to store validation/test data       ",
    "\t-w scale all signal weights by this factor         ",
    "\t-V include only these input variables              ",
    "\t-z exclude input variables from the list           ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const unsigned nLines = sizeof(optionLines) / sizeof(optionLines[0]);

  std::cout << "Usage:  " << (prog != 0 ? prog : "SprBumpHunterApp")
            << " training_data_file" << std::endl;
  for( unsigned line = 0; line < nLines; ++line )
    std::cout << optionLines[line] << std::endl;
}

int main ( int  argc,
char **  argv 
)

Definition at line 83 of file SprBumpHunterApp.cc.

References c, TestMuL1L2Filter_cff::cerr, GenMuonPlsPt100GeV_cfg::cout, lat::endl(), filter, help(), i, p, split, t, tree, v, vars, w, and weights.

00084 {
00085   // check command line
00086   if( argc < 2 ) {
00087     help(argv[0]);
00088     return 1;
00089   }
00090 
00091   // init
00092   string tupleFile;
00093   int readMode = 0;
00094   SprRWFactory::DataType writeMode = SprRWFactory::Root;
00095   unsigned nmin = 1;
00096   int verbose = 0;
00097   string outFile;
00098   string resumeFile;
00099   int iCrit = 3;
00100   string valFile;
00101   string valHbkFile;
00102   int nbump = 1;
00103   double apeel = 1.;
00104   bool scaleWeights = false;
00105   double sW = 1.;
00106   string includeList, excludeList;
00107   string inputClassesString;
00108   double bW = 1.;
00109   double lambda = 2.;
00110   double omega = 5.;
00111   bool split = false;
00112   double splitFactor = 0;
00113   bool splitRandomize = false;   
00114   string transformerFile;
00115 
00116   // decode command line
00117   int c;
00118   extern char* optarg;
00119   while( (c = getopt(argc,argv,"ho:a:An:v:f:c:P:L:O:K:Dt:p:b:x:y:Q:w:V:z:")) != EOF ) {
00120     switch( c )
00121       {
00122       case 'h' :
00123         help(argv[0]);
00124         return 1;
00125       case 'o' :
00126         tupleFile = optarg;
00127         break;
00128       case 'a' :
00129         readMode = (optarg==0 ? 0 : atoi(optarg));
00130         break;
00131       case 'A' :
00132         writeMode = SprRWFactory::Ascii;
00133         break;
00134       case 'n' :
00135         nmin = (optarg==0 ? 1 : atoi(optarg));
00136         break;
00137       case 'v' :
00138         verbose = (optarg==0 ? 0 : atoi(optarg));
00139         break;
00140       case 'f' :
00141         outFile = optarg;
00142         break;
00143       case 'c' :
00144         iCrit = (optarg==0 ? 3 : atoi(optarg));
00145         break;
00146       case 'P' :
00147         bW = (optarg==0 ? 1. : atof(optarg));
00148         break;
00149       case 'L' :
00150         lambda = (optarg==0 ? 2. : atof(optarg));
00151         break;
00152       case 'O' :
00153         omega = (optarg==0 ? 5. : atof(optarg));
00154         break;
00155       case 'K' :
00156         split = true;
00157         splitFactor = (optarg==0 ? 0 : atof(optarg));
00158         break;
00159       case 'D' :
00160         splitRandomize = true;
00161         break;
00162       case 't' :
00163         valFile = optarg;
00164         break;
00165       case 'p' :
00166         valHbkFile = optarg;
00167         break;
00168       case 'b' :
00169         nbump = (optarg==0 ? 1 : atoi(optarg));
00170         break;
00171       case 'x' :
00172         apeel = (optarg==0 ? 1. : atof(optarg));
00173         break;
00174       case 'y' :
00175         inputClassesString = optarg;
00176         break;
00177       case 'Q' :
00178         transformerFile = optarg;
00179         break;
00180       case 'w' :
00181         if( optarg != 0 ) {
00182           scaleWeights = true;
00183           sW = atof(optarg);
00184         }
00185         break;
00186       case 'V' :
00187         includeList = optarg;
00188         break;
00189       case 'z' :
00190         excludeList = optarg;
00191         break;
00192       }
00193   }
00194 
00195   // There has to be 1 argument after all options.
00196   string trFile = argv[argc-1];
00197   if( trFile.empty() ) {
00198     cerr << "No training file is specified." << endl;
00199     return 1;
00200   }
00201 
00202   // make reader
00203   SprRWFactory::DataType inputType 
00204     = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00205   auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00206 
00207   // include variables
00208   set<string> includeSet;
00209   if( !includeList.empty() ) {
00210     vector<vector<string> > includeVars;
00211     SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00212     assert( !includeVars.empty() );
00213     for( int i=0;i<includeVars[0].size();i++ ) 
00214       includeSet.insert(includeVars[0][i]);
00215     if( !reader->chooseVars(includeSet) ) {
00216       cerr << "Unable to include variables in training set." << endl;
00217       return 2;
00218     }
00219     else {
00220       cout << "Following variables have been included in optimization: ";
00221       for( set<string>::const_iterator 
00222              i=includeSet.begin();i!=includeSet.end();i++ )
00223         cout << "\"" << *i << "\"" << " ";
00224       cout << endl;
00225     }
00226   }
00227 
00228   // exclude variables
00229   set<string> excludeSet;
00230   if( !excludeList.empty() ) {
00231     vector<vector<string> > excludeVars;
00232     SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00233     assert( !excludeVars.empty() );
00234     for( int i=0;i<excludeVars[0].size();i++ ) 
00235       excludeSet.insert(excludeVars[0][i]);
00236     if( !reader->chooseAllBut(excludeSet) ) {
00237       cerr << "Unable to exclude variables from training set." << endl;
00238       return 2;
00239     }
00240     else {
00241       cout << "Following variables have been excluded from optimization: ";
00242       for( set<string>::const_iterator 
00243              i=excludeSet.begin();i!=excludeSet.end();i++ )
00244         cout << "\"" << *i << "\"" << " ";
00245       cout << endl;
00246     }
00247   }
00248 
00249   // read training data from file
00250   auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00251   if( filter.get() == 0 ) {
00252     cerr << "Unable to read data from file " << trFile.c_str() << endl;
00253     return 2;
00254   }
00255   vector<string> vars;
00256   filter->vars(vars);
00257   cout << "Read data from file " << trFile.c_str() 
00258        << " for variables";
00259   for( int i=0;i<vars.size();i++ ) 
00260     cout << " \"" << vars[i].c_str() << "\"";
00261   cout << endl;
00262   cout << "Total number of points read: " << filter->size() << endl;
00263 
00264   // filter training data by class
00265   vector<SprClass> inputClasses;
00266   if( !filter->filterByClass(inputClassesString.c_str()) ) {
00267     cerr << "Cannot choose input classes for string " 
00268          << inputClassesString << endl;
00269     return 2;
00270   }
00271   filter->classes(inputClasses);
00272   assert( inputClasses.size() > 1 );
00273   cout << "Training data filtered by class." << endl;
00274   for( int i=0;i<inputClasses.size();i++ ) {
00275     cout << "Points in class " << inputClasses[i] << ":   " 
00276          << filter->ptsInClass(inputClasses[i]) << endl;
00277   }
00278 
00279   // scale weights
00280   if( scaleWeights ) {
00281     cout << "Signal weights are multiplied by " << sW << endl;
00282     filter->scaleWeights(inputClasses[1],sW);
00283   }
00284 
00285   // read validation data from file
00286   auto_ptr<SprAbsFilter> valFilter;
00287   if( split && !valFile.empty() ) {
00288     cerr << "Unable to split training data and use validation data " 
00289          << "from a separate file." << endl;
00290     return 2;
00291   }
00292   if( split ) {
00293     cout << "Splitting training data with factor " << splitFactor << endl;
00294     if( splitRandomize )
00295       cout << "Will use randomized splitting." << endl;
00296     vector<double> weights;
00297     SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00298     if( splitted == 0 ) {
00299       cerr << "Unable to split training data." << endl;
00300       return 2;
00301     }
00302     bool ownData = true;
00303     valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00304     cout << "Training data re-filtered:" << endl;
00305     for( int i=0;i<inputClasses.size();i++ ) {
00306       cout << "Points in class " << inputClasses[i] << ":   " 
00307            << filter->ptsInClass(inputClasses[i]) << endl;
00308     }
00309   }
00310   if( !valFile.empty() ) {
00311     auto_ptr<SprAbsReader> 
00312       valReader(SprRWFactory::makeReader(inputType,readMode));
00313     if( !includeSet.empty() ) {
00314       if( !valReader->chooseVars(includeSet) ) {
00315         cerr << "Unable to include variables in validation set." << endl;
00316         return 2;
00317       }
00318     }
00319     if( !excludeSet.empty() ) {
00320       if( !valReader->chooseAllBut(excludeSet) ) {
00321         cerr << "Unable to exclude variables from validation set." << endl;
00322         return 2;
00323       }
00324     }
00325     valFilter.reset(valReader->read(valFile.c_str()));
00326     if( valFilter.get() == 0 ) {
00327       cerr << "Unable to read data from file " << valFile.c_str() << endl;
00328       return 2;
00329     }
00330     vector<string> valVars;
00331     valFilter->vars(valVars);
00332     cout << "Read validation data from file " << valFile.c_str()
00333          << " for variables";
00334     for( int i=0;i<valVars.size();i++ )
00335       cout << " \"" << valVars[i].c_str() << "\"";
00336     cout << endl;
00337     cout << "Total number of points read: " << valFilter->size() << endl;
00338   }
00339   
00340   // filter validation data by class
00341   if( valFilter.get() != 0 ) {
00342     if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00343       cerr << "Cannot choose input classes for string " 
00344            << inputClassesString << endl;
00345       return 2;
00346     }
00347     valFilter->classes(inputClasses);
00348     cout << "Validation data filtered by class." << endl;
00349     for( int i=0;i<inputClasses.size();i++ ) {
00350       cout << "Points in class " << inputClasses[i] << ":   " 
00351            << valFilter->ptsInClass(inputClasses[i]) << endl;
00352     }
00353   }
00354 
00355   // scale weights
00356   if( scaleWeights && valFilter.get()!=0 )
00357     valFilter->scaleWeights(inputClasses[1],sW);
00358 
00359   // apply transformation of variables to training and test data
00360   auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00361   if( !transformerFile.empty() ) {
00362     SprVarTransformerReader transReader;
00363     const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00364     if( t == 0 ) {
00365       cerr << "Unable to read VarTransformer from file "
00366            << transformerFile.c_str() << endl;
00367       return 2;
00368     }
00369     SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00370     SprTransformerFilter* t_valid = 0;
00371     if( valFilter.get() != 0 )
00372       t_valid = new SprTransformerFilter(valFilter.get());
00373     bool replaceOriginalData = true;
00374     if( !t_train->transform(t,replaceOriginalData) ) {
00375       cerr << "Unable to apply VarTransformer to training data." << endl;
00376       return 2;
00377     }
00378     if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00379       cerr << "Unable to apply VarTransformer to validation data." << endl;
00380       return 2;
00381     }
00382     cout << "Variable transformation from file "
00383          << transformerFile.c_str() << " has been applied to "
00384          << "training and validation data." << endl;
00385     garbage_train.reset(filter.release());
00386     garbage_valid.reset(valFilter.release());
00387     filter.reset(t_train);
00388     valFilter.reset(t_valid);
00389   }
00390 
00391   // make optimization criterion
00392   auto_ptr<SprAbsTwoClassCriterion> crit;
00393   switch( iCrit )
00394     {
00395     case 1 :
00396       crit.reset(new SprTwoClassIDFraction);
00397       cout << "Optimization criterion set to "
00398            << "Fraction of correctly classified events " << endl;
00399       break;
00400     case 2 :
00401       crit.reset(new SprTwoClassSignalSignif);
00402       cout << "Optimization criterion set to "
00403            << "Signal significance S/sqrt(S+B) " << endl;
00404       break;
00405     case 3 :
00406       crit.reset(new SprTwoClassPurity);
00407       cout << "Optimization criterion set to "
00408            << "Purity S/(S+B) " << endl;
00409       break;
00410     case 4 :
00411       crit.reset(new SprTwoClassTaggerEff);
00412       cout << "Optimization criterion set to "
00413            << "Tagging efficiency Q = e*(1-2w)^2 " << endl;
00414       break;
00415     case 5 :
00416       crit.reset(new SprTwoClassGiniIndex);
00417       cout << "Optimization criterion set to "
00418            << "Gini index  -1+p^2+q^2 " << endl;
00419       break;
00420     case 6 :
00421       crit.reset(new SprTwoClassCrossEntropy);
00422       cout << "Optimization criterion set to "
00423            << "Cross-entropy p*log(p)+q*log(q) " << endl;
00424       break;
00425     case 7 :
00426       crit.reset(new SprTwoClassUniformPriorUL90);
00427       cout << "Optimization criterion set to "
00428            << "Inverse of 90% Bayesian upper limit with uniform prior" << endl;
00429       break;
00430     case 8 :
00431       crit.reset(new SprTwoClassBKDiscovery);
00432       cout << "Optimization criterion set to "
00433            << "Discovery potential 2*(sqrt(S+B)-sqrt(B))" << endl;
00434       break;
00435     case 9 :
00436       crit.reset(new SprTwoClassPunzi(bW));
00437       cout << "Optimization criterion set to "
00438            << "Punzi's sensitivity S/(0.5*nSigma+sqrt(B))" << endl;
00439       break;
00440     case 10 :
00441       crit.reset(new SprTwoClassBgrndSmoother(bW,lambda,omega));
00442       cout << "Optimization criterion set to "
00443            << "background-smoothed Punzi's sensitivity" << endl;
00444       break;
00445     default :
00446       cerr << "Unable to make initialization criterion." << endl;
00447       return 3;
00448     }
00449 
00450   // make decision tree
00451   SprBumpHunter bump(filter.get(),crit.get(),nbump,nmin,apeel);
00452 
00453   // train
00454   if( !bump.train(verbose) ) {
00455     cerr << "Unable to find bumps." << endl;
00456     return 4;
00457   }
00458 
00459   // save trained decision tree
00460   if( !outFile.empty() ) {
00461     if( !bump.store(outFile.c_str()) ) {
00462       cerr << "Cannot store bump hunter in file " << outFile.c_str() << endl;
00463       return 5;
00464     }
00465   }
00466 
00467   // make trained decision tree
00468   auto_ptr<SprTrainedDecisionTree> trainedTree(bump.makeTrained());
00469 
00470   // compute FOM for the validation data
00471   if( valFilter.get() != 0 ) {
00472     double wcor0(0), wmis0(0), wcor1(0), wmis1(0);
00473     int ncor0(0), nmis0(0), ncor1(0), nmis1(0);
00474     for( int i=0;i<valFilter->size();i++ ) {
00475       const SprPoint* p = (*valFilter.get())[i];
00476       double w = valFilter->w(i);
00477       if( trainedTree->accept(p) ) {
00478         if(      p->class_ == inputClasses[0] ) {
00479           wmis0 += w;
00480           nmis0++;
00481         }
00482         else if( p->class_ == inputClasses[1] ) {
00483           wcor1 += w;
00484           ncor1++;
00485         }
00486       }
00487       else {
00488         if(      p->class_ == inputClasses[0] ) {
00489           wcor0 += w;
00490           ncor0++;
00491         }
00492         else if( p->class_ == inputClasses[1] ) {
00493           wmis1 += w;
00494           nmis1++;
00495         }
00496       }
00497     }
00498     double vFom = crit->fom(wcor0,wmis0,wcor1,wmis1);
00499     cout << "=====================================================" << endl;
00500     cout << "Validation FOM=" << vFom << endl;
00501     cout << "Content of the signal region:"
00502          << "   W0=" << wmis0 << "  W1=" << wcor1 
00503          << "   N0=" << nmis0 << "  N1=" << ncor1 
00504          << endl;
00505     cout << "=====================================================" << endl;
00506   }
00507 
00508   // make histogram if requested
00509   if( tupleFile.empty() && valHbkFile.empty() ) return 0;
00510 
00511   // make a wrapper to store box numbers
00512   class BoxNumberWrapper : public SprTrainedDecisionTree {
00513   public:
00514     virtual ~BoxNumberWrapper() {}
00515     BoxNumberWrapper(const SprTrainedDecisionTree& tree)
00516       : SprTrainedDecisionTree(tree) {}
00517     double response(const std::vector<double>& v) const {
00518       return this->nBox(v);
00519     }
00520   };
00521 
00522   // feed training data
00523   if( !tupleFile.empty() ) {
00524     // make a writer
00525     auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00526     if( !tuple->init(tupleFile.c_str()) ) {
00527       cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00528       return 6;
00529     }
00530     // wrap
00531     BoxNumberWrapper boxNumber(*(trainedTree.get()));
00532     // feed 
00533     SprDataFeeder feeder(filter.get(),tuple.get());
00534     feeder.addClassifier(trainedTree.get(),"bump");
00535     feeder.addClassifier(&boxNumber,"box");
00536     if( !feeder.feed(1000) ) {
00537       cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00538       return 6;
00539     }
00540   }
00541 
00542   if( !valHbkFile.empty() ) {
00543     // make a writer
00544     auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"test"));
00545     if( !tuple->init(valHbkFile.c_str()) ) {
00546       cerr << "Unable to open output file " << valHbkFile.c_str() << endl;
00547       return 7;
00548     }
00549     // wrap
00550     BoxNumberWrapper boxNumber(*(trainedTree.get()));
00551     // feed 
00552     SprDataFeeder feeder(valFilter.get(),tuple.get());
00553     feeder.addClassifier(trainedTree.get(),"bump");
00554     feeder.addClassifier(&boxNumber,"box");
00555     if( !feeder.feed(1000) ) {
00556       cerr << "Cannot feed data into file " << valHbkFile.c_str() << endl;
00557       return 7;
00558     }
00559   }
00560 
00561   // exit
00562   return 0;
00563 }


Generated on Tue Jun 9 17:55:00 2009 for CMSSW by  doxygen 1.5.4