CMS 3D CMS Logo

SprAdaBoostDecisionTreeApp.cc File Reference

#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAdaBoost.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedAdaBoost.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDecisionTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedDecisionTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTopdownTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedTopdownTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassCrossEntropy.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassIDFraction.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprCrossValidator.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include <stdlib.h>
#include <unistd.h>
#include <iostream>
#include <vector>
#include <set>
#include <string>
#include <memory>
#include <iomanip>

Go to the source code of this file.

Functions

void help (const char *prog)
int main (int argc, char **argv)


Function Documentation

void help ( const char *  prog  ) 

Definition at line 44 of file SprAdaBoostDecisionTreeApp.cc.

References GenMuonPlsPt100GeV_cfg::cout, and lat::endl().

// Prints the command-line usage summary of this executable to std::cout.
//
// prog - program name shown on the "Usage:" line (callers pass argv[0]).
//
// The text is written with plain newlines and flushed once at the end,
// instead of flushing the stream after every line.
void help(const char* prog)
{
  // One entry per printed line, in output order. Adjacent string literals
  // are concatenated by the compiler, so entries that span two source lines
  // still produce a single output line.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help                                        ",
    "\t-j use regular tree instead of faster topdown tree ",
    "\t-M AdaBoost mode                                   ",
    "\t\t 1 = Discrete AdaBoost (default)                 ",
    "\t\t 2 = Real AdaBoost                               ",
    "\t\t 3 = Epsilon AdaBoost                            ",
    "\t-E epsilon for Epsilon and Real AdaBoosts (def=0.01)",
    "\t-o output Tuple file                               ",
    "\t-a input ascii file mode (see SprSimpleReader.hh)  ",
    "\t-A save output data in ascii instead of Root       ",
    "\t-n number of AdaBoost training cycles              ",
    "\t-l minimal number of entries per tree leaf (def=1) ",
    "\t-y list of input classes (see SprAbsFilter.hh)     ",
    "\t-Q apply variable transformation saved in file     ",
    "\t-c criterion for optimization                      ",
    "\t\t 1 = correctly classified fraction               ",
    "\t\t 5 = Gini index (default)                        ",
    "\t\t 6 = cross-entropy                               ",
    "\t-g per-event loss for (cross-)validation           ",
    "\t\t 1 - quadratic loss (y-f(x))^2                   ",
    "\t\t 2 - exponential loss exp(-y*f(x))               ",
    "\t\t 3 - misid fraction                              ",
    "\t-b max number of sampled features (def=0 no sampling)",
    "\t-m replace data values below this cutoff with medians",
    "\t-i count splits on input variables                 ",
    "\t-s use standard AdaBoost (see SprTrainedAdaBoost.hh)",
    "\t-e skip initial event reweighting when resuming    ",
    "\t-u store data with modified weights to file        ",
    "\t-v verbose level (0=silent default,1,2)            ",
    "\t-f store trained AdaBoost to file                  ",
    "\t-F generate code for AdaBoost and store to file    ",
    "\t-r resume training for AdaBoost stored in file     ",
    "\t-K keep this fraction in training set and          ",
    "\t\t put the rest into validation set                ",
    "\t-D randomize training set split-up                 ",
    "\t-t read validation/test data from a file           ",
    // Closing parenthesis added; the original text left it unbalanced.
    "\t\t (must be in same format as input data!!!)       ",
    "\t-d frequency of print-outs for validation data     ",
    "\t-w scale all signal weights by this factor         ",
    "\t-V include only these input variables              ",
    "\t-z exclude input variables from the list           ",
    "\t-Z exclude input variables from the list, "
    "but put them in the output file ",
    "\t\t Variables must be listed in quotes and separated by commas.",
    "\t-x cross-validate by splitting data into a given "
    "number of pieces",
    "\t-q a set of minimal node sizes for cross-validation",
    "\t\t Node sizes must be listed in quotes and separated by commas."
  };
  const unsigned nLines = sizeof(optionLines)/sizeof(optionLines[0]);

  std::cout << "Usage:  " << prog << " training_data_file" << '\n';
  for( unsigned i=0;i<nLines;i++ )
    std::cout << optionLines[i] << '\n';
  std::cout.flush();
}

int main ( int  argc,
char **  argv 
)

Definition at line 102 of file SprAdaBoostDecisionTreeApp.cc.

