CMS 3D CMS Logo

SprBaggerDecisionTreeApp.cc File Reference

#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprBagger.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprArcE4.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedBagger.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTopdownTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedTopdownTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDecisionTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedDecisionTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassSignalSignif.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassIDFraction.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassTaggerEff.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPurity.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassCrossEntropy.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassUniformPriorUL90.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassBKDiscovery.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPunzi.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprCrossValidator.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include <stdlib.h>
#include <unistd.h>
#include <iostream>
#include <vector>
#include <set>
#include <string>
#include <memory>
#include <iomanip>

Go to the source code of this file.

Functions

void help (const char *prog)
int main (int argc, char **argv)


Function Documentation

void help ( const char *  prog  ) 

Definition at line 51 of file SprBaggerDecisionTreeApp.cc.

References GenMuonPlsPt100GeV_cfg::cout, and lat::endl().

// Print the command-line usage summary for this executable to standard output.
//
// prog - program name shown on the "Usage:" line (main() passes argv[0]).
//
// The option list mirrors the getopt() option string decoded in main().
void help(const char* prog)
{
  using std::cout;
  using std::endl;

  cout << "Usage:  " << prog 
       << " training_data_file" << endl;
  cout << "\t Options: " << endl;
  cout << "\t-h --- help                                        " << endl;
  cout << "\t-j use regular tree instead of faster topdown tree " << endl;
  cout << "\t-k discrete decision tree output (default=continuous)"<< endl;
  cout << "\t-o output Tuple file                               " << endl;
  cout << "\t-a input ascii file mode (see SprSimpleReader.hh)  " << endl;
  cout << "\t-A save output data in ascii instead of Root       " << endl;
  cout << "\t-n number of Bagger training cycles                " << endl;
  cout << "\t-l minimal number of entries per tree leaf (def=1) " << endl;
  cout << "\t-s max number of sampled features (def=0 no sampling)"<< endl;
  cout << "\t-y list of input classes (see SprAbsFilter.hh)     " << endl;
  cout << "\t-Q apply variable transformation saved in file     " << endl;
  cout << "\t-b use a version of Breiman's arc-x4 algorithm     " << endl;
  cout << "\t-v verbose level (0=silent default,1,2)            " << endl;
  cout << "\t-f store trained Bagger to file                    " << endl;
  // This application stores code for the trained Bagger, not for AdaBoost.
  cout << "\t-F generate code for Bagger and store to file      " << endl;
  cout << "\t-c criterion for optimization                      " << endl;
  cout << "\t\t 1 = correctly classified fraction               " << endl;
  cout << "\t\t 2 = signal significance s/sqrt(s+b)             " << endl;
  cout << "\t\t 3 = purity s/(s+b)                              " << endl;
  cout << "\t\t 4 = tagger efficiency Q                         " << endl;
  cout << "\t\t 5 = Gini index (default)                        " << endl;
  cout << "\t\t 6 = cross-entropy                               " << endl;
  cout << "\t\t 7 = 90% Bayesian upper limit with uniform prior " << endl;
  cout << "\t\t 8 = discovery potential 2*(sqrt(s+b)-sqrt(b))   " << endl;
  cout << "\t\t 9 = Punzi's sensitivity s/(0.5*nSigma+sqrt(b))  " << endl;
  cout << "\t\t -P background normalization factor for Punzi FOM" << endl;
  cout << "\t-g per-event loss for (cross-)validation           " << endl;
  cout << "\t\t 1 - quadratic loss (y-f(x))^2                   " << endl;
  cout << "\t\t 2 - exponential loss exp(-y*f(x))               " << endl;
  cout << "\t\t 3 - misid fraction                              " << endl;
  cout << "\t-m replace data values below this cutoff with medians" << endl;
  cout << "\t-i count splits on input variables                 " << endl;
  cout << "\t-r resume training for Bagger stored in file       " << endl;
  cout << "\t-K keep this fraction in training set and          " << endl;
  cout << "\t\t put the rest into validation set                " << endl;
  cout << "\t-D randomize training set split-up                 " << endl;
  cout << "\t-G generate seed from time of day for bootstrap    " << endl;
  cout << "\t-t read validation/test data from a file           " << endl;
  cout << "\t\t (must be in same format as input data!!!)       " << endl;
  cout << "\t-d frequency of print-outs for validation data     " << endl;
  cout << "\t-w scale all signal weights by this factor         " << endl;
  cout << "\t-V include only these input variables              " << endl;
  cout << "\t-z exclude input variables from the list           " << endl;
  cout << "\t-Z exclude input variables from the list, "
       << "but put them in the output file " << endl;
  cout << "\t\t Variables must be listed in quotes and separated by commas." 
       << endl;
  cout << "\t-x cross-validate by splitting data into a given "
       << "number of pieces" << endl;
  cout << "\t-q a set of minimal node sizes for cross-validation" << endl;
  cout << "\t\t Node sizes must be listed in quotes and separated by commas." 
       << endl;
}

