CMS 3D CMS Logo

SprDecisionTreeApp.cc File Reference

#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDecisionTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedDecisionTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTopdownTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedTopdownTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassSignalSignif.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassIDFraction.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassTaggerEff.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPurity.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassCrossEntropy.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassUniformPriorUL90.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassBKDiscovery.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPunzi.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprCrossValidator.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include <stdlib.h>
#include <unistd.h>
#include <iostream>
#include <vector>
#include <set>
#include <string>
#include <memory>
#include <iomanip>
#include <fstream>

Go to the source code of this file.

Functions

void help (const char *prog)
int main (int argc, char **argv)


Function Documentation

void help ( const char *  prog  ) 

Definition at line 46 of file SprDecisionTreeApp.cc.

References std::cout and std::endl().

00047 {
00048   cout << "Usage:  " << prog 
00049        << " training_data_file" << endl;
00050   cout << "\t Options: " << endl;
00051   cout << "\t-h --- help                                        " << endl;
00052   cout << "\t-o output Tuple file                               " << endl;
00053   cout << "\t-a input ascii file mode (see SprSimpleReader.hh)  " << endl;
00054   cout << "\t-A save output data in ascii instead of Root       " << endl;
00055   cout << "\t-n minimal number of events per tree node (def=1)  " << endl;
00056   cout << "\t-m --- merge nodes after training (def = no merge) " << endl;
00057   cout << "\t-y list of input classes (see SprAbsFilter.hh)     " << endl;
00058   cout << "\t-Q apply variable transformation saved in file     " << endl;
00059   cout << "\t-v verbose level (0=silent default,1,2)            " << endl;
00060   cout << "\t-T use Topdown tree with continuous output         " << endl;
00061   cout << "\t-f store decision tree to file in human-readable format" << endl;
00062   cout << "\t-F store decision tree to file in machine-readable format"<< endl;
00063   cout << "\t-c criterion for optimization                      " << endl;
00064   cout << "\t\t 1 = correctly classified fraction               " << endl;
00065   cout << "\t\t 2 = signal significance s/sqrt(s+b)             " << endl;
00066   cout << "\t\t 3 = purity s/(s+b)                              " << endl;
00067   cout << "\t\t 4 = tagger efficiency Q                         " << endl;
00068   cout << "\t\t 5 = Gini index (default)                        " << endl;
00069   cout << "\t\t 6 = cross-entropy                               " << endl;
00070   cout << "\t\t 7 = 90% Bayesian upper limit with uniform prior " << endl;
00071   cout << "\t\t 8 = discovery potential 2*(sqrt(s+b)-sqrt(b))   " << endl;
00072   cout << "\t\t 9 = Punzi's sensitivity s/(0.5*nSigma+sqrt(b))  " << endl;
00073   cout << "\t\t -P background normalization factor for Punzi FOM" << endl;
00074   cout << "\t-g per-event loss for (cross-)validation           " << endl;
00075   cout << "\t\t 1 - quadratic loss (y-f(x))^2                   " << endl;
00076   cout << "\t\t 2 - exponential loss exp(-y*f(x))               " << endl;
00077   cout << "\t\t 3 - misid fraction                              " << endl;
00078   cout << "\t-i count splits on input variables                 " << endl;
00079   cout << "\t-K keep this fraction in training set and          " << endl;
00080   cout << "\t\t put the rest into validation set                " << endl;
00081   cout << "\t-D randomize training set split-up                 " << endl;
00082   cout << "\t-t read validation/test data from a file           " << endl;
00083   cout << "\t\t (must be in same format as input data!!!        " << endl;
00084   cout << "\t-p output file to store validation/test data       " << endl;
00085   cout << "\t-w scale all signal weights by this factor         " << endl;
00086   cout << "\t-V include only these input variables              " << endl;
00087   cout << "\t-z exclude input variables from the list           " << endl;
00088   cout << "\t\t Variables must be listed in quotes and separated by commas." 
00089        << endl;
00090   cout << "\t-x cross-validate by splitting data into a given "
00091        << "number of pieces" << endl;
00092   cout << "\t-q a set of minimal node sizes for cross-validation" << endl;
00093   cout << "\t\t Node sizes must be listed in quotes and separated by commas." 
00094        << endl;
00095 }