References begin, c, TestMuL1L2Filter_cff::cerr, GenMuonPlsPt100GeV_cfg::cout, d, end, lat::endl(), HLT_VtxMuL3::Epsilon, geometryDiff::epsilon, filter, find(), help(), i, j, size, split, t, tree, vars, and weights.

00103 {
00104   // check command line
00105   if( argc < 2 ) {
00106     help(argv[0]);
00107     return 1;
00108   }
00109 
00110   // init
00111   string tupleFile;
00112   int readMode = 0;
00113   SprRWFactory::DataType writeMode = SprRWFactory::Root;
00114   unsigned cycles = 0;
00115   unsigned nmin = 1;
00116   int iCrit = 5;
00117   int verbose = 0;
00118   string outFile;
00119   string codeFile;
00120   string resumeFile;
00121   string valFile;
00122   unsigned valPrint = 0;
00123   bool scaleWeights = false;
00124   double sW = 1.;
00125   bool countTreeSplits = false;
00126   bool useStandardAB = false;
00127   int iAdaBoostMode = 1;
00128   double epsilon = 0.01;
00129   bool skipInitialEventReweighting = false;
00130   string weightedDataOut;
00131   bool setLowCutoff = false;
00132   double lowCutoff = 0;
00133   string includeList, excludeList;
00134   unsigned nCross = 0;
00135   string nodeValidationString;
00136   int nFeaturesToSample = 0;
00137   bool bagInput = false;
00138   bool useTopdown = true;
00139   int iLoss = 0;
00140   string inputClassesString;
00141   string stringVarsDoNotFeed;
00142   bool split = false;
00143   double splitFactor = 0;
00144   bool splitRandomize = false;
00145   string transformerFile;
00146 
00147   // decode command line
00148   int c;
00149   extern char* optarg;
00150   //  extern int optind;
00151   while((c = getopt(argc,argv,"hjM:E:o:a:An:l:y:Q:c:g:b:m:iseu:v:f:F:r:K:Dt:d:w:V:z:Z:x:q:")) != EOF ) {
00152     switch( c )
00153       {
00154       case 'h' :
00155         help(argv[0]);
00156         return 1;
00157       case 'j' :
00158         useTopdown = false;
00159         break;
00160       case 'M' :
00161         iAdaBoostMode = (optarg==0 ? 1 : atoi(optarg));
00162         break;
00163       case 'E' :
00164         epsilon = (optarg==0 ? 0.01 : atof(optarg));
00165         break;
00166       case 'o' :
00167         tupleFile = optarg;
00168         break;
00169       case 'a' :
00170         readMode = (optarg==0 ? 0 : atoi(optarg));
00171         break;
00172       case 'A' :
00173         writeMode = SprRWFactory::Ascii;
00174         break;
00175       case 'n' :
00176         cycles = (optarg==0 ? 1 : atoi(optarg));
00177         break;
00178       case 'l' :
00179         nmin = (optarg==0 ? 1 : atoi(optarg));
00180         break;
00181       case 'y' :
00182         inputClassesString = optarg;
00183         break;
00184       case 'Q' :
00185         transformerFile = optarg;
00186         break;
00187       case 'c' :
00188         iCrit = (optarg==0 ? 5 : atoi(optarg));
00189         break;
00190       case 'g' :
00191         iLoss = (optarg==0 ? 0 : atoi(optarg));
00192         break;
00193       case 'b' :
00194         bagInput = true;
00195         nFeaturesToSample = (optarg==0 ? 0 : atoi(optarg));
00196         break;
00197       case 'm' :
00198         if( optarg != 0 ) {
00199           setLowCutoff = true;
00200           lowCutoff = atof(optarg);
00201         }
00202         break;
00203       case 'i' :
00204         countTreeSplits = true;
00205         break;
00206       case 's' :
00207         useStandardAB = true;
00208         break;
00209       case 'e' :
00210         skipInitialEventReweighting = true;
00211         break;
00212       case 'u' :
00213         weightedDataOut = optarg;
00214         break;
00215       case 'v' :
00216         verbose = (optarg==0 ? 0 : atoi(optarg));
00217         break;
00218       case 'f' :
00219         outFile = optarg;
00220         break;
00221       case 'F' :
00222         codeFile = optarg;
00223         break;
00224       case 'r' :
00225         resumeFile = optarg;
00226         break;
00227       case 'K' :
00228         split = true;
00229         splitFactor = (optarg==0 ? 0 : atof(optarg));
00230         break;
00231       case 'D' :
00232         splitRandomize = true;
00233         break;
00234       case 't' :
00235         valFile = optarg;
00236         break;
00237       case 'd' :
00238         valPrint = (optarg==0 ? 0 : atoi(optarg));
00239         break;
00240       case 'w' :
00241         if( optarg != 0 ) {
00242           scaleWeights = true;
00243           sW = atof(optarg);
00244         }
00245         break;
00246       case 'V' :
00247         includeList = optarg;
00248         break;
00249       case 'z' :
00250         excludeList = optarg;
00251         break;
00252       case 'Z' :
00253         stringVarsDoNotFeed = optarg;
00254         break;
00255       case 'x' :
00256         nCross = (optarg==0 ? 0 : atoi(optarg));
00257         break;
00258       case 'q' :
00259         nodeValidationString = optarg;
00260         break;
00261       }
00262   }
00263 
00264   // There has to be 1 argument after all options.
00265   string trFile = argv[argc-1];
00266   if( trFile.empty() ) {
00267     cerr << "No training file is specified." << endl;
00268     return 1;
00269   }
00270 
00271   // make reader
00272   SprRWFactory::DataType inputType 
00273     = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00274   auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00275 
00276   // include variables
00277   set<string> includeSet;
00278   if( !includeList.empty() ) {
00279     vector<vector<string> > includeVars;
00280     SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00281     assert( !includeVars.empty() );
00282     for( int i=0;i<includeVars[0].size();i++ ) 
00283       includeSet.insert(includeVars[0][i]);
00284     if( !reader->chooseVars(includeSet) ) {
00285       cerr << "Unable to include variables in training set." << endl;
00286       return 2;
00287     }
00288     else {
00289       cout << "Following variables have been included in optimization: ";
00290       for( set<string>::const_iterator 
00291              i=includeSet.begin();i!=includeSet.end();i++ )
00292         cout << "\"" << *i << "\"" << " ";
00293       cout << endl;
00294     }
00295   }
00296 
00297   // exclude variables
00298   set<string> excludeSet;
00299   if( !excludeList.empty() ) {
00300     vector<vector<string> > excludeVars;
00301     SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00302     assert( !excludeVars.empty() );
00303     for( int i=0;i<excludeVars[0].size();i++ ) 
00304       excludeSet.insert(excludeVars[0][i]);
00305     if( !reader->chooseAllBut(excludeSet) ) {
00306       cerr << "Unable to exclude variables from training set." << endl;
00307       return 2;
00308     }
00309     else {
00310       cout << "Following variables have been excluded from optimization: ";
00311       for( set<string>::const_iterator 
00312              i=excludeSet.begin();i!=excludeSet.end();i++ )
00313         cout << "\"" << *i << "\"" << " ";
00314       cout << endl;
00315     }
00316   }
00317 
00318   // read training data from file
00319   auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00320   if( filter.get() == 0 ) {
00321     cerr << "Unable to read data from file " << trFile.c_str() << endl;
00322     return 2;
00323   }
00324   vector<string> vars;
00325   filter->vars(vars);
00326   cout << "Read data from file " << trFile.c_str() 
00327        << " for variables";
00328   for( int i=0;i<vars.size();i++ ) 
00329     cout << " \"" << vars[i].c_str() << "\"";
00330   cout << endl;
00331   cout << "Total number of points read: " << filter->size() << endl;
00332 
00333   // filter training data by class
00334   vector<SprClass> inputClasses;
00335   if( !filter->filterByClass(inputClassesString.c_str()) ) {
00336     cerr << "Cannot choose input classes for string " 
00337          << inputClassesString << endl;
00338     return 2;
00339   }
00340   filter->classes(inputClasses);
00341   assert( inputClasses.size() > 1 );
00342   cout << "Training data filtered by class." << endl;
00343   for( int i=0;i<inputClasses.size();i++ ) {
00344     cout << "Points in class " << inputClasses[i] << ":   " 
00345          << filter->ptsInClass(inputClasses[i]) << endl;
00346   }
00347 
00348   // scale weights
00349   if( scaleWeights ) {
00350     cout << "Signal weights are multiplied by " << sW << endl;
00351     filter->scaleWeights(inputClasses[1],sW);
00352   }
00353 
00354   // apply low cutoff
00355   if( setLowCutoff ) {
00356     if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00357       cerr << "Unable to replace missing values in training data." << endl;
00358       return 2;
00359     }
00360     else
00361       cout << "Values below " << lowCutoff << " in training data"
00362            << " have been replaced with medians." << endl;
00363   }
00364 
00365   // read validation data from file
00366   auto_ptr<SprAbsFilter> valFilter;
00367   if( split && !valFile.empty() ) {
00368     cerr << "Unable to split training data and use validation data " 
00369          << "from a separate file." << endl;
00370     return 2;
00371   }
00372   if( split && valPrint!=0 ) {
00373     cout << "Splitting training data with factor " << splitFactor << endl;
00374     if( splitRandomize )
00375       cout << "Will use randomized splitting." << endl;
00376     vector<double> weights;
00377     SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00378     if( splitted == 0 ) {
00379       cerr << "Unable to split training data." << endl;
00380       return 2;
00381     }
00382     bool ownData = true;
00383     valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00384     cout << "Training data re-filtered:" << endl;
00385     for( int i=0;i<inputClasses.size();i++ ) {
00386       cout << "Points in class " << inputClasses[i] << ":   " 
00387            << filter->ptsInClass(inputClasses[i]) << endl;
00388     }
00389   }
00390   if( !valFile.empty() && valPrint!=0 ) {
00391     auto_ptr<SprAbsReader> 
00392       valReader(SprRWFactory::makeReader(inputType,readMode));
00393     if( !includeSet.empty() ) {
00394       if( !valReader->chooseVars(includeSet) ) {
00395         cerr << "Unable to include variables in validation set." << endl;
00396         return 2;
00397       }
00398     }
00399     if( !excludeSet.empty() ) {
00400       if( !valReader->chooseAllBut(excludeSet) ) {
00401         cerr << "Unable to exclude variables from validation set." << endl;
00402         return 2;
00403       }
00404     }
00405     valFilter.reset(valReader->read(valFile.c_str()));
00406     if( valFilter.get() == 0 ) {
00407       cerr << "Unable to read data from file " << valFile.c_str() << endl;
00408       return 2;
00409     }
00410     vector<string> valVars;
00411     valFilter->vars(valVars);
00412     cout << "Read validation data from file " << valFile.c_str() 
00413          << " for variables";
00414     for( int i=0;i<valVars.size();i++ ) 
00415       cout << " \"" << valVars[i].c_str() << "\"";
00416     cout << endl;
00417     cout << "Total number of points read: " << valFilter->size() << endl;
00418     cout << "Points in class 0: " << valFilter->ptsInClass(inputClasses[0])
00419          << " 1: " << valFilter->ptsInClass(inputClasses[1]) << endl;
00420   }
00421 
00422   // filter validation data by class
00423   if( valFilter.get() != 0 ) {
00424     if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00425       cerr << "Cannot choose input classes for string " 
00426            << inputClassesString << endl;
00427       return 2;
00428     }
00429     valFilter->classes(inputClasses);
00430     cout << "Validation data filtered by class." << endl;
00431     for( int i=0;i<inputClasses.size();i++ ) {
00432       cout << "Points in class " << inputClasses[i] << ":   " 
00433            << valFilter->ptsInClass(inputClasses[i]) << endl;
00434     }
00435   }
00436 
00437   // scale weights
00438   if( scaleWeights && valFilter.get()!=0 )
00439     valFilter->scaleWeights(inputClasses[1],sW);
00440 
00441   // apply low cutoff
00442   if( setLowCutoff && valFilter.get()!=0 ) {
00443     if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00444       cerr << "Unable to replace missing values in validation data." << endl;
00445       return 2;
00446     }
00447     else
00448       cout << "Values below " << lowCutoff << " in validation data"
00449            << " have been replaced with medians." << endl;
00450   }
00451 
00452   // apply transformation of variables to training and test data
00453   auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00454   if( !transformerFile.empty() ) {
00455     SprVarTransformerReader transReader;
00456     const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00457     if( t == 0 ) {
00458       cerr << "Unable to read VarTransformer from file "
00459            << transformerFile.c_str() << endl;
00460       return 2;
00461     }
00462     SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00463     SprTransformerFilter* t_valid = 0;
00464     if( valFilter.get() != 0 )
00465       t_valid = new SprTransformerFilter(valFilter.get());
00466     bool replaceOriginalData = true;
00467     if( !t_train->transform(t,replaceOriginalData) ) {
00468       cerr << "Unable to apply VarTransformer to training data." << endl;
00469       return 2;
00470     }
00471     if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00472       cerr << "Unable to apply VarTransformer to validation data." << endl;
00473       return 2;
00474     }
00475     cout << "Variable transformation from file "
00476          << transformerFile.c_str() << " has been applied to "
00477          << "training and validation data." << endl;
00478     garbage_train.reset(filter.release());
00479     garbage_valid.reset(valFilter.release());
00480     filter.reset(t_train);
00481     valFilter.reset(t_valid);
00482   }
00483 
00484   // make optimization criterion
00485   auto_ptr<SprAbsTwoClassCriterion> crit;
00486   switch( iCrit )
00487     {
00488     case 1 :
00489       crit.reset(new SprTwoClassIDFraction);
00490       cout << "Optimization criterion set to "
00491            << "Fraction of correctly classified events " << endl;
00492       break;
00493     case 5 :
00494       crit.reset(new SprTwoClassGiniIndex);
00495       cout << "Optimization criterion set to "
00496            << "Gini index  -1+p^2+q^2 " << endl;
00497       break;
00498     case 6 :
00499       crit.reset(new SprTwoClassCrossEntropy);
00500       cout << "Optimization criterion set to "
00501            << "Cross-entropy p*log(p)+q*log(q) " << endl;
00502       break;
00503     default :
00504       cerr << "Unable to make initialization criterion." << endl;
00505       return 3;
00506     }
00507 
00508   // make per-event loss
00509   auto_ptr<SprAverageLoss> loss;
00510   switch( iLoss )
00511     {
00512     case 1 :
00513       loss.reset(new SprAverageLoss(&SprLoss::quadratic,
00514                                     &SprTransformation::logit));
00515       cout << "Per-event loss set to "
00516            << "Quadratic loss (y-f(x))^2 " << endl;
00517       break;
00518     case 2 :
00519       loss.reset(new SprAverageLoss(&SprLoss::exponential));
00520       cout << "Per-event loss set to "
00521            << "Exponential loss exp(-y*f(x)) " << endl;
00522       break;
00523     case 3 :
00524       loss.reset(new SprAverageLoss(&SprLoss::correct_id,
00525                                &SprTransformation::inftyRangeToDiscrete01));
00526       cout << "Per-event loss set to "
00527            << "Misid rate int(y==f(x)) " << endl;
00528       break;
00529     default :
00530       cout << "No per-event loss is chosen. Will use the default." << endl;
00531       break;
00532     }
00533 
00534   // make AdaBoost mode
00535   SprTrainedAdaBoost::AdaBoostMode abMode = SprTrainedAdaBoost::Discrete;
00536   switch( iAdaBoostMode )
00537     {
00538     case 1 :
00539       abMode = SprTrainedAdaBoost::Discrete;
00540       cout << "Will train Discrete AdaBoost." << endl;
00541       break;
00542     case 2 :
00543       abMode = SprTrainedAdaBoost::Real;
00544       cout << "Will train Real AdaBoost." << endl;
00545       break;
00546     case 3 :
00547       abMode = SprTrainedAdaBoost::Epsilon;
00548       cout << "Will train Epsilon AdaBoost." << endl;
00549       break;
00550    default :
00551       cout << "Will train Discrete AdaBoost." << endl;
00552       break;
00553     }
00554 
00555   // make bootstrap for resampling input features
00556   auto_ptr<SprIntegerBootstrap> bootstrap;
00557   if( nFeaturesToSample > filter->dim() ) 
00558     nFeaturesToSample = filter->dim();
00559   if( nFeaturesToSample > 0 ) {
00560     bootstrap.reset(new SprIntegerBootstrap(filter->dim(),nFeaturesToSample));
00561     if( !resumeFile.empty() ) bootstrap->init(-1);
00562   }
00563 
00564   // make decision tree
00565   bool discrete = true;
00566   if( abMode==SprTrainedAdaBoost::Real ) discrete = false;
00567   bool doMerge = false;
00568   auto_ptr<SprDecisionTree> tree;
00569   if( useTopdown )
00570     tree.reset(new SprTopdownTree(filter.get(),crit.get(),
00571                                   nmin,discrete,bootstrap.get()));
00572   else
00573     tree.reset(new SprDecisionTree(filter.get(),crit.get(),
00574                                    nmin,doMerge,discrete,bootstrap.get()));
00575   if( countTreeSplits ) tree->startSplitCounter();
00576   if( abMode == SprTrainedAdaBoost::Real ) tree->forceMixedNodes();
00577   tree->useFastSort();
00578 
00579   // if cross-validation requested, cross-validate and exit
00580   if( nCross > 0 ) {
00581     // message
00582     cout << "Will cross-validate by dividing training data into " 
00583          << nCross << " subsamples." << endl;
00584     vector<vector<int> > nodeMinSize;
00585 
00586     // decode validation string
00587     if( !nodeValidationString.empty() )
00588       SprStringParser::parseToInts(nodeValidationString.c_str(),nodeMinSize);
00589     else {
00590       nodeMinSize.resize(1);
00591       nodeMinSize[0].push_back(nmin);
00592     }
00593     if( nodeMinSize.empty() || nodeMinSize[0].empty() ) {
00594       cerr << "Unable to determine node size for cross-validation." << endl;
00595       return 4;
00596     }
00597     else {
00598       cout << "Will cross-validate for trees with minimal node sizes: ";
00599       for( int i=0;i<nodeMinSize[0].size();i++ )
00600         cout << nodeMinSize[0][i] << " ";
00601       cout << endl;
00602     }
00603 
00604     // loop over nodes to prepare classifiers
00605     vector<SprDecisionTree*> trees(nodeMinSize[0].size());
00606     vector<SprAbsClassifier*> classifiers(nodeMinSize[0].size());
00607     for( int i=0;i<nodeMinSize[0].size();i++ ) {
00608       SprDecisionTree* tree1 = 0;
00609       if( useTopdown )
00610         tree1 = 
00611           new SprTopdownTree(filter.get(),crit.get(),
00612                              nodeMinSize[0][i],
00613                              discrete,bootstrap.get());
00614       else
00615         tree1 =
00616           new SprDecisionTree(filter.get(),crit.get(),
00617                               nodeMinSize[0][i],doMerge,
00618                               discrete,bootstrap.get());
00619       if( abMode == SprTrainedAdaBoost::Real ) tree1->forceMixedNodes();
00620       tree1->useFastSort();
00621       SprAdaBoost* ab1 = new SprAdaBoost(filter.get(),cycles,
00622                                          useStandardAB,abMode,bagInput);
00623       cout << "Setting epsilon to " << epsilon << endl;
00624       ab1->setEpsilon(epsilon);
00625       if( !ab1->addTrainable(tree1,SprUtils::lowerBound(0.5)) ) {
00626         cerr << "Unable to add decision tree to AdaBoost for CV." << endl;
00627         for( int j=0;j<trees.size();j++ ) {
00628           delete trees[j];
00629           delete classifiers[j];
00630         }
00631         return 4;
00632       }
00633       trees[i] = tree1;
00634       classifiers[i] = ab1;
00635     }
00636 
00637     // cross-validate
00638     vector<double> cvFom;
00639     SprCrossValidator cv(filter.get(),nCross);
00640     if( !cv.validate(crit.get(),loss.get(),classifiers,
00641                      inputClasses[0],inputClasses[1],
00642                      SprUtils::lowerBound(0.5),cvFom,verbose) ) {
00643       cerr << "Unable to cross-validate." << endl;
00644       for( int j=0;j<trees.size();j++ ) {
00645         delete trees[j];
00646         delete classifiers[j];
00647       }
00648       return 4;
00649     }
00650     else {
00651       cout << "Cross-validated FOMs:" << endl;
00652       for( int i=0;i<cvFom.size();i++ ) {
00653         cout << "Node size=" << setw(8) << nodeMinSize[0][i] 
00654              << "      FOM=" << setw(10) << cvFom[i] << endl;
00655       }
00656     }
00657 
00658     // cleanup
00659     for( int j=0;j<trees.size();j++ ) {
00660       delete trees[j];
00661       delete classifiers[j];
00662     }
00663 
00664     // normal exit
00665     return 0;
00666   }// end cross-validation
00667 
00668   // make AdaBoost
00669   SprAdaBoost ab(filter.get(),cycles,useStandardAB,abMode,bagInput);
00670   cout << "Setting epsilon to " << epsilon << endl;
00671   ab.setEpsilon(epsilon);
00672 
00673   // set validation
00674   if( valFilter.get()!=0 && !valFilter->empty() )
00675     ab.setValidation(valFilter.get(),valPrint,loss.get());
00676 
00677   // read saved Boost from file
00678   if( !resumeFile.empty() ) {
00679     if( !SprClassifierReader::readTrainable(resumeFile.c_str(),
00680                                             &ab,verbose) ) {
00681       cerr << "Failed to read saved AdaBoost from file " 
00682            << resumeFile.c_str() << endl;
00683       return 5;
00684     }
00685     cout << "Read saved AdaBoost from file " << resumeFile.c_str()
00686          << " with " << ab.nTrained() << " trained classifiers." << endl;
00687   }
00688   if( skipInitialEventReweighting ) ab.skipInitialEventReweighting(true);
00689 
00690   // add trainable tree
00691   if( !ab.addTrainable(tree.get(),SprUtils::lowerBound(0.5)) ) {
00692     cerr << "Unable to add decision tree to AdaBoost." << endl;
00693     return 6;
00694   }
00695 
00696   // train
00697   if( !ab.train(verbose) )
00698     cerr << "AdaBoost terminated with error." << endl;
00699   if( ab.nTrained() == 0 ) {
00700     cerr << "Unable to train AdaBoost." << endl;
00701     return 7;
00702   }
00703   else {
00704     cout << "AdaBoost finished training with " << ab.nTrained() 
00705          << " classifiers." << endl;
00706   }
00707 
00708   // save trained AdaBoost
00709   if( !outFile.empty() ) {
00710     if( !ab.store(outFile.c_str()) ) {
00711       cerr << "Cannot store AdaBoost in file " << outFile.c_str() << endl;
00712       return 8;
00713     }
00714   }
00715 
00716   // print out counted splits
00717   if( countTreeSplits ) tree->printSplitCounter(cout);
00718 
00719   // save reweighted data
00720   if( !weightedDataOut.empty() ) {
00721     if( !ab.storeData(weightedDataOut.c_str()) ) {
00722       cerr << "Cannot store weighted AdaBoost data to file " 
00723            << weightedDataOut.c_str() << endl;
00724       return 9;
00725     }
00726   }
00727 
00728   // make a trained AdaBoost
00729   auto_ptr<SprTrainedAdaBoost> trainedAda(ab.makeTrained());
00730   if( trainedAda.get() == 0 ) {
00731     cerr << "Unable to get trained AdaBoost." << endl;
00732     return 7;
00733   }
00734 
00735   // store code into file
00736   if( !codeFile.empty() ) {
00737     if( !trainedAda->storeCode(codeFile.c_str()) ) {
00738       cerr << "Unable to store code for trained AdaBoost." << endl;
00739       return 8;
00740     }
00741   }
00742 
00743   // make histogram if requested
00744   if( tupleFile.empty() ) 
00745     return 0;
00746 
00747   // make a writer
00748   auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00749   if( !tuple->init(tupleFile.c_str()) ) {
00750     cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00751     return 10;
00752   }
00753 
00754   // determine if certain variables are to be excluded from usage,
00755   // but included in the output storage file (-Z option)
00756   string printVarsDoNotFeed;
00757   vector<vector<string> > varsDoNotFeed;
00758   SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed);
00759   vector<unsigned> mapper;
00760   for( int d=0;d<vars.size();d++ ) {
00761     if( varsDoNotFeed.empty() ||
00762         (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d])
00763          ==varsDoNotFeed[0].end()) ) {
00764       mapper.push_back(d);
00765     }
00766     else {
00767       printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? "" : ", " );
00768       printVarsDoNotFeed += vars[d];
00769     }
00770   }
00771   if( !printVarsDoNotFeed.empty() ) {
00772     cout << "The following variables are not used in the algorithm, " 
00773          << "but will be included in the output file: " 
00774          << printVarsDoNotFeed.c_str() << endl;
00775   }
00776 
00777   // feed
00778   SprDataFeeder feeder(filter.get(),tuple.get(),mapper);
00779   feeder.addClassifier(trainedAda.get(),"ada");
00780   if( !feeder.feed(1000) ) {
00781     cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00782     return 11;
00783   }
00784 
00785   // exit
00786   return 0;
00787 }


Generated on Tue Jun 9 17:55:00 2009 for CMSSW by  doxygen 1.5.4