#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprBagger.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprArcE4.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedBagger.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTopdownTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedTopdownTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDecisionTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedDecisionTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassSignalSignif.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassIDFraction.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassTaggerEff.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPurity.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassCrossEntropy.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassUniformPriorUL90.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassBKDiscovery.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPunzi.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprCrossValidator.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include <stdlib.h>
#include <unistd.h>

#include <algorithm>
#include <cassert>
#include <iomanip>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>
Go to the source code of this file.
Functions
  void help(const char *prog)
  int main(int argc, char **argv)

void help(const char *prog)
Definition at line 51 of file SprBaggerDecisionTreeApp.cc.
References std::cout and std::endl.
/// Prints the command-line usage of this application to standard output.
///
/// Every option accepted by main()'s getopt string
/// ("hjko:a:An:l:s:y:Q:bv:f:F:c:P:g:m:ir:K:DGt:d:w:V:z:Z:x:q:") is listed.
///
/// @param prog program name (normally argv[0]) shown in the usage line.
void help(const char* prog)
{
  std::cout << "Usage: " << prog
            << " training_data_file" << std::endl;
  std::cout << "\t Options: " << std::endl;
  std::cout << "\t-h --- help " << std::endl;
  std::cout << "\t-j use regular tree instead of faster topdown tree " << std::endl;
  std::cout << "\t-k discrete decision tree output (default=continuous)"<< std::endl;
  std::cout << "\t-o output Tuple file " << std::endl;
  std::cout << "\t-a input ascii file mode (see SprSimpleReader.hh) " << std::endl;
  std::cout << "\t-A save output data in ascii instead of Root " << std::endl;
  std::cout << "\t-n number of Bagger training cycles " << std::endl;
  std::cout << "\t-l minimal number of entries per tree leaf (def=1) " << std::endl;
  std::cout << "\t-s max number of sampled features (def=0 no sampling)"<< std::endl;
  std::cout << "\t-y list of input classes (see SprAbsFilter.hh) " << std::endl;
  std::cout << "\t-Q apply variable transformation saved in file " << std::endl;
  std::cout << "\t-b use a version of Breiman's arc-x4 algorithm " << std::endl;
  std::cout << "\t-v verbose level (0=silent default,1,2) " << std::endl;
  std::cout << "\t-f store trained Bagger to file " << std::endl;
  // This application trains a Bagger, so -F stores code for the trained
  // Bagger (the old text wrongly said AdaBoost).
  std::cout << "\t-F generate code for trained Bagger and store to file " << std::endl;
  std::cout << "\t-c criterion for optimization " << std::endl;
  std::cout << "\t\t 1 = correctly classified fraction " << std::endl;
  std::cout << "\t\t 2 = signal significance s/sqrt(s+b) " << std::endl;
  std::cout << "\t\t 3 = purity s/(s+b) " << std::endl;
  std::cout << "\t\t 4 = tagger efficiency Q " << std::endl;
  std::cout << "\t\t 5 = Gini index (default) " << std::endl;
  std::cout << "\t\t 6 = cross-entropy " << std::endl;
  std::cout << "\t\t 7 = 90% Bayesian upper limit with uniform prior " << std::endl;
  std::cout << "\t\t 8 = discovery potential 2*(sqrt(s+b)-sqrt(b)) " << std::endl;
  std::cout << "\t\t 9 = Punzi's sensitivity s/(0.5*nSigma+sqrt(b)) " << std::endl;
  std::cout << "\t\t -P background normalization factor for Punzi FOM" << std::endl;
  std::cout << "\t-g per-event loss for (cross-)validation " << std::endl;
  std::cout << "\t\t 1 - quadratic loss (y-f(x))^2 " << std::endl;
  std::cout << "\t\t 2 - exponential loss exp(-y*f(x)) " << std::endl;
  std::cout << "\t\t 3 - misid fraction " << std::endl;
  std::cout << "\t-m replace data values below this cutoff with medians" << std::endl;
  std::cout << "\t-i count splits on input variables " << std::endl;
  std::cout << "\t-r resume training for Bagger stored in file " << std::endl;
  // main() only splits the training data when -d is also given.
  std::cout << "\t-K keep this fraction in training set and " << std::endl;
  std::cout << "\t\t put the rest into validation set (requires -d) " << std::endl;
  std::cout << "\t-D randomize training set split-up " << std::endl;
  std::cout << "\t-G generate seed from time of day for bootstrap " << std::endl;
  // main() only reads the validation file when -d is also given.
  std::cout << "\t-t read validation/test data from a file (requires -d) " << std::endl;
  std::cout << "\t\t (must be in same format as input data!!!) " << std::endl;
  std::cout << "\t-d frequency of print-outs for validation data " << std::endl;
  std::cout << "\t-w scale all signal weights by this factor " << std::endl;
  std::cout << "\t-V include only these input variables " << std::endl;
  std::cout << "\t-z exclude input variables from the list " << std::endl;
  std::cout << "\t-Z exclude input variables from the list, "
            << "but put them in the output file " << std::endl;
  std::cout << "\t\t Variables must be listed in quotes and separated by commas."
            << std::endl;
  std::cout << "\t-x cross-validate by splitting data into a given "
            << "number of pieces" << std::endl;
  std::cout << "\t-q a set of minimal node sizes for cross-validation" << std::endl;
  std::cout << "\t\t Node sizes must be listed in quotes and separated by commas."
            << std::endl;
}
int main(int argc, char **argv)
Definition at line 111 of file SprBaggerDecisionTreeApp.cc.
References begin, c, std::cerr, std::cout, d, end, std::endl, filter, std::find, help(), i, j, size, split, t, tree, vars, and weights.
00112 { 00113 // check command line 00114 if( argc < 2 ) { 00115 help(argv[0]); 00116 return 1; 00117 } 00118 00119 // init 00120 string tupleFile; 00121 int readMode = 0; 00122 SprRWFactory::DataType writeMode = SprRWFactory::Root; 00123 unsigned cycles = 0; 00124 unsigned nmin = 1; 00125 int verbose = 0; 00126 string outFile; 00127 string codeFile; 00128 string resumeFile; 00129 int iCrit = 5; 00130 string valFile; 00131 unsigned valPrint = 0; 00132 bool scaleWeights = false; 00133 double sW = 1.; 00134 int nFeaturesToSample = 0; 00135 bool countTreeSplits = false; 00136 bool setLowCutoff = false; 00137 double lowCutoff = 0; 00138 string includeList, excludeList; 00139 unsigned nCross = 0; 00140 string nodeValidationString; 00141 bool useTopdown = true; 00142 bool discrete = false; 00143 int iLoss = 0; 00144 string inputClassesString; 00145 bool useArcE4 = false; 00146 double bW = 1.; 00147 string stringVarsDoNotFeed; 00148 bool split = false; 00149 double splitFactor = 0; 00150 bool splitRandomize = false; 00151 bool initBootstrapFromTimeOfDay = false; 00152 string transformerFile; 00153 00154 // decode command line 00155 int c; 00156 extern char* optarg; 00157 // extern int optind; 00158 while( (c = getopt(argc,argv,"hjko:a:An:l:s:y:Q:bv:f:F:c:P:g:m:ir:K:DGt:d:w:V:z:Z:x:q:")) 00159 != EOF ) { 00160 switch( c ) 00161 { 00162 case 'h' : 00163 help(argv[0]); 00164 return 1; 00165 case 'j' : 00166 useTopdown = false; 00167 break; 00168 case 'k' : 00169 discrete = true; 00170 break; 00171 case 'o' : 00172 tupleFile = optarg; 00173 break; 00174 case 'a' : 00175 readMode = (optarg==0 ? 0 : atoi(optarg)); 00176 break; 00177 case 'A' : 00178 writeMode = SprRWFactory::Ascii; 00179 break; 00180 case 'n' : 00181 cycles = (optarg==0 ? 1 : atoi(optarg)); 00182 break; 00183 case 'l' : 00184 nmin = (optarg==0 ? 1 : atoi(optarg)); 00185 break; 00186 case 's' : 00187 nFeaturesToSample = (optarg==0 ? 
0 : atoi(optarg)); 00188 break; 00189 case 'y' : 00190 inputClassesString = optarg; 00191 break; 00192 case 'Q' : 00193 transformerFile = optarg; 00194 break; 00195 case 'b' : 00196 useArcE4 = true; 00197 break; 00198 case 'v' : 00199 verbose = (optarg==0 ? 0 : atoi(optarg)); 00200 break; 00201 case 'f' : 00202 outFile = optarg; 00203 break; 00204 case 'F' : 00205 codeFile = optarg; 00206 break; 00207 case 'c' : 00208 iCrit = (optarg==0 ? 5 : atoi(optarg)); 00209 break; 00210 case 'P' : 00211 bW = (optarg==0 ? 1 : atof(optarg)); 00212 break; 00213 case 'g' : 00214 iLoss = (optarg==0 ? 0 : atoi(optarg)); 00215 break; 00216 case 'm' : 00217 if( optarg != 0 ) { 00218 setLowCutoff = true; 00219 lowCutoff = atof(optarg); 00220 } 00221 break; 00222 case 'i' : 00223 countTreeSplits = true; 00224 break; 00225 case 'r' : 00226 resumeFile = optarg; 00227 break; 00228 case 'K' : 00229 split = true; 00230 splitFactor = (optarg==0 ? 0 : atof(optarg)); 00231 break; 00232 case 'D' : 00233 splitRandomize = true; 00234 break; 00235 case 'G' : 00236 initBootstrapFromTimeOfDay = true; 00237 break; 00238 case 't' : 00239 valFile = optarg; 00240 break; 00241 case 'd' : 00242 valPrint = (optarg==0 ? 0 : atoi(optarg)); 00243 break; 00244 case 'w' : 00245 if( optarg != 0 ) { 00246 scaleWeights = true; 00247 sW = atof(optarg); 00248 } 00249 break; 00250 case 'V' : 00251 includeList = optarg; 00252 break; 00253 case 'z' : 00254 excludeList = optarg; 00255 break; 00256 case 'Z' : 00257 stringVarsDoNotFeed = optarg; 00258 break; 00259 case 'x' : 00260 nCross = (optarg==0 ? 0 : atoi(optarg)); 00261 break; 00262 case 'q' : 00263 nodeValidationString = optarg; 00264 break; 00265 } 00266 } 00267 00268 // There has to be 1 argument after all options. 00269 string trFile = argv[argc-1]; 00270 if( trFile.empty() ) { 00271 cerr << "No training file is specified." << endl; 00272 return 1; 00273 } 00274 00275 // make reader 00276 SprRWFactory::DataType inputType 00277 = ( readMode==0 ? 
SprRWFactory::Root : SprRWFactory::Ascii ); 00278 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode)); 00279 00280 // include variables 00281 set<string> includeSet; 00282 if( !includeList.empty() ) { 00283 vector<vector<string> > includeVars; 00284 SprStringParser::parseToStrings(includeList.c_str(),includeVars); 00285 assert( !includeVars.empty() ); 00286 for( int i=0;i<includeVars[0].size();i++ ) 00287 includeSet.insert(includeVars[0][i]); 00288 if( !reader->chooseVars(includeSet) ) { 00289 cerr << "Unable to include variables in training set." << endl; 00290 return 2; 00291 } 00292 else { 00293 cout << "Following variables have been included in optimization: "; 00294 for( set<string>::const_iterator 00295 i=includeSet.begin();i!=includeSet.end();i++ ) 00296 cout << "\"" << *i << "\"" << " "; 00297 cout << endl; 00298 } 00299 } 00300 00301 // exclude variables 00302 set<string> excludeSet; 00303 if( !excludeList.empty() ) { 00304 vector<vector<string> > excludeVars; 00305 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars); 00306 assert( !excludeVars.empty() ); 00307 for( int i=0;i<excludeVars[0].size();i++ ) 00308 excludeSet.insert(excludeVars[0][i]); 00309 if( !reader->chooseAllBut(excludeSet) ) { 00310 cerr << "Unable to exclude variables from training set." 
<< endl; 00311 return 2; 00312 } 00313 else { 00314 cout << "Following variables have been excluded from optimization: "; 00315 for( set<string>::const_iterator 00316 i=excludeSet.begin();i!=excludeSet.end();i++ ) 00317 cout << "\"" << *i << "\"" << " "; 00318 cout << endl; 00319 } 00320 } 00321 00322 // read training data from file 00323 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str())); 00324 if( filter.get() == 0 ) { 00325 cerr << "Unable to read data from file " << trFile.c_str() << endl; 00326 return 2; 00327 } 00328 vector<string> vars; 00329 filter->vars(vars); 00330 cout << "Read data from file " << trFile.c_str() 00331 << " for variables"; 00332 for( int i=0;i<vars.size();i++ ) 00333 cout << " \"" << vars[i].c_str() << "\""; 00334 cout << endl; 00335 cout << "Total number of points read: " << filter->size() << endl; 00336 00337 // filter training data by class 00338 vector<SprClass> inputClasses; 00339 if( !filter->filterByClass(inputClassesString.c_str()) ) { 00340 cerr << "Cannot choose input classes for string " 00341 << inputClassesString << endl; 00342 return 2; 00343 } 00344 filter->classes(inputClasses); 00345 assert( inputClasses.size() > 1 ); 00346 cout << "Training data filtered by class." << endl; 00347 for( int i=0;i<inputClasses.size();i++ ) { 00348 cout << "Points in class " << inputClasses[i] << ": " 00349 << filter->ptsInClass(inputClasses[i]) << endl; 00350 } 00351 00352 // scale weights 00353 if( scaleWeights ) { 00354 cout << "Signal weights are multiplied by " << sW << endl; 00355 filter->scaleWeights(inputClasses[1],sW); 00356 } 00357 00358 // apply low cutoff 00359 if( setLowCutoff ) { 00360 if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) { 00361 cerr << "Unable to replace missing values in training data." << endl; 00362 return 2; 00363 } 00364 else 00365 cout << "Values below " << lowCutoff << " in training data" 00366 << " have been replaced with medians." 
<< endl; 00367 } 00368 00369 // read validation data from file 00370 auto_ptr<SprAbsFilter> valFilter; 00371 if( split && !valFile.empty() ) { 00372 cerr << "Unable to split training data and use validation data " 00373 << "from a separate file." << endl; 00374 return 2; 00375 } 00376 if( split && valPrint!=0 ) { 00377 cout << "Splitting training data with factor " << splitFactor << endl; 00378 if( splitRandomize ) 00379 cout << "Will use randomized splitting." << endl; 00380 vector<double> weights; 00381 SprData* splitted = filter->split(splitFactor,weights,splitRandomize); 00382 if( splitted == 0 ) { 00383 cerr << "Unable to split training data." << endl; 00384 return 2; 00385 } 00386 bool ownData = true; 00387 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData)); 00388 cout << "Training data re-filtered:" << endl; 00389 for( int i=0;i<inputClasses.size();i++ ) { 00390 cout << "Points in class " << inputClasses[i] << ": " 00391 << filter->ptsInClass(inputClasses[i]) << endl; 00392 } 00393 } 00394 if( !valFile.empty() && valPrint!=0 ) { 00395 auto_ptr<SprAbsReader> 00396 valReader(SprRWFactory::makeReader(inputType,readMode)); 00397 if( !includeSet.empty() ) { 00398 if( !valReader->chooseVars(includeSet) ) { 00399 cerr << "Unable to include variables in validation set." << endl; 00400 return 2; 00401 } 00402 } 00403 if( !excludeSet.empty() ) { 00404 if( !valReader->chooseAllBut(excludeSet) ) { 00405 cerr << "Unable to exclude variables from validation set." 
<< endl; 00406 return 2; 00407 } 00408 } 00409 valFilter.reset(valReader->read(valFile.c_str())); 00410 if( valFilter.get() == 0 ) { 00411 cerr << "Unable to read data from file " << valFile.c_str() << endl; 00412 return 2; 00413 } 00414 vector<string> valVars; 00415 valFilter->vars(valVars); 00416 cout << "Read validation data from file " << valFile.c_str() 00417 << " for variables"; 00418 for( int i=0;i<valVars.size();i++ ) 00419 cout << " \"" << valVars[i].c_str() << "\""; 00420 cout << endl; 00421 cout << "Total number of points read: " << valFilter->size() << endl; 00422 } 00423 00424 // filter validation data by class 00425 if( valFilter.get() != 0 ) { 00426 if( !valFilter->filterByClass(inputClassesString.c_str()) ) { 00427 cerr << "Cannot choose input classes for string " 00428 << inputClassesString << endl; 00429 return 2; 00430 } 00431 valFilter->classes(inputClasses); 00432 cout << "Validation data filtered by class." << endl; 00433 for( int i=0;i<inputClasses.size();i++ ) { 00434 cout << "Points in class " << inputClasses[i] << ": " 00435 << valFilter->ptsInClass(inputClasses[i]) << endl; 00436 } 00437 } 00438 00439 // scale weights 00440 if( scaleWeights && valFilter.get()!=0 ) 00441 valFilter->scaleWeights(inputClasses[1],sW); 00442 00443 // apply low cutoff 00444 if( setLowCutoff && valFilter.get()!=0 ) { 00445 if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) { 00446 cerr << "Unable to replace missing values in validation data." << endl; 00447 return 2; 00448 } 00449 else 00450 cout << "Values below " << lowCutoff << " in validation data" 00451 << " have been replaced with medians." 
<< endl; 00452 } 00453 00454 // apply transformation of variables to training and test data 00455 auto_ptr<SprAbsFilter> garbage_train, garbage_valid; 00456 if( !transformerFile.empty() ) { 00457 SprVarTransformerReader transReader; 00458 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str()); 00459 if( t == 0 ) { 00460 cerr << "Unable to read VarTransformer from file " 00461 << transformerFile.c_str() << endl; 00462 return 2; 00463 } 00464 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get()); 00465 SprTransformerFilter* t_valid = 0; 00466 if( valFilter.get() != 0 ) 00467 t_valid = new SprTransformerFilter(valFilter.get()); 00468 bool replaceOriginalData = true; 00469 if( !t_train->transform(t,replaceOriginalData) ) { 00470 cerr << "Unable to apply VarTransformer to training data." << endl; 00471 return 2; 00472 } 00473 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) { 00474 cerr << "Unable to apply VarTransformer to validation data." << endl; 00475 return 2; 00476 } 00477 cout << "Variable transformation from file " 00478 << transformerFile.c_str() << " has been applied to " 00479 << "training and validation data." 
<< endl; 00480 garbage_train.reset(filter.release()); 00481 garbage_valid.reset(valFilter.release()); 00482 filter.reset(t_train); 00483 valFilter.reset(t_valid); 00484 } 00485 00486 // make optimization criterion 00487 auto_ptr<SprAbsTwoClassCriterion> crit; 00488 switch( iCrit ) 00489 { 00490 case 1 : 00491 crit.reset(new SprTwoClassIDFraction); 00492 cout << "Optimization criterion set to " 00493 << "Fraction of correctly classified events " << endl; 00494 break; 00495 case 2 : 00496 crit.reset(new SprTwoClassSignalSignif); 00497 cout << "Optimization criterion set to " 00498 << "Signal significance S/sqrt(S+B) " << endl; 00499 break; 00500 case 3 : 00501 crit.reset(new SprTwoClassPurity); 00502 cout << "Optimization criterion set to " 00503 << "Purity S/(S+B) " << endl; 00504 break; 00505 case 4 : 00506 crit.reset(new SprTwoClassTaggerEff); 00507 cout << "Optimization criterion set to " 00508 << "Tagging efficiency Q = e*(1-2w)^2 " << endl; 00509 break; 00510 case 5 : 00511 crit.reset(new SprTwoClassGiniIndex); 00512 cout << "Optimization criterion set to " 00513 << "Gini index -1+p^2+q^2 " << endl; 00514 break; 00515 case 6 : 00516 crit.reset(new SprTwoClassCrossEntropy); 00517 cout << "Optimization criterion set to " 00518 << "Cross-entropy p*log(p)+q*log(q) " << endl; 00519 break; 00520 case 7 : 00521 crit.reset(new SprTwoClassUniformPriorUL90); 00522 cout << "Optimization criterion set to " 00523 << "Inverse of 90% Bayesian upper limit with uniform prior" << endl; 00524 break; 00525 case 8 : 00526 crit.reset(new SprTwoClassBKDiscovery); 00527 cout << "Optimization criterion set to " 00528 << "Discovery potential 2*(sqrt(S+B)-sqrt(B))" << endl; 00529 break; 00530 case 9 : 00531 crit.reset(new SprTwoClassPunzi(bW)); 00532 cout << "Optimization criterion set to " 00533 << "Punzi's sensitivity S/(0.5*nSigma+sqrt(B))" << endl; 00534 break; 00535 default : 00536 cerr << "Unable to make initialization criterion." 
<< endl; 00537 return 3; 00538 } 00539 00540 // check criterion vs classifier 00541 if( useArcE4 && !crit->symmetric() ) { 00542 cerr << "Unable to use arc-e4 with an asymmetric criterion." << endl; 00543 return 3; 00544 } 00545 00546 // make per-event loss 00547 auto_ptr<SprAverageLoss> loss; 00548 switch( iLoss ) 00549 { 00550 case 1 : 00551 loss.reset(new SprAverageLoss(&SprLoss::quadratic)); 00552 cout << "Per-event loss set to " 00553 << "Quadratic loss (y-f(x))^2 " << endl; 00554 break; 00555 case 2 : 00556 loss.reset(new SprAverageLoss(&SprLoss::purity_ratio)); 00557 cout << "Per-event loss set to " 00558 << "Exponential loss exp(-y*f(x)) " << endl; 00559 break; 00560 case 3 : 00561 loss.reset(new SprAverageLoss(&SprLoss::correct_id, 00562 &SprTransformation::continuous01ToDiscrete01)); 00563 cout << "Per-event loss set to " 00564 << "Misid rate int(y==f(x)) " << endl; 00565 break; 00566 default : 00567 cout << "No per-event loss is chosen. Will use the default." << endl; 00568 break; 00569 } 00570 00571 // make bootstrap for resampling input features 00572 auto_ptr<SprIntegerBootstrap> bootstrap; 00573 if( nFeaturesToSample > filter->dim() ) 00574 nFeaturesToSample = filter->dim(); 00575 if( nFeaturesToSample > 0 ) { 00576 bootstrap.reset(new SprIntegerBootstrap(filter->dim(),nFeaturesToSample)); 00577 if( !resumeFile.empty() || initBootstrapFromTimeOfDay ) 00578 bootstrap->init(-1); 00579 } 00580 00581 // make decision tree 00582 bool doMerge = !crit->symmetric(); 00583 if( doMerge ) useTopdown = false; 00584 auto_ptr<SprDecisionTree> tree; 00585 if( useTopdown ) { 00586 tree.reset(new SprTopdownTree(filter.get(),crit.get(),nmin, 00587 discrete,bootstrap.get())); 00588 } 00589 else { 00590 tree.reset(new SprDecisionTree(filter.get(),crit.get(),nmin,doMerge, 00591 discrete,bootstrap.get())); 00592 } 00593 if( countTreeSplits ) tree->startSplitCounter(); 00594 tree->useFastSort(); 00595 00596 // if cross-validation requested, cross-validate and exit 00597 
if( nCross > 0 ) { 00598 // message 00599 cout << "Will cross-validate by dividing training data into " 00600 << nCross << " subsamples." << endl; 00601 vector<vector<int> > nodeMinSize; 00602 00603 // decode validation string 00604 if( !nodeValidationString.empty() ) 00605 SprStringParser::parseToInts(nodeValidationString.c_str(),nodeMinSize); 00606 else { 00607 nodeMinSize.resize(1); 00608 nodeMinSize[0].push_back(nmin); 00609 } 00610 if( nodeMinSize.empty() || nodeMinSize[0].empty() ) { 00611 cerr << "Unable to determine node size for cross-validation." << endl; 00612 return 4; 00613 } 00614 else { 00615 cout << "Will cross-validate for trees with minimal node sizes: "; 00616 for( int i=0;i<nodeMinSize[0].size();i++ ) 00617 cout << nodeMinSize[0][i] << " "; 00618 cout << endl; 00619 } 00620 00621 // loop over nodes to prepare classifiers 00622 vector<SprDecisionTree*> trees(nodeMinSize[0].size()); 00623 vector<SprAbsClassifier*> classifiers(nodeMinSize[0].size()); 00624 for( int i=0;i<nodeMinSize[0].size();i++ ) { 00625 SprDecisionTree* tree1 = 0; 00626 if( useTopdown ) { 00627 tree1 = new SprTopdownTree(filter.get(),crit.get(),nodeMinSize[0][i], 00628 discrete,bootstrap.get()); 00629 } 00630 else { 00631 tree1 = new SprDecisionTree(filter.get(),crit.get(),nodeMinSize[0][i], 00632 doMerge,discrete,bootstrap.get()); 00633 } 00634 tree1->useFastSort(); 00635 SprBagger* bagger1 = 0; 00636 if( useArcE4 ) 00637 bagger1 = new SprArcE4(filter.get(),cycles,discrete); 00638 else 00639 bagger1 = new SprBagger(filter.get(),cycles,discrete); 00640 if( initBootstrapFromTimeOfDay 00641 && !bagger1->initBootstrapFromTimeOfDay() ) { 00642 cerr << "Unable to generate seed from time of day for Bagger." << endl; 00643 return 4; 00644 } 00645 if( !bagger1->addTrainable(tree1) ) { 00646 cerr << "Unable to add decision tree to Bagger for CV." 
<< endl; 00647 for( int j=0;j<trees.size();j++ ) { 00648 delete trees[j]; 00649 delete classifiers[j]; 00650 } 00651 return 4; 00652 } 00653 trees[i] = tree1; 00654 classifiers[i] = bagger1; 00655 } 00656 00657 // cross-validate 00658 vector<double> cvFom; 00659 SprCrossValidator cv(filter.get(),nCross); 00660 if( !cv.validate(crit.get(),loss.get(),classifiers, 00661 inputClasses[0],inputClasses[1], 00662 SprUtils::lowerBound(0.5),cvFom,verbose) ) { 00663 cerr << "Unable to cross-validate." << endl; 00664 for( int j=0;j<trees.size();j++ ) { 00665 delete trees[j]; 00666 delete classifiers[j]; 00667 } 00668 return 4; 00669 } 00670 else { 00671 cout << "Cross-validated FOMs:" << endl; 00672 for( int i=0;i<cvFom.size();i++ ) { 00673 cout << "Node size=" << setw(8) << nodeMinSize[0][i] 00674 << " FOM=" << setw(10) << cvFom[i] << endl; 00675 } 00676 } 00677 00678 // cleanup 00679 for( int j=0;j<trees.size();j++ ) { 00680 delete trees[j]; 00681 delete classifiers[j]; 00682 } 00683 00684 // normal exit 00685 return 0; 00686 }// end cross-validation 00687 00688 // make Bagger 00689 auto_ptr<SprBagger> bagger; 00690 if( useArcE4 ) 00691 bagger.reset(new SprArcE4(filter.get(),cycles,discrete)); 00692 else 00693 bagger.reset(new SprBagger(filter.get(),cycles,discrete)); 00694 00695 // set seed for bootstrap if necessary 00696 if( initBootstrapFromTimeOfDay && !bagger->initBootstrapFromTimeOfDay() ) { 00697 cerr << "Unable to generate seed from time of day for Bagger." 
<< endl; 00698 return 4; 00699 } 00700 00701 // set validation 00702 if( valFilter.get()!=0 && !valFilter->empty() ) 00703 bagger->setValidation(valFilter.get(),valPrint,crit.get(),loss.get()); 00704 00705 // read saved Bagger from file 00706 if( !resumeFile.empty() ) { 00707 if( !SprClassifierReader::readTrainable(resumeFile.c_str(), 00708 bagger.get(),verbose) ) { 00709 cerr << "Failed to read saved Bagger from file " 00710 << resumeFile.c_str() << endl; 00711 return 5; 00712 } 00713 cout << "Read saved Bagger from file " << resumeFile.c_str() 00714 << " with " << bagger->nTrained() << " trained classifiers." 00715 << endl; 00716 } 00717 00718 // add trainable tree 00719 if( !bagger->addTrainable(tree.get()) ) { 00720 cerr << "Unable to add decision tree to Bagger." << endl; 00721 return 6; 00722 } 00723 00724 // train 00725 if( !bagger->train(verbose) ) 00726 cerr << "Bagger terminated with error." << endl; 00727 if( bagger->nTrained() == 0 ) { 00728 cerr << "Unable to train Bagger." << endl; 00729 return 7; 00730 } 00731 else { 00732 cout << "Bagger finished training with " << bagger->nTrained() 00733 << " classifiers." << endl; 00734 } 00735 00736 // save trained Bagger 00737 if( !outFile.empty() ) { 00738 if( !bagger->store(outFile.c_str()) ) { 00739 cerr << "Cannot store Bagger in file " << outFile.c_str() << endl; 00740 return 8; 00741 } 00742 } 00743 00744 // print out counted splits 00745 if( countTreeSplits ) tree->printSplitCounter(cout); 00746 00747 // make a trained Bagger 00748 auto_ptr<SprTrainedBagger> trainedBagger(bagger->makeTrained()); 00749 if( trainedBagger.get() == 0 ) { 00750 cerr << "Unable to get trained Bagger." << endl; 00751 return 7; 00752 } 00753 00754 // store code into file 00755 if( !codeFile.empty() ) { 00756 if( !trainedBagger->storeCode(codeFile.c_str()) ) { 00757 cerr << "Unable to store code for trained Bagger." 
<< endl; 00758 return 8; 00759 } 00760 } 00761 00762 // make histogram if requested 00763 if( tupleFile.empty() ) 00764 return 0; 00765 00766 // make a writer 00767 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training")); 00768 if( !tuple->init(tupleFile.c_str()) ) { 00769 cerr << "Unable to open output file " << tupleFile.c_str() << endl; 00770 return 9; 00771 } 00772 00773 // determine if certain variables are to be excluded from usage, 00774 // but included in the output storage file (-Z option) 00775 string printVarsDoNotFeed; 00776 vector<vector<string> > varsDoNotFeed; 00777 SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed); 00778 vector<unsigned> mapper; 00779 for( int d=0;d<vars.size();d++ ) { 00780 if( varsDoNotFeed.empty() || 00781 (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d]) 00782 ==varsDoNotFeed[0].end()) ) { 00783 mapper.push_back(d); 00784 } 00785 else { 00786 printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? "" : ", " ); 00787 printVarsDoNotFeed += vars[d]; 00788 } 00789 } 00790 if( !printVarsDoNotFeed.empty() ) { 00791 cout << "The following variables are not used in the algorithm, " 00792 << "but will be included in the output file: " 00793 << printVarsDoNotFeed.c_str() << endl; 00794 } 00795 00796 // feed 00797 SprDataFeeder feeder(filter.get(),tuple.get(),mapper); 00798 feeder.addClassifier(trainedBagger.get(),"bag"); 00799 if( !feeder.feed(1000) ) { 00800 cerr << "Cannot feed data into file " << tupleFile.c_str() << endl; 00801 return 10; 00802 } 00803 00804 // exit 00805 return 0; 00806 }