int main ( int  argc,
char **  argv 
)

Definition at line 111 of file SprBaggerDecisionTreeApp.cc.

References begin, c, TestMuL1L2Filter_cff::cerr, GenMuonPlsPt100GeV_cfg::cout, d, end, lat::endl(), filter, find(), help(), i, j, size, split, t, tree, vars, and weights.

00112 {
00113   // check command line
00114   if( argc < 2 ) {
00115     help(argv[0]);
00116     return 1;
00117   }
00118 
00119   // init
00120   string tupleFile;
00121   int readMode = 0;
00122   SprRWFactory::DataType writeMode = SprRWFactory::Root;
00123   unsigned cycles = 0;
00124   unsigned nmin = 1;
00125   int verbose = 0;
00126   string outFile;
00127   string codeFile;
00128   string resumeFile;
00129   int iCrit = 5;
00130   string valFile;
00131   unsigned valPrint = 0;
00132   bool scaleWeights = false;
00133   double sW = 1.;
00134   int nFeaturesToSample = 0;
00135   bool countTreeSplits = false;
00136   bool setLowCutoff = false;
00137   double lowCutoff = 0;
00138   string includeList, excludeList;
00139   unsigned nCross = 0;
00140   string nodeValidationString;
00141   bool useTopdown = true;
00142   bool discrete = false;
00143   int iLoss = 0;
00144   string inputClassesString;
00145   bool useArcE4 = false;
00146   double bW = 1.;
00147   string stringVarsDoNotFeed;
00148   bool split = false;
00149   double splitFactor = 0;
00150   bool splitRandomize = false;
00151   bool initBootstrapFromTimeOfDay = false;
00152   string transformerFile;
00153 
00154   // decode command line
00155   int c;
00156   extern char* optarg;
00157   //  extern int optind;
00158   while( (c = getopt(argc,argv,"hjko:a:An:l:s:y:Q:bv:f:F:c:P:g:m:ir:K:DGt:d:w:V:z:Z:x:q:")) 
00159          != EOF ) {
00160     switch( c )
00161       {
00162       case 'h' :
00163         help(argv[0]);
00164         return 1;
00165       case 'j' :
00166         useTopdown = false;
00167         break;
00168       case 'k' :
00169         discrete = true;
00170         break;
00171       case 'o' :
00172         tupleFile = optarg;
00173         break;
00174       case 'a' :
00175         readMode = (optarg==0 ? 0 : atoi(optarg));
00176         break;
00177       case 'A' :
00178         writeMode = SprRWFactory::Ascii;
00179         break;
00180       case 'n' :
00181         cycles = (optarg==0 ? 1 : atoi(optarg));
00182         break;
00183       case 'l' :
00184         nmin = (optarg==0 ? 1 : atoi(optarg));
00185         break;
00186       case 's' :
00187         nFeaturesToSample = (optarg==0 ? 0 : atoi(optarg));
00188         break;
00189       case 'y' :
00190         inputClassesString = optarg;
00191         break;
00192       case 'Q' :
00193         transformerFile = optarg;
00194         break;
00195       case 'b' :
00196         useArcE4 = true;
00197         break;
00198       case 'v' :
00199         verbose = (optarg==0 ? 0 : atoi(optarg));
00200         break;
00201       case 'f' :
00202         outFile = optarg;
00203         break;
00204       case 'F' :
00205         codeFile = optarg;
00206         break;
00207       case 'c' :
00208         iCrit = (optarg==0 ? 5 : atoi(optarg));
00209         break;
00210       case 'P' :
00211         bW = (optarg==0 ? 1 : atof(optarg));
00212         break;
00213       case 'g' :
00214         iLoss = (optarg==0 ? 0 : atoi(optarg));
00215         break;
00216       case 'm' :
00217         if( optarg != 0 ) {
00218           setLowCutoff = true;
00219           lowCutoff = atof(optarg);
00220         }
00221         break;
00222       case 'i' :
00223         countTreeSplits = true;
00224         break;
00225       case 'r' :
00226         resumeFile = optarg;
00227         break;
00228       case 'K' :
00229         split = true;
00230         splitFactor = (optarg==0 ? 0 : atof(optarg));
00231         break;
00232       case 'D' :
00233         splitRandomize = true;
00234         break;
00235       case 'G' :
00236         initBootstrapFromTimeOfDay = true;
00237         break;
00238       case 't' :
00239         valFile = optarg;
00240         break;
00241       case 'd' :
00242         valPrint = (optarg==0 ? 0 : atoi(optarg));
00243         break;
00244       case 'w' :
00245         if( optarg != 0 ) {
00246           scaleWeights = true;
00247           sW = atof(optarg);
00248         }
00249         break;
00250       case 'V' :
00251         includeList = optarg;
00252         break;
00253       case 'z' :
00254         excludeList = optarg;
00255         break;
00256       case 'Z' :
00257         stringVarsDoNotFeed = optarg;
00258         break;
00259       case 'x' :
00260         nCross = (optarg==0 ? 0 : atoi(optarg));
00261         break;
00262       case 'q' :
00263         nodeValidationString = optarg;
00264         break;
00265       }
00266   }
00267 
00268   // There has to be 1 argument after all options.
00269   string trFile = argv[argc-1];
00270   if( trFile.empty() ) {
00271     cerr << "No training file is specified." << endl;
00272     return 1;
00273   }
00274 
00275   // make reader
00276   SprRWFactory::DataType inputType 
00277     = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00278   auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00279 
00280   // include variables
00281   set<string> includeSet;
00282   if( !includeList.empty() ) {
00283     vector<vector<string> > includeVars;
00284     SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00285     assert( !includeVars.empty() );
00286     for( int i=0;i<includeVars[0].size();i++ ) 
00287       includeSet.insert(includeVars[0][i]);
00288     if( !reader->chooseVars(includeSet) ) {
00289       cerr << "Unable to include variables in training set." << endl;
00290       return 2;
00291     }
00292     else {
00293       cout << "Following variables have been included in optimization: ";
00294       for( set<string>::const_iterator 
00295              i=includeSet.begin();i!=includeSet.end();i++ )
00296         cout << "\"" << *i << "\"" << " ";
00297       cout << endl;
00298     }
00299   }
00300 
00301   // exclude variables
00302   set<string> excludeSet;
00303   if( !excludeList.empty() ) {
00304     vector<vector<string> > excludeVars;
00305     SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00306     assert( !excludeVars.empty() );
00307     for( int i=0;i<excludeVars[0].size();i++ ) 
00308       excludeSet.insert(excludeVars[0][i]);
00309     if( !reader->chooseAllBut(excludeSet) ) {
00310       cerr << "Unable to exclude variables from training set." << endl;
00311       return 2;
00312     }
00313     else {
00314       cout << "Following variables have been excluded from optimization: ";
00315       for( set<string>::const_iterator 
00316              i=excludeSet.begin();i!=excludeSet.end();i++ )
00317         cout << "\"" << *i << "\"" << " ";
00318       cout << endl;
00319     }
00320   }
00321 
00322   // read training data from file
00323   auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00324   if( filter.get() == 0 ) {
00325     cerr << "Unable to read data from file " << trFile.c_str() << endl;
00326     return 2;
00327   }
00328   vector<string> vars;
00329   filter->vars(vars);
00330   cout << "Read data from file " << trFile.c_str() 
00331        << " for variables";
00332   for( int i=0;i<vars.size();i++ ) 
00333     cout << " \"" << vars[i].c_str() << "\"";
00334   cout << endl;
00335   cout << "Total number of points read: " << filter->size() << endl;
00336 
00337   // filter training data by class
00338   vector<SprClass> inputClasses;
00339   if( !filter->filterByClass(inputClassesString.c_str()) ) {
00340     cerr << "Cannot choose input classes for string " 
00341          << inputClassesString << endl;
00342     return 2;
00343   }
00344   filter->classes(inputClasses);
00345   assert( inputClasses.size() > 1 );
00346   cout << "Training data filtered by class." << endl;
00347   for( int i=0;i<inputClasses.size();i++ ) {
00348     cout << "Points in class " << inputClasses[i] << ":   " 
00349          << filter->ptsInClass(inputClasses[i]) << endl;
00350   }
00351 
00352   // scale weights
00353   if( scaleWeights ) {
00354     cout << "Signal weights are multiplied by " << sW << endl;
00355     filter->scaleWeights(inputClasses[1],sW);
00356   }
00357 
00358   // apply low cutoff
00359   if( setLowCutoff ) {
00360     if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00361       cerr << "Unable to replace missing values in training data." << endl;
00362       return 2;
00363     }
00364     else
00365       cout << "Values below " << lowCutoff << " in training data"
00366            << " have been replaced with medians." << endl;
00367   }
00368 
00369   // read validation data from file
00370   auto_ptr<SprAbsFilter> valFilter;
00371   if( split && !valFile.empty() ) {
00372     cerr << "Unable to split training data and use validation data "
00373          << "from a separate file." << endl;
00374     return 2;
00375   }
00376   if( split && valPrint!=0 ) {
00377     cout << "Splitting training data with factor " << splitFactor << endl;
00378     if( splitRandomize )
00379       cout << "Will use randomized splitting." << endl;
00380     vector<double> weights;
00381     SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00382     if( splitted == 0 ) {
00383       cerr << "Unable to split training data." << endl;
00384       return 2;
00385     }
00386     bool ownData = true;
00387     valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00388     cout << "Training data re-filtered:" << endl;
00389     for( int i=0;i<inputClasses.size();i++ ) {
00390       cout << "Points in class " << inputClasses[i] << ":   "
00391            << filter->ptsInClass(inputClasses[i]) << endl;
00392     }
00393   }
00394   if( !valFile.empty() && valPrint!=0 ) {
00395     auto_ptr<SprAbsReader> 
00396       valReader(SprRWFactory::makeReader(inputType,readMode));
00397     if( !includeSet.empty() ) {
00398       if( !valReader->chooseVars(includeSet) ) {
00399         cerr << "Unable to include variables in validation set." << endl;
00400         return 2;
00401       }
00402     }
00403     if( !excludeSet.empty() ) {
00404       if( !valReader->chooseAllBut(excludeSet) ) {
00405         cerr << "Unable to exclude variables from validation set." << endl;
00406         return 2;
00407       }
00408     }
00409     valFilter.reset(valReader->read(valFile.c_str()));
00410     if( valFilter.get() == 0 ) {
00411       cerr << "Unable to read data from file " << valFile.c_str() << endl;
00412       return 2;
00413     }
00414     vector<string> valVars;
00415     valFilter->vars(valVars);
00416     cout << "Read validation data from file " << valFile.c_str() 
00417          << " for variables";
00418     for( int i=0;i<valVars.size();i++ ) 
00419       cout << " \"" << valVars[i].c_str() << "\"";
00420     cout << endl;
00421     cout << "Total number of points read: " << valFilter->size() << endl;
00422   }
00423 
00424   // filter validation data by class
00425   if( valFilter.get() != 0 ) {
00426     if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00427       cerr << "Cannot choose input classes for string " 
00428            << inputClassesString << endl;
00429       return 2;
00430     }
00431     valFilter->classes(inputClasses);
00432     cout << "Validation data filtered by class." << endl;
00433     for( int i=0;i<inputClasses.size();i++ ) {
00434       cout << "Points in class " << inputClasses[i] << ":   " 
00435            << valFilter->ptsInClass(inputClasses[i]) << endl;
00436     }
00437   }
00438 
00439   // scale weights
00440   if( scaleWeights && valFilter.get()!=0 )
00441     valFilter->scaleWeights(inputClasses[1],sW);
00442 
00443   // apply low cutoff
00444   if( setLowCutoff && valFilter.get()!=0 ) {
00445     if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00446       cerr << "Unable to replace missing values in validation data." << endl;
00447       return 2;
00448     }
00449     else
00450       cout << "Values below " << lowCutoff << " in validation data"
00451            << " have been replaced with medians." << endl;
00452   }
00453 
00454   // apply transformation of variables to training and test data
00455   auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00456   if( !transformerFile.empty() ) {
00457     SprVarTransformerReader transReader;
00458     const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00459     if( t == 0 ) {
00460       cerr << "Unable to read VarTransformer from file "
00461            << transformerFile.c_str() << endl;
00462       return 2;
00463     }
00464     SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00465     SprTransformerFilter* t_valid = 0;
00466     if( valFilter.get() != 0 )
00467       t_valid = new SprTransformerFilter(valFilter.get());
00468     bool replaceOriginalData = true;
00469     if( !t_train->transform(t,replaceOriginalData) ) {
00470       cerr << "Unable to apply VarTransformer to training data." << endl;
00471       return 2;
00472     }
00473     if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00474       cerr << "Unable to apply VarTransformer to validation data." << endl;
00475       return 2;
00476     }
00477     cout << "Variable transformation from file "
00478          << transformerFile.c_str() << " has been applied to "
00479          << "training and validation data." << endl;
00480     garbage_train.reset(filter.release());
00481     garbage_valid.reset(valFilter.release());
00482     filter.reset(t_train);
00483     valFilter.reset(t_valid);
00484   }
00485 
00486   // make optimization criterion
00487   auto_ptr<SprAbsTwoClassCriterion> crit;
00488   switch( iCrit )
00489     {
00490     case 1 :
00491       crit.reset(new SprTwoClassIDFraction);
00492       cout << "Optimization criterion set to "
00493            << "Fraction of correctly classified events " << endl;
00494       break;
00495     case 2 :
00496       crit.reset(new SprTwoClassSignalSignif);
00497       cout << "Optimization criterion set to "
00498            << "Signal significance S/sqrt(S+B) " << endl;
00499       break;
00500     case 3 :
00501       crit.reset(new SprTwoClassPurity);
00502       cout << "Optimization criterion set to "
00503            << "Purity S/(S+B) " << endl;
00504       break;
00505     case 4 :
00506       crit.reset(new SprTwoClassTaggerEff);
00507       cout << "Optimization criterion set to "
00508            << "Tagging efficiency Q = e*(1-2w)^2 " << endl;
00509       break;
00510     case 5 :
00511       crit.reset(new SprTwoClassGiniIndex);
00512       cout << "Optimization criterion set to "
00513            << "Gini index  -1+p^2+q^2 " << endl;
00514       break;
00515     case 6 :
00516       crit.reset(new SprTwoClassCrossEntropy);
00517       cout << "Optimization criterion set to "
00518            << "Cross-entropy p*log(p)+q*log(q) " << endl;
00519       break;
00520     case 7 :
00521       crit.reset(new SprTwoClassUniformPriorUL90);
00522       cout << "Optimization criterion set to "
00523            << "Inverse of 90% Bayesian upper limit with uniform prior" << endl;
00524       break;
00525     case 8 :
00526       crit.reset(new SprTwoClassBKDiscovery);
00527       cout << "Optimization criterion set to "
00528            << "Discovery potential 2*(sqrt(S+B)-sqrt(B))" << endl;
00529       break;
00530     case 9 :
00531       crit.reset(new SprTwoClassPunzi(bW));
00532       cout << "Optimization criterion set to "
00533            << "Punzi's sensitivity S/(0.5*nSigma+sqrt(B))" << endl;
00534       break;
00535     default :
00536       cerr << "Unable to make initialization criterion." << endl;
00537       return 3;
00538     }
00539 
00540   // check criterion vs classifier
00541   if( useArcE4 && !crit->symmetric() ) {
00542     cerr << "Unable to use arc-e4 with an asymmetric criterion." << endl;
00543     return 3;
00544   }
00545 
00546   // make per-event loss
00547   auto_ptr<SprAverageLoss> loss;
00548   switch( iLoss )
00549     {
00550     case 1 :
00551       loss.reset(new SprAverageLoss(&SprLoss::quadratic));
00552       cout << "Per-event loss set to "
00553            << "Quadratic loss (y-f(x))^2 " << endl;
00554       break;
00555     case 2 :
00556       loss.reset(new SprAverageLoss(&SprLoss::purity_ratio));
00557       cout << "Per-event loss set to "
00558            << "Exponential loss exp(-y*f(x)) " << endl;
00559       break;
00560     case 3 :
00561       loss.reset(new SprAverageLoss(&SprLoss::correct_id,
00562                                 &SprTransformation::continuous01ToDiscrete01));
00563       cout << "Per-event loss set to "
00564            << "Misid rate int(y==f(x)) " << endl;
00565       break;
00566     default :
00567       cout << "No per-event loss is chosen. Will use the default." << endl;
00568       break;
00569     }
00570 
00571   // make bootstrap for resampling input features
00572   auto_ptr<SprIntegerBootstrap> bootstrap;
00573   if( nFeaturesToSample > filter->dim() ) 
00574     nFeaturesToSample = filter->dim();
00575   if( nFeaturesToSample > 0 ) {
00576     bootstrap.reset(new SprIntegerBootstrap(filter->dim(),nFeaturesToSample));
00577     if( !resumeFile.empty() || initBootstrapFromTimeOfDay ) 
00578       bootstrap->init(-1);
00579   }
00580 
00581   // make decision tree
00582   bool doMerge = !crit->symmetric();
00583   if( doMerge ) useTopdown = false;
00584   auto_ptr<SprDecisionTree> tree;
00585   if( useTopdown ) {
00586     tree.reset(new SprTopdownTree(filter.get(),crit.get(),nmin,
00587                                   discrete,bootstrap.get()));
00588   }
00589   else {
00590     tree.reset(new SprDecisionTree(filter.get(),crit.get(),nmin,doMerge,
00591                                    discrete,bootstrap.get()));
00592   }
00593   if( countTreeSplits ) tree->startSplitCounter();
00594   tree->useFastSort();
00595 
00596   // if cross-validation requested, cross-validate and exit
00597   if( nCross > 0 ) {
00598     // message
00599     cout << "Will cross-validate by dividing training data into " 
00600          << nCross << " subsamples." << endl;
00601     vector<vector<int> > nodeMinSize;
00602 
00603     // decode validation string
00604     if( !nodeValidationString.empty() )
00605       SprStringParser::parseToInts(nodeValidationString.c_str(),nodeMinSize);
00606     else {
00607       nodeMinSize.resize(1);
00608       nodeMinSize[0].push_back(nmin);
00609     }
00610     if( nodeMinSize.empty() || nodeMinSize[0].empty() ) {
00611       cerr << "Unable to determine node size for cross-validation." << endl;
00612       return 4;
00613     }
00614     else {
00615       cout << "Will cross-validate for trees with minimal node sizes: ";
00616       for( int i=0;i<nodeMinSize[0].size();i++ )
00617         cout << nodeMinSize[0][i] << " ";
00618       cout << endl;
00619     }
00620 
00621     // loop over nodes to prepare classifiers
00622     vector<SprDecisionTree*> trees(nodeMinSize[0].size());
00623     vector<SprAbsClassifier*> classifiers(nodeMinSize[0].size());
00624     for( int i=0;i<nodeMinSize[0].size();i++ ) {
00625       SprDecisionTree* tree1 = 0;
00626       if( useTopdown ) {
00627         tree1 = new SprTopdownTree(filter.get(),crit.get(),nodeMinSize[0][i],
00628                                    discrete,bootstrap.get());
00629       }
00630       else {
00631         tree1 = new SprDecisionTree(filter.get(),crit.get(),nodeMinSize[0][i],
00632                                     doMerge,discrete,bootstrap.get());
00633       }
00634       tree1->useFastSort();
00635       SprBagger* bagger1 = 0;
00636       if( useArcE4 )
00637         bagger1 = new SprArcE4(filter.get(),cycles,discrete);
00638       else
00639         bagger1 = new SprBagger(filter.get(),cycles,discrete);
00640       if( initBootstrapFromTimeOfDay 
00641           && !bagger1->initBootstrapFromTimeOfDay() ) {
00642         cerr << "Unable to generate seed from time of day for Bagger." << endl;
00643         return 4;
00644       }
00645       if( !bagger1->addTrainable(tree1) ) {
00646         cerr << "Unable to add decision tree to Bagger for CV." << endl;
00647         for( int j=0;j<trees.size();j++ ) {
00648           delete trees[j];
00649           delete classifiers[j];
00650         }
00651         return 4;
00652       }
00653       trees[i] = tree1;
00654       classifiers[i] = bagger1;
00655     }
00656 
00657     // cross-validate
00658     vector<double> cvFom;
00659     SprCrossValidator cv(filter.get(),nCross);
00660     if( !cv.validate(crit.get(),loss.get(),classifiers,
00661                      inputClasses[0],inputClasses[1],
00662                      SprUtils::lowerBound(0.5),cvFom,verbose) ) {
00663       cerr << "Unable to cross-validate." << endl;
00664       for( int j=0;j<trees.size();j++ ) {
00665         delete trees[j];
00666         delete classifiers[j];
00667       }
00668       return 4;
00669     }
00670     else {
00671       cout << "Cross-validated FOMs:" << endl;
00672       for( int i=0;i<cvFom.size();i++ ) {
00673         cout << "Node size=" << setw(8) << nodeMinSize[0][i] 
00674              << "      FOM=" << setw(10) << cvFom[i] << endl;
00675       }
00676     }
00677 
00678     // cleanup
00679     for( int j=0;j<trees.size();j++ ) {
00680       delete trees[j];
00681       delete classifiers[j];
00682     }
00683 
00684     // normal exit
00685     return 0;
00686   }// end cross-validation
00687 
00688   // make Bagger
00689   auto_ptr<SprBagger> bagger;
00690   if( useArcE4 )
00691     bagger.reset(new SprArcE4(filter.get(),cycles,discrete));
00692   else
00693     bagger.reset(new SprBagger(filter.get(),cycles,discrete));
00694 
00695   // set seed for bootstrap if necessary
00696   if( initBootstrapFromTimeOfDay && !bagger->initBootstrapFromTimeOfDay() ) {
00697     cerr << "Unable to generate seed from time of day for Bagger." << endl;
00698     return 4;
00699   }
00700 
00701   // set validation
00702   if( valFilter.get()!=0 && !valFilter->empty() )
00703     bagger->setValidation(valFilter.get(),valPrint,crit.get(),loss.get());
00704 
00705   // read saved Bagger from file
00706   if( !resumeFile.empty() ) {
00707     if( !SprClassifierReader::readTrainable(resumeFile.c_str(),
00708                                             bagger.get(),verbose) ) {
00709       cerr << "Failed to read saved Bagger from file "
00710            << resumeFile.c_str() << endl;
00711       return 5;
00712     }
00713     cout << "Read saved Bagger from file " << resumeFile.c_str()
00714          << " with " << bagger->nTrained() << " trained classifiers." 
00715          << endl;
00716   }
00717 
00718   // add trainable tree
00719   if( !bagger->addTrainable(tree.get()) ) {
00720     cerr << "Unable to add decision tree to Bagger." << endl;
00721     return 6;
00722   }
00723 
00724   // train
00725   if( !bagger->train(verbose) )
00726     cerr << "Bagger terminated with error." << endl;
00727   if( bagger->nTrained() == 0 ) {
00728     cerr << "Unable to train Bagger." << endl;
00729     return 7;
00730   }
00731   else {
00732     cout << "Bagger finished training with " << bagger->nTrained() 
00733          << " classifiers." << endl;
00734   }
00735 
00736   // save trained Bagger
00737   if( !outFile.empty() ) {
00738     if( !bagger->store(outFile.c_str()) ) {
00739       cerr << "Cannot store Bagger in file " << outFile.c_str() << endl;
00740       return 8;
00741     }
00742   }
00743 
00744   // print out counted splits
00745   if( countTreeSplits ) tree->printSplitCounter(cout);
00746 
00747   // make a trained Bagger
00748   auto_ptr<SprTrainedBagger> trainedBagger(bagger->makeTrained());
00749   if( trainedBagger.get() == 0 ) {
00750     cerr << "Unable to get trained Bagger." << endl;
00751     return 7;
00752   }
00753 
00754   // store code into file
00755   if( !codeFile.empty() ) {
00756     if( !trainedBagger->storeCode(codeFile.c_str()) ) {
00757       cerr << "Unable to store code for trained Bagger." << endl;
00758       return 8;
00759     }
00760   }
00761 
00762   // make histogram if requested
00763   if( tupleFile.empty() ) 
00764     return 0;
00765 
00766   // make a writer
00767   auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00768   if( !tuple->init(tupleFile.c_str()) ) {
00769     cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00770     return 9;
00771   }
00772 
00773   // determine if certain variables are to be excluded from usage,
00774   // but included in the output storage file (-Z option)
00775   string printVarsDoNotFeed;
00776   vector<vector<string> > varsDoNotFeed;
00777   SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed);
00778   vector<unsigned> mapper;
00779   for( int d=0;d<vars.size();d++ ) {
00780     if( varsDoNotFeed.empty() ||
00781         (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d])
00782          ==varsDoNotFeed[0].end()) ) {
00783       mapper.push_back(d);
00784     }
00785     else {
00786       printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? "" : ", " );
00787       printVarsDoNotFeed += vars[d];
00788     }
00789   }
00790   if( !printVarsDoNotFeed.empty() ) {
00791     cout << "The following variables are not used in the algorithm, " 
00792          << "but will be included in the output file: " 
00793          << printVarsDoNotFeed.c_str() << endl;
00794   }
00795 
00796   // feed
00797   SprDataFeeder feeder(filter.get(),tuple.get(),mapper);
00798   feeder.addClassifier(trainedBagger.get(),"bag");
00799   if( !feeder.feed(1000) ) {
00800     cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00801     return 10;
00802   }
00803 
00804   // exit
00805   return 0;
00806 }


Generated on Tue Jun 9 17:55:00 2009 for CMSSW by  doxygen 1.5.4