#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAdaBoost.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedAdaBoost.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDecisionTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedDecisionTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTopdownTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedTopdownTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassCrossEntropy.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassIDFraction.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprCrossValidator.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include <stdlib.h>
#include <unistd.h>
#include <iostream>
#include <vector>
#include <set>
#include <string>
#include <memory>
#include <iomanip>
Functions

    void help (const char *prog)
    int  main (int argc, char **argv)
void help (const char *prog)
Definition at line 44 of file SprAdaBoostDecisionTreeApp.cc.
References std::cout and std::endl.
{
  cout << "Usage:  " << prog
       << " training_data_file" << endl;
  cout << "\t Options: " << endl;
  cout << "\t-h --- help " << endl;
  cout << "\t-j use regular tree instead of faster topdown tree " << endl;
  cout << "\t-M AdaBoost mode " << endl;
  cout << "\t\t 1 = Discrete AdaBoost (default) " << endl;
  cout << "\t\t 2 = Real AdaBoost " << endl;
  cout << "\t\t 3 = Epsilon AdaBoost " << endl;
  cout << "\t-E epsilon for Epsilon and Real AdaBoosts (def=0.01)" << endl;
  cout << "\t-o output Tuple file " << endl;
  cout << "\t-a input ascii file mode (see SprSimpleReader.hh) " << endl;
  cout << "\t-A save output data in ascii instead of Root " << endl;
  cout << "\t-n number of AdaBoost training cycles " << endl;
  cout << "\t-l minimal number of entries per tree leaf (def=1) " << endl;
  cout << "\t-y list of input classes (see SprAbsFilter.hh) " << endl;
  cout << "\t-Q apply variable transformation saved in file " << endl;
  cout << "\t-c criterion for optimization " << endl;
  cout << "\t\t 1 = correctly classified fraction " << endl;
  cout << "\t\t 5 = Gini index (default) " << endl;
  cout << "\t\t 6 = cross-entropy " << endl;
  cout << "\t-g per-event loss for (cross-)validation " << endl;
  cout << "\t\t 1 - quadratic loss (y-f(x))^2 " << endl;
  cout << "\t\t 2 - exponential loss exp(-y*f(x)) " << endl;
  cout << "\t\t 3 - misid fraction " << endl;
  cout << "\t-b max number of sampled features (def=0 no sampling)" << endl;
  cout << "\t-m replace data values below this cutoff with medians" << endl;
  cout << "\t-i count splits on input variables " << endl;
  cout << "\t-s use standard AdaBoost (see SprTrainedAdaBoost.hh)" << endl;
  cout << "\t-e skip initial event reweighting when resuming " << endl;
  cout << "\t-u store data with modified weights to file " << endl;
  cout << "\t-v verbose level (0=silent default,1,2) " << endl;
  cout << "\t-f store trained AdaBoost to file " << endl;
  cout << "\t-F generate code for AdaBoost and store to file " << endl;
  cout << "\t-r resume training for AdaBoost stored in file " << endl;
  cout << "\t-K keep this fraction in training set and " << endl;
  cout << "\t\t put the rest into validation set " << endl;
  cout << "\t-D randomize training set split-up " << endl;
  cout << "\t-t read validation/test data from a file " << endl;
  cout << "\t\t (must be in same format as input data!!! " << endl;
  cout << "\t-d frequency of print-outs for validation data " << endl;
  cout << "\t-w scale all signal weights by this factor " << endl;
  cout << "\t-V include only these input variables " << endl;
  cout << "\t-z exclude input variables from the list " << endl;
  cout << "\t-Z exclude input variables from the list, "
       << "but put them in the output file " << endl;
  cout << "\t\t Variables must be listed in quotes and separated by commas."
       << endl;
  cout << "\t-x cross-validate by splitting data into a given "
       << "number of pieces" << endl;
  cout << "\t-q a set of minimal node sizes for cross-validation" << endl;
  cout << "\t\t Node sizes must be listed in quotes and separated by commas."
       << endl;
}
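For orientation, a typical invocation of this application might look like the line below. The executable name is taken from the source file name, the data file names are placeholders, and the flags map onto the options documented above: 100 boosting cycles, at least 5 events per tree leaf, the trained AdaBoost saved to a file, validation data read from a separate file, and validation print-outs every 10 cycles, with the training file given as the final positional argument.

    SprAdaBoostDecisionTreeApp -n 100 -l 5 -f ada.spr -t validation.root -d 10 training.root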
int main (int argc, char **argv)
Definition at line 102 of file SprAdaBoostDecisionTreeApp.cc.
References help(), std::cerr, std::cout, and std::endl.
{
  // check command line
  if( argc < 2 ) {
    help(argv[0]);
    return 1;
  }

  // init
  string tupleFile;
  int readMode = 0;
  SprRWFactory::DataType writeMode = SprRWFactory::Root;
  unsigned cycles = 0;
  unsigned nmin = 1;
  int iCrit = 5;
  int verbose = 0;
  string outFile;
  string codeFile;
  string resumeFile;
  string valFile;
  unsigned valPrint = 0;
  bool scaleWeights = false;
  double sW = 1.;
  bool countTreeSplits = false;
  bool useStandardAB = false;
  int iAdaBoostMode = 1;
  double epsilon = 0.01;
  bool skipInitialEventReweighting = false;
  string weightedDataOut;
  bool setLowCutoff = false;
  double lowCutoff = 0;
  string includeList, excludeList;
  unsigned nCross = 0;
  string nodeValidationString;
  int nFeaturesToSample = 0;
  bool bagInput = false;
  bool useTopdown = true;
  int iLoss = 0;
  string inputClassesString;
  string stringVarsDoNotFeed;
  bool split = false;
  double splitFactor = 0;
  bool splitRandomize = false;
  string transformerFile;

  // decode command line
  int c;
  extern char* optarg;
  // extern int optind;
  while((c = getopt(argc,argv,"hjM:E:o:a:An:l:y:Q:c:g:b:m:iseu:v:f:F:r:K:Dt:d:w:V:z:Z:x:q:")) != EOF ) {
    switch( c )
      {
      case 'h' :
        help(argv[0]);
        return 1;
      case 'j' :
        useTopdown = false;
        break;
      case 'M' :
        iAdaBoostMode = (optarg==0 ? 1 : atoi(optarg));
        break;
      case 'E' :
        epsilon = (optarg==0 ? 0.01 : atof(optarg));
        break;
      case 'o' :
        tupleFile = optarg;
        break;
      case 'a' :
        readMode = (optarg==0 ? 0 : atoi(optarg));
        break;
      case 'A' :
        writeMode = SprRWFactory::Ascii;
        break;
      case 'n' :
        cycles = (optarg==0 ? 1 : atoi(optarg));
        break;
      case 'l' :
        nmin = (optarg==0 ? 1 : atoi(optarg));
        break;
      case 'y' :
        inputClassesString = optarg;
        break;
      case 'Q' :
        transformerFile = optarg;
        break;
      case 'c' :
        iCrit = (optarg==0 ? 5 : atoi(optarg));
        break;
      case 'g' :
        iLoss = (optarg==0 ? 0 : atoi(optarg));
        break;
      case 'b' :
        bagInput = true;
        nFeaturesToSample = (optarg==0 ? 0 : atoi(optarg));
        break;
      case 'm' :
        if( optarg != 0 ) {
          setLowCutoff = true;
          lowCutoff = atof(optarg);
        }
        break;
      case 'i' :
        countTreeSplits = true;
        break;
      case 's' :
        useStandardAB = true;
        break;
      case 'e' :
        skipInitialEventReweighting = true;
        break;
      case 'u' :
        weightedDataOut = optarg;
        break;
      case 'v' :
        verbose = (optarg==0 ? 0 : atoi(optarg));
        break;
      case 'f' :
        outFile = optarg;
        break;
      case 'F' :
        codeFile = optarg;
        break;
      case 'r' :
        resumeFile = optarg;
        break;
      case 'K' :
        split = true;
        splitFactor = (optarg==0 ? 0 : atof(optarg));
        break;
      case 'D' :
        splitRandomize = true;
        break;
      case 't' :
        valFile = optarg;
        break;
      case 'd' :
        valPrint = (optarg==0 ? 0 : atoi(optarg));
        break;
      case 'w' :
        if( optarg != 0 ) {
          scaleWeights = true;
          sW = atof(optarg);
        }
        break;
      case 'V' :
        includeList = optarg;
        break;
      case 'z' :
        excludeList = optarg;
        break;
      case 'Z' :
        stringVarsDoNotFeed = optarg;
        break;
      case 'x' :
        nCross = (optarg==0 ? 0 : atoi(optarg));
        break;
      case 'q' :
        nodeValidationString = optarg;
        break;
      }
  }

  // There has to be 1 argument after all options.
  string trFile = argv[argc-1];
  if( trFile.empty() ) {
    cerr << "No training file is specified." << endl;
    return 1;
  }

  // make reader
  SprRWFactory::DataType inputType
    = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
  auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));

  // include variables
  set<string> includeSet;
  if( !includeList.empty() ) {
    vector<vector<string> > includeVars;
    SprStringParser::parseToStrings(includeList.c_str(),includeVars);
    assert( !includeVars.empty() );
    for( int i=0;i<includeVars[0].size();i++ )
      includeSet.insert(includeVars[0][i]);
    if( !reader->chooseVars(includeSet) ) {
      cerr << "Unable to include variables in training set." << endl;
      return 2;
    }
    else {
      cout << "Following variables have been included in optimization: ";
      for( set<string>::const_iterator
             i=includeSet.begin();i!=includeSet.end();i++ )
        cout << "\"" << *i << "\"" << " ";
      cout << endl;
    }
  }

  // exclude variables
  set<string> excludeSet;
  if( !excludeList.empty() ) {
    vector<vector<string> > excludeVars;
    SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
    assert( !excludeVars.empty() );
    for( int i=0;i<excludeVars[0].size();i++ )
      excludeSet.insert(excludeVars[0][i]);
    if( !reader->chooseAllBut(excludeSet) ) {
      cerr << "Unable to exclude variables from training set." << endl;
      return 2;
    }
    else {
      cout << "Following variables have been excluded from optimization: ";
      for( set<string>::const_iterator
             i=excludeSet.begin();i!=excludeSet.end();i++ )
        cout << "\"" << *i << "\"" << " ";
      cout << endl;
    }
  }

  // read training data from file
  auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
  if( filter.get() == 0 ) {
    cerr << "Unable to read data from file " << trFile.c_str() << endl;
    return 2;
  }
  vector<string> vars;
  filter->vars(vars);
  cout << "Read data from file " << trFile.c_str()
       << " for variables";
  for( int i=0;i<vars.size();i++ )
    cout << " \"" << vars[i].c_str() << "\"";
  cout << endl;
  cout << "Total number of points read: " << filter->size() << endl;

  // filter training data by class
  vector<SprClass> inputClasses;
  if( !filter->filterByClass(inputClassesString.c_str()) ) {
    cerr << "Cannot choose input classes for string "
         << inputClassesString << endl;
    return 2;
  }
  filter->classes(inputClasses);
  assert( inputClasses.size() > 1 );
  cout << "Training data filtered by class." << endl;
  for( int i=0;i<inputClasses.size();i++ ) {
    cout << "Points in class " << inputClasses[i] << ": "
         << filter->ptsInClass(inputClasses[i]) << endl;
  }

  // scale weights
  if( scaleWeights ) {
    cout << "Signal weights are multiplied by " << sW << endl;
    filter->scaleWeights(inputClasses[1],sW);
  }

  // apply low cutoff
  if( setLowCutoff ) {
    if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
      cerr << "Unable to replace missing values in training data." << endl;
      return 2;
    }
    else
      cout << "Values below " << lowCutoff << " in training data"
           << " have been replaced with medians." << endl;
  }

  // read validation data from file
  auto_ptr<SprAbsFilter> valFilter;
  if( split && !valFile.empty() ) {
    cerr << "Unable to split training data and use validation data "
         << "from a separate file." << endl;
    return 2;
  }
  if( split && valPrint!=0 ) {
    cout << "Splitting training data with factor " << splitFactor << endl;
    if( splitRandomize )
      cout << "Will use randomized splitting." << endl;
    vector<double> weights;
    SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
    if( splitted == 0 ) {
      cerr << "Unable to split training data." << endl;
      return 2;
    }
    bool ownData = true;
    valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
    cout << "Training data re-filtered:" << endl;
    for( int i=0;i<inputClasses.size();i++ ) {
      cout << "Points in class " << inputClasses[i] << ": "
           << filter->ptsInClass(inputClasses[i]) << endl;
    }
  }
  if( !valFile.empty() && valPrint!=0 ) {
    auto_ptr<SprAbsReader>
      valReader(SprRWFactory::makeReader(inputType,readMode));
    if( !includeSet.empty() ) {
      if( !valReader->chooseVars(includeSet) ) {
        cerr << "Unable to include variables in validation set." << endl;
        return 2;
      }
    }
    if( !excludeSet.empty() ) {
      if( !valReader->chooseAllBut(excludeSet) ) {
        cerr << "Unable to exclude variables from validation set." << endl;
        return 2;
      }
    }
    valFilter.reset(valReader->read(valFile.c_str()));
    if( valFilter.get() == 0 ) {
      cerr << "Unable to read data from file " << valFile.c_str() << endl;
      return 2;
    }
    vector<string> valVars;
    valFilter->vars(valVars);
    cout << "Read validation data from file " << valFile.c_str()
         << " for variables";
    for( int i=0;i<valVars.size();i++ )
      cout << " \"" << valVars[i].c_str() << "\"";
    cout << endl;
    cout << "Total number of points read: " << valFilter->size() << endl;
    cout << "Points in class 0: " << valFilter->ptsInClass(inputClasses[0])
         << " 1: " << valFilter->ptsInClass(inputClasses[1]) << endl;
  }

  // filter validation data by class
  if( valFilter.get() != 0 ) {
    if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
      cerr << "Cannot choose input classes for string "
           << inputClassesString << endl;
      return 2;
    }
    valFilter->classes(inputClasses);
    cout << "Validation data filtered by class." << endl;
    for( int i=0;i<inputClasses.size();i++ ) {
      cout << "Points in class " << inputClasses[i] << ": "
           << valFilter->ptsInClass(inputClasses[i]) << endl;
    }
  }

  // scale weights
  if( scaleWeights && valFilter.get()!=0 )
    valFilter->scaleWeights(inputClasses[1],sW);

  // apply low cutoff
  if( setLowCutoff && valFilter.get()!=0 ) {
    if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
      cerr << "Unable to replace missing values in validation data." << endl;
      return 2;
    }
    else
      cout << "Values below " << lowCutoff << " in validation data"
           << " have been replaced with medians." << endl;
  }

  // apply transformation of variables to training and test data
  auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
  if( !transformerFile.empty() ) {
    SprVarTransformerReader transReader;
    const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
    if( t == 0 ) {
      cerr << "Unable to read VarTransformer from file "
           << transformerFile.c_str() << endl;
      return 2;
    }
    SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
    SprTransformerFilter* t_valid = 0;
    if( valFilter.get() != 0 )
      t_valid = new SprTransformerFilter(valFilter.get());
    bool replaceOriginalData = true;
    if( !t_train->transform(t,replaceOriginalData) ) {
      cerr << "Unable to apply VarTransformer to training data." << endl;
      return 2;
    }
    if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
      cerr << "Unable to apply VarTransformer to validation data." << endl;
      return 2;
    }
    cout << "Variable transformation from file "
         << transformerFile.c_str() << " has been applied to "
         << "training and validation data." << endl;
    garbage_train.reset(filter.release());
    garbage_valid.reset(valFilter.release());
    filter.reset(t_train);
    valFilter.reset(t_valid);
  }

  // make optimization criterion
  auto_ptr<SprAbsTwoClassCriterion> crit;
  switch( iCrit )
    {
    case 1 :
      crit.reset(new SprTwoClassIDFraction);
      cout << "Optimization criterion set to "
           << "Fraction of correctly classified events " << endl;
      break;
    case 5 :
      crit.reset(new SprTwoClassGiniIndex);
      cout << "Optimization criterion set to "
           << "Gini index -1+p^2+q^2 " << endl;
      break;
    case 6 :
      crit.reset(new SprTwoClassCrossEntropy);
      cout << "Optimization criterion set to "
           << "Cross-entropy p*log(p)+q*log(q) " << endl;
      break;
    default :
      cerr << "Unable to make initialization criterion." << endl;
      return 3;
    }

  // make per-event loss
  auto_ptr<SprAverageLoss> loss;
  switch( iLoss )
    {
    case 1 :
      loss.reset(new SprAverageLoss(&SprLoss::quadratic,
                                    &SprTransformation::logit));
      cout << "Per-event loss set to "
           << "Quadratic loss (y-f(x))^2 " << endl;
      break;
    case 2 :
      loss.reset(new SprAverageLoss(&SprLoss::exponential));
      cout << "Per-event loss set to "
           << "Exponential loss exp(-y*f(x)) " << endl;
      break;
    case 3 :
      loss.reset(new SprAverageLoss(&SprLoss::correct_id,
                                    &SprTransformation::inftyRangeToDiscrete01));
      cout << "Per-event loss set to "
           << "Misid rate int(y==f(x)) " << endl;
      break;
    default :
      cout << "No per-event loss is chosen. Will use the default." << endl;
      break;
    }

  // make AdaBoost mode
  SprTrainedAdaBoost::AdaBoostMode abMode = SprTrainedAdaBoost::Discrete;
  switch( iAdaBoostMode )
    {
    case 1 :
      abMode = SprTrainedAdaBoost::Discrete;
      cout << "Will train Discrete AdaBoost." << endl;
      break;
    case 2 :
      abMode = SprTrainedAdaBoost::Real;
      cout << "Will train Real AdaBoost." << endl;
      break;
    case 3 :
      abMode = SprTrainedAdaBoost::Epsilon;
      cout << "Will train Epsilon AdaBoost." << endl;
      break;
    default :
      cout << "Will train Discrete AdaBoost." << endl;
      break;
    }

  // make bootstrap for resampling input features
  auto_ptr<SprIntegerBootstrap> bootstrap;
  if( nFeaturesToSample > filter->dim() )
    nFeaturesToSample = filter->dim();
  if( nFeaturesToSample > 0 ) {
    bootstrap.reset(new SprIntegerBootstrap(filter->dim(),nFeaturesToSample));
    if( !resumeFile.empty() ) bootstrap->init(-1);
  }

  // make decision tree
  bool discrete = true;
  if( abMode==SprTrainedAdaBoost::Real ) discrete = false;
  bool doMerge = false;
  auto_ptr<SprDecisionTree> tree;
  if( useTopdown )
    tree.reset(new SprTopdownTree(filter.get(),crit.get(),
                                  nmin,discrete,bootstrap.get()));
  else
    tree.reset(new SprDecisionTree(filter.get(),crit.get(),
                                   nmin,doMerge,discrete,bootstrap.get()));
  if( countTreeSplits ) tree->startSplitCounter();
  if( abMode == SprTrainedAdaBoost::Real ) tree->forceMixedNodes();
  tree->useFastSort();

  // if cross-validation requested, cross-validate and exit
  if( nCross > 0 ) {
    // message
    cout << "Will cross-validate by dividing training data into "
         << nCross << " subsamples." << endl;
    vector<vector<int> > nodeMinSize;

    // decode validation string
    if( !nodeValidationString.empty() )
      SprStringParser::parseToInts(nodeValidationString.c_str(),nodeMinSize);
    else {
      nodeMinSize.resize(1);
      nodeMinSize[0].push_back(nmin);
    }
    if( nodeMinSize.empty() || nodeMinSize[0].empty() ) {
      cerr << "Unable to determine node size for cross-validation." << endl;
      return 4;
    }
    else {
      cout << "Will cross-validate for trees with minimal node sizes: ";
      for( int i=0;i<nodeMinSize[0].size();i++ )
        cout << nodeMinSize[0][i] << " ";
      cout << endl;
    }

    // loop over nodes to prepare classifiers
    vector<SprDecisionTree*> trees(nodeMinSize[0].size());
    vector<SprAbsClassifier*> classifiers(nodeMinSize[0].size());
    for( int i=0;i<nodeMinSize[0].size();i++ ) {
      SprDecisionTree* tree1 = 0;
      if( useTopdown )
        tree1 =
          new SprTopdownTree(filter.get(),crit.get(),
                             nodeMinSize[0][i],
                             discrete,bootstrap.get());
      else
        tree1 =
          new SprDecisionTree(filter.get(),crit.get(),
                              nodeMinSize[0][i],doMerge,
                              discrete,bootstrap.get());
      if( abMode == SprTrainedAdaBoost::Real ) tree1->forceMixedNodes();
      tree1->useFastSort();
      SprAdaBoost* ab1 = new SprAdaBoost(filter.get(),cycles,
                                         useStandardAB,abMode,bagInput);
      cout << "Setting epsilon to " << epsilon << endl;
      ab1->setEpsilon(epsilon);
      if( !ab1->addTrainable(tree1,SprUtils::lowerBound(0.5)) ) {
        cerr << "Unable to add decision tree to AdaBoost for CV." << endl;
        for( int j=0;j<trees.size();j++ ) {
          delete trees[j];
          delete classifiers[j];
        }
        return 4;
      }
      trees[i] = tree1;
      classifiers[i] = ab1;
    }

    // cross-validate
    vector<double> cvFom;
    SprCrossValidator cv(filter.get(),nCross);
    if( !cv.validate(crit.get(),loss.get(),classifiers,
                     inputClasses[0],inputClasses[1],
                     SprUtils::lowerBound(0.5),cvFom,verbose) ) {
      cerr << "Unable to cross-validate." << endl;
      for( int j=0;j<trees.size();j++ ) {
        delete trees[j];
        delete classifiers[j];
      }
      return 4;
    }
    else {
      cout << "Cross-validated FOMs:" << endl;
      for( int i=0;i<cvFom.size();i++ ) {
        cout << "Node size=" << setw(8) << nodeMinSize[0][i]
             << " FOM=" << setw(10) << cvFom[i] << endl;
      }
    }

    // cleanup
    for( int j=0;j<trees.size();j++ ) {
      delete trees[j];
      delete classifiers[j];
    }

    // normal exit
    return 0;
  }// end cross-validation

  // make AdaBoost
  SprAdaBoost ab(filter.get(),cycles,useStandardAB,abMode,bagInput);
  cout << "Setting epsilon to " << epsilon << endl;
  ab.setEpsilon(epsilon);

  // set validation
  if( valFilter.get()!=0 && !valFilter->empty() )
    ab.setValidation(valFilter.get(),valPrint,loss.get());

  // read saved Boost from file
  if( !resumeFile.empty() ) {
    if( !SprClassifierReader::readTrainable(resumeFile.c_str(),
                                            &ab,verbose) ) {
      cerr << "Failed to read saved AdaBoost from file "
           << resumeFile.c_str() << endl;
      return 5;
    }
    cout << "Read saved AdaBoost from file " << resumeFile.c_str()
         << " with " << ab.nTrained() << " trained classifiers." << endl;
  }
  if( skipInitialEventReweighting ) ab.skipInitialEventReweighting(true);

  // add trainable tree
  if( !ab.addTrainable(tree.get(),SprUtils::lowerBound(0.5)) ) {
    cerr << "Unable to add decision tree to AdaBoost." << endl;
    return 6;
  }

  // train
  if( !ab.train(verbose) )
    cerr << "AdaBoost terminated with error." << endl;
  if( ab.nTrained() == 0 ) {
    cerr << "Unable to train AdaBoost." << endl;
    return 7;
  }
  else {
    cout << "AdaBoost finished training with " << ab.nTrained()
         << " classifiers." << endl;
  }

  // save trained AdaBoost
  if( !outFile.empty() ) {
    if( !ab.store(outFile.c_str()) ) {
      cerr << "Cannot store AdaBoost in file " << outFile.c_str() << endl;
      return 8;
    }
  }

  // print out counted splits
  if( countTreeSplits ) tree->printSplitCounter(cout);

  // save reweighted data
  if( !weightedDataOut.empty() ) {
    if( !ab.storeData(weightedDataOut.c_str()) ) {
      cerr << "Cannot store weighted AdaBoost data to file "
           << weightedDataOut.c_str() << endl;
      return 9;
    }
  }

  // make a trained AdaBoost
  auto_ptr<SprTrainedAdaBoost> trainedAda(ab.makeTrained());
  if( trainedAda.get() == 0 ) {
    cerr << "Unable to get trained AdaBoost." << endl;
    return 7;
  }

  // store code into file
  if( !codeFile.empty() ) {
    if( !trainedAda->storeCode(codeFile.c_str()) ) {
      cerr << "Unable to store code for trained AdaBoost." << endl;
      return 8;
    }
  }

  // make histogram if requested
  if( tupleFile.empty() )
    return 0;

  // make a writer
  auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
  if( !tuple->init(tupleFile.c_str()) ) {
    cerr << "Unable to open output file " << tupleFile.c_str() << endl;
    return 10;
  }

  // determine if certain variables are to be excluded from usage,
  // but included in the output storage file (-Z option)
  string printVarsDoNotFeed;
  vector<vector<string> > varsDoNotFeed;
  SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed);
  vector<unsigned> mapper;
  for( int d=0;d<vars.size();d++ ) {
    if( varsDoNotFeed.empty() ||
        (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d])
         ==varsDoNotFeed[0].end()) ) {
      mapper.push_back(d);
    }
    else {
      printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? "" : ", " );
      printVarsDoNotFeed += vars[d];
    }
  }
  if( !printVarsDoNotFeed.empty() ) {
    cout << "The following variables are not used in the algorithm, "
         << "but will be included in the output file: "
         << printVarsDoNotFeed.c_str() << endl;
  }

  // feed
  SprDataFeeder feeder(filter.get(),tuple.get(),mapper);
  feeder.addClassifier(trainedAda.get(),"ada");
  if( !feeder.feed(1000) ) {
    cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
    return 11;
  }

  // exit
  return 0;
}
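To make the control flow easier to follow, the fragment below is a minimal sketch of the non-cross-validation path through main(): read the training data, build a topdown tree optimizing the Gini index, boost it, and store the result. It uses only classes and calls that appear in the listing above; option handling, validation data, variable transformations, and most error reporting are omitted, the file names and class-label string are placeholders, and std::auto_ptr is kept only to match the original code's style.

    #include "PhysicsTools/StatPatternRecognition/interface/SprAdaBoost.hh"
    #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedAdaBoost.hh"
    #include "PhysicsTools/StatPatternRecognition/interface/SprTopdownTree.hh"
    #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
    #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
    #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
    #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
    #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
    #include <memory>

    int trainAdaBoostSketch()
    {
      // read training data (Root input, i.e. readMode==0 in the listing above)
      std::auto_ptr<SprAbsReader>
        reader(SprRWFactory::makeReader(SprRWFactory::Root,0));
      std::auto_ptr<SprAbsFilter> filter(reader->read("training.root"));
      if( filter.get() == 0 ) return 1;
      // choose input classes; "0,1" is a placeholder label string (see SprAbsFilter.hh)
      if( !filter->filterByClass("0,1") ) return 1;

      // weak learner: topdown tree optimizing the Gini index, >=5 events per leaf,
      // discrete output (Discrete AdaBoost), no feature sampling
      SprTwoClassGiniIndex crit;
      bool discrete = true;
      SprTopdownTree tree(filter.get(),&crit,5,discrete,0);
      tree.useFastSort();

      // boost the tree for 100 cycles in Discrete AdaBoost mode
      SprAdaBoost ab(filter.get(),100,false,SprTrainedAdaBoost::Discrete,false);
      ab.setEpsilon(0.01);
      if( !ab.addTrainable(&tree,SprUtils::lowerBound(0.5)) ) return 2;
      if( !ab.train(0) || ab.nTrained()==0 ) return 3;

      // make the trained classifier and persist the AdaBoost configuration
      std::auto_ptr<SprTrainedAdaBoost> trained(ab.makeTrained());
      if( trained.get() == 0 ) return 4;
      return ab.store("ada.spr") ? 0 : 5;
    }

The sketch mirrors the ordering enforced by main(): the tree is constructed on the filtered data and registered with AdaBoost via addTrainable() with a 0.5 cut before train() is called, and only after training succeeds is the trained classifier extracted with makeTrained() and stored.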