00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprAdaBoost.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedAdaBoost.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprStdBackprop.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedStdBackprop.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprCrossValidator.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00022 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00023 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00024 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00025 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00026
00027 #include <stdlib.h>
00028 #include <unistd.h>
00029 #include <iostream>
00030 #include <vector>
00031 #include <set>
00032 #include <string>
00033 #include <memory>
00034 #include <iomanip>
00035
00036 using namespace std;
00037
00038
// Prints the command-line usage summary for this executable to stdout.
//
// prog - program name to show in the "Usage:" line (normally argv[0]).
//
// Each entry of the table below is written as one output line followed by
// std::endl, so the output is flushed line by line.
void help(const char* prog)
{
  // One entry per printed line, in display order.
  static const char* const kOptionLines[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-o output Tuple file ",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-A save output data in ascii instead of Root ",
    "\t-M AdaBoost mode ",
    "\t\t 1 = Discrete AdaBoost (default) ",
    "\t\t 2 = Real AdaBoost ",
    "\t\t 3 = Epsilon AdaBoost ",
    "\t-E epsilon for Epsilon and Real AdaBoosts (def=0.01)",
    "\t-n number of AdaBoost training cycles (1 for single NN)",
    "\t-l number of Neural Net training cycles ",
    "\t-N neural net configuration, e.g., '6:3:1' (see SprStdBackprop.hh)",
    "\t-L learning rate of the network (default=0.1) ",
    "\t-I learning rate for network initialization (def=0.1)",
    "\t-i number of input points to use for initialization (def=all)",
    "\t-y list of input classes (see SprAbsFilter.hh) ",
    "\t-Q apply variable transformation saved in file ",
    "\t-g per-event loss for (cross-)validation ",
    "\t\t 1 - quadratic loss (y-f(x))^2 ",
    "\t\t 2 - exponential loss exp(-y*f(x)) ",
    "\t-m replace data values below this cutoff with medians",
    "\t-s use standard AdaBoost (see SprTrainedAdaBoost.hh)",
    "\t-e skip initial event reweighting when resuming ",
    "\t-u store data with modified weights to file ",
    "\t-v verbose level (0=silent default,1,2) ",
    "\t-f store trained AdaBoost to file ",
    "\t-r resume training for AdaBoost stored in file ",
    "\t-R resume training for a single neural net stored in file",
    "\t-S resume training from SNNS configuration stored in file",
    "\t-K keep this fraction in training set and ",
    "\t\t put the rest into validation set ",
    "\t-D randomize training set split-up ",
    "\t-t read validation/test data from a file ",
    "\t\t (must be in same format as input data!!! ",
    "\t-d frequency of print-outs for validation data ",
    "\t-w scale all signal weights by this factor ",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t-Z exclude input variables from the list, "
    "but put them in the output file ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const unsigned lineCount = sizeof(kOptionLines) / sizeof(kOptionLines[0]);

  std::cout << "Usage: " << prog << " training_data_file " << std::endl;
  for( unsigned k = 0; k < lineCount; ++k ) {
    std::cout << kOptionLines[k] << std::endl;
  }
}
00090
00091
00092 int main(int argc, char ** argv)
00093 {
00094
00095 if( argc < 2 ) {
00096 help(argv[0]);
00097 return 1;
00098 }
00099
00100
00101 string tupleFile;
00102 int readMode = 0;
00103 SprRWFactory::DataType writeMode = SprRWFactory::Root;
00104 unsigned adaCycles = 0;
00105 unsigned nnCycles = 0;
00106 double eta = 0.1;
00107 int iLoss = 1;
00108 int verbose = 0;
00109 string outFile;
00110 string valFile;
00111 unsigned valPrint = 0;
00112 bool scaleWeights = false;
00113 double sW = 1.;
00114 bool useStandardAB = false;
00115 int iAdaBoostMode = 1;
00116 double epsilon = 0.01;
00117 bool skipInitialEventReweighting = false;
00118 string weightedDataOut;
00119 bool setLowCutoff = false;
00120 double lowCutoff = 0;
00121 string includeList, excludeList;
00122 string inputClassesString;
00123 string stringVarsDoNotFeed;
00124 string resumeFile, resumeSNNSFile, resumeNNFile;
00125 string netConfig;
00126 double initEta = 0.1;
00127 unsigned initPoints = 0;
00128 bool split = false;
00129 double splitFactor = 0;
00130 bool splitRandomize = false;
00131 string transformerFile;
00132
00133
00134 int c;
00135 extern char* optarg;
00136
00137 while((c = getopt(argc,argv,"ho:a:AM:E:n:l:N:L:I:i:y:Q:g:m:seu:v:f:r:R:S:K:Dt:d:w:V:z:Z:")) != EOF ) {
00138 switch( c )
00139 {
00140 case 'h' :
00141 help(argv[0]);
00142 return 1;
00143 case 'M' :
00144 iAdaBoostMode = (optarg==0 ? 1 : atoi(optarg));
00145 break;
00146 case 'E' :
00147 epsilon = (optarg==0 ? 0.01 : atof(optarg));
00148 break;
00149 case 'o' :
00150 tupleFile = optarg;
00151 break;
00152 case 'a' :
00153 readMode = (optarg==0 ? 0 : atoi(optarg));
00154 break;
00155 case 'A' :
00156 writeMode = SprRWFactory::Ascii;
00157 break;
00158 case 'n' :
00159 adaCycles = (optarg==0 ? 1 : atoi(optarg));
00160 break;
00161 case 'l' :
00162 nnCycles = (optarg==0 ? 1 : atoi(optarg));
00163 break;
00164 case 'N' :
00165 netConfig = optarg;
00166 break;
00167 case 'L' :
00168 eta = (optarg==0 ? 0.1 : atof(optarg));
00169 break;
00170 case 'I' :
00171 initEta = (optarg==0 ? 0.1 : atof(optarg));
00172 break;
00173 case 'i' :
00174 initPoints = (optarg==0 ? 0 : atoi(optarg));
00175 break;
00176 case 'y' :
00177 inputClassesString = optarg;
00178 break;
00179 case 'Q' :
00180 transformerFile = optarg;
00181 break;
00182 case 'g' :
00183 iLoss = (optarg==0 ? 0 : atoi(optarg));
00184 break;
00185 case 'm' :
00186 if( optarg != 0 ) {
00187 setLowCutoff = true;
00188 lowCutoff = atof(optarg);
00189 }
00190 break;
00191 case 's' :
00192 useStandardAB = true;
00193 break;
00194 case 'e' :
00195 skipInitialEventReweighting = true;
00196 break;
00197 case 'u' :
00198 weightedDataOut = optarg;
00199 break;
00200 case 'v' :
00201 verbose = (optarg==0 ? 0 : atoi(optarg));
00202 break;
00203 case 'f' :
00204 outFile = optarg;
00205 break;
00206 case 'r' :
00207 resumeFile = optarg;
00208 break;
00209 case 'R' :
00210 resumeNNFile = optarg;
00211 break;
00212 case 'S' :
00213 resumeSNNSFile = optarg;
00214 break;
00215 case 'K' :
00216 split = true;
00217 splitFactor = (optarg==0 ? 0 : atof(optarg));
00218 break;
00219 case 'D' :
00220 splitRandomize = true;
00221 break;
00222 case 't' :
00223 valFile = optarg;
00224 break;
00225 case 'd' :
00226 valPrint = (optarg==0 ? 0 : atoi(optarg));
00227 break;
00228 case 'w' :
00229 if( optarg != 0 ) {
00230 scaleWeights = true;
00231 sW = atof(optarg);
00232 }
00233 break;
00234 case 'V' :
00235 includeList = optarg;
00236 break;
00237 case 'z' :
00238 excludeList = optarg;
00239 break;
00240 case 'Z' :
00241 stringVarsDoNotFeed = optarg;
00242 break;
00243 }
00244 }
00245
00246
00247 string trFile = argv[argc-1];
00248
00249
00250 SprRWFactory::DataType inputType
00251 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00252 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00253
00254
00255 set<string> includeSet;
00256 if( !includeList.empty() ) {
00257 vector<vector<string> > includeVars;
00258 SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00259 assert( !includeVars.empty() );
00260 for( int i=0;i<includeVars[0].size();i++ )
00261 includeSet.insert(includeVars[0][i]);
00262 if( !reader->chooseVars(includeSet) ) {
00263 cerr << "Unable to include variables in training set." << endl;
00264 return 2;
00265 }
00266 else {
00267 cout << "Following variables have been included in optimization: ";
00268 for( set<string>::const_iterator
00269 i=includeSet.begin();i!=includeSet.end();i++ )
00270 cout << "\"" << *i << "\"" << " ";
00271 cout << endl;
00272 }
00273 }
00274
00275
00276 set<string> excludeSet;
00277 if( !excludeList.empty() ) {
00278 vector<vector<string> > excludeVars;
00279 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00280 assert( !excludeVars.empty() );
00281 for( int i=0;i<excludeVars[0].size();i++ )
00282 excludeSet.insert(excludeVars[0][i]);
00283 if( !reader->chooseAllBut(excludeSet) ) {
00284 cerr << "Unable to exclude variables from training set." << endl;
00285 return 2;
00286 }
00287 else {
00288 cout << "Following variables have been excluded from optimization: ";
00289 for( set<string>::const_iterator
00290 i=excludeSet.begin();i!=excludeSet.end();i++ )
00291 cout << "\"" << *i << "\"" << " ";
00292 cout << endl;
00293 }
00294 }
00295
00296
00297 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00298 if( filter.get() == 0 ) {
00299 cerr << "Unable to read data from file " << trFile.c_str() << endl;
00300 return 2;
00301 }
00302 vector<string> vars;
00303 filter->vars(vars);
00304 cout << "Read data from file " << trFile.c_str()
00305 << " for variables";
00306 for( int i=0;i<vars.size();i++ )
00307 cout << " \"" << vars[i].c_str() << "\"";
00308 cout << endl;
00309 cout << "Total number of points read: " << filter->size() << endl;
00310
00311
00312 vector<SprClass> inputClasses;
00313 if( !filter->filterByClass(inputClassesString.c_str()) ) {
00314 cerr << "Cannot choose input classes for string "
00315 << inputClassesString << endl;
00316 return 2;
00317 }
00318 filter->classes(inputClasses);
00319 assert( inputClasses.size() > 1 );
00320 cout << "Training data filtered by class." << endl;
00321 for( int i=0;i<inputClasses.size();i++ ) {
00322 cout << "Points in class " << inputClasses[i] << ": "
00323 << filter->ptsInClass(inputClasses[i]) << endl;
00324 }
00325
00326
00327 if( scaleWeights ) {
00328 cout << "Signal weights are multiplied by " << sW << endl;
00329 filter->scaleWeights(inputClasses[1],sW);
00330 }
00331
00332
00333 if( setLowCutoff ) {
00334 if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00335 cerr << "Unable to replace missing values in training data." << endl;
00336 return 2;
00337 }
00338 else
00339 cout << "Values below " << lowCutoff << " in training data"
00340 << " have been replaced with medians." << endl;
00341 }
00342
00343
00344 auto_ptr<SprAbsFilter> valFilter;
00345 if( split && !valFile.empty() ) {
00346 cerr << "Unable to split training data and use validation data "
00347 << "from a separate file." << endl;
00348 return 2;
00349 }
00350 if( split && valPrint!=0 ) {
00351 cout << "Splitting training data with factor " << splitFactor << endl;
00352 if( splitRandomize )
00353 cout << "Will use randomized splitting." << endl;
00354 vector<double> weights;
00355 SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00356 if( splitted == 0 ) {
00357 cerr << "Unable to split training data." << endl;
00358 return 2;
00359 }
00360 bool ownData = true;
00361 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00362 cout << "Training data re-filtered:" << endl;
00363 for( int i=0;i<inputClasses.size();i++ ) {
00364 cout << "Points in class " << inputClasses[i] << ": "
00365 << filter->ptsInClass(inputClasses[i]) << endl;
00366 }
00367 }
00368 if( !valFile.empty() && valPrint!=0 ) {
00369 auto_ptr<SprAbsReader>
00370 valReader(SprRWFactory::makeReader(inputType,readMode));
00371 if( !includeSet.empty() ) {
00372 if( !valReader->chooseVars(includeSet) ) {
00373 cerr << "Unable to include variables in validation set." << endl;
00374 return 2;
00375 }
00376 }
00377 if( !excludeSet.empty() ) {
00378 if( !valReader->chooseAllBut(excludeSet) ) {
00379 cerr << "Unable to exclude variables from validation set." << endl;
00380 return 2;
00381 }
00382 }
00383 valFilter.reset(valReader->read(valFile.c_str()));
00384 if( valFilter.get() == 0 ) {
00385 cerr << "Unable to read data from file " << valFile.c_str() << endl;
00386 return 2;
00387 }
00388 vector<string> valVars;
00389 valFilter->vars(valVars);
00390 cout << "Read validation data from file " << valFile.c_str()
00391 << " for variables";
00392 for( int i=0;i<valVars.size();i++ )
00393 cout << " \"" << valVars[i].c_str() << "\"";
00394 cout << endl;
00395 cout << "Total number of points read: " << valFilter->size() << endl;
00396 cout << "Points in class 0: " << valFilter->ptsInClass(inputClasses[0])
00397 << " 1: " << valFilter->ptsInClass(inputClasses[1]) << endl;
00398 }
00399
00400
00401 if( valFilter.get() != 0 ) {
00402 if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00403 cerr << "Cannot choose input classes for string "
00404 << inputClassesString << endl;
00405 return 2;
00406 }
00407 valFilter->classes(inputClasses);
00408 cout << "Validation data filtered by class." << endl;
00409 for( int i=0;i<inputClasses.size();i++ ) {
00410 cout << "Points in class " << inputClasses[i] << ": "
00411 << valFilter->ptsInClass(inputClasses[i]) << endl;
00412 }
00413 }
00414
00415
00416 if( scaleWeights && valFilter.get()!=0 )
00417 valFilter->scaleWeights(inputClasses[1],sW);
00418
00419
00420 if( setLowCutoff && valFilter.get()!=0 ) {
00421 if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00422 cerr << "Unable to replace missing values in validation data." << endl;
00423 return 2;
00424 }
00425 else
00426 cout << "Values below " << lowCutoff << " in validation data"
00427 << " have been replaced with medians." << endl;
00428 }
00429
00430
00431 auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00432 if( !transformerFile.empty() ) {
00433 SprVarTransformerReader transReader;
00434 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00435 if( t == 0 ) {
00436 cerr << "Unable to read VarTransformer from file "
00437 << transformerFile.c_str() << endl;
00438 return 2;
00439 }
00440 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00441 SprTransformerFilter* t_valid = 0;
00442 if( valFilter.get() != 0 )
00443 t_valid = new SprTransformerFilter(valFilter.get());
00444 bool replaceOriginalData = true;
00445 if( !t_train->transform(t,replaceOriginalData) ) {
00446 cerr << "Unable to apply VarTransformer to training data." << endl;
00447 return 2;
00448 }
00449 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00450 cerr << "Unable to apply VarTransformer to validation data." << endl;
00451 return 2;
00452 }
00453 cout << "Variable transformation from file "
00454 << transformerFile.c_str() << " has been applied to "
00455 << "training and validation data." << endl;
00456 garbage_train.reset(filter.release());
00457 garbage_valid.reset(valFilter.release());
00458 filter.reset(t_train);
00459 valFilter.reset(t_valid);
00460 }
00461
00462
00463 auto_ptr<SprAverageLoss> loss;
00464 switch( iLoss )
00465 {
00466 case 1 :
00467 if( adaCycles > 1 ) {
00468 loss.reset(new SprAverageLoss(&SprLoss::quadratic,
00469 &SprTransformation::logit));
00470 }
00471 else {
00472 loss.reset(new SprAverageLoss(&SprLoss::quadratic));
00473 }
00474 cout << "Per-event loss set to "
00475 << "Quadratic loss (y-f(x))^2 " << endl;
00476 useStandardAB = true;
00477 break;
00478 case 2 :
00479 if( adaCycles > 1 ) {
00480 loss.reset(new SprAverageLoss(&SprLoss::exponential));
00481 }
00482 else {
00483 loss.reset(new SprAverageLoss(&SprLoss::exponential,
00484 &SprTransformation::logitInverse));
00485 }
00486 cout << "Per-event loss set to "
00487 << "Exponential loss exp(-y*f(x)) " << endl;
00488 useStandardAB = true;
00489 break;
00490 default :
00491 cout << "No per-event loss is chosen. Will use the default." << endl;
00492 break;
00493 }
00494
00495
00496 SprTrainedAdaBoost::AdaBoostMode abMode = SprTrainedAdaBoost::Discrete;
00497 switch( iAdaBoostMode )
00498 {
00499 case 1 :
00500 abMode = SprTrainedAdaBoost::Discrete;
00501 cout << "Will train Discrete AdaBoost." << endl;
00502 break;
00503 case 2 :
00504 abMode = SprTrainedAdaBoost::Real;
00505 cout << "Will train Real AdaBoost." << endl;
00506 break;
00507 case 3 :
00508 abMode = SprTrainedAdaBoost::Epsilon;
00509 cout << "Will train Epsilon AdaBoost." << endl;
00510 break;
00511 default :
00512 cout << "Will train Discrete AdaBoost." << endl;
00513 break;
00514 }
00515
00516
00517 int resume = int(!resumeFile.empty())
00518 + int(!resumeNNFile.empty())
00519 + int(!resumeSNNSFile.empty());
00520 if( resume > 1 ) {
00521 cerr << "Reading more than one classifier configuration is not allowed."
00522 << " Requested: " << resume << endl;
00523 return 5;
00524 }
00525 if( (!resumeNNFile.empty() || !resumeSNNSFile.empty())
00526 && !netConfig.empty() ) {
00527 cerr << "What do you want to do - read NN configuration from a file "
00528 << "or specify configuration on the command line? "
00529 << "Life is tough - you cannot do both." << endl;
00530 return 5;
00531 }
00532
00533
00534 auto_ptr<SprStdBackprop> stdnn;
00535 if( adaCycles>0 && resumeNNFile.empty() && resumeSNNSFile.empty() ) {
00536 stdnn.reset(new SprStdBackprop(filter.get(),
00537 netConfig.c_str(),
00538 nnCycles,
00539 eta));
00540 if( !stdnn->init(initEta,initPoints) ) {
00541 cerr << "Unable to initialize neural net." << endl;
00542 return 6;
00543 }
00544 }
00545 else {
00546 stdnn.reset(new SprStdBackprop(filter.get(),
00547 nnCycles,
00548 eta));
00549 }
00550
00551
00552 SprTrainedStdBackprop* trainedNN = 0;
00553 if( !resumeSNNSFile.empty() ) {
00554 if( !stdnn->readSNNS(resumeSNNSFile.c_str()) ) {
00555 cerr << "Unable to read SNNS configuration from file "
00556 << resumeSNNSFile.c_str() << endl;
00557 return 6;
00558 }
00559 trainedNN = stdnn->makeTrained();
00560 cout << "Read SNNS configuration from file "
00561 << resumeSNNSFile.c_str() << endl;
00562 }
00563 if( !resumeNNFile.empty() ) {
00564 if( !SprClassifierReader::readTrainable(resumeNNFile.c_str(),
00565 stdnn.get(),verbose) ) {
00566 cerr << "Unable to read SPR NN configuration from file "
00567 << resumeNNFile.c_str() << endl;
00568 return 6;
00569 }
00570 trainedNN = stdnn->makeTrained();
00571 cout << "Read SPR neural net configuration from file "
00572 << resumeNNFile.c_str() << endl;
00573 }
00574
00575
00576 auto_ptr<SprAbsClassifier> classifier;
00577 if( adaCycles != 1 ) {
00578
00579 SprAdaBoost* ab = new SprAdaBoost(filter.get(),
00580 adaCycles,
00581 useStandardAB,
00582 abMode);
00583 cout << "Setting epsilon to " << epsilon << endl;
00584 ab->setEpsilon(epsilon);
00585
00586
00587 if( skipInitialEventReweighting ) ab->skipInitialEventReweighting(true);
00588
00589
00590 if( valFilter.get()!=0 && !valFilter->empty() )
00591 ab->setValidation(valFilter.get(),valPrint,loss.get());
00592
00593
00594 if( resumeFile.empty() ) {
00595 if( trainedNN != 0 ) {
00596 if( !ab->addTrained(trainedNN,true) ) {
00597 cerr << "Unable to add first trained NN to AdaBoost." << endl;
00598 return 6;
00599 }
00600 }
00601 }
00602 else {
00603 if( !SprClassifierReader::readTrainable(resumeFile.c_str(),
00604 ab,verbose) ) {
00605 cerr << "Failed to read saved AdaBoost from file "
00606 << resumeFile.c_str() << endl;
00607 return 6;
00608 }
00609 cout << "Read saved AdaBoost from file " << resumeFile.c_str()
00610 << " with " << ab->nTrained() << " trained classifiers." << endl;
00611 }
00612
00613
00614 if( !ab->addTrainable(stdnn.get()) ) {
00615 cerr << "Unable to add neural net to AdaBoost." << endl;
00616 return 6;
00617 }
00618
00619
00620 classifier.reset(ab);
00621 }
00622 else {
00623
00624 if( valFilter.get()!=0 && !valFilter->empty() )
00625 stdnn->setValidation(valFilter.get(),valPrint,loss.get());
00626
00627
00628 classifier.reset(stdnn.release());
00629 }
00630
00631
00632 if( !classifier->train(verbose) ) {
00633 cerr << "Training terminated with error." << endl;
00634 return 7;
00635 }
00636 else {
00637 cout << "Training done." << endl;
00638 if( adaCycles != 1 ) {
00639 SprAdaBoost* ab = static_cast<SprAdaBoost*>(classifier.get());
00640 cout << "AdaBoost finished training with " << ab->nTrained()
00641 << " classifiers." << endl;
00642 }
00643 }
00644
00645
00646 if( !outFile.empty() ) {
00647 if( !classifier->store(outFile.c_str()) ) {
00648 cerr << "Cannot store classifier in file " << outFile.c_str() << endl;
00649 return 8;
00650 }
00651 }
00652
00653
00654 if( adaCycles > 1 ) {
00655 if( !weightedDataOut.empty() ) {
00656 SprAdaBoost* ab = static_cast<SprAdaBoost*>(classifier.get());
00657 if( !ab->storeData(weightedDataOut.c_str()) ) {
00658 cerr << "Cannot store weighted AdaBoost data to file "
00659 << weightedDataOut.c_str() << endl;
00660 return 9;
00661 }
00662 }
00663 }
00664
00665
00666 auto_ptr<SprAbsTrainedClassifier> trained(classifier->makeTrained());
00667 if( trained.get() == 0 ) {
00668 cerr << "Unable to get trained classifier." << endl;
00669 return 9;
00670 }
00671
00672
00673 if( tupleFile.empty() )
00674 return 0;
00675
00676
00677 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00678 if( !tuple->init(tupleFile.c_str()) ) {
00679 cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00680 return 10;
00681 }
00682
00683
00684
00685 string printVarsDoNotFeed;
00686 vector<vector<string> > varsDoNotFeed;
00687 SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed);
00688 vector<unsigned> mapper;
00689 for( int d=0;d<vars.size();d++ ) {
00690 if( varsDoNotFeed.empty() ||
00691 (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d])
00692 ==varsDoNotFeed[0].end()) ) {
00693 mapper.push_back(d);
00694 }
00695 else {
00696 printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? "" : ", " );
00697 printVarsDoNotFeed += vars[d];
00698 }
00699 }
00700 if( !printVarsDoNotFeed.empty() ) {
00701 cout << "The following variables are not used in the algorithm, "
00702 << "but will be included in the output file: "
00703 << printVarsDoNotFeed.c_str() << endl;
00704 }
00705
00706
00707 SprDataFeeder feeder(filter.get(),tuple.get(),mapper);
00708 string classifierName;
00709 if( adaCycles != 1 )
00710 classifierName = "adann";
00711 else
00712 classifierName = "nn";
00713 feeder.addClassifier(trained.get(),classifierName.c_str());
00714 if( !feeder.feed(1000) ) {
00715 cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00716 return 11;
00717 }
00718
00719
00720 return 0;
00721 }