00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprBagger.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprArcE4.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedBagger.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprTopdownTree.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedTopdownTree.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprDecisionTree.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedDecisionTree.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassSignalSignif.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassIDFraction.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassTaggerEff.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPurity.hh"
00022 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
00023 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassCrossEntropy.hh"
00024 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassUniformPriorUL90.hh"
00025 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassBKDiscovery.hh"
00026 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPunzi.hh"
00027 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00028 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00029 #include "PhysicsTools/StatPatternRecognition/interface/SprCrossValidator.hh"
00030 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00031 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00032 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00033 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00034 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00035 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00036 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00037 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00038
00039 #include <stdlib.h>
00040 #include <unistd.h>
00041 #include <iostream>
00042 #include <vector>
00043 #include <set>
00044 #include <string>
00045 #include <memory>
00046 #include <iomanip>
00047
00048 using namespace std;
00049
00050
// Print command-line usage for this Bagger-of-decision-trees executable to
// standard output.
//   prog - program name (normally argv[0]) shown in the "Usage:" line.
// Every option accepted by main()'s getopt() string is listed here; keep the
// two in sync when adding options.
void help(const char* prog)
{
  std::cout << "Usage: " << prog
            << " training_data_file" << std::endl;
  std::cout << "\t Options: " << std::endl;
  std::cout << "\t-h --- help " << std::endl;
  std::cout << "\t-j use regular tree instead of faster topdown tree " << std::endl;
  std::cout << "\t-k discrete decision tree output (default=continuous)"<< std::endl;
  std::cout << "\t-o output Tuple file " << std::endl;
  std::cout << "\t-a input ascii file mode (see SprSimpleReader.hh) " << std::endl;
  std::cout << "\t-A save output data in ascii instead of Root " << std::endl;
  std::cout << "\t-n number of Bagger training cycles " << std::endl;
  std::cout << "\t-l minimal number of entries per tree leaf (def=1) " << std::endl;
  std::cout << "\t-s max number of sampled features (def=0 no sampling)"<< std::endl;
  std::cout << "\t-y list of input classes (see SprAbsFilter.hh) " << std::endl;
  std::cout << "\t-Q apply variable transformation saved in file " << std::endl;
  std::cout << "\t-b use a version of Breiman's arc-x4 algorithm " << std::endl;
  std::cout << "\t-v verbose level (0=silent default,1,2) " << std::endl;
  std::cout << "\t-f store trained Bagger to file " << std::endl;
  // -F writes code for the trained Bagger (not AdaBoost; this app only
  // builds Baggers / arc-x4).
  std::cout << "\t-F generate code for Bagger and store to file " << std::endl;
  std::cout << "\t-c criterion for optimization " << std::endl;
  std::cout << "\t\t 1 = correctly classified fraction " << std::endl;
  std::cout << "\t\t 2 = signal significance s/sqrt(s+b) " << std::endl;
  std::cout << "\t\t 3 = purity s/(s+b) " << std::endl;
  std::cout << "\t\t 4 = tagger efficiency Q " << std::endl;
  std::cout << "\t\t 5 = Gini index (default) " << std::endl;
  std::cout << "\t\t 6 = cross-entropy " << std::endl;
  std::cout << "\t\t 7 = 90% Bayesian upper limit with uniform prior " << std::endl;
  std::cout << "\t\t 8 = discovery potential 2*(sqrt(s+b)-sqrt(b)) " << std::endl;
  std::cout << "\t\t 9 = Punzi's sensitivity s/(0.5*nSigma+sqrt(b)) " << std::endl;
  std::cout << "\t\t -P background normalization factor for Punzi FOM" << std::endl;
  std::cout << "\t-g per-event loss for (cross-)validation " << std::endl;
  std::cout << "\t\t 1 - quadratic loss (y-f(x))^2 " << std::endl;
  std::cout << "\t\t 2 - exponential loss exp(-y*f(x)) " << std::endl;
  std::cout << "\t\t 3 - misid fraction " << std::endl;
  std::cout << "\t-m replace data values below this cutoff with medians" << std::endl;
  std::cout << "\t-i count splits on input variables " << std::endl;
  std::cout << "\t-r resume training for Bagger stored in file " << std::endl;
  std::cout << "\t-K keep this fraction in training set and " << std::endl;
  std::cout << "\t\t put the rest into validation set " << std::endl;
  std::cout << "\t-D randomize training set split-up " << std::endl;
  std::cout << "\t-G generate seed from time of day for bootstrap " << std::endl;
  std::cout << "\t-t read validation/test data from a file " << std::endl;
  std::cout << "\t\t (must be in same format as input data!!!) " << std::endl;
  std::cout << "\t-d frequency of print-outs for validation data " << std::endl;
  std::cout << "\t-w scale all signal weights by this factor " << std::endl;
  std::cout << "\t-V include only these input variables " << std::endl;
  std::cout << "\t-z exclude input variables from the list " << std::endl;
  std::cout << "\t-Z exclude input variables from the list, "
            << "but put them in the output file " << std::endl;
  std::cout << "\t\t Variables must be listed in quotes and separated by commas."
            << std::endl;
  std::cout << "\t-x cross-validate by splitting data into a given "
            << "number of pieces" << std::endl;
  std::cout << "\t-q a set of minimal node sizes for cross-validation" << std::endl;
  std::cout << "\t\t Node sizes must be listed in quotes and separated by commas."
            << std::endl;
}
00109
00110
00111 int main(int argc, char ** argv)
00112 {
00113
00114 if( argc < 2 ) {
00115 help(argv[0]);
00116 return 1;
00117 }
00118
00119
00120 string tupleFile;
00121 int readMode = 0;
00122 SprRWFactory::DataType writeMode = SprRWFactory::Root;
00123 unsigned cycles = 0;
00124 unsigned nmin = 1;
00125 int verbose = 0;
00126 string outFile;
00127 string codeFile;
00128 string resumeFile;
00129 int iCrit = 5;
00130 string valFile;
00131 unsigned valPrint = 0;
00132 bool scaleWeights = false;
00133 double sW = 1.;
00134 int nFeaturesToSample = 0;
00135 bool countTreeSplits = false;
00136 bool setLowCutoff = false;
00137 double lowCutoff = 0;
00138 string includeList, excludeList;
00139 unsigned nCross = 0;
00140 string nodeValidationString;
00141 bool useTopdown = true;
00142 bool discrete = false;
00143 int iLoss = 0;
00144 string inputClassesString;
00145 bool useArcE4 = false;
00146 double bW = 1.;
00147 string stringVarsDoNotFeed;
00148 bool split = false;
00149 double splitFactor = 0;
00150 bool splitRandomize = false;
00151 bool initBootstrapFromTimeOfDay = false;
00152 string transformerFile;
00153
00154
00155 int c;
00156 extern char* optarg;
00157
00158 while( (c = getopt(argc,argv,"hjko:a:An:l:s:y:Q:bv:f:F:c:P:g:m:ir:K:DGt:d:w:V:z:Z:x:q:"))
00159 != EOF ) {
00160 switch( c )
00161 {
00162 case 'h' :
00163 help(argv[0]);
00164 return 1;
00165 case 'j' :
00166 useTopdown = false;
00167 break;
00168 case 'k' :
00169 discrete = true;
00170 break;
00171 case 'o' :
00172 tupleFile = optarg;
00173 break;
00174 case 'a' :
00175 readMode = (optarg==0 ? 0 : atoi(optarg));
00176 break;
00177 case 'A' :
00178 writeMode = SprRWFactory::Ascii;
00179 break;
00180 case 'n' :
00181 cycles = (optarg==0 ? 1 : atoi(optarg));
00182 break;
00183 case 'l' :
00184 nmin = (optarg==0 ? 1 : atoi(optarg));
00185 break;
00186 case 's' :
00187 nFeaturesToSample = (optarg==0 ? 0 : atoi(optarg));
00188 break;
00189 case 'y' :
00190 inputClassesString = optarg;
00191 break;
00192 case 'Q' :
00193 transformerFile = optarg;
00194 break;
00195 case 'b' :
00196 useArcE4 = true;
00197 break;
00198 case 'v' :
00199 verbose = (optarg==0 ? 0 : atoi(optarg));
00200 break;
00201 case 'f' :
00202 outFile = optarg;
00203 break;
00204 case 'F' :
00205 codeFile = optarg;
00206 break;
00207 case 'c' :
00208 iCrit = (optarg==0 ? 5 : atoi(optarg));
00209 break;
00210 case 'P' :
00211 bW = (optarg==0 ? 1 : atof(optarg));
00212 break;
00213 case 'g' :
00214 iLoss = (optarg==0 ? 0 : atoi(optarg));
00215 break;
00216 case 'm' :
00217 if( optarg != 0 ) {
00218 setLowCutoff = true;
00219 lowCutoff = atof(optarg);
00220 }
00221 break;
00222 case 'i' :
00223 countTreeSplits = true;
00224 break;
00225 case 'r' :
00226 resumeFile = optarg;
00227 break;
00228 case 'K' :
00229 split = true;
00230 splitFactor = (optarg==0 ? 0 : atof(optarg));
00231 break;
00232 case 'D' :
00233 splitRandomize = true;
00234 break;
00235 case 'G' :
00236 initBootstrapFromTimeOfDay = true;
00237 break;
00238 case 't' :
00239 valFile = optarg;
00240 break;
00241 case 'd' :
00242 valPrint = (optarg==0 ? 0 : atoi(optarg));
00243 break;
00244 case 'w' :
00245 if( optarg != 0 ) {
00246 scaleWeights = true;
00247 sW = atof(optarg);
00248 }
00249 break;
00250 case 'V' :
00251 includeList = optarg;
00252 break;
00253 case 'z' :
00254 excludeList = optarg;
00255 break;
00256 case 'Z' :
00257 stringVarsDoNotFeed = optarg;
00258 break;
00259 case 'x' :
00260 nCross = (optarg==0 ? 0 : atoi(optarg));
00261 break;
00262 case 'q' :
00263 nodeValidationString = optarg;
00264 break;
00265 }
00266 }
00267
00268
00269 string trFile = argv[argc-1];
00270 if( trFile.empty() ) {
00271 cerr << "No training file is specified." << endl;
00272 return 1;
00273 }
00274
00275
00276 SprRWFactory::DataType inputType
00277 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00278 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00279
00280
00281 set<string> includeSet;
00282 if( !includeList.empty() ) {
00283 vector<vector<string> > includeVars;
00284 SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00285 assert( !includeVars.empty() );
00286 for( int i=0;i<includeVars[0].size();i++ )
00287 includeSet.insert(includeVars[0][i]);
00288 if( !reader->chooseVars(includeSet) ) {
00289 cerr << "Unable to include variables in training set." << endl;
00290 return 2;
00291 }
00292 else {
00293 cout << "Following variables have been included in optimization: ";
00294 for( set<string>::const_iterator
00295 i=includeSet.begin();i!=includeSet.end();i++ )
00296 cout << "\"" << *i << "\"" << " ";
00297 cout << endl;
00298 }
00299 }
00300
00301
00302 set<string> excludeSet;
00303 if( !excludeList.empty() ) {
00304 vector<vector<string> > excludeVars;
00305 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00306 assert( !excludeVars.empty() );
00307 for( int i=0;i<excludeVars[0].size();i++ )
00308 excludeSet.insert(excludeVars[0][i]);
00309 if( !reader->chooseAllBut(excludeSet) ) {
00310 cerr << "Unable to exclude variables from training set." << endl;
00311 return 2;
00312 }
00313 else {
00314 cout << "Following variables have been excluded from optimization: ";
00315 for( set<string>::const_iterator
00316 i=excludeSet.begin();i!=excludeSet.end();i++ )
00317 cout << "\"" << *i << "\"" << " ";
00318 cout << endl;
00319 }
00320 }
00321
00322
00323 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00324 if( filter.get() == 0 ) {
00325 cerr << "Unable to read data from file " << trFile.c_str() << endl;
00326 return 2;
00327 }
00328 vector<string> vars;
00329 filter->vars(vars);
00330 cout << "Read data from file " << trFile.c_str()
00331 << " for variables";
00332 for( int i=0;i<vars.size();i++ )
00333 cout << " \"" << vars[i].c_str() << "\"";
00334 cout << endl;
00335 cout << "Total number of points read: " << filter->size() << endl;
00336
00337
00338 vector<SprClass> inputClasses;
00339 if( !filter->filterByClass(inputClassesString.c_str()) ) {
00340 cerr << "Cannot choose input classes for string "
00341 << inputClassesString << endl;
00342 return 2;
00343 }
00344 filter->classes(inputClasses);
00345 assert( inputClasses.size() > 1 );
00346 cout << "Training data filtered by class." << endl;
00347 for( int i=0;i<inputClasses.size();i++ ) {
00348 cout << "Points in class " << inputClasses[i] << ": "
00349 << filter->ptsInClass(inputClasses[i]) << endl;
00350 }
00351
00352
00353 if( scaleWeights ) {
00354 cout << "Signal weights are multiplied by " << sW << endl;
00355 filter->scaleWeights(inputClasses[1],sW);
00356 }
00357
00358
00359 if( setLowCutoff ) {
00360 if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00361 cerr << "Unable to replace missing values in training data." << endl;
00362 return 2;
00363 }
00364 else
00365 cout << "Values below " << lowCutoff << " in training data"
00366 << " have been replaced with medians." << endl;
00367 }
00368
00369
00370 auto_ptr<SprAbsFilter> valFilter;
00371 if( split && !valFile.empty() ) {
00372 cerr << "Unable to split training data and use validation data "
00373 << "from a separate file." << endl;
00374 return 2;
00375 }
00376 if( split && valPrint!=0 ) {
00377 cout << "Splitting training data with factor " << splitFactor << endl;
00378 if( splitRandomize )
00379 cout << "Will use randomized splitting." << endl;
00380 vector<double> weights;
00381 SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00382 if( splitted == 0 ) {
00383 cerr << "Unable to split training data." << endl;
00384 return 2;
00385 }
00386 bool ownData = true;
00387 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00388 cout << "Training data re-filtered:" << endl;
00389 for( int i=0;i<inputClasses.size();i++ ) {
00390 cout << "Points in class " << inputClasses[i] << ": "
00391 << filter->ptsInClass(inputClasses[i]) << endl;
00392 }
00393 }
00394 if( !valFile.empty() && valPrint!=0 ) {
00395 auto_ptr<SprAbsReader>
00396 valReader(SprRWFactory::makeReader(inputType,readMode));
00397 if( !includeSet.empty() ) {
00398 if( !valReader->chooseVars(includeSet) ) {
00399 cerr << "Unable to include variables in validation set." << endl;
00400 return 2;
00401 }
00402 }
00403 if( !excludeSet.empty() ) {
00404 if( !valReader->chooseAllBut(excludeSet) ) {
00405 cerr << "Unable to exclude variables from validation set." << endl;
00406 return 2;
00407 }
00408 }
00409 valFilter.reset(valReader->read(valFile.c_str()));
00410 if( valFilter.get() == 0 ) {
00411 cerr << "Unable to read data from file " << valFile.c_str() << endl;
00412 return 2;
00413 }
00414 vector<string> valVars;
00415 valFilter->vars(valVars);
00416 cout << "Read validation data from file " << valFile.c_str()
00417 << " for variables";
00418 for( int i=0;i<valVars.size();i++ )
00419 cout << " \"" << valVars[i].c_str() << "\"";
00420 cout << endl;
00421 cout << "Total number of points read: " << valFilter->size() << endl;
00422 }
00423
00424
00425 if( valFilter.get() != 0 ) {
00426 if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00427 cerr << "Cannot choose input classes for string "
00428 << inputClassesString << endl;
00429 return 2;
00430 }
00431 valFilter->classes(inputClasses);
00432 cout << "Validation data filtered by class." << endl;
00433 for( int i=0;i<inputClasses.size();i++ ) {
00434 cout << "Points in class " << inputClasses[i] << ": "
00435 << valFilter->ptsInClass(inputClasses[i]) << endl;
00436 }
00437 }
00438
00439
00440 if( scaleWeights && valFilter.get()!=0 )
00441 valFilter->scaleWeights(inputClasses[1],sW);
00442
00443
00444 if( setLowCutoff && valFilter.get()!=0 ) {
00445 if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00446 cerr << "Unable to replace missing values in validation data." << endl;
00447 return 2;
00448 }
00449 else
00450 cout << "Values below " << lowCutoff << " in validation data"
00451 << " have been replaced with medians." << endl;
00452 }
00453
00454
00455 auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00456 if( !transformerFile.empty() ) {
00457 SprVarTransformerReader transReader;
00458 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00459 if( t == 0 ) {
00460 cerr << "Unable to read VarTransformer from file "
00461 << transformerFile.c_str() << endl;
00462 return 2;
00463 }
00464 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00465 SprTransformerFilter* t_valid = 0;
00466 if( valFilter.get() != 0 )
00467 t_valid = new SprTransformerFilter(valFilter.get());
00468 bool replaceOriginalData = true;
00469 if( !t_train->transform(t,replaceOriginalData) ) {
00470 cerr << "Unable to apply VarTransformer to training data." << endl;
00471 return 2;
00472 }
00473 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00474 cerr << "Unable to apply VarTransformer to validation data." << endl;
00475 return 2;
00476 }
00477 cout << "Variable transformation from file "
00478 << transformerFile.c_str() << " has been applied to "
00479 << "training and validation data." << endl;
00480 garbage_train.reset(filter.release());
00481 garbage_valid.reset(valFilter.release());
00482 filter.reset(t_train);
00483 valFilter.reset(t_valid);
00484 }
00485
00486
00487 auto_ptr<SprAbsTwoClassCriterion> crit;
00488 switch( iCrit )
00489 {
00490 case 1 :
00491 crit.reset(new SprTwoClassIDFraction);
00492 cout << "Optimization criterion set to "
00493 << "Fraction of correctly classified events " << endl;
00494 break;
00495 case 2 :
00496 crit.reset(new SprTwoClassSignalSignif);
00497 cout << "Optimization criterion set to "
00498 << "Signal significance S/sqrt(S+B) " << endl;
00499 break;
00500 case 3 :
00501 crit.reset(new SprTwoClassPurity);
00502 cout << "Optimization criterion set to "
00503 << "Purity S/(S+B) " << endl;
00504 break;
00505 case 4 :
00506 crit.reset(new SprTwoClassTaggerEff);
00507 cout << "Optimization criterion set to "
00508 << "Tagging efficiency Q = e*(1-2w)^2 " << endl;
00509 break;
00510 case 5 :
00511 crit.reset(new SprTwoClassGiniIndex);
00512 cout << "Optimization criterion set to "
00513 << "Gini index -1+p^2+q^2 " << endl;
00514 break;
00515 case 6 :
00516 crit.reset(new SprTwoClassCrossEntropy);
00517 cout << "Optimization criterion set to "
00518 << "Cross-entropy p*log(p)+q*log(q) " << endl;
00519 break;
00520 case 7 :
00521 crit.reset(new SprTwoClassUniformPriorUL90);
00522 cout << "Optimization criterion set to "
00523 << "Inverse of 90% Bayesian upper limit with uniform prior" << endl;
00524 break;
00525 case 8 :
00526 crit.reset(new SprTwoClassBKDiscovery);
00527 cout << "Optimization criterion set to "
00528 << "Discovery potential 2*(sqrt(S+B)-sqrt(B))" << endl;
00529 break;
00530 case 9 :
00531 crit.reset(new SprTwoClassPunzi(bW));
00532 cout << "Optimization criterion set to "
00533 << "Punzi's sensitivity S/(0.5*nSigma+sqrt(B))" << endl;
00534 break;
00535 default :
00536 cerr << "Unable to make initialization criterion." << endl;
00537 return 3;
00538 }
00539
00540
00541 if( useArcE4 && !crit->symmetric() ) {
00542 cerr << "Unable to use arc-e4 with an asymmetric criterion." << endl;
00543 return 3;
00544 }
00545
00546
00547 auto_ptr<SprAverageLoss> loss;
00548 switch( iLoss )
00549 {
00550 case 1 :
00551 loss.reset(new SprAverageLoss(&SprLoss::quadratic));
00552 cout << "Per-event loss set to "
00553 << "Quadratic loss (y-f(x))^2 " << endl;
00554 break;
00555 case 2 :
00556 loss.reset(new SprAverageLoss(&SprLoss::purity_ratio));
00557 cout << "Per-event loss set to "
00558 << "Exponential loss exp(-y*f(x)) " << endl;
00559 break;
00560 case 3 :
00561 loss.reset(new SprAverageLoss(&SprLoss::correct_id,
00562 &SprTransformation::continuous01ToDiscrete01));
00563 cout << "Per-event loss set to "
00564 << "Misid rate int(y==f(x)) " << endl;
00565 break;
00566 default :
00567 cout << "No per-event loss is chosen. Will use the default." << endl;
00568 break;
00569 }
00570
00571
00572 auto_ptr<SprIntegerBootstrap> bootstrap;
00573 if( nFeaturesToSample > filter->dim() )
00574 nFeaturesToSample = filter->dim();
00575 if( nFeaturesToSample > 0 ) {
00576 bootstrap.reset(new SprIntegerBootstrap(filter->dim(),nFeaturesToSample));
00577 if( !resumeFile.empty() || initBootstrapFromTimeOfDay )
00578 bootstrap->init(-1);
00579 }
00580
00581
00582 bool doMerge = !crit->symmetric();
00583 if( doMerge ) useTopdown = false;
00584 auto_ptr<SprDecisionTree> tree;
00585 if( useTopdown ) {
00586 tree.reset(new SprTopdownTree(filter.get(),crit.get(),nmin,
00587 discrete,bootstrap.get()));
00588 }
00589 else {
00590 tree.reset(new SprDecisionTree(filter.get(),crit.get(),nmin,doMerge,
00591 discrete,bootstrap.get()));
00592 }
00593 if( countTreeSplits ) tree->startSplitCounter();
00594 tree->useFastSort();
00595
00596
00597 if( nCross > 0 ) {
00598
00599 cout << "Will cross-validate by dividing training data into "
00600 << nCross << " subsamples." << endl;
00601 vector<vector<int> > nodeMinSize;
00602
00603
00604 if( !nodeValidationString.empty() )
00605 SprStringParser::parseToInts(nodeValidationString.c_str(),nodeMinSize);
00606 else {
00607 nodeMinSize.resize(1);
00608 nodeMinSize[0].push_back(nmin);
00609 }
00610 if( nodeMinSize.empty() || nodeMinSize[0].empty() ) {
00611 cerr << "Unable to determine node size for cross-validation." << endl;
00612 return 4;
00613 }
00614 else {
00615 cout << "Will cross-validate for trees with minimal node sizes: ";
00616 for( int i=0;i<nodeMinSize[0].size();i++ )
00617 cout << nodeMinSize[0][i] << " ";
00618 cout << endl;
00619 }
00620
00621
00622 vector<SprDecisionTree*> trees(nodeMinSize[0].size());
00623 vector<SprAbsClassifier*> classifiers(nodeMinSize[0].size());
00624 for( int i=0;i<nodeMinSize[0].size();i++ ) {
00625 SprDecisionTree* tree1 = 0;
00626 if( useTopdown ) {
00627 tree1 = new SprTopdownTree(filter.get(),crit.get(),nodeMinSize[0][i],
00628 discrete,bootstrap.get());
00629 }
00630 else {
00631 tree1 = new SprDecisionTree(filter.get(),crit.get(),nodeMinSize[0][i],
00632 doMerge,discrete,bootstrap.get());
00633 }
00634 tree1->useFastSort();
00635 SprBagger* bagger1 = 0;
00636 if( useArcE4 )
00637 bagger1 = new SprArcE4(filter.get(),cycles,discrete);
00638 else
00639 bagger1 = new SprBagger(filter.get(),cycles,discrete);
00640 if( initBootstrapFromTimeOfDay
00641 && !bagger1->initBootstrapFromTimeOfDay() ) {
00642 cerr << "Unable to generate seed from time of day for Bagger." << endl;
00643 return 4;
00644 }
00645 if( !bagger1->addTrainable(tree1) ) {
00646 cerr << "Unable to add decision tree to Bagger for CV." << endl;
00647 for( int j=0;j<trees.size();j++ ) {
00648 delete trees[j];
00649 delete classifiers[j];
00650 }
00651 return 4;
00652 }
00653 trees[i] = tree1;
00654 classifiers[i] = bagger1;
00655 }
00656
00657
00658 vector<double> cvFom;
00659 SprCrossValidator cv(filter.get(),nCross);
00660 if( !cv.validate(crit.get(),loss.get(),classifiers,
00661 inputClasses[0],inputClasses[1],
00662 SprUtils::lowerBound(0.5),cvFom,verbose) ) {
00663 cerr << "Unable to cross-validate." << endl;
00664 for( int j=0;j<trees.size();j++ ) {
00665 delete trees[j];
00666 delete classifiers[j];
00667 }
00668 return 4;
00669 }
00670 else {
00671 cout << "Cross-validated FOMs:" << endl;
00672 for( int i=0;i<cvFom.size();i++ ) {
00673 cout << "Node size=" << setw(8) << nodeMinSize[0][i]
00674 << " FOM=" << setw(10) << cvFom[i] << endl;
00675 }
00676 }
00677
00678
00679 for( int j=0;j<trees.size();j++ ) {
00680 delete trees[j];
00681 delete classifiers[j];
00682 }
00683
00684
00685 return 0;
00686 }
00687
00688
00689 auto_ptr<SprBagger> bagger;
00690 if( useArcE4 )
00691 bagger.reset(new SprArcE4(filter.get(),cycles,discrete));
00692 else
00693 bagger.reset(new SprBagger(filter.get(),cycles,discrete));
00694
00695
00696 if( initBootstrapFromTimeOfDay && !bagger->initBootstrapFromTimeOfDay() ) {
00697 cerr << "Unable to generate seed from time of day for Bagger." << endl;
00698 return 4;
00699 }
00700
00701
00702 if( valFilter.get()!=0 && !valFilter->empty() )
00703 bagger->setValidation(valFilter.get(),valPrint,crit.get(),loss.get());
00704
00705
00706 if( !resumeFile.empty() ) {
00707 if( !SprClassifierReader::readTrainable(resumeFile.c_str(),
00708 bagger.get(),verbose) ) {
00709 cerr << "Failed to read saved Bagger from file "
00710 << resumeFile.c_str() << endl;
00711 return 5;
00712 }
00713 cout << "Read saved Bagger from file " << resumeFile.c_str()
00714 << " with " << bagger->nTrained() << " trained classifiers."
00715 << endl;
00716 }
00717
00718
00719 if( !bagger->addTrainable(tree.get()) ) {
00720 cerr << "Unable to add decision tree to Bagger." << endl;
00721 return 6;
00722 }
00723
00724
00725 if( !bagger->train(verbose) )
00726 cerr << "Bagger terminated with error." << endl;
00727 if( bagger->nTrained() == 0 ) {
00728 cerr << "Unable to train Bagger." << endl;
00729 return 7;
00730 }
00731 else {
00732 cout << "Bagger finished training with " << bagger->nTrained()
00733 << " classifiers." << endl;
00734 }
00735
00736
00737 if( !outFile.empty() ) {
00738 if( !bagger->store(outFile.c_str()) ) {
00739 cerr << "Cannot store Bagger in file " << outFile.c_str() << endl;
00740 return 8;
00741 }
00742 }
00743
00744
00745 if( countTreeSplits ) tree->printSplitCounter(cout);
00746
00747
00748 auto_ptr<SprTrainedBagger> trainedBagger(bagger->makeTrained());
00749 if( trainedBagger.get() == 0 ) {
00750 cerr << "Unable to get trained Bagger." << endl;
00751 return 7;
00752 }
00753
00754
00755 if( !codeFile.empty() ) {
00756 if( !trainedBagger->storeCode(codeFile.c_str()) ) {
00757 cerr << "Unable to store code for trained Bagger." << endl;
00758 return 8;
00759 }
00760 }
00761
00762
00763 if( tupleFile.empty() )
00764 return 0;
00765
00766
00767 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00768 if( !tuple->init(tupleFile.c_str()) ) {
00769 cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00770 return 9;
00771 }
00772
00773
00774
00775 string printVarsDoNotFeed;
00776 vector<vector<string> > varsDoNotFeed;
00777 SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed);
00778 vector<unsigned> mapper;
00779 for( int d=0;d<vars.size();d++ ) {
00780 if( varsDoNotFeed.empty() ||
00781 (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d])
00782 ==varsDoNotFeed[0].end()) ) {
00783 mapper.push_back(d);
00784 }
00785 else {
00786 printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? "" : ", " );
00787 printVarsDoNotFeed += vars[d];
00788 }
00789 }
00790 if( !printVarsDoNotFeed.empty() ) {
00791 cout << "The following variables are not used in the algorithm, "
00792 << "but will be included in the output file: "
00793 << printVarsDoNotFeed.c_str() << endl;
00794 }
00795
00796
00797 SprDataFeeder feeder(filter.get(),tuple.get(),mapper);
00798 feeder.addClassifier(trainedBagger.get(),"bag");
00799 if( !feeder.feed(1000) ) {
00800 cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00801 return 10;
00802 }
00803
00804
00805 return 0;
00806 }