int main ( int  argc,
char **  argv 
)

Definition at line 98 of file SprDecisionTreeApp.cc.

References c, std::cerr, std::cout, std::endl(), filter, help(), i, j, p, size, split, t, tree, v, vars, w, and weights.

00099 {
00100   // check command line
00101   if( argc < 2 ) {
00102     help(argv[0]);
00103     return 1;
00104   }
00105 
00106   // init
00107   string tupleFile;
00108   int readMode = 0;
00109   SprRWFactory::DataType writeMode = SprRWFactory::Root;
00110   unsigned nmin = 1;
00111   int verbose = 0;
00112   bool useTopdown = false;
00113   string outHuman, outMachine;
00114   string resumeFile;
00115   int iCrit = 5;
00116   string valFile;
00117   string valHbkFile;
00118   bool doMerge = false;
00119   int iLoss = 0;
00120   bool scaleWeights = false;
00121   double sW = 1.;
00122   bool countTreeSplits = false;
00123   string includeList, excludeList;
00124   unsigned nCross = 0;
00125   string nodeValidationString;
00126   string inputClassesString;
00127   double bW = 1.;
00128   bool split = false;
00129   double splitFactor = 0;
00130   bool splitRandomize = false;
00131   string transformerFile;
00132 
00133   // decode command line
00134   int c;
00135   extern char* optarg;
00136   while( (c = getopt(argc,argv,"ho:a:An:v:f:TF:c:P:g:iK:Dt:p:my:Q:w:V:z:x:q:")) != EOF ) {
00137     switch( c )
00138       {
00139       case 'h' :
00140         help(argv[0]);
00141         return 1;
00142       case 'o' :
00143         tupleFile = optarg;
00144         break;
00145       case 'a' :
00146         readMode = (optarg==0 ? 0 : atoi(optarg));
00147         break;
00148       case 'A' :
00149         writeMode = SprRWFactory::Ascii;
00150         break;
00151       case 'n' :
00152         nmin = (optarg==0 ? 1 : atoi(optarg));
00153         break;
00154       case 'v' :
00155         verbose = (optarg==0 ? 0 : atoi(optarg));
00156         break;
00157       case 'T' :
00158         useTopdown = true;
00159         break;
00160       case 'f' :
00161         outHuman = optarg;
00162         break;
00163       case 'F' :
00164         outMachine = optarg;
00165         break;
00166       case 'c' :
00167         iCrit = (optarg==0 ? 5 : atoi(optarg));
00168         break;
00169       case 'P' :
00170         bW = (optarg==0 ? 1 : atof(optarg));
00171         break;
00172       case 'g' :
00173         iLoss = (optarg==0 ? 0 : atoi(optarg));
00174         break;
00175       case 'i' :
00176         countTreeSplits = true;
00177         break;
00178       case 'K' :
00179         split = true;
00180         splitFactor = (optarg==0 ? 0 : atof(optarg));
00181         break;
00182       case 'D' :
00183         splitRandomize = true;
00184         break;
00185       case 't' :
00186         valFile = optarg;
00187         break;
00188       case 'p' :
00189         valHbkFile = optarg;
00190         break;
00191       case 'm' :
00192         doMerge = true;
00193         break;
00194       case 'y' :
00195         inputClassesString = optarg;
00196         break;
00197       case 'Q' :
00198         transformerFile = optarg;
00199         break;
00200       case 'w' :
00201         if( optarg != 0 ) {
00202           scaleWeights = true;
00203           sW = atof(optarg);
00204         }
00205         break;
00206       case 'V' :
00207         includeList = optarg;
00208         break;
00209       case 'z' :
00210         excludeList = optarg;
00211         break;
00212       case 'x' :
00213         nCross = (optarg==0 ? 0 : atoi(optarg));
00214         break;
00215       case 'q' :
00216         nodeValidationString = optarg;
00217         break;
00218       }
00219   }
00220 
00221   // There has to be 1 argument after all options.
00222   string trFile = argv[argc-1];
00223   if( trFile.empty() ) {
00224     cerr << "No training file is specified." << endl;
00225     return 1;
00226   }
00227 
00228   // cannot merge nodes in Topdown trees
00229   if( doMerge ) useTopdown = false;
00230 
00231   // make reader
00232   SprRWFactory::DataType inputType 
00233     = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00234   auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00235 
00236   // include variables
00237   set<string> includeSet;
00238   if( !includeList.empty() ) {
00239     vector<vector<string> > includeVars;
00240     SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00241     assert( !includeVars.empty() );
00242     for( int i=0;i<includeVars[0].size();i++ ) 
00243       includeSet.insert(includeVars[0][i]);
00244     if( !reader->chooseVars(includeSet) ) {
00245       cerr << "Unable to include variables in training set." << endl;
00246       return 2;
00247     }
00248     else {
00249       cout << "Following variables have been included in optimization: ";
00250       for( set<string>::const_iterator 
00251              i=includeSet.begin();i!=includeSet.end();i++ )
00252         cout << "\"" << *i << "\"" << " ";
00253       cout << endl;
00254     }
00255   }
00256 
00257   // exclude variables
00258   set<string> excludeSet;
00259   if( !excludeList.empty() ) {
00260     vector<vector<string> > excludeVars;
00261     SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00262     assert( !excludeVars.empty() );
00263     for( int i=0;i<excludeVars[0].size();i++ ) 
00264       excludeSet.insert(excludeVars[0][i]);
00265     if( !reader->chooseAllBut(excludeSet) ) {
00266       cerr << "Unable to exclude variables from training set." << endl;
00267       return 2;
00268     }
00269     else {
00270       cout << "Following variables have been excluded from optimization: ";
00271       for( set<string>::const_iterator 
00272              i=excludeSet.begin();i!=excludeSet.end();i++ )
00273         cout << "\"" << *i << "\"" << " ";
00274       cout << endl;
00275     }
00276   }
00277 
00278   // read training data from file
00279   auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00280   if( filter.get() == 0 ) {
00281     cerr << "Unable to read data from file " << trFile.c_str() << endl;
00282     return 2;
00283   }
00284   vector<string> vars;
00285   filter->vars(vars);
00286   cout << "Read data from file " << trFile.c_str() 
00287        << " for variables";
00288   for( int i=0;i<vars.size();i++ ) 
00289     cout << " \"" << vars[i].c_str() << "\"";
00290   cout << endl;
00291   cout << "Total number of points read: " << filter->size() << endl;
00292 
00293   // filter training data by class
00294   vector<SprClass> inputClasses;
00295   if( !filter->filterByClass(inputClassesString.c_str()) ) {
00296     cerr << "Cannot choose input classes for string " 
00297          << inputClassesString << endl;
00298     return 2;
00299   }
00300   filter->classes(inputClasses);
00301   assert( inputClasses.size() > 1 );
00302   cout << "Training data filtered by class." << endl;
00303   for( int i=0;i<inputClasses.size();i++ ) {
00304     cout << "Points in class " << inputClasses[i] << ":   " 
00305          << filter->ptsInClass(inputClasses[i]) << endl;
00306   }
00307 
00308   // scale weights
00309   if( scaleWeights ) {
00310     cout << "Signal weights are multiplied by " << sW << endl;
00311     filter->scaleWeights(inputClasses[1],sW);
00312   }
00313 
00314   // read validation data from file
00315   auto_ptr<SprAbsFilter> valFilter;
00316   if( split && !valFile.empty() ) {
00317     cerr << "Unable to split training data and use validation data " 
00318          << "from a separate file." << endl;
00319     return 2;
00320   }
00321   if( split ) {
00322     cout << "Splitting training data with factor " << splitFactor << endl;
00323     if( splitRandomize )
00324       cout << "Will use randomized splitting." << endl;
00325     vector<double> weights;
00326     SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00327     if( splitted == 0 ) {
00328       cerr << "Unable to split training data." << endl;
00329       return 2;
00330     }
00331     bool ownData = true;
00332     valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00333     cout << "Training data re-filtered:" << endl;
00334     for( int i=0;i<inputClasses.size();i++ ) {
00335       cout << "Points in class " << inputClasses[i] << ":   " 
00336            << filter->ptsInClass(inputClasses[i]) << endl;
00337     }
00338   }
00339   if( !valFile.empty() ) {
00340     auto_ptr<SprAbsReader> 
00341       valReader(SprRWFactory::makeReader(inputType,readMode));
00342     if( !includeSet.empty() ) {
00343       if( !valReader->chooseVars(includeSet) ) {
00344         cerr << "Unable to include variables in validation set." << endl;
00345         return 2;
00346       }
00347     }
00348     if( !excludeSet.empty() ) {
00349       if( !valReader->chooseAllBut(excludeSet) ) {
00350         cerr << "Unable to exclude variables from validation set." << endl;
00351         return 2;
00352       }
00353     }
00354     valFilter.reset(valReader->read(valFile.c_str()));
00355     if( valFilter.get() == 0 ) {
00356       cerr << "Unable to read data from file " << valFile.c_str() << endl;
00357       return 2;
00358     }
00359     vector<string> valVars;
00360     valFilter->vars(valVars);
00361     cout << "Read validation data from file " << valFile.c_str()
00362          << " for variables";
00363     for( int i=0;i<valVars.size();i++ )
00364       cout << " \"" << valVars[i].c_str() << "\"";
00365     cout << endl;
00366     cout << "Total number of points read: " << valFilter->size() << endl;
00367   }
00368   
00369   // filter validation data by class
00370   if( valFilter.get() != 0 ) {
00371     if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00372       cerr << "Cannot choose input classes for string " 
00373            << inputClassesString << endl;
00374       return 2;
00375     }
00376     valFilter->classes(inputClasses);
00377     cout << "Validation data filtered by class." << endl;
00378     for( int i=0;i<inputClasses.size();i++ ) {
00379       cout << "Points in class " << inputClasses[i] << ":   " 
00380            << valFilter->ptsInClass(inputClasses[i]) << endl;
00381     }
00382   }
00383 
00384   // scale weights
00385   if( scaleWeights && valFilter.get()!=0 )
00386     valFilter->scaleWeights(inputClasses[1],sW);
00387 
00388   // apply transformation of variables to training and test data
00389   auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00390   if( !transformerFile.empty() ) {
00391     SprVarTransformerReader transReader;
00392     const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00393     if( t == 0 ) {
00394       cerr << "Unable to read VarTransformer from file "
00395            << transformerFile.c_str() << endl;
00396       return 2;
00397     }
00398     SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00399     SprTransformerFilter* t_valid = 0;
00400     if( valFilter.get() != 0 )
00401       t_valid = new SprTransformerFilter(valFilter.get());
00402     bool replaceOriginalData = true;
00403     if( !t_train->transform(t,replaceOriginalData) ) {
00404       cerr << "Unable to apply VarTransformer to training data." << endl;
00405       return 2;
00406     }
00407     if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00408       cerr << "Unable to apply VarTransformer to validation data." << endl;
00409       return 2;
00410     }
00411     cout << "Variable transformation from file "
00412          << transformerFile.c_str() << " has been applied to "
00413          << "training and validation data." << endl;
00414     garbage_train.reset(filter.release());
00415     garbage_valid.reset(valFilter.release());
00416     filter.reset(t_train);
00417     valFilter.reset(t_valid);
00418   }
00419 
00420   // make optimization criterion
00421   auto_ptr<SprAbsTwoClassCriterion> crit;
00422   switch( iCrit )
00423     {
00424     case 1 :
00425       crit.reset(new SprTwoClassIDFraction);
00426       cout << "Optimization criterion set to "
00427            << "Fraction of correctly classified events " << endl;
00428       break;
00429     case 2 :
00430       crit.reset(new SprTwoClassSignalSignif);
00431       cout << "Optimization criterion set to "
00432            << "Signal significance S/sqrt(S+B) " << endl;
00433       break;
00434     case 3 :
00435       crit.reset(new SprTwoClassPurity);
00436       cout << "Optimization criterion set to "
00437            << "Purity S/(S+B) " << endl;
00438       break;
00439     case 4 :
00440       crit.reset(new SprTwoClassTaggerEff);
00441       cout << "Optimization criterion set to "
00442            << "Tagging efficiency Q = e*(1-2w)^2 " << endl;
00443       break;
00444     case 5 :
00445       crit.reset(new SprTwoClassGiniIndex);
00446       cout << "Optimization criterion set to "
00447            << "Gini index  -1+p^2+q^2 " << endl;
00448       break;
00449     case 6 :
00450       crit.reset(new SprTwoClassCrossEntropy);
00451       cout << "Optimization criterion set to "
00452            << "Cross-entropy p*log(p)+q*log(q) " << endl;
00453       break;
00454     case 7 :
00455       crit.reset(new SprTwoClassUniformPriorUL90);
00456       cout << "Optimization criterion set to "
00457            << "Inverse of 90% Bayesian upper limit with uniform prior" << endl;
00458       break;
00459     case 8 :
00460       crit.reset(new SprTwoClassBKDiscovery);
00461       cout << "Optimization criterion set to "
00462            << "Discovery potential 2*(sqrt(S+B)-sqrt(B))" << endl;
00463       break;
00464     case 9 :
00465       crit.reset(new SprTwoClassPunzi(bW));
00466       cout << "Optimization criterion set to "
00467            << "Punzi's sensitivity S/(0.5*nSigma+sqrt(B))" << endl;
00468       break;
00469     default :
00470       cerr << "Unable to make initialization criterion." << endl;
00471       return 3;
00472     }
00473 
00474   // make per-event loss
00475   auto_ptr<SprAverageLoss> loss;
00476   switch( iLoss )
00477     {
00478     case 1 :
00479       loss.reset(new SprAverageLoss(&SprLoss::quadratic));
00480       cout << "Per-event loss set to "
00481            << "Quadratic loss (y-f(x))^2 " << endl;
00482       break;
00483     case 2 :
00484       loss.reset(new SprAverageLoss(&SprLoss::purity_ratio));
00485       cout << "Per-event loss set to "
00486            << "Exponential loss exp(-y*f(x)) " << endl;
00487       break;
00488     case 3 :
00489       loss.reset(new SprAverageLoss(&SprLoss::correct_id,
00490                                 &SprTransformation::continuous01ToDiscrete01));
00491       cout << "Per-event loss set to "
00492            << "Misid rate int(y==f(x)) " << endl;
00493       break;
00494     default :
00495       cout << "No per-event loss is chosen. Will use the default." << endl;
00496       break;
00497     }
00498 
00499   // if cross-validation requested, cross-validate and exit
00500   if( nCross > 0 ) {
00501     // message
00502     cout << "Will cross-validate by dividing training data into " 
00503          << nCross << " subsamples." << endl;
00504     vector<vector<int> > nodeMinSize;
00505 
00506     // decode validation string
00507     if( !nodeValidationString.empty() )
00508       SprStringParser::parseToInts(nodeValidationString.c_str(),nodeMinSize);
00509     else {
00510       nodeMinSize.resize(1);
00511       nodeMinSize[0].push_back(nmin);
00512     }
00513     if( nodeMinSize.empty() || nodeMinSize[0].empty() ) {
00514       cerr << "Unable to determine node size for cross-validation." << endl;
00515       return 4;
00516     }
00517     else {
00518       cout << "Will cross-validate for trees with minimal node sizes: ";
00519       for( int i=0;i<nodeMinSize[0].size();i++ )
00520         cout << nodeMinSize[0][i] << " ";
00521       cout << endl;
00522     }
00523 
00524     // loop over nodes to prepare classifiers
00525     vector<SprAbsClassifier*> classifiers(nodeMinSize[0].size());
00526     for( int i=0;i<nodeMinSize[0].size();i++ ) {
00527       SprDecisionTree* tree1 = 0;
00528       if( useTopdown ) {
00529         bool discrete = false;
00530         tree1 = new SprTopdownTree(filter.get(),crit.get(),
00531                                    nodeMinSize[0][i],discrete);
00532       }
00533       else {
00534         bool discrete = true;
00535         tree1 = new SprDecisionTree(filter.get(),crit.get(),
00536                                     nodeMinSize[0][i],doMerge,discrete);
00537       }
00538       classifiers[i] = tree1;
00539     }
00540 
00541     // cross-validate
00542     vector<double> cvFom;
00543     SprCrossValidator cv(filter.get(),nCross);
00544     if( !cv.validate(crit.get(),loss.get(),classifiers,0,1,
00545                      SprUtils::lowerBound(0.5),cvFom,verbose) ) {
00546       cerr << "Unable to cross-validate." << endl;
00547       for( int j=0;j<classifiers.size();j++ ) {
00548         delete classifiers[j];
00549       }
00550       return 4;
00551     }
00552     else {
00553       cout << "Cross-validated FOMs:" << endl;
00554       for( int i=0;i<cvFom.size();i++ ) {
00555         cout << "Node size=" << setw(8) << nodeMinSize[0][i] 
00556              << "      FOM=" << setw(10) << cvFom[i] << endl;
00557       }
00558     }
00559 
00560     // cleanup
00561     for( int j=0;j<classifiers.size();j++ ) {
00562       delete classifiers[j];
00563     }
00564 
00565     // normal exit
00566     return 0;
00567   }// end cross-validation
00568 
00569   // make decision tree
00570   auto_ptr<SprDecisionTree> tree;
00571   if( useTopdown ) {
00572     bool discrete = false;
00573     tree.reset(new SprTopdownTree(filter.get(),crit.get(),nmin,discrete));
00574   }
00575   else {
00576     tree.reset( new SprDecisionTree(filter.get(),crit.get(),
00577                                     nmin,doMerge,true));
00578     if( countTreeSplits ) tree->startSplitCounter();
00579     tree->setShowBackgroundNodes(true);
00580   }
00581 
00582   // train
00583   if( !tree->train(verbose) ) {
00584     cerr << "Unable to train decision tree." << endl;
00585     return 4;
00586   }
00587   cout << "Finished training decision tree." << endl;
00588 
00589   // save trained decision tree in human-readable format
00590   if( !outHuman.empty() ) {
00591     if( !tree->store(outHuman.c_str()) ) {
00592       cerr << "Cannot store decision tree in file " 
00593            << outHuman.c_str() << endl;
00594       return 5;
00595     }
00596   }
00597 
00598   // print out counted splits
00599   if( countTreeSplits ) tree->printSplitCounter(cout);
00600 
00601   // make trained decision tree
00602   auto_ptr<SprTrainedDecisionTree> trainedTree(tree->makeTrained());
00603 
00604   // save trained tree in machine-readable format
00605   if( !outMachine.empty() ) {
00606     if( !trainedTree->store(outMachine.c_str()) ) {
00607       cerr << "Unable to save trained tree into " 
00608            << outMachine.c_str() << endl;
00609       return 5;
00610     }
00611   }
00612 
00613   // compute FOM for the validation data
00614   if( valFilter.get() != 0 ) {
00615     double wcor0(0), wmis0(0), wcor1(0), wmis1(0);
00616     int ncor0(0), nmis0(0), ncor1(0), nmis1(0);
00617     if( loss.get() != 0 ) loss->reset();
00618     for( int i=0;i<valFilter->size();i++ ) {
00619       const SprPoint* p = (*valFilter.get())[i];
00620       double w = valFilter->w(i);
00621       double resp = trainedTree->response(p->x_);
00622       if( trainedTree->accept(p) ) {
00623         if(      p->class_ == inputClasses[0] ) {
00624           wmis0 += w;
00625           nmis0++;
00626           if( loss.get() != 0 ) loss->update(0,resp,w);
00627         }
00628         else if( p->class_ == inputClasses[1] ) {
00629           wcor1 += w;
00630           ncor1++;
00631           if( loss.get() != 0 ) loss->update(1,resp,w);
00632         }
00633       }
00634       else {
00635         if(      p->class_ == inputClasses[0] ) {
00636           wcor0 += w;
00637           ncor0++;
00638           if( loss.get() != 0 ) loss->update(0,resp,w);
00639         }
00640         else if( p->class_ == inputClasses[1] ) {
00641           wmis1 += w;
00642           nmis1++;
00643           if( loss.get() != 0 ) loss->update(1,resp,w);
00644         }
00645       }
00646     }
00647     double vFom = crit->fom(wcor0,wmis0,wcor1,wmis1);
00648     double vLoss = 0;
00649     if( loss.get() != 0 ) vLoss = loss->value();
00650     cout << "=====================================================" << endl;
00651     cout << "Validation FOM=" << vFom << "  Loss=" << vLoss << endl;
00652     cout << "Content of the signal region:"
00653          << "   W0=" << wmis0 << "  W1=" << wcor1 
00654          << "   N0=" << nmis0 << "  N1=" << ncor1 
00655          << endl;
00656     cout << "=====================================================" << endl;
00657   }
00658 
00659   // make histogram if requested
00660   if( tupleFile.empty() && valHbkFile.empty() ) return 0;
00661 
00662   // make a wrapper to store box numbers
00663   class BoxNumberWrapper : public SprTrainedDecisionTree {
00664   public:
00665     virtual ~BoxNumberWrapper() {}
00666     BoxNumberWrapper(const SprTrainedDecisionTree& tree)
00667       : SprTrainedDecisionTree(tree) {}
00668     double response(const std::vector<double>& v) const {
00669       return this->nBox(v);
00670     }
00671   };
00672 
00673   // feed training data
00674   if( !tupleFile.empty() ) {
00675     // make a writer
00676     auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00677     if( !tuple->init(tupleFile.c_str()) ) {
00678       cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00679       return 6;
00680     }
00681     // wrap
00682     BoxNumberWrapper boxNumber(*(trainedTree.get()));
00683     // feed 
00684     SprDataFeeder feeder(filter.get(),tuple.get());
00685     feeder.addClassifier(trainedTree.get(),"tree");
00686     feeder.addClassifier(&boxNumber,"box");
00687     if( !feeder.feed(1000) ) {
00688       cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00689       return 6;
00690     }
00691   }
00692 
00693   if( !valHbkFile.empty() ) {
00694     // make a writer
00695     auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"test"));
00696     if( !tuple->init(valHbkFile.c_str()) ) {
00697       cerr << "Unable to open output file " << valHbkFile.c_str() << endl;
00698       return 7;
00699     }
00700     // wrap
00701     BoxNumberWrapper boxNumber(*(trainedTree.get()));
00702     // feed 
00703     SprDataFeeder feeder(valFilter.get(),tuple.get());
00704     feeder.addClassifier(trainedTree.get(),"tree");
00705     feeder.addClassifier(&boxNumber,"box");
00706     if( !feeder.feed(1000) ) {
00707       cerr << "Cannot feed data into file " << valHbkFile.c_str() << endl;
00708       return 7;
00709     }
00710   }
00711 
00712   // exit
00713   return 0;
00714 }


Generated on Tue Jun 9 17:55:00 2009 for CMSSW by  doxygen 1.5.4