#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAdaBoost.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedAdaBoost.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStdBackprop.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedStdBackprop.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprCrossValidator.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include <stdlib.h>
#include <unistd.h>
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>
Go to the source code of this file.
Functions:
  void help(const char *prog)
  int main(int argc, char **argv)
void help(const char *prog)
Definition at line 39 of file SprStdBackpropApp.cc.
References std::cout and std::endl().
00040 { 00041 cout << "Usage: " << prog 00042 << " training_data_file " << endl; 00043 cout << "\t Options: " << endl; 00044 cout << "\t-h --- help " << endl; 00045 cout << "\t-o output Tuple file " << endl; 00046 cout << "\t-a input ascii file mode (see SprSimpleReader.hh) " << endl; 00047 cout << "\t-A save output data in ascii instead of Root " << endl; 00048 cout << "\t-M AdaBoost mode " << endl; 00049 cout << "\t\t 1 = Discrete AdaBoost (default) " << endl; 00050 cout << "\t\t 2 = Real AdaBoost " << endl; 00051 cout << "\t\t 3 = Epsilon AdaBoost " << endl; 00052 cout << "\t-E epsilon for Epsilon and Real AdaBoosts (def=0.01)" << endl; 00053 cout << "\t-n number of AdaBoost training cycles (1 for single NN)" << endl; 00054 cout << "\t-l number of Neural Net training cycles " << endl; 00055 cout << "\t-N neural net configuration, e.g., '6:3:1' (see SprStdBackprop.hh)" << endl; 00056 cout << "\t-L learning rate of the network (default=0.1) " << endl; 00057 cout << "\t-I learning rate for network initialization (def=0.1)" << endl; 00058 cout << "\t-i number of input points to use for initialization (def=all)" 00059 << endl; 00060 cout << "\t-y list of input classes (see SprAbsFilter.hh) " << endl; 00061 cout << "\t-Q apply variable transformation saved in file " << endl; 00062 cout << "\t-g per-event loss for (cross-)validation " << endl; 00063 cout << "\t\t 1 - quadratic loss (y-f(x))^2 " << endl; 00064 cout << "\t\t 2 - exponential loss exp(-y*f(x)) " << endl; 00065 cout << "\t-m replace data values below this cutoff with medians" << endl; 00066 cout << "\t-s use standard AdaBoost (see SprTrainedAdaBoost.hh)"<< endl; 00067 cout << "\t-e skip initial event reweighting when resuming " << endl; 00068 cout << "\t-u store data with modified weights to file " << endl; 00069 cout << "\t-v verbose level (0=silent default,1,2) " << endl; 00070 cout << "\t-f store trained AdaBoost to file " << endl; 00071 cout << "\t-r resume training for AdaBoost stored in file " << 
endl; 00072 cout << "\t-R resume training for a single neural net stored in file" 00073 << endl; 00074 cout << "\t-S resume training from SNNS configuration stored in file" 00075 << endl; 00076 cout << "\t-K keep this fraction in training set and " << endl; 00077 cout << "\t\t put the rest into validation set " << endl; 00078 cout << "\t-D randomize training set split-up " << endl; 00079 cout << "\t-t read validation/test data from a file " << endl; 00080 cout << "\t\t (must be in same format as input data!!! " << endl; 00081 cout << "\t-d frequency of print-outs for validation data " << endl; 00082 cout << "\t-w scale all signal weights by this factor " << endl; 00083 cout << "\t-V include only these input variables " << endl; 00084 cout << "\t-z exclude input variables from the list " << endl; 00085 cout << "\t-Z exclude input variables from the list, " 00086 << "but put them in the output file " << endl; 00087 cout << "\t\t Variables must be listed in quotes and separated by commas." 00088 << endl; 00089 }
Definition at line 92 of file SprStdBackpropApp.cc.
References begin, c, std::cerr, std::cout, d, end, std::endl(), SprTrainedAdaBoost::Epsilon, epsilon, eta, filter, find(), help(), i, int, split, t, vars, and weights.
00093 { 00094 // check command line 00095 if( argc < 2 ) { 00096 help(argv[0]); 00097 return 1; 00098 } 00099 00100 // init 00101 string tupleFile; 00102 int readMode = 0; 00103 SprRWFactory::DataType writeMode = SprRWFactory::Root; 00104 unsigned adaCycles = 0; 00105 unsigned nnCycles = 0; 00106 double eta = 0.1; 00107 int iLoss = 1; 00108 int verbose = 0; 00109 string outFile; 00110 string valFile; 00111 unsigned valPrint = 0; 00112 bool scaleWeights = false; 00113 double sW = 1.; 00114 bool useStandardAB = false; 00115 int iAdaBoostMode = 1; 00116 double epsilon = 0.01; 00117 bool skipInitialEventReweighting = false; 00118 string weightedDataOut; 00119 bool setLowCutoff = false; 00120 double lowCutoff = 0; 00121 string includeList, excludeList; 00122 string inputClassesString; 00123 string stringVarsDoNotFeed; 00124 string resumeFile, resumeSNNSFile, resumeNNFile; 00125 string netConfig; 00126 double initEta = 0.1; 00127 unsigned initPoints = 0; 00128 bool split = false; 00129 double splitFactor = 0; 00130 bool splitRandomize = false; 00131 string transformerFile; 00132 00133 // decode command line 00134 int c; 00135 extern char* optarg; 00136 // extern int optind; 00137 while((c = getopt(argc,argv,"ho:a:AM:E:n:l:N:L:I:i:y:Q:g:m:seu:v:f:r:R:S:K:Dt:d:w:V:z:Z:")) != EOF ) { 00138 switch( c ) 00139 { 00140 case 'h' : 00141 help(argv[0]); 00142 return 1; 00143 case 'M' : 00144 iAdaBoostMode = (optarg==0 ? 1 : atoi(optarg)); 00145 break; 00146 case 'E' : 00147 epsilon = (optarg==0 ? 0.01 : atof(optarg)); 00148 break; 00149 case 'o' : 00150 tupleFile = optarg; 00151 break; 00152 case 'a' : 00153 readMode = (optarg==0 ? 0 : atoi(optarg)); 00154 break; 00155 case 'A' : 00156 writeMode = SprRWFactory::Ascii; 00157 break; 00158 case 'n' : 00159 adaCycles = (optarg==0 ? 1 : atoi(optarg)); 00160 break; 00161 case 'l' : 00162 nnCycles = (optarg==0 ? 
1 : atoi(optarg)); 00163 break; 00164 case 'N' : 00165 netConfig = optarg; 00166 break; 00167 case 'L' : 00168 eta = (optarg==0 ? 0.1 : atof(optarg)); 00169 break; 00170 case 'I' : 00171 initEta = (optarg==0 ? 0.1 : atof(optarg)); 00172 break; 00173 case 'i' : 00174 initPoints = (optarg==0 ? 0 : atoi(optarg)); 00175 break; 00176 case 'y' : 00177 inputClassesString = optarg; 00178 break; 00179 case 'Q' : 00180 transformerFile = optarg; 00181 break; 00182 case 'g' : 00183 iLoss = (optarg==0 ? 0 : atoi(optarg)); 00184 break; 00185 case 'm' : 00186 if( optarg != 0 ) { 00187 setLowCutoff = true; 00188 lowCutoff = atof(optarg); 00189 } 00190 break; 00191 case 's' : 00192 useStandardAB = true; 00193 break; 00194 case 'e' : 00195 skipInitialEventReweighting = true; 00196 break; 00197 case 'u' : 00198 weightedDataOut = optarg; 00199 break; 00200 case 'v' : 00201 verbose = (optarg==0 ? 0 : atoi(optarg)); 00202 break; 00203 case 'f' : 00204 outFile = optarg; 00205 break; 00206 case 'r' : 00207 resumeFile = optarg; 00208 break; 00209 case 'R' : 00210 resumeNNFile = optarg; 00211 break; 00212 case 'S' : 00213 resumeSNNSFile = optarg; 00214 break; 00215 case 'K' : 00216 split = true; 00217 splitFactor = (optarg==0 ? 0 : atof(optarg)); 00218 break; 00219 case 'D' : 00220 splitRandomize = true; 00221 break; 00222 case 't' : 00223 valFile = optarg; 00224 break; 00225 case 'd' : 00226 valPrint = (optarg==0 ? 0 : atoi(optarg)); 00227 break; 00228 case 'w' : 00229 if( optarg != 0 ) { 00230 scaleWeights = true; 00231 sW = atof(optarg); 00232 } 00233 break; 00234 case 'V' : 00235 includeList = optarg; 00236 break; 00237 case 'z' : 00238 excludeList = optarg; 00239 break; 00240 case 'Z' : 00241 stringVarsDoNotFeed = optarg; 00242 break; 00243 } 00244 } 00245 00246 // Get training file. 00247 string trFile = argv[argc-1]; 00248 00249 // make reader 00250 SprRWFactory::DataType inputType 00251 = ( readMode==0 ? 
SprRWFactory::Root : SprRWFactory::Ascii ); 00252 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode)); 00253 00254 // include variables 00255 set<string> includeSet; 00256 if( !includeList.empty() ) { 00257 vector<vector<string> > includeVars; 00258 SprStringParser::parseToStrings(includeList.c_str(),includeVars); 00259 assert( !includeVars.empty() ); 00260 for( int i=0;i<includeVars[0].size();i++ ) 00261 includeSet.insert(includeVars[0][i]); 00262 if( !reader->chooseVars(includeSet) ) { 00263 cerr << "Unable to include variables in training set." << endl; 00264 return 2; 00265 } 00266 else { 00267 cout << "Following variables have been included in optimization: "; 00268 for( set<string>::const_iterator 00269 i=includeSet.begin();i!=includeSet.end();i++ ) 00270 cout << "\"" << *i << "\"" << " "; 00271 cout << endl; 00272 } 00273 } 00274 00275 // exclude variables 00276 set<string> excludeSet; 00277 if( !excludeList.empty() ) { 00278 vector<vector<string> > excludeVars; 00279 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars); 00280 assert( !excludeVars.empty() ); 00281 for( int i=0;i<excludeVars[0].size();i++ ) 00282 excludeSet.insert(excludeVars[0][i]); 00283 if( !reader->chooseAllBut(excludeSet) ) { 00284 cerr << "Unable to exclude variables from training set." 
<< endl; 00285 return 2; 00286 } 00287 else { 00288 cout << "Following variables have been excluded from optimization: "; 00289 for( set<string>::const_iterator 00290 i=excludeSet.begin();i!=excludeSet.end();i++ ) 00291 cout << "\"" << *i << "\"" << " "; 00292 cout << endl; 00293 } 00294 } 00295 00296 // read training data from file 00297 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str())); 00298 if( filter.get() == 0 ) { 00299 cerr << "Unable to read data from file " << trFile.c_str() << endl; 00300 return 2; 00301 } 00302 vector<string> vars; 00303 filter->vars(vars); 00304 cout << "Read data from file " << trFile.c_str() 00305 << " for variables"; 00306 for( int i=0;i<vars.size();i++ ) 00307 cout << " \"" << vars[i].c_str() << "\""; 00308 cout << endl; 00309 cout << "Total number of points read: " << filter->size() << endl; 00310 00311 // filter training data by class 00312 vector<SprClass> inputClasses; 00313 if( !filter->filterByClass(inputClassesString.c_str()) ) { 00314 cerr << "Cannot choose input classes for string " 00315 << inputClassesString << endl; 00316 return 2; 00317 } 00318 filter->classes(inputClasses); 00319 assert( inputClasses.size() > 1 ); 00320 cout << "Training data filtered by class." << endl; 00321 for( int i=0;i<inputClasses.size();i++ ) { 00322 cout << "Points in class " << inputClasses[i] << ": " 00323 << filter->ptsInClass(inputClasses[i]) << endl; 00324 } 00325 00326 // scale weights 00327 if( scaleWeights ) { 00328 cout << "Signal weights are multiplied by " << sW << endl; 00329 filter->scaleWeights(inputClasses[1],sW); 00330 } 00331 00332 // apply low cutoff 00333 if( setLowCutoff ) { 00334 if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) { 00335 cerr << "Unable to replace missing values in training data." << endl; 00336 return 2; 00337 } 00338 else 00339 cout << "Values below " << lowCutoff << " in training data" 00340 << " have been replaced with medians." 
<< endl; 00341 } 00342 00343 // read validation data from file 00344 auto_ptr<SprAbsFilter> valFilter; 00345 if( split && !valFile.empty() ) { 00346 cerr << "Unable to split training data and use validation data " 00347 << "from a separate file." << endl; 00348 return 2; 00349 } 00350 if( split && valPrint!=0 ) { 00351 cout << "Splitting training data with factor " << splitFactor << endl; 00352 if( splitRandomize ) 00353 cout << "Will use randomized splitting." << endl; 00354 vector<double> weights; 00355 SprData* splitted = filter->split(splitFactor,weights,splitRandomize); 00356 if( splitted == 0 ) { 00357 cerr << "Unable to split training data." << endl; 00358 return 2; 00359 } 00360 bool ownData = true; 00361 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData)); 00362 cout << "Training data re-filtered:" << endl; 00363 for( int i=0;i<inputClasses.size();i++ ) { 00364 cout << "Points in class " << inputClasses[i] << ": " 00365 << filter->ptsInClass(inputClasses[i]) << endl; 00366 } 00367 } 00368 if( !valFile.empty() && valPrint!=0 ) { 00369 auto_ptr<SprAbsReader> 00370 valReader(SprRWFactory::makeReader(inputType,readMode)); 00371 if( !includeSet.empty() ) { 00372 if( !valReader->chooseVars(includeSet) ) { 00373 cerr << "Unable to include variables in validation set." << endl; 00374 return 2; 00375 } 00376 } 00377 if( !excludeSet.empty() ) { 00378 if( !valReader->chooseAllBut(excludeSet) ) { 00379 cerr << "Unable to exclude variables from validation set." 
<< endl; 00380 return 2; 00381 } 00382 } 00383 valFilter.reset(valReader->read(valFile.c_str())); 00384 if( valFilter.get() == 0 ) { 00385 cerr << "Unable to read data from file " << valFile.c_str() << endl; 00386 return 2; 00387 } 00388 vector<string> valVars; 00389 valFilter->vars(valVars); 00390 cout << "Read validation data from file " << valFile.c_str() 00391 << " for variables"; 00392 for( int i=0;i<valVars.size();i++ ) 00393 cout << " \"" << valVars[i].c_str() << "\""; 00394 cout << endl; 00395 cout << "Total number of points read: " << valFilter->size() << endl; 00396 cout << "Points in class 0: " << valFilter->ptsInClass(inputClasses[0]) 00397 << " 1: " << valFilter->ptsInClass(inputClasses[1]) << endl; 00398 } 00399 00400 // filter validation data by class 00401 if( valFilter.get() != 0 ) { 00402 if( !valFilter->filterByClass(inputClassesString.c_str()) ) { 00403 cerr << "Cannot choose input classes for string " 00404 << inputClassesString << endl; 00405 return 2; 00406 } 00407 valFilter->classes(inputClasses); 00408 cout << "Validation data filtered by class." << endl; 00409 for( int i=0;i<inputClasses.size();i++ ) { 00410 cout << "Points in class " << inputClasses[i] << ": " 00411 << valFilter->ptsInClass(inputClasses[i]) << endl; 00412 } 00413 } 00414 00415 // scale weights 00416 if( scaleWeights && valFilter.get()!=0 ) 00417 valFilter->scaleWeights(inputClasses[1],sW); 00418 00419 // apply low cutoff 00420 if( setLowCutoff && valFilter.get()!=0 ) { 00421 if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) { 00422 cerr << "Unable to replace missing values in validation data." << endl; 00423 return 2; 00424 } 00425 else 00426 cout << "Values below " << lowCutoff << " in validation data" 00427 << " have been replaced with medians." 
<< endl; 00428 } 00429 00430 // apply transformation of variables to training and test data 00431 auto_ptr<SprAbsFilter> garbage_train, garbage_valid; 00432 if( !transformerFile.empty() ) { 00433 SprVarTransformerReader transReader; 00434 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str()); 00435 if( t == 0 ) { 00436 cerr << "Unable to read VarTransformer from file " 00437 << transformerFile.c_str() << endl; 00438 return 2; 00439 } 00440 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get()); 00441 SprTransformerFilter* t_valid = 0; 00442 if( valFilter.get() != 0 ) 00443 t_valid = new SprTransformerFilter(valFilter.get()); 00444 bool replaceOriginalData = true; 00445 if( !t_train->transform(t,replaceOriginalData) ) { 00446 cerr << "Unable to apply VarTransformer to training data." << endl; 00447 return 2; 00448 } 00449 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) { 00450 cerr << "Unable to apply VarTransformer to validation data." << endl; 00451 return 2; 00452 } 00453 cout << "Variable transformation from file " 00454 << transformerFile.c_str() << " has been applied to " 00455 << "training and validation data." 
<< endl; 00456 garbage_train.reset(filter.release()); 00457 garbage_valid.reset(valFilter.release()); 00458 filter.reset(t_train); 00459 valFilter.reset(t_valid); 00460 } 00461 00462 // make per-event loss 00463 auto_ptr<SprAverageLoss> loss; 00464 switch( iLoss ) 00465 { 00466 case 1 : 00467 if( adaCycles > 1 ) { 00468 loss.reset(new SprAverageLoss(&SprLoss::quadratic, 00469 &SprTransformation::logit)); 00470 } 00471 else { 00472 loss.reset(new SprAverageLoss(&SprLoss::quadratic)); 00473 } 00474 cout << "Per-event loss set to " 00475 << "Quadratic loss (y-f(x))^2 " << endl; 00476 useStandardAB = true; 00477 break; 00478 case 2 : 00479 if( adaCycles > 1 ) { 00480 loss.reset(new SprAverageLoss(&SprLoss::exponential)); 00481 } 00482 else { 00483 loss.reset(new SprAverageLoss(&SprLoss::exponential, 00484 &SprTransformation::logitInverse)); 00485 } 00486 cout << "Per-event loss set to " 00487 << "Exponential loss exp(-y*f(x)) " << endl; 00488 useStandardAB = true; 00489 break; 00490 default : 00491 cout << "No per-event loss is chosen. Will use the default." << endl; 00492 break; 00493 } 00494 00495 // make AdaBoost mode 00496 SprTrainedAdaBoost::AdaBoostMode abMode = SprTrainedAdaBoost::Discrete; 00497 switch( iAdaBoostMode ) 00498 { 00499 case 1 : 00500 abMode = SprTrainedAdaBoost::Discrete; 00501 cout << "Will train Discrete AdaBoost." << endl; 00502 break; 00503 case 2 : 00504 abMode = SprTrainedAdaBoost::Real; 00505 cout << "Will train Real AdaBoost." << endl; 00506 break; 00507 case 3 : 00508 abMode = SprTrainedAdaBoost::Epsilon; 00509 cout << "Will train Epsilon AdaBoost." << endl; 00510 break; 00511 default : 00512 cout << "Will train Discrete AdaBoost." << endl; 00513 break; 00514 } 00515 00516 // sanity check 00517 int resume = int(!resumeFile.empty()) 00518 + int(!resumeNNFile.empty()) 00519 + int(!resumeSNNSFile.empty()); 00520 if( resume > 1 ) { 00521 cerr << "Reading more than one classifier configuration is not allowed." 
00522 << " Requested: " << resume << endl; 00523 return 5; 00524 } 00525 if( (!resumeNNFile.empty() || !resumeSNNSFile.empty()) 00526 && !netConfig.empty() ) { 00527 cerr << "What do you want to do - read NN configuration from a file " 00528 << "or specify configuration on the command line? " 00529 << "Life is tough - you cannot do both." << endl; 00530 return 5; 00531 } 00532 00533 // make a single NN 00534 auto_ptr<SprStdBackprop> stdnn; 00535 if( adaCycles>0 && resumeNNFile.empty() && resumeSNNSFile.empty() ) { 00536 stdnn.reset(new SprStdBackprop(filter.get(), 00537 netConfig.c_str(), 00538 nnCycles, 00539 eta)); 00540 if( !stdnn->init(initEta,initPoints) ) { 00541 cerr << "Unable to initialize neural net." << endl; 00542 return 6; 00543 } 00544 } 00545 else { 00546 stdnn.reset(new SprStdBackprop(filter.get(), 00547 nnCycles, 00548 eta)); 00549 } 00550 00551 // read saved NN from file 00552 SprTrainedStdBackprop* trainedNN = 0; 00553 if( !resumeSNNSFile.empty() ) { 00554 if( !stdnn->readSNNS(resumeSNNSFile.c_str()) ) { 00555 cerr << "Unable to read SNNS configuration from file " 00556 << resumeSNNSFile.c_str() << endl; 00557 return 6; 00558 } 00559 trainedNN = stdnn->makeTrained(); 00560 cout << "Read SNNS configuration from file " 00561 << resumeSNNSFile.c_str() << endl; 00562 } 00563 if( !resumeNNFile.empty() ) { 00564 if( !SprClassifierReader::readTrainable(resumeNNFile.c_str(), 00565 stdnn.get(),verbose) ) { 00566 cerr << "Unable to read SPR NN configuration from file " 00567 << resumeNNFile.c_str() << endl; 00568 return 6; 00569 } 00570 trainedNN = stdnn->makeTrained(); 00571 cout << "Read SPR neural net configuration from file " 00572 << resumeNNFile.c_str() << endl; 00573 } 00574 00575 // make classifier to train 00576 auto_ptr<SprAbsClassifier> classifier; 00577 if( adaCycles != 1 ) { 00578 // make AdaBoost 00579 SprAdaBoost* ab = new SprAdaBoost(filter.get(), 00580 adaCycles, 00581 useStandardAB, 00582 abMode); 00583 cout << "Setting epsilon to " << 
epsilon << endl; 00584 ab->setEpsilon(epsilon); 00585 00586 // skip reweigting 00587 if( skipInitialEventReweighting ) ab->skipInitialEventReweighting(true); 00588 00589 // set validation 00590 if( valFilter.get()!=0 && !valFilter->empty() ) 00591 ab->setValidation(valFilter.get(),valPrint,loss.get()); 00592 00593 // read saved AdaBoost 00594 if( resumeFile.empty() ) { 00595 if( trainedNN != 0 ) { 00596 if( !ab->addTrained(trainedNN,true) ) { 00597 cerr << "Unable to add first trained NN to AdaBoost." << endl; 00598 return 6; 00599 } 00600 } 00601 } 00602 else { 00603 if( !SprClassifierReader::readTrainable(resumeFile.c_str(), 00604 ab,verbose) ) { 00605 cerr << "Failed to read saved AdaBoost from file " 00606 << resumeFile.c_str() << endl; 00607 return 6; 00608 } 00609 cout << "Read saved AdaBoost from file " << resumeFile.c_str() 00610 << " with " << ab->nTrained() << " trained classifiers." << endl; 00611 } 00612 00613 // add a trainable NN 00614 if( !ab->addTrainable(stdnn.get()) ) { 00615 cerr << "Unable to add neural net to AdaBoost." << endl; 00616 return 6; 00617 } 00618 00619 // reset classifier 00620 classifier.reset(ab); 00621 } 00622 else { 00623 // set validation 00624 if( valFilter.get()!=0 && !valFilter->empty() ) 00625 stdnn->setValidation(valFilter.get(),valPrint,loss.get()); 00626 00627 // reset classifier 00628 classifier.reset(stdnn.release()); 00629 } 00630 00631 // train 00632 if( !classifier->train(verbose) ) { 00633 cerr << "Training terminated with error." << endl; 00634 return 7; 00635 } 00636 else { 00637 cout << "Training done." << endl; 00638 if( adaCycles != 1 ) { 00639 SprAdaBoost* ab = static_cast<SprAdaBoost*>(classifier.get()); 00640 cout << "AdaBoost finished training with " << ab->nTrained() 00641 << " classifiers." 
<< endl; 00642 } 00643 } 00644 00645 // save trained classifier 00646 if( !outFile.empty() ) { 00647 if( !classifier->store(outFile.c_str()) ) { 00648 cerr << "Cannot store classifier in file " << outFile.c_str() << endl; 00649 return 8; 00650 } 00651 } 00652 00653 // save reweighted data 00654 if( adaCycles > 1 ) { 00655 if( !weightedDataOut.empty() ) { 00656 SprAdaBoost* ab = static_cast<SprAdaBoost*>(classifier.get()); 00657 if( !ab->storeData(weightedDataOut.c_str()) ) { 00658 cerr << "Cannot store weighted AdaBoost data to file " 00659 << weightedDataOut.c_str() << endl; 00660 return 9; 00661 } 00662 } 00663 } 00664 00665 // make a trained AdaBoost 00666 auto_ptr<SprAbsTrainedClassifier> trained(classifier->makeTrained()); 00667 if( trained.get() == 0 ) { 00668 cerr << "Unable to get trained classifier." << endl; 00669 return 9; 00670 } 00671 00672 // make histogram if requested 00673 if( tupleFile.empty() ) 00674 return 0; 00675 00676 // make a writer 00677 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training")); 00678 if( !tuple->init(tupleFile.c_str()) ) { 00679 cerr << "Unable to open output file " << tupleFile.c_str() << endl; 00680 return 10; 00681 } 00682 00683 // determine if certain variables are to be excluded from usage, 00684 // but included in the output storage file (-Z option) 00685 string printVarsDoNotFeed; 00686 vector<vector<string> > varsDoNotFeed; 00687 SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed); 00688 vector<unsigned> mapper; 00689 for( int d=0;d<vars.size();d++ ) { 00690 if( varsDoNotFeed.empty() || 00691 (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d]) 00692 ==varsDoNotFeed[0].end()) ) { 00693 mapper.push_back(d); 00694 } 00695 else { 00696 printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? 
"" : ", " ); 00697 printVarsDoNotFeed += vars[d]; 00698 } 00699 } 00700 if( !printVarsDoNotFeed.empty() ) { 00701 cout << "The following variables are not used in the algorithm, " 00702 << "but will be included in the output file: " 00703 << printVarsDoNotFeed.c_str() << endl; 00704 } 00705 00706 // feed 00707 SprDataFeeder feeder(filter.get(),tuple.get(),mapper); 00708 string classifierName; 00709 if( adaCycles != 1 ) 00710 classifierName = "adann"; 00711 else 00712 classifierName = "nn"; 00713 feeder.addClassifier(trained.get(),classifierName.c_str()); 00714 if( !feeder.feed(1000) ) { 00715 cerr << "Cannot feed data into file " << tupleFile.c_str() << endl; 00716 return 11; 00717 } 00718 00719 // exit 00720 return 0; 00721 }