#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDecisionTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedDecisionTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTopdownTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedTopdownTree.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassSignalSignif.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassIDFraction.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassTaggerEff.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPurity.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassCrossEntropy.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassUniformPriorUL90.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassBKDiscovery.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPunzi.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprCrossValidator.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include <stdlib.h>
#include <unistd.h>
#include <iostream>
#include <vector>
#include <set>
#include <string>
#include <memory>
#include <iomanip>
#include <fstream>
Functions
void help(const char *prog)
int main(int argc, char **argv)
void help(const char *prog)
Definition at line 46 of file SprDecisionTreeApp.cc.
{
  cout << "Usage:  " << prog
       << " training_data_file" << endl;
  cout << "\t Options: " << endl;
  cout << "\t-h --- help " << endl;
  cout << "\t-o output Tuple file " << endl;
  cout << "\t-a input ascii file mode (see SprSimpleReader.hh) " << endl;
  cout << "\t-A save output data in ascii instead of Root " << endl;
  cout << "\t-n minimal number of events per tree node (def=1) " << endl;
  cout << "\t-m --- merge nodes after training (def = no merge) " << endl;
  cout << "\t-y list of input classes (see SprAbsFilter.hh) " << endl;
  cout << "\t-Q apply variable transformation saved in file " << endl;
  cout << "\t-v verbose level (0=silent default,1,2) " << endl;
  cout << "\t-T use Topdown tree with continuous output " << endl;
  cout << "\t-f store decision tree to file in human-readable format" << endl;
  cout << "\t-F store decision tree to file in machine-readable format" << endl;
  cout << "\t-c criterion for optimization " << endl;
  cout << "\t\t 1 = correctly classified fraction " << endl;
  cout << "\t\t 2 = signal significance s/sqrt(s+b) " << endl;
  cout << "\t\t 3 = purity s/(s+b) " << endl;
  cout << "\t\t 4 = tagger efficiency Q " << endl;
  cout << "\t\t 5 = Gini index (default) " << endl;
  cout << "\t\t 6 = cross-entropy " << endl;
  cout << "\t\t 7 = 90% Bayesian upper limit with uniform prior " << endl;
  cout << "\t\t 8 = discovery potential 2*(sqrt(s+b)-sqrt(b)) " << endl;
  cout << "\t\t 9 = Punzi's sensitivity s/(0.5*nSigma+sqrt(b)) " << endl;
  cout << "\t\t -P background normalization factor for Punzi FOM" << endl;
  cout << "\t-g per-event loss for (cross-)validation " << endl;
  cout << "\t\t 1 - quadratic loss (y-f(x))^2 " << endl;
  cout << "\t\t 2 - exponential loss exp(-y*f(x)) " << endl;
  cout << "\t\t 3 - misid fraction " << endl;
  cout << "\t-i count splits on input variables " << endl;
  cout << "\t-K keep this fraction in training set and " << endl;
  cout << "\t\t put the rest into validation set " << endl;
  cout << "\t-D randomize training set split-up " << endl;
  cout << "\t-t read validation/test data from a file " << endl;
  cout << "\t\t (must be in same format as input data!!! " << endl;
  cout << "\t-p output file to store validation/test data " << endl;
  cout << "\t-w scale all signal weights by this factor " << endl;
  cout << "\t-V include only these input variables " << endl;
  cout << "\t-z exclude input variables from the list " << endl;
  cout << "\t\t Variables must be listed in quotes and separated by commas."
       << endl;
  cout << "\t-x cross-validate by splitting data into a given "
       << "number of pieces" << endl;
  cout << "\t-q a set of minimal node sizes for cross-validation" << endl;
  cout << "\t\t Node sizes must be listed in quotes and separated by commas."
       << endl;
}
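For orientation, a hypothetical invocation of the application could look as follows; the binary name SprDecisionTreeApp and the file names train.pat, tree.txt and out.root are placeholders, and -a 1 is assumed to select a standard ASCII read mode of SprSimpleReader:

  SprDecisionTreeApp -a 1 -n 100 -c 5 -f tree.txt -o out.root train.pat

This would train a discrete-output decision tree with at least 100 events per node, optimize the default Gini index, store the tree in human-readable form in tree.txt, and write the classified training points to out.root.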
int main(int argc, char **argv)
Definition at line 98 of file SprDecisionTreeApp.cc.
{
  // check command line
  if( argc < 2 ) {
    help(argv[0]);
    return 1;
  }

  // init
  string tupleFile;
  int readMode = 0;
  SprRWFactory::DataType writeMode = SprRWFactory::Root;
  unsigned nmin = 1;
  int verbose = 0;
  bool useTopdown = false;
  string outHuman, outMachine;
  string resumeFile;
  int iCrit = 5;
  string valFile;
  string valHbkFile;
  bool doMerge = false;
  int iLoss = 0;
  bool scaleWeights = false;
  double sW = 1.;
  bool countTreeSplits = false;
  string includeList, excludeList;
  unsigned nCross = 0;
  string nodeValidationString;
  string inputClassesString;
  double bW = 1.;
  bool split = false;
  double splitFactor = 0;
  bool splitRandomize = false;
  string transformerFile;

  // decode command line
  int c;
  extern char* optarg;
  while( (c = getopt(argc,argv,"ho:a:An:v:f:TF:c:P:g:iK:Dt:p:my:Q:w:V:z:x:q:")) != EOF ) {
    switch( c )
      {
      case 'h' :
        help(argv[0]);
        return 1;
      case 'o' :
        tupleFile = optarg;
        break;
      case 'a' :
        readMode = (optarg==0 ? 0 : atoi(optarg));
        break;
      case 'A' :
        writeMode = SprRWFactory::Ascii;
        break;
      case 'n' :
        nmin = (optarg==0 ? 1 : atoi(optarg));
        break;
      case 'v' :
        verbose = (optarg==0 ? 0 : atoi(optarg));
        break;
      case 'T' :
        useTopdown = true;
        break;
      case 'f' :
        outHuman = optarg;
        break;
      case 'F' :
        outMachine = optarg;
        break;
      case 'c' :
        iCrit = (optarg==0 ? 5 : atoi(optarg));
        break;
      case 'P' :
        bW = (optarg==0 ? 1 : atof(optarg));
        break;
      case 'g' :
        iLoss = (optarg==0 ? 0 : atoi(optarg));
        break;
      case 'i' :
        countTreeSplits = true;
        break;
      case 'K' :
        split = true;
        splitFactor = (optarg==0 ? 0 : atof(optarg));
        break;
      case 'D' :
        splitRandomize = true;
        break;
      case 't' :
        valFile = optarg;
        break;
      case 'p' :
        valHbkFile = optarg;
        break;
      case 'm' :
        doMerge = true;
        break;
      case 'y' :
        inputClassesString = optarg;
        break;
      case 'Q' :
        transformerFile = optarg;
        break;
      case 'w' :
        if( optarg != 0 ) {
          scaleWeights = true;
          sW = atof(optarg);
        }
        break;
      case 'V' :
        includeList = optarg;
        break;
      case 'z' :
        excludeList = optarg;
        break;
      case 'x' :
        nCross = (optarg==0 ? 0 : atoi(optarg));
        break;
      case 'q' :
        nodeValidationString = optarg;
        break;
      }
  }

  // There has to be 1 argument after all options.
  string trFile = argv[argc-1];
  if( trFile.empty() ) {
    cerr << "No training file is specified." << endl;
    return 1;
  }

  // cannot merge nodes in Topdown trees
  if( doMerge ) useTopdown = false;

  // make reader
  SprRWFactory::DataType inputType
    = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
  auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));

  // include variables
  set<string> includeSet;
  if( !includeList.empty() ) {
    vector<vector<string> > includeVars;
    SprStringParser::parseToStrings(includeList.c_str(),includeVars);
    assert( !includeVars.empty() );
    for( int i=0;i<includeVars[0].size();i++ )
      includeSet.insert(includeVars[0][i]);
    if( !reader->chooseVars(includeSet) ) {
      cerr << "Unable to include variables in training set." << endl;
      return 2;
    }
    else {
      cout << "Following variables have been included in optimization: ";
      for( set<string>::const_iterator
             i=includeSet.begin();i!=includeSet.end();i++ )
        cout << "\"" << *i << "\"" << " ";
      cout << endl;
    }
  }

  // exclude variables
  set<string> excludeSet;
  if( !excludeList.empty() ) {
    vector<vector<string> > excludeVars;
    SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
    assert( !excludeVars.empty() );
    for( int i=0;i<excludeVars[0].size();i++ )
      excludeSet.insert(excludeVars[0][i]);
    if( !reader->chooseAllBut(excludeSet) ) {
      cerr << "Unable to exclude variables from training set." << endl;
      return 2;
    }
    else {
      cout << "Following variables have been excluded from optimization: ";
      for( set<string>::const_iterator
             i=excludeSet.begin();i!=excludeSet.end();i++ )
        cout << "\"" << *i << "\"" << " ";
      cout << endl;
    }
  }

  // read training data from file
  auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
  if( filter.get() == 0 ) {
    cerr << "Unable to read data from file " << trFile.c_str() << endl;
    return 2;
  }
  vector<string> vars;
  filter->vars(vars);
  cout << "Read data from file " << trFile.c_str()
       << " for variables";
  for( int i=0;i<vars.size();i++ )
    cout << " \"" << vars[i].c_str() << "\"";
  cout << endl;
  cout << "Total number of points read: " << filter->size() << endl;

  // filter training data by class
  vector<SprClass> inputClasses;
  if( !filter->filterByClass(inputClassesString.c_str()) ) {
    cerr << "Cannot choose input classes for string "
         << inputClassesString << endl;
    return 2;
  }
  filter->classes(inputClasses);
  assert( inputClasses.size() > 1 );
  cout << "Training data filtered by class." << endl;
  for( int i=0;i<inputClasses.size();i++ ) {
    cout << "Points in class " << inputClasses[i] << ": "
         << filter->ptsInClass(inputClasses[i]) << endl;
  }

  // scale weights
  if( scaleWeights ) {
    cout << "Signal weights are multiplied by " << sW << endl;
    filter->scaleWeights(inputClasses[1],sW);
  }

  // read validation data from file
  auto_ptr<SprAbsFilter> valFilter;
  if( split && !valFile.empty() ) {
    cerr << "Unable to split training data and use validation data "
         << "from a separate file." << endl;
    return 2;
  }
  if( split ) {
    cout << "Splitting training data with factor " << splitFactor << endl;
    if( splitRandomize )
      cout << "Will use randomized splitting." << endl;
    vector<double> weights;
    SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
    if( splitted == 0 ) {
      cerr << "Unable to split training data." << endl;
      return 2;
    }
    bool ownData = true;
    valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
    cout << "Training data re-filtered:" << endl;
    for( int i=0;i<inputClasses.size();i++ ) {
      cout << "Points in class " << inputClasses[i] << ": "
           << filter->ptsInClass(inputClasses[i]) << endl;
    }
  }
  if( !valFile.empty() ) {
    auto_ptr<SprAbsReader>
      valReader(SprRWFactory::makeReader(inputType,readMode));
    if( !includeSet.empty() ) {
      if( !valReader->chooseVars(includeSet) ) {
        cerr << "Unable to include variables in validation set." << endl;
        return 2;
      }
    }
    if( !excludeSet.empty() ) {
      if( !valReader->chooseAllBut(excludeSet) ) {
        cerr << "Unable to exclude variables from validation set." << endl;
        return 2;
      }
    }
    valFilter.reset(valReader->read(valFile.c_str()));
    if( valFilter.get() == 0 ) {
      cerr << "Unable to read data from file " << valFile.c_str() << endl;
      return 2;
    }
    vector<string> valVars;
    valFilter->vars(valVars);
    cout << "Read validation data from file " << valFile.c_str()
         << " for variables";
    for( int i=0;i<valVars.size();i++ )
      cout << " \"" << valVars[i].c_str() << "\"";
    cout << endl;
    cout << "Total number of points read: " << valFilter->size() << endl;
  }

  // filter validation data by class
  if( valFilter.get() != 0 ) {
    if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
      cerr << "Cannot choose input classes for string "
           << inputClassesString << endl;
      return 2;
    }
    valFilter->classes(inputClasses);
    cout << "Validation data filtered by class." << endl;
    for( int i=0;i<inputClasses.size();i++ ) {
      cout << "Points in class " << inputClasses[i] << ": "
           << valFilter->ptsInClass(inputClasses[i]) << endl;
    }
  }

  // scale weights
  if( scaleWeights && valFilter.get()!=0 )
    valFilter->scaleWeights(inputClasses[1],sW);

  // apply transformation of variables to training and test data
  auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
  if( !transformerFile.empty() ) {
    SprVarTransformerReader transReader;
    const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
    if( t == 0 ) {
      cerr << "Unable to read VarTransformer from file "
           << transformerFile.c_str() << endl;
      return 2;
    }
    SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
    SprTransformerFilter* t_valid = 0;
    if( valFilter.get() != 0 )
      t_valid = new SprTransformerFilter(valFilter.get());
    bool replaceOriginalData = true;
    if( !t_train->transform(t,replaceOriginalData) ) {
      cerr << "Unable to apply VarTransformer to training data." << endl;
      return 2;
    }
    if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
      cerr << "Unable to apply VarTransformer to validation data." << endl;
      return 2;
    }
    cout << "Variable transformation from file "
         << transformerFile.c_str() << " has been applied to "
         << "training and validation data." << endl;
    garbage_train.reset(filter.release());
    garbage_valid.reset(valFilter.release());
    filter.reset(t_train);
    valFilter.reset(t_valid);
  }

  // make optimization criterion
  auto_ptr<SprAbsTwoClassCriterion> crit;
  switch( iCrit )
    {
    case 1 :
      crit.reset(new SprTwoClassIDFraction);
      cout << "Optimization criterion set to "
           << "Fraction of correctly classified events " << endl;
      break;
    case 2 :
      crit.reset(new SprTwoClassSignalSignif);
      cout << "Optimization criterion set to "
           << "Signal significance S/sqrt(S+B) " << endl;
      break;
    case 3 :
      crit.reset(new SprTwoClassPurity);
      cout << "Optimization criterion set to "
           << "Purity S/(S+B) " << endl;
      break;
    case 4 :
      crit.reset(new SprTwoClassTaggerEff);
      cout << "Optimization criterion set to "
           << "Tagging efficiency Q = e*(1-2w)^2 " << endl;
      break;
    case 5 :
      crit.reset(new SprTwoClassGiniIndex);
      cout << "Optimization criterion set to "
           << "Gini index -1+p^2+q^2 " << endl;
      break;
    case 6 :
      crit.reset(new SprTwoClassCrossEntropy);
      cout << "Optimization criterion set to "
           << "Cross-entropy p*log(p)+q*log(q) " << endl;
      break;
    case 7 :
      crit.reset(new SprTwoClassUniformPriorUL90);
      cout << "Optimization criterion set to "
           << "Inverse of 90% Bayesian upper limit with uniform prior" << endl;
      break;
    case 8 :
      crit.reset(new SprTwoClassBKDiscovery);
      cout << "Optimization criterion set to "
           << "Discovery potential 2*(sqrt(S+B)-sqrt(B))" << endl;
      break;
    case 9 :
      crit.reset(new SprTwoClassPunzi(bW));
      cout << "Optimization criterion set to "
           << "Punzi's sensitivity S/(0.5*nSigma+sqrt(B))" << endl;
      break;
    default :
      cerr << "Unable to make initialization criterion." << endl;
      return 3;
    }

  // make per-event loss
  auto_ptr<SprAverageLoss> loss;
  switch( iLoss )
    {
    case 1 :
      loss.reset(new SprAverageLoss(&SprLoss::quadratic));
      cout << "Per-event loss set to "
           << "Quadratic loss (y-f(x))^2 " << endl;
      break;
    case 2 :
      loss.reset(new SprAverageLoss(&SprLoss::purity_ratio));
      cout << "Per-event loss set to "
           << "Exponential loss exp(-y*f(x)) " << endl;
      break;
    case 3 :
      loss.reset(new SprAverageLoss(&SprLoss::correct_id,
                                    &SprTransformation::continuous01ToDiscrete01));
      cout << "Per-event loss set to "
           << "Misid rate int(y==f(x)) " << endl;
      break;
    default :
      cout << "No per-event loss is chosen. Will use the default." << endl;
      break;
    }

  // if cross-validation requested, cross-validate and exit
  if( nCross > 0 ) {
    // message
    cout << "Will cross-validate by dividing training data into "
         << nCross << " subsamples." << endl;
    vector<vector<int> > nodeMinSize;

    // decode validation string
    if( !nodeValidationString.empty() )
      SprStringParser::parseToInts(nodeValidationString.c_str(),nodeMinSize);
    else {
      nodeMinSize.resize(1);
      nodeMinSize[0].push_back(nmin);
    }
    if( nodeMinSize.empty() || nodeMinSize[0].empty() ) {
      cerr << "Unable to determine node size for cross-validation." << endl;
      return 4;
    }
    else {
      cout << "Will cross-validate for trees with minimal node sizes: ";
      for( int i=0;i<nodeMinSize[0].size();i++ )
        cout << nodeMinSize[0][i] << " ";
      cout << endl;
    }

    // loop over nodes to prepare classifiers
    vector<SprAbsClassifier*> classifiers(nodeMinSize[0].size());
    for( int i=0;i<nodeMinSize[0].size();i++ ) {
      SprDecisionTree* tree1 = 0;
      if( useTopdown ) {
        bool discrete = false;
        tree1 = new SprTopdownTree(filter.get(),crit.get(),
                                   nodeMinSize[0][i],discrete);
      }
      else {
        bool discrete = true;
        tree1 = new SprDecisionTree(filter.get(),crit.get(),
                                    nodeMinSize[0][i],doMerge,discrete);
      }
      classifiers[i] = tree1;
    }

    // cross-validate
    vector<double> cvFom;
    SprCrossValidator cv(filter.get(),nCross);
    if( !cv.validate(crit.get(),loss.get(),classifiers,0,1,
                     SprUtils::lowerBound(0.5),cvFom,verbose) ) {
      cerr << "Unable to cross-validate." << endl;
      for( int j=0;j<classifiers.size();j++ ) {
        delete classifiers[j];
      }
      return 4;
    }
    else {
      cout << "Cross-validated FOMs:" << endl;
      for( int i=0;i<cvFom.size();i++ ) {
        cout << "Node size=" << setw(8) << nodeMinSize[0][i]
             << " FOM=" << setw(10) << cvFom[i] << endl;
      }
    }

    // cleanup
    for( int j=0;j<classifiers.size();j++ ) {
      delete classifiers[j];
    }

    // normal exit
    return 0;
  }// end cross-validation

  // make decision tree
  auto_ptr<SprDecisionTree> tree;
  if( useTopdown ) {
    bool discrete = false;
    tree.reset(new SprTopdownTree(filter.get(),crit.get(),nmin,discrete));
  }
  else {
    tree.reset( new SprDecisionTree(filter.get(),crit.get(),
                                    nmin,doMerge,true));
    if( countTreeSplits ) tree->startSplitCounter();
    tree->setShowBackgroundNodes(true);
  }

  // train
  if( !tree->train(verbose) ) {
    cerr << "Unable to train decision tree." << endl;
    return 4;
  }
  cout << "Finished training decision tree." << endl;

  // save trained decision tree in human-readable format
  if( !outHuman.empty() ) {
    if( !tree->store(outHuman.c_str()) ) {
      cerr << "Cannot store decision tree in file "
           << outHuman.c_str() << endl;
      return 5;
    }
  }

  // print out counted splits
  if( countTreeSplits ) tree->printSplitCounter(cout);

  // make trained decision tree
  auto_ptr<SprTrainedDecisionTree> trainedTree(tree->makeTrained());

  // save trained tree in machine-readable format
  if( !outMachine.empty() ) {
    if( !trainedTree->store(outMachine.c_str()) ) {
      cerr << "Unable to save trained tree into "
           << outMachine.c_str() << endl;
      return 5;
    }
  }

  // compute FOM for the validation data
  if( valFilter.get() != 0 ) {
    double wcor0(0), wmis0(0), wcor1(0), wmis1(0);
    int ncor0(0), nmis0(0), ncor1(0), nmis1(0);
    if( loss.get() != 0 ) loss->reset();
    for( int i=0;i<valFilter->size();i++ ) {
      const SprPoint* p = (*valFilter.get())[i];
      double w = valFilter->w(i);
      double resp = trainedTree->response(p->x_);
      if( trainedTree->accept(p) ) {
        if( p->class_ == inputClasses[0] ) {
          wmis0 += w;
          nmis0++;
          if( loss.get() != 0 ) loss->update(0,resp,w);
        }
        else if( p->class_ == inputClasses[1] ) {
          wcor1 += w;
          ncor1++;
          if( loss.get() != 0 ) loss->update(1,resp,w);
        }
      }
      else {
        if( p->class_ == inputClasses[0] ) {
          wcor0 += w;
          ncor0++;
          if( loss.get() != 0 ) loss->update(0,resp,w);
        }
        else if( p->class_ == inputClasses[1] ) {
          wmis1 += w;
          nmis1++;
          if( loss.get() != 0 ) loss->update(1,resp,w);
        }
      }
    }
    double vFom = crit->fom(wcor0,wmis0,wcor1,wmis1);
    double vLoss = 0;
    if( loss.get() != 0 ) vLoss = loss->value();
    cout << "=====================================================" << endl;
    cout << "Validation FOM=" << vFom << " Loss=" << vLoss << endl;
    cout << "Content of the signal region:"
         << " W0=" << wmis0 << " W1=" << wcor1
         << " N0=" << nmis0 << " N1=" << ncor1
         << endl;
    cout << "=====================================================" << endl;
  }

  // make histogram if requested
  if( tupleFile.empty() && valHbkFile.empty() ) return 0;

  // make a wrapper to store box numbers
  class BoxNumberWrapper : public SprTrainedDecisionTree {
  public:
    virtual ~BoxNumberWrapper() {}
    BoxNumberWrapper(const SprTrainedDecisionTree& tree)
      : SprTrainedDecisionTree(tree) {}
    double response(const std::vector<double>& v) const {
      return this->nBox(v);
    }
  };

  // feed training data
  if( !tupleFile.empty() ) {
    // make a writer
    auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
    if( !tuple->init(tupleFile.c_str()) ) {
      cerr << "Unable to open output file " << tupleFile.c_str() << endl;
      return 6;
    }
    // wrap
    BoxNumberWrapper boxNumber(*(trainedTree.get()));
    // feed
    SprDataFeeder feeder(filter.get(),tuple.get());
    feeder.addClassifier(trainedTree.get(),"tree");
    feeder.addClassifier(&boxNumber,"box");
    if( !feeder.feed(1000) ) {
      cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
      return 6;
    }
  }

  if( !valHbkFile.empty() ) {
    // make a writer
    auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"test"));
    if( !tuple->init(valHbkFile.c_str()) ) {
      cerr << "Unable to open output file " << valHbkFile.c_str() << endl;
      return 7;
    }
    // wrap
    BoxNumberWrapper boxNumber(*(trainedTree.get()));
    // feed
    SprDataFeeder feeder(valFilter.get(),tuple.get());
    feeder.addClassifier(trainedTree.get(),"tree");
    feeder.addClassifier(&boxNumber,"box");
    if( !feeder.feed(1000) ) {
      cerr << "Cannot feed data into file " << valHbkFile.c_str() << endl;
      return 7;
    }
  }

  // exit
  return 0;
}
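The same classes can also be driven directly from user code. Below is a minimal sketch, not part of SprDecisionTreeApp.cc, that uses only calls appearing in main() above; the file names train.pat and tree.spr, the ASCII read mode 1, and the minimal node size of 100 are assumptions chosen for illustration, and auto_ptr is kept to mirror the style of the application.

  // Hypothetical standalone driver: train a Gini-optimized decision tree
  // on an ASCII file and evaluate it on the first training point.
  #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
  #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
  #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
  #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
  #include "PhysicsTools/StatPatternRecognition/interface/SprPoint.hh"
  #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
  #include "PhysicsTools/StatPatternRecognition/interface/SprDecisionTree.hh"
  #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedDecisionTree.hh"
  #include <iostream>
  #include <memory>

  int main()
  {
    // read an ASCII training file (read mode 1 assumed, as with -a 1)
    std::auto_ptr<SprAbsReader>
      reader(SprRWFactory::makeReader(SprRWFactory::Ascii,1));
    std::auto_ptr<SprAbsFilter> filter(reader->read("train.pat"));
    if( filter.get() == 0 ) return 1;

    // choose default input classes, as the application does for an empty -y string
    if( !filter->filterByClass("") ) return 1;

    // Gini index criterion (the application's default, -c 5)
    SprTwoClassGiniIndex crit;

    // discrete-output tree with at least 100 events per node, no node merging
    SprDecisionTree tree(filter.get(),&crit,100,false,true);
    if( !tree.train(0) ) return 2;

    // freeze the result, store it, and evaluate it on the first training point
    std::auto_ptr<SprTrainedDecisionTree> trained(tree.makeTrained());
    trained->store("tree.spr");
    const SprPoint* p = (*filter.get())[0];
    std::cout << "Response: " << trained->response(p->x_) << std::endl;
    return 0;
  }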