00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprBumpHunter.hh"
00022 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedDecisionTree.hh"
00023 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00024 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00025 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00026 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00027 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassSignalSignif.hh"
00028 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassIDFraction.hh"
00029 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassTaggerEff.hh"
00030 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPurity.hh"
00031 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
00032 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassCrossEntropy.hh"
00033 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassUniformPriorUL90.hh"
00034 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassBKDiscovery.hh"
00035 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPunzi.hh"
00036 #include "PhysicsTools/StatPatternRecognition/interface/SprDataMoments.hh"
00037 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00038 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00039 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00040 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00041 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00042 #include "PhysicsTools/StatPatternRecognition/src/SprSymMatrix.hh"
00043 #include "PhysicsTools/StatPatternRecognition/src/SprVector.hh"
00044
00045 #include <stdlib.h>
00046 #include <unistd.h>
00047 #include <iostream>
00048 #include <vector>
00049 #include <set>
00050 #include <algorithm>
00051 #include <functional>
00052 #include <utility>
00053 #include <iomanip>
00054 #include <cmath>
00055 #include <memory>
00056
00057 using namespace std;
00058
00059
/**
 * Strict "less-than" comparison of (value, index) pairs that looks at the
 * value (pair::first) only; the index (pair::second) is ignored.
 *
 * The adaptor typedefs are spelled out explicitly instead of being inherited
 * from std::binary_function, which was deprecated in C++11 and removed in
 * C++17.  Adaptors such as std::not2 only require these three typedefs, so
 * existing callers keep working.
 *
 * Note: negating this predicate yields ">=", which is NOT a strict weak
 * ordering and therefore must not be handed to the standard sort algorithms.
 */
struct SEACmpPairFirst {
  typedef std::pair<double,int> first_argument_type;
  typedef std::pair<double,int> second_argument_type;
  typedef bool result_type;

  /// Returns true when the value stored in l is strictly smaller than in r.
  bool operator()(const std::pair<double,int>& l,
                  const std::pair<double,int>& r) const {
    return (l.first < r.first);
  }
};
00067
00068
00069 void cleanup(vector<const SprTrainedDecisionTree*>& trained)
00070 {
00071 for( int i=0;i<trained.size();i++ )
00072 delete trained[i];
00073 }
00074
00075
/**
 * Prints the usage line and the list of command-line options to std::cout.
 *
 * @param prog  program name shown in the usage line.  May be null (argv[0]
 *              is null when the process is started with an empty argument
 *              vector); a placeholder is printed in that case, because
 *              streaming a null C string would put std::cout into a failed
 *              state and suppress all later output.
 */
