CMS 3D CMS Logo

SprStdBackpropApp.cc

Go to the documentation of this file.
00001 //$Id: SprStdBackpropApp.cc,v 1.4 2007/11/12 06:19:11 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprAdaBoost.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedAdaBoost.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprStdBackprop.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedStdBackprop.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprCrossValidator.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00022 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00023 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00024 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00025 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00026 
#include <stdlib.h>
#include <unistd.h>

#include <algorithm>
#include <cassert>
#include <iomanip>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>
00035 
00036 using namespace std;
00037 
00038 
/// Print the command-line usage of this application to standard output.
///
/// @param prog  program name shown in the usage line (normally argv[0]).
void help(const char* prog) 
{
  // Every entry below is printed on its own line, in this order, after the
  // usage line. Keep this list in sync with the getopt() string in main().
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help                                        ",
    "\t-o output Tuple file                               ",
    "\t-a input ascii file mode (see SprSimpleReader.hh)  ",
    "\t-A save output data in ascii instead of Root       ",
    "\t-M AdaBoost mode                                   ",
    "\t\t 1 = Discrete AdaBoost (default)                 ",
    "\t\t 2 = Real AdaBoost                               ",
    "\t\t 3 = Epsilon AdaBoost                            ",
    "\t-E epsilon for Epsilon and Real AdaBoosts (def=0.01)",
    "\t-n number of AdaBoost training cycles (1 for single NN)",
    "\t-l number of Neural Net training cycles            ",
    "\t-N neural net configuration, e.g., '6:3:1' (see SprStdBackprop.hh)",
    "\t-L learning rate of the network (default=0.1)      ",
    "\t-I learning rate for network initialization (def=0.1)",
    "\t-i number of input points to use for initialization (def=all)",
    "\t-y list of input classes (see SprAbsFilter.hh)     ",
    "\t-Q apply variable transformation saved in file     ",
    "\t-g per-event loss for (cross-)validation           ",
    "\t\t 1 - quadratic loss (y-f(x))^2                   ",
    "\t\t 2 - exponential loss exp(-y*f(x))               ",
    "\t-m replace data values below this cutoff with medians",
    "\t-s use standard AdaBoost (see SprTrainedAdaBoost.hh)",
    "\t-e skip initial event reweighting when resuming    ",
    "\t-u store data with modified weights to file        ",
    "\t-v verbose level (0=silent default,1,2)            ",
    "\t-f store trained AdaBoost to file                  ",
    "\t-r resume training for AdaBoost stored in file     ",
    "\t-R resume training for a single neural net stored in file",
    "\t-S resume training from SNNS configuration stored in file",
    "\t-K keep this fraction in training set and          ",
    "\t\t put the rest into validation set                ",
    "\t-D randomize training set split-up                 ",
    "\t-t read validation/test data from a file           ",
    // The parenthesis opened here was previously never closed.
    "\t\t (must be in same format as input data!!!)       ",
    "\t-d frequency of print-outs for validation data     ",
    "\t-w scale all signal weights by this factor         ",
    "\t-V include only these input variables              ",
    "\t-z exclude input variables from the list           ",
    "\t-Z exclude input variables from the list, "
    "but put them in the output file ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const unsigned nLines = sizeof(optionLines)/sizeof(optionLines[0]);

  std::cout << "Usage:  " << prog 
            << " training_data_file " << std::endl;
  for( unsigned i=0;i<nLines;i++ )
    std::cout << optionLines[i] << std::endl;
}
00090 
00091 
00092 int main(int argc, char ** argv)
00093 {
00094   // check command line
00095   if( argc < 2 ) {
00096     help(argv[0]);
00097     return 1;
00098   }
00099 
00100   // init
00101   string tupleFile;
00102   int readMode = 0;
00103   SprRWFactory::DataType writeMode = SprRWFactory::Root;
00104   unsigned adaCycles = 0;
00105   unsigned nnCycles = 0;
00106   double eta = 0.1;
00107   int iLoss = 1;
00108   int verbose = 0;
00109   string outFile;
00110   string valFile;
00111   unsigned valPrint = 0;
00112   bool scaleWeights = false;
00113   double sW = 1.;
00114   bool useStandardAB = false;
00115   int iAdaBoostMode = 1;
00116   double epsilon = 0.01;
00117   bool skipInitialEventReweighting = false;
00118   string weightedDataOut;
00119   bool setLowCutoff = false;
00120   double lowCutoff = 0;
00121   string includeList, excludeList;
00122   string inputClassesString;
00123   string stringVarsDoNotFeed;
00124   string resumeFile, resumeSNNSFile, resumeNNFile;
00125   string netConfig;
00126   double initEta = 0.1;
00127   unsigned initPoints = 0;
00128   bool split = false;
00129   double splitFactor = 0;
00130   bool splitRandomize = false;
00131   string transformerFile;
00132 
00133   // decode command line
00134   int c;
00135   extern char* optarg;
00136   //  extern int optind;
00137   while((c = getopt(argc,argv,"ho:a:AM:E:n:l:N:L:I:i:y:Q:g:m:seu:v:f:r:R:S:K:Dt:d:w:V:z:Z:")) != EOF ) {
00138     switch( c )
00139       {
00140       case 'h' :
00141         help(argv[0]);
00142         return 1;
00143       case 'M' :
00144         iAdaBoostMode = (optarg==0 ? 1 : atoi(optarg));
00145         break;
00146       case 'E' :
00147         epsilon = (optarg==0 ? 0.01 : atof(optarg));
00148         break;
00149       case 'o' :
00150         tupleFile = optarg;
00151         break;
00152       case 'a' :
00153         readMode = (optarg==0 ? 0 : atoi(optarg));
00154         break;
00155       case 'A' :
00156         writeMode = SprRWFactory::Ascii;
00157         break;
00158       case 'n' :
00159         adaCycles = (optarg==0 ? 1 : atoi(optarg));
00160         break;
00161       case 'l' :
00162         nnCycles = (optarg==0 ? 1 : atoi(optarg));
00163         break;
00164       case 'N' :
00165         netConfig = optarg;
00166         break;
00167       case 'L' :
00168         eta = (optarg==0 ? 0.1 : atof(optarg));
00169         break;
00170       case 'I' :
00171         initEta = (optarg==0 ? 0.1 : atof(optarg));
00172         break;
00173       case 'i' :
00174         initPoints = (optarg==0 ? 0 : atoi(optarg));
00175         break;
00176       case 'y' :
00177         inputClassesString = optarg;
00178         break;
00179       case 'Q' :
00180         transformerFile = optarg;
00181         break;
00182       case 'g' :
00183         iLoss = (optarg==0 ? 0 : atoi(optarg));
00184         break;
00185       case 'm' :
00186         if( optarg != 0 ) {
00187           setLowCutoff = true;
00188           lowCutoff = atof(optarg);
00189         }
00190         break;
00191       case 's' :
00192         useStandardAB = true;
00193         break;
00194       case 'e' :
00195         skipInitialEventReweighting = true;
00196         break;
00197       case 'u' :
00198         weightedDataOut = optarg;
00199         break;
00200       case 'v' :
00201         verbose = (optarg==0 ? 0 : atoi(optarg));
00202         break;
00203       case 'f' :
00204         outFile = optarg;
00205         break;
00206       case 'r' :
00207         resumeFile = optarg;
00208         break;
00209       case 'R' :
00210         resumeNNFile = optarg;
00211         break;
00212       case 'S' :
00213         resumeSNNSFile = optarg;
00214         break;
00215       case 'K' :
00216         split = true;
00217         splitFactor = (optarg==0 ? 0 : atof(optarg));
00218         break;
00219       case 'D' :
00220         splitRandomize = true;
00221         break;
00222       case 't' :
00223         valFile = optarg;
00224         break;
00225       case 'd' :
00226         valPrint = (optarg==0 ? 0 : atoi(optarg));
00227         break;
00228       case 'w' :
00229         if( optarg != 0 ) {
00230           scaleWeights = true;
00231           sW = atof(optarg);
00232         }
00233         break;
00234       case 'V' :
00235         includeList = optarg;
00236         break;
00237       case 'z' :
00238         excludeList = optarg;
00239         break;
00240       case 'Z' :
00241         stringVarsDoNotFeed = optarg;
00242         break;
00243       }
00244   }
00245 
00246   // Get training file.
00247   string trFile = argv[argc-1];
00248 
00249   // make reader
00250   SprRWFactory::DataType inputType 
00251     = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00252   auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00253 
00254   // include variables
00255   set<string> includeSet;
00256   if( !includeList.empty() ) {
00257     vector<vector<string> > includeVars;
00258     SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00259     assert( !includeVars.empty() );
00260     for( int i=0;i<includeVars[0].size();i++ ) 
00261       includeSet.insert(includeVars[0][i]);
00262     if( !reader->chooseVars(includeSet) ) {
00263       cerr << "Unable to include variables in training set." << endl;
00264       return 2;
00265     }
00266     else {
00267       cout << "Following variables have been included in optimization: ";
00268       for( set<string>::const_iterator 
00269              i=includeSet.begin();i!=includeSet.end();i++ )
00270         cout << "\"" << *i << "\"" << " ";
00271       cout << endl;
00272     }
00273   }
00274 
00275   // exclude variables
00276   set<string> excludeSet;
00277   if( !excludeList.empty() ) {
00278     vector<vector<string> > excludeVars;
00279     SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00280     assert( !excludeVars.empty() );
00281     for( int i=0;i<excludeVars[0].size();i++ ) 
00282       excludeSet.insert(excludeVars[0][i]);
00283     if( !reader->chooseAllBut(excludeSet) ) {
00284       cerr << "Unable to exclude variables from training set." << endl;
00285       return 2;
00286     }
00287     else {
00288       cout << "Following variables have been excluded from optimization: ";
00289       for( set<string>::const_iterator 
00290              i=excludeSet.begin();i!=excludeSet.end();i++ )
00291         cout << "\"" << *i << "\"" << " ";
00292       cout << endl;
00293     }
00294   }
00295 
00296   // read training data from file
00297   auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00298   if( filter.get() == 0 ) {
00299     cerr << "Unable to read data from file " << trFile.c_str() << endl;
00300     return 2;
00301   }
00302   vector<string> vars;
00303   filter->vars(vars);
00304   cout << "Read data from file " << trFile.c_str() 
00305        << " for variables";
00306   for( int i=0;i<vars.size();i++ ) 
00307     cout << " \"" << vars[i].c_str() << "\"";
00308   cout << endl;
00309   cout << "Total number of points read: " << filter->size() << endl;
00310 
00311   // filter training data by class
00312   vector<SprClass> inputClasses;
00313   if( !filter->filterByClass(inputClassesString.c_str()) ) {
00314     cerr << "Cannot choose input classes for string " 
00315          << inputClassesString << endl;
00316     return 2;
00317   }
00318   filter->classes(inputClasses);
00319   assert( inputClasses.size() > 1 );
00320   cout << "Training data filtered by class." << endl;
00321   for( int i=0;i<inputClasses.size();i++ ) {
00322     cout << "Points in class " << inputClasses[i] << ":   " 
00323          << filter->ptsInClass(inputClasses[i]) << endl;
00324   }
00325 
00326   // scale weights
00327   if( scaleWeights ) {
00328     cout << "Signal weights are multiplied by " << sW << endl;
00329     filter->scaleWeights(inputClasses[1],sW);
00330   }
00331 
00332   // apply low cutoff
00333   if( setLowCutoff ) {
00334     if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00335       cerr << "Unable to replace missing values in training data." << endl;
00336       return 2;
00337     }
00338     else
00339       cout << "Values below " << lowCutoff << " in training data"
00340            << " have been replaced with medians." << endl;
00341   }
00342 
00343   // read validation data from file
00344   auto_ptr<SprAbsFilter> valFilter;
00345   if( split && !valFile.empty() ) {
00346     cerr << "Unable to split training data and use validation data " 
00347          << "from a separate file." << endl;
00348     return 2;
00349   }
00350   if( split && valPrint!=0 ) {
00351     cout << "Splitting training data with factor " << splitFactor << endl;
00352     if( splitRandomize )
00353       cout << "Will use randomized splitting." << endl;
00354     vector<double> weights;
00355     SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00356     if( splitted == 0 ) {
00357       cerr << "Unable to split training data." << endl;
00358       return 2;
00359     }
00360     bool ownData = true;
00361     valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00362     cout << "Training data re-filtered:" << endl;
00363     for( int i=0;i<inputClasses.size();i++ ) {
00364       cout << "Points in class " << inputClasses[i] << ":   " 
00365            << filter->ptsInClass(inputClasses[i]) << endl;
00366     }
00367   }
00368   if( !valFile.empty() && valPrint!=0 ) {
00369     auto_ptr<SprAbsReader> 
00370       valReader(SprRWFactory::makeReader(inputType,readMode));
00371     if( !includeSet.empty() ) {
00372       if( !valReader->chooseVars(includeSet) ) {
00373         cerr << "Unable to include variables in validation set." << endl;
00374         return 2;
00375       }
00376     }
00377     if( !excludeSet.empty() ) {
00378       if( !valReader->chooseAllBut(excludeSet) ) {
00379         cerr << "Unable to exclude variables from validation set." << endl;
00380         return 2;
00381       }
00382     }
00383     valFilter.reset(valReader->read(valFile.c_str()));
00384     if( valFilter.get() == 0 ) {
00385       cerr << "Unable to read data from file " << valFile.c_str() << endl;
00386       return 2;
00387     }
00388     vector<string> valVars;
00389     valFilter->vars(valVars);
00390     cout << "Read validation data from file " << valFile.c_str() 
00391          << " for variables";
00392     for( int i=0;i<valVars.size();i++ ) 
00393       cout << " \"" << valVars[i].c_str() << "\"";
00394     cout << endl;
00395     cout << "Total number of points read: " << valFilter->size() << endl;
00396     cout << "Points in class 0: " << valFilter->ptsInClass(inputClasses[0])
00397          << " 1: " << valFilter->ptsInClass(inputClasses[1]) << endl;
00398   }
00399 
00400   // filter validation data by class
00401   if( valFilter.get() != 0 ) {
00402     if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00403       cerr << "Cannot choose input classes for string " 
00404            << inputClassesString << endl;
00405       return 2;
00406     }
00407     valFilter->classes(inputClasses);
00408     cout << "Validation data filtered by class." << endl;
00409     for( int i=0;i<inputClasses.size();i++ ) {
00410       cout << "Points in class " << inputClasses[i] << ":   " 
00411            << valFilter->ptsInClass(inputClasses[i]) << endl;
00412     }
00413   }
00414 
00415   // scale weights
00416   if( scaleWeights && valFilter.get()!=0 )
00417     valFilter->scaleWeights(inputClasses[1],sW);
00418 
00419   // apply low cutoff
00420   if( setLowCutoff && valFilter.get()!=0 ) {
00421     if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00422       cerr << "Unable to replace missing values in validation data." << endl;
00423       return 2;
00424     }
00425     else
00426       cout << "Values below " << lowCutoff << " in validation data"
00427            << " have been replaced with medians." << endl;
00428   }
00429 
00430   // apply transformation of variables to training and test data
00431   auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00432   if( !transformerFile.empty() ) {
00433     SprVarTransformerReader transReader;
00434     const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00435     if( t == 0 ) {
00436       cerr << "Unable to read VarTransformer from file "
00437            << transformerFile.c_str() << endl;
00438       return 2;
00439     }
00440     SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00441     SprTransformerFilter* t_valid = 0;
00442     if( valFilter.get() != 0 )
00443       t_valid = new SprTransformerFilter(valFilter.get());
00444     bool replaceOriginalData = true;
00445     if( !t_train->transform(t,replaceOriginalData) ) {
00446       cerr << "Unable to apply VarTransformer to training data." << endl;
00447       return 2;
00448     }
00449     if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00450       cerr << "Unable to apply VarTransformer to validation data." << endl;
00451       return 2;
00452     }
00453     cout << "Variable transformation from file "
00454          << transformerFile.c_str() << " has been applied to "
00455          << "training and validation data." << endl;
00456     garbage_train.reset(filter.release());
00457     garbage_valid.reset(valFilter.release());
00458     filter.reset(t_train);
00459     valFilter.reset(t_valid);
00460   }
00461 
00462   // make per-event loss
00463   auto_ptr<SprAverageLoss> loss;
00464   switch( iLoss )
00465     {
00466     case 1 :
00467       if( adaCycles > 1 ) {
00468         loss.reset(new SprAverageLoss(&SprLoss::quadratic,
00469                                       &SprTransformation::logit));
00470       }
00471       else {
00472         loss.reset(new SprAverageLoss(&SprLoss::quadratic));
00473       }
00474       cout << "Per-event loss set to "
00475            << "Quadratic loss (y-f(x))^2 " << endl;
00476       useStandardAB = true;
00477       break;
00478     case 2 :
00479       if( adaCycles > 1 ) {
00480         loss.reset(new SprAverageLoss(&SprLoss::exponential));
00481       }
00482       else {
00483         loss.reset(new SprAverageLoss(&SprLoss::exponential,
00484                                       &SprTransformation::logitInverse));
00485       }
00486       cout << "Per-event loss set to "
00487            << "Exponential loss exp(-y*f(x)) " << endl;
00488       useStandardAB = true;
00489       break;
00490     default :
00491       cout << "No per-event loss is chosen. Will use the default." << endl;
00492       break;
00493     }
00494 
00495   // make AdaBoost mode
00496   SprTrainedAdaBoost::AdaBoostMode abMode = SprTrainedAdaBoost::Discrete;
00497   switch( iAdaBoostMode )
00498     {
00499     case 1 :
00500       abMode = SprTrainedAdaBoost::Discrete;
00501       cout << "Will train Discrete AdaBoost." << endl;
00502       break;
00503     case 2 :
00504       abMode = SprTrainedAdaBoost::Real;
00505       cout << "Will train Real AdaBoost." << endl;
00506       break;
00507     case 3 :
00508       abMode = SprTrainedAdaBoost::Epsilon;
00509       cout << "Will train Epsilon AdaBoost." << endl;
00510       break;
00511    default :
00512       cout << "Will train Discrete AdaBoost." << endl;
00513       break;
00514     }
00515 
00516   // sanity check
00517   int resume = int(!resumeFile.empty()) 
00518     + int(!resumeNNFile.empty())
00519     + int(!resumeSNNSFile.empty());
00520   if( resume > 1 ) {
00521     cerr << "Reading more than one classifier configuration is not allowed." 
00522          << " Requested: " << resume << endl;
00523     return 5;
00524   }
00525   if( (!resumeNNFile.empty() || !resumeSNNSFile.empty()) 
00526       && !netConfig.empty() ) {
00527     cerr << "What do you want to do - read NN configuration from a file " 
00528          << "or specify configuration on the command line? "
00529          << "Life is tough - you cannot do both." << endl;
00530     return 5;
00531   }
00532 
00533   // make a single NN
00534   auto_ptr<SprStdBackprop> stdnn;
00535   if( adaCycles>0 && resumeNNFile.empty() && resumeSNNSFile.empty() ) {
00536     stdnn.reset(new SprStdBackprop(filter.get(),
00537                                    netConfig.c_str(),
00538                                    nnCycles,
00539                                    eta));
00540     if( !stdnn->init(initEta,initPoints) ) {
00541       cerr << "Unable to initialize neural net." << endl;
00542       return 6;
00543     }
00544   }
00545   else {
00546     stdnn.reset(new SprStdBackprop(filter.get(),
00547                                    nnCycles,
00548                                    eta));
00549   }
00550   
00551   // read saved NN from file
00552   SprTrainedStdBackprop* trainedNN = 0;
00553   if( !resumeSNNSFile.empty() ) {
00554     if( !stdnn->readSNNS(resumeSNNSFile.c_str()) ) {
00555       cerr << "Unable to read SNNS configuration from file " 
00556            << resumeSNNSFile.c_str() << endl;
00557       return 6;
00558     }
00559     trainedNN = stdnn->makeTrained();
00560     cout << "Read SNNS configuration from file " 
00561          << resumeSNNSFile.c_str() << endl;
00562   }
00563   if( !resumeNNFile.empty() ) {
00564     if( !SprClassifierReader::readTrainable(resumeNNFile.c_str(),
00565                                             stdnn.get(),verbose) ) {
00566       cerr << "Unable to read SPR NN configuration from file " 
00567            << resumeNNFile.c_str() << endl;
00568       return 6;
00569     }
00570     trainedNN = stdnn->makeTrained();
00571     cout << "Read SPR neural net configuration from file " 
00572          << resumeNNFile.c_str() << endl;
00573   }
00574               
00575   // make classifier to train
00576   auto_ptr<SprAbsClassifier> classifier;
00577   if( adaCycles != 1 ) {
00578     // make AdaBoost
00579     SprAdaBoost* ab = new SprAdaBoost(filter.get(),
00580                                       adaCycles,
00581                                       useStandardAB,
00582                                       abMode);
00583     cout << "Setting epsilon to " << epsilon << endl;
00584     ab->setEpsilon(epsilon);
00585     
00586     // skip reweigting
00587     if( skipInitialEventReweighting ) ab->skipInitialEventReweighting(true);
00588 
00589     // set validation
00590     if( valFilter.get()!=0 && !valFilter->empty() )
00591       ab->setValidation(valFilter.get(),valPrint,loss.get());
00592     
00593     // read saved AdaBoost
00594     if( resumeFile.empty() ) {
00595       if( trainedNN != 0 ) {
00596         if( !ab->addTrained(trainedNN,true) ) {
00597           cerr << "Unable to add first trained NN to AdaBoost." << endl;
00598           return 6;
00599         }
00600       }
00601     }
00602     else {
00603       if( !SprClassifierReader::readTrainable(resumeFile.c_str(),
00604                                               ab,verbose) ) {
00605         cerr << "Failed to read saved AdaBoost from file " 
00606              << resumeFile.c_str() << endl;
00607         return 6;
00608       }
00609       cout << "Read saved AdaBoost from file " << resumeFile.c_str()
00610            << " with " << ab->nTrained() << " trained classifiers." << endl;
00611     }
00612 
00613     // add a trainable NN
00614     if( !ab->addTrainable(stdnn.get()) ) {
00615       cerr << "Unable to add neural net to AdaBoost." << endl;
00616       return 6;
00617     }
00618     
00619     // reset classifier
00620     classifier.reset(ab);
00621   }
00622   else {
00623     // set validation
00624     if( valFilter.get()!=0 && !valFilter->empty() )
00625       stdnn->setValidation(valFilter.get(),valPrint,loss.get());
00626     
00627     // reset classifier
00628     classifier.reset(stdnn.release());
00629   }
00630   
00631   // train
00632   if( !classifier->train(verbose) ) {
00633     cerr << "Training terminated with error." << endl;
00634     return 7;
00635   }
00636   else {
00637     cout << "Training done." << endl;
00638     if( adaCycles != 1 ) {
00639       SprAdaBoost* ab = static_cast<SprAdaBoost*>(classifier.get());
00640       cout << "AdaBoost finished training with " << ab->nTrained() 
00641            << " classifiers." << endl;    
00642     }
00643   }
00644 
00645   // save trained classifier
00646   if( !outFile.empty() ) {
00647     if( !classifier->store(outFile.c_str()) ) {
00648       cerr << "Cannot store classifier in file " << outFile.c_str() << endl;
00649       return 8;
00650     }
00651   }
00652 
00653   // save reweighted data
00654   if( adaCycles > 1 ) {
00655     if( !weightedDataOut.empty() ) {
00656       SprAdaBoost* ab = static_cast<SprAdaBoost*>(classifier.get());
00657       if( !ab->storeData(weightedDataOut.c_str()) ) {
00658         cerr << "Cannot store weighted AdaBoost data to file " 
00659              << weightedDataOut.c_str() << endl;
00660         return 9;
00661       }
00662     }
00663   }
00664 
00665   // make a trained AdaBoost
00666   auto_ptr<SprAbsTrainedClassifier> trained(classifier->makeTrained());
00667   if( trained.get() == 0 ) {
00668     cerr << "Unable to get trained classifier." << endl;
00669     return 9;
00670   }
00671 
00672   // make histogram if requested
00673   if( tupleFile.empty() ) 
00674     return 0;
00675 
00676   // make a writer
00677   auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00678   if( !tuple->init(tupleFile.c_str()) ) {
00679     cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00680     return 10;
00681   }
00682 
00683   // determine if certain variables are to be excluded from usage,
00684   // but included in the output storage file (-Z option)
00685   string printVarsDoNotFeed;
00686   vector<vector<string> > varsDoNotFeed;
00687   SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed);
00688   vector<unsigned> mapper;
00689   for( int d=0;d<vars.size();d++ ) {
00690     if( varsDoNotFeed.empty() ||
00691         (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d])
00692          ==varsDoNotFeed[0].end()) ) {
00693       mapper.push_back(d);
00694     }
00695     else {
00696       printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? "" : ", " );
00697       printVarsDoNotFeed += vars[d];
00698     }
00699   }
00700   if( !printVarsDoNotFeed.empty() ) {
00701     cout << "The following variables are not used in the algorithm, " 
00702          << "but will be included in the output file: " 
00703          << printVarsDoNotFeed.c_str() << endl;
00704   }
00705 
00706   // feed
00707   SprDataFeeder feeder(filter.get(),tuple.get(),mapper);
00708   string classifierName;
00709   if( adaCycles != 1 )
00710     classifierName = "adann";
00711   else
00712     classifierName = "nn";
00713   feeder.addClassifier(trained.get(),classifierName.c_str());
00714   if( !feeder.feed(1000) ) {
00715     cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00716     return 11;
00717   }
00718 
00719   // exit
00720   return 0;
00721 }

Generated on Tue Jun 9 17:41:59 2009 for CMSSW by  doxygen 1.5.4