void help(const char* prog)
{
  // One entry per output line; every line is terminated with std::endl.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-y list of input classes (see SprAbsFilter.hh) ",
    "\t-Q apply variable transformation saved in file ",
    "\t-c criterion for optimization ",
    "\t\t 1 = correctly classified fraction (default) ",
    "\t\t 2 = signal significance s/sqrt(s+b) ",
    "\t\t 3 = purity s/(s+b) ",
    "\t\t 4 = tagger efficiency Q ",
    "\t\t 5 = Gini index ",
    "\t\t 6 = cross-entropy ",
    "\t\t 7 = 90% Bayesian upper limit with uniform prior ",
    "\t\t 8 = discovery potential 2*(sqrt(s+b)-sqrt(b)) ",
    "\t\t 9 = Punzi's sensitivity s/(0.5*nSigma+sqrt(b)) ",
    "\t\t -P background normalization factor for Punzi FOM",
    "\t-r compute correlations among intervals ",
    "\t-w scale all signal weights by this factor ",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const std::size_t nLines = sizeof(optionLines)/sizeof(optionLines[0]);

  const char* const shownName = ( prog != 0 ? prog : "<program>" );
  std::cout << "Usage: " << shownName
            << " training_data_file" << std::endl;
  for( std::size_t i = 0; i < nLines; ++i )
    std::cout << optionLines[i] << std::endl;
}
00103
00104
00105 int main(int argc, char ** argv)
00106 {
00107
00108 if( argc < 2 ) {
00109 help(argv[0]);
00110 return 1;
00111 }
00112
00113
00114 string tupleFile;
00115 int readMode = 0;
00116 int iCrit = 1;
00117 bool computeCorr = false;
00118 bool scaleWeights = false;
00119 double sW = 1.;
00120 string includeList, excludeList;
00121 string inputClassesString;
00122 double bW = 1.;
00123 string transformerFile;
00124
00125
00126 int c;
00127 extern char* optarg;
00128
00129 while( (c = getopt(argc,argv,"ha:y:Q:c:P:rw:V:z:")) != EOF ) {
00130 switch( c )
00131 {
00132 case 'h' :
00133 help(argv[0]);
00134 return 1;
00135 case 'a' :
00136 readMode = (optarg==0 ? 0 : atoi(optarg));
00137 break;
00138 case 'y' :
00139 inputClassesString = optarg;
00140 break;
00141 case 'Q' :
00142 transformerFile = optarg;
00143 break;
00144 case 'c' :
00145 iCrit = (optarg==0 ? 1 : atoi(optarg));
00146 break;
00147 case 'P' :
00148 bW = (optarg==0 ? 1 : atof(optarg));
00149 break;
00150 case 'r' :
00151 computeCorr = true;
00152 break;
00153 case 'w' :
00154 if( optarg != 0 ) {
00155 scaleWeights = true;
00156 sW = atof(optarg);
00157 }
00158 break;
00159 case 'V' :
00160 includeList = optarg;
00161 break;
00162 case 'z' :
00163 excludeList = optarg;
00164 break;
00165 }
00166 }
00167
00168
00169 string trFile = argv[argc-1];
00170 if( trFile.empty() ) {
00171 cerr << "No training file is specified." << endl;
00172 return 1;
00173 }
00174
00175
00176 SprRWFactory::DataType inputType
00177 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00178 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00179
00180
00181 set<string> includeSet;
00182 if( !includeList.empty() ) {
00183 vector<vector<string> > includeVars;
00184 SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00185 assert( !includeVars.empty() );
00186 for( int i=0;i<includeVars[0].size();i++ )
00187 includeSet.insert(includeVars[0][i]);
00188 if( !reader->chooseVars(includeSet) ) {
00189 cerr << "Unable to include variables in training set." << endl;
00190 return 2;
00191 }
00192 else {
00193 cout << "Following variables have been included in optimization: ";
00194 for( set<string>::const_iterator
00195 i=includeSet.begin();i!=includeSet.end();i++ )
00196 cout << "\"" << *i << "\"" << " ";
00197 cout << endl;
00198 }
00199 }
00200
00201
00202 set<string> excludeSet;
00203 if( !excludeList.empty() ) {
00204 vector<vector<string> > excludeVars;
00205 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00206 assert( !excludeVars.empty() );
00207 for( int i=0;i<excludeVars[0].size();i++ )
00208 excludeSet.insert(excludeVars[0][i]);
00209 if( !reader->chooseAllBut(excludeSet) ) {
00210 cerr << "Unable to exclude variables from training set." << endl;
00211 return 2;
00212 }
00213 else {
00214 cout << "Following variables have been excluded from optimization: ";
00215 for( set<string>::const_iterator
00216 i=excludeSet.begin();i!=excludeSet.end();i++ )
00217 cout << "\"" << *i << "\"" << " ";
00218 cout << endl;
00219 }
00220 }
00221
00222
00223 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00224 if( filter.get() == 0 ) {
00225 cerr << "Unable to read data from file " << trFile.c_str() << endl;
00226 return 2;
00227 }
00228 vector<string> vars;
00229 filter->vars(vars);
00230 cout << "Read data from file " << trFile.c_str()
00231 << " for variables";
00232 for( int i=0;i<vars.size();i++ )
00233 cout << " \"" << vars[i].c_str() << "\"";
00234 cout << endl;
00235 cout << "Total number of points read: " << filter->size() << endl;
00236
00237
00238 vector<SprClass> inputClasses;
00239 if( !filter->filterByClass(inputClassesString.c_str()) ) {
00240 cerr << "Cannot choose input classes for string "
00241 << inputClassesString << endl;
00242 return 2;
00243 }
00244 filter->classes(inputClasses);
00245 assert( inputClasses.size() > 1 );
00246 cout << "Training data filtered by class." << endl;
00247 for( int i=0;i<inputClasses.size();i++ ) {
00248 cout << "Points in class " << inputClasses[i] << ": "
00249 << filter->ptsInClass(inputClasses[i]) << endl;
00250 }
00251
00252
00253 assert( vars.size() == filter->dim() );
00254 cout << "=================================" << endl;
00255 cout << "Input variables:" << endl;
00256 for( int i=0;i<vars.size();i++ )
00257 cout << i << " " << vars[i] << endl;
00258 cout << "=================================" << endl;
00259
00260
00261 if( scaleWeights )
00262 filter->scaleWeights(inputClasses[1],sW);
00263
00264
00265 auto_ptr<SprAbsFilter> garbage_train;
00266 if( !transformerFile.empty() ) {
00267 SprVarTransformerReader transReader;
00268 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00269 if( t == 0 ) {
00270 cerr << "Unable to read VarTransformer from file "
00271 << transformerFile.c_str() << endl;
00272 return 2;
00273 }
00274 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00275 bool replaceOriginalData = true;
00276 if( !t_train->transform(t,replaceOriginalData) ) {
00277 cerr << "Unable to apply VarTransformer to training data." << endl;
00278 return 2;
00279 }
00280 cout << "Variable transformation from file "
00281 << transformerFile.c_str() << " has been applied to data." << endl;
00282 garbage_train.reset(filter.release());
00283 filter.reset(t_train);
00284 }
00285
00286
00287 auto_ptr<SprAbsTwoClassCriterion> crit;
00288 switch( iCrit )
00289 {
00290 case 1 :
00291 crit.reset(new SprTwoClassIDFraction);
00292 cout << "Optimization criterion set to "
00293 << "Fraction of correctly classified events " << endl;
00294 break;
00295 case 2 :
00296 crit.reset(new SprTwoClassSignalSignif);
00297 cout << "Optimization criterion set to "
00298 << "Signal significance S/sqrt(S+B) " << endl;
00299 break;
00300 case 3 :
00301 crit.reset(new SprTwoClassPurity);
00302 cout << "Optimization criterion set to "
00303 << "Purity S/(S+B) " << endl;
00304 break;
00305 case 4 :
00306 crit.reset(new SprTwoClassTaggerEff);
00307 cout << "Optimization criterion set to "
00308 << "Tagging efficiency Q = e*(1-2w)^2 " << endl;
00309 break;
00310 case 5 :
00311 crit.reset(new SprTwoClassGiniIndex);
00312 cout << "Optimization criterion set to "
00313 << "Gini index -1+p^2+q^2 " << endl;
00314 break;
00315 case 6 :
00316 crit.reset(new SprTwoClassCrossEntropy);
00317 cout << "Optimization criterion set to "
00318 << "Cross-entropy p*log(p)+q*log(q) " << endl;
00319 break;
00320 case 7 :
00321 crit.reset(new SprTwoClassUniformPriorUL90);
00322 cout << "Optimization criterion set to "
00323 << "Inverse of 90% Bayesian upper limit with uniform prior" << endl;
00324 break;
00325 case 8 :
00326 crit.reset(new SprTwoClassBKDiscovery);
00327 cout << "Optimization criterion set to "
00328 << "Discovery potential 2*(sqrt(S+B)-sqrt(B))" << endl;
00329 break;
00330 case 9 :
00331 crit.reset(new SprTwoClassPunzi(bW));
00332 cout << "Optimization criterion set to "
00333 << "Punzi's sensitivity S/(0.5*nSigma+sqrt(B))" << endl;
00334 break;
00335 default :
00336 cerr << "Unable to make initialization criterion." << endl;
00337 return 3;
00338 }
00339
00340
00341 SprSymMatrix cov;
00342 SprVector mean;
00343 SprDataMoments moms(filter.get());
00344
00345
00346 if( !moms.covariance(cov,mean) ) {
00347 cerr << "Unable to compute covariance matrix for entire data." << endl;
00348 return 4;
00349 }
00350 cout << "Variables with zero variance: ";
00351 for( int i=0;i<vars.size();i++ ) {
00352 if( cov[i][i] < SprUtils::eps() )
00353 cout << vars[i].c_str() << ",";
00354 }
00355 cout << endl;
00356
00357
00358 for( int c=0;c<2;c++ ) {
00359 vector<SprClass> classes(1);
00360 classes[0] = inputClasses[c];
00361 filter->chooseClasses(classes);
00362 if( !filter->filter() ) {
00363 cerr << "Unable to filter class " << c << endl;
00364 return 4;
00365 }
00366 if( !moms.covariance(cov,mean) ) {
00367 cerr << "Unable to compute covariance matrix for input variables."
00368 << endl;
00369 return 4;
00370 }
00371
00372 cout << "===============================================" << endl;
00373 cout << "Covariance matrix computed with "
00374 << filter->ptsInClass(classes[0]) << " events." << endl;
00375 cout << "Input variable correlations in class " << c << ":" << endl;
00376 cout << "Column ";
00377 for( int i=0;i<filter->dim();i++ )
00378 cout << setw(10) << i << " ";
00379 cout << endl;
00380 cout << "--------";
00381 for( int i=0;i<filter->dim();i++ )
00382 cout << setw(10) << "----------" << "-";
00383 cout << endl;
00384 for( int i=0;i<filter->dim();i++ ) {
00385 cout << "Row " << i << " | ";
00386 for( int j=0;j<filter->dim();j++ )
00387 cout << setw(10) << cov[i][j]/sqrt(cov[i][i])/sqrt(cov[j][j]) << " ";
00388 cout << endl;
00389 }
00390 cout << "===============================================" << endl;
00391 }
00392 filter->clear();
00393
00394
00395 vector<double> corrLabel(filter->dim());
00396 vector<pair<double,int> > absCorrLabel(filter->dim());
00397 double meani(0), vari(0);
00398 for( int i=0;i<filter->dim();i++ ) {
00399 corrLabel[i] = moms.correlClassLabel(i,meani,vari);
00400 absCorrLabel[i] = pair<double,int>(fabs(corrLabel[i]),i);
00401 }
00402 stable_sort(absCorrLabel.begin(),absCorrLabel.end(),not2(SEACmpPairFirst()));
00403 cout << "===============================================" << endl;
00404 cout << "Correlations with class label:" << endl;
00405 for( int i=0;i<filter->dim();i++ ) {
00406 int k = absCorrLabel[i].second;
00407 cout << setw(40) << vars[k] << " " << setw(10) << corrLabel[k] << endl;
00408 }
00409 cout << "===============================================" << endl;
00410
00411
00412 vector<double> corrLabel2(filter->dim());
00413 vector<pair<double,int> > absCorrLabel2(filter->dim());
00414 double meani2(0), vari2(0);
00415 for( int i=0;i<filter->dim();i++ ) {
00416 corrLabel2[i] = moms.absCorrelClassLabel(i,meani2,vari2);
00417 absCorrLabel2[i] = pair<double,int>(fabs(corrLabel2[i]),i);
00418 }
00419 stable_sort(absCorrLabel2.begin(),absCorrLabel2.end(),
00420 not2(SEACmpPairFirst()));
00421 cout << "===============================================" << endl;
00422 cout << "Correlations of absolute values with class label:" << endl;
00423 for( int i=0;i<filter->dim();i++ ) {
00424 int k = absCorrLabel2[i].second;
00425 cout << setw(40) << vars[k] << " " << setw(10) << corrLabel2[k] << endl;
00426 }
00427 cout << "===============================================" << endl;
00428
00429
00430 vector<pair<double,int> > fom(filter->dim(),
00431 pair<double,int>(SprUtils::min(),-1));
00432 vector<const SprTrainedDecisionTree*> trained(filter->dim(),0);
00433 vector<double> w1vec(filter->dim()), w0vec(filter->dim());
00434
00435 SprData tempData("myDummy1Ddata",vector<string>(1,"dummy"));
00436 vector<double> x(1);
00437 for( int j=0;j<filter->size();j++ ) {
00438 const SprPoint* p = (*filter.get())[j];
00439 x[0] = p->x_[0];
00440 tempData.insert(p->index_,p->class_,x);
00441 }
00442
00443 vector<double> weights;
00444 filter->weights(weights);
00445
00446 SprEmptyFilter tempFilter(&tempData,weights);
00447 tempFilter.chooseClasses(inputClasses);
00448
00449 for( int d=0;d<filter->dim();d++ ) {
00450 if( d != 0 ) {
00451 for( int j=0;j<filter->size();j++ )
00452 tempFilter[j]->x_[0] = (*filter.get())[j]->x_[d];
00453 }
00454
00455 cout << "Optimizing interval in dimension " << d << endl;
00456 SprBumpHunter hunter(&tempFilter,crit.get(),1,int(0.01*filter->size()),1.);
00457 if( !hunter.train() ) {
00458 cerr << "Unable to train interval for dimension " << d << endl;
00459 continue;
00460 }
00461 const SprTrainedDecisionTree* t = hunter.makeTrained();
00462 trained[d] = t;
00463
00464 double wmis0(0), wcor0(0), wmis1(0), wcor1(0);
00465 for( int j=0;j<filter->size();j++ ) {
00466 const SprPoint* p = tempFilter[j];
00467 double w = tempFilter.w(j);
00468 if( p->class_ == inputClasses[0] ) {
00469 if( t->accept(p) )
00470 wmis0 += w;
00471 else
00472 wcor0 += w;
00473 }
00474 else if( p->class_ == inputClasses[1] ) {
00475 if( t->accept(p) )
00476 wcor1 += w;
00477 else
00478 wmis1 += w;
00479 }
00480 }
00481 fom[d] = pair<double,int>(crit->fom(wcor0,wmis0,wcor1,wmis1),d);
00482 w1vec[d] = wcor1;
00483 w0vec[d] = wmis0;
00484 }
00485
00486
00487 stable_sort(fom.begin(),fom.end(),not2(SEACmpPairFirst()));
00488
00489
00490 double w0 = filter->weightInClass(inputClasses[0]);
00491 double w1 = filter->weightInClass(inputClasses[1]);
00492 double fmin = crit->fom(0,w0,w1,0);
00493 double fmax = crit->fom(0,0,w1,0);
00494 cout << "Possible FOM range: " << fmin << " " << fmax << endl;
00495 for( int i=0;i<filter->dim();i++ ) {
00496 SprBox limits;
00497 int k = fom[i].second;
00498 if( k>=0 && trained[k]!=0 ) trained[k]->box(0,limits);
00499 SprBox::const_iterator iter = limits.find(0);
00500 if( iter != limits.end() ) {
00501 cout << i << " FOM= " << setw(8) << fom[i].first
00502 << " for variable \"" << setw(15) << vars[k] << "\""
00503 << " with acceptance interval "
00504 << setw(10) << iter->second.first << " "
00505 << setw(10) << iter->second.second
00506 << " W0=" << w0vec[k] << " W1=" << w1vec[k] << endl;
00507 }
00508 }
00509
00510
00511 if( computeCorr ) {
00512 SprSymMatrix corr(filter->dim());
00513 for( int i=0;i<filter->dim();i++ ) {
00514 int c1 = fom[i].second;
00515 if( c1<0 || trained[c1]==0 ) {
00516 cerr << "Unable to compute correlations: "
00517 << "There are uncomputed intervals." << endl;
00518 cleanup(trained);
00519 return 5;
00520 }
00521 for( int j=i+1;j<filter->dim();j++ ) {
00522 int c2 = fom[j].second;
00523 if( c2<0 || trained[c2]==0 ) {
00524 cerr << "Unable to compute correlations: "
00525 << "There are uncomputed intervals." << endl;
00526 cleanup(trained);
00527 return 5;
00528 }
00529 double a(0), b(0), c(0), d(0);
00530 for( int k=0;k<filter->size();k++ ) {
00531 const SprPoint* p = (*filter.get())[k];
00532 double w = filter->w(k);
00533 vector<double> x1(1), x2(1);
00534 x1[0] = p->x_[c1];
00535 x2[0] = p->x_[c2];
00536 if( p->class_ == inputClasses[0] ) {
00537 if( trained[c1]->accept(x1) ) {
00538 if( trained[c2]->accept(x2) )
00539 d += w;
00540 else
00541 c += w;
00542 }
00543 else {
00544 if( trained[c2]->accept(x2) )
00545 b += w;
00546 else
00547 a += w;
00548 }
00549 }
00550 else if( p->class_ == inputClasses[1] ) {
00551 if( trained[c1]->accept(x1) ) {
00552 if( trained[c2]->accept(x2) )
00553 a += w;
00554 else
00555 b += w;
00556 }
00557 else {
00558 if( trained[c2]->accept(x2) )
00559 c += w;
00560 else
00561 d += w;
00562 }
00563 }
00564 }
00565 if( (a+b)<SprUtils::eps() || (c+d)<SprUtils::eps()
00566 || (a+c)<SprUtils::eps() || (b+d)<SprUtils::eps() ) {
00567 cerr << "Unable to compute correlations: One of the sums is zero."
00568 << endl;
00569 cleanup(trained);
00570 return 5;
00571 }
00572 corr[i][j] = (a*d-b*c) / sqrt((a+b)*(c+d)*(a+c)*(b+d));
00573 }
00574 corr[i][i] = 1;
00575 }
00576
00577 cout << "Interval correlations: " << endl;
00578 cout << "Column ";
00579 for( int i=0;i<filter->dim();i++ )
00580 cout << setw(10) << i << " ";
00581 cout << endl;
00582 cout << "--------";
00583 for( int i=0;i<filter->dim();i++ )
00584 cout << setw(10) << "----------" << "-";
00585 cout << endl;
00586 for( int i=0;i<filter->dim();i++ ) {
00587 cout << "Row " << i << " | ";
00588 for( int j=0;j<filter->dim();j++ )
00589 cout << setw(10) << corr[i][j] << " ";
00590 cout << endl;
00591 }
00592 }
00593
00594
00595 cleanup(trained);
00596
00597
00598 return 0;
00599 }