CMS 3D CMS Logo

SprOutputWriterApp.cc

Go to the documentation of this file.
00001 //$Id: SprOutputWriterApp.cc,v 1.6 2007/12/01 01:29:41 narsky Exp $
00002 
00003 
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprCoordinateMapper.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedAdaBoost.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedFisher.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedLogitR.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00022 
00023 #include <stdlib.h>
00024 #include <unistd.h>
00025 #include <iostream>
00026 #include <set>
00027 #include <vector>
00028 #include <memory>
00029 #include <string>
00030 #include <cassert>
00031 #include <algorithm>
00032 
00033 using namespace std;
00034 
00035 
/// Print the command-line usage summary of this executable to standard output.
///
/// The first line shows the expected positional arguments; the remaining
/// lines describe every option accepted by the getopt() string in main().
///
/// @param prog  Program name inserted into the "Usage:" line (normally argv[0]).
void help(const char* prog) 
{
  // One entry per printed line, in output order. Trailing blanks inside the
  // literals are part of the emitted text and are kept on purpose.
  static const char* const usageLines[] = {
    "\t (List of files must be in quotes, separated by commas.)",
    "\t Options: ",
    "\t-h --- help                                        ",
    "\t-y list of input classes (see SprAbsFilter.hh)     ",
    "\t-Q apply variable transformation saved in file     ",
    "\t-a input ascii file mode (see SprSimpleReader.hh)  ",
    "\t-A save output data in ascii instead of Root       ",
    "\t-K use 1-fraction of input data                    ",
    "\t\t This option is for consistency with other execs.",
    "\t-v verbose level (0=silent default,1,2)            ",
    "\t-w scale all signal weights by this factor         ",
    "\t-t output tuple name (default=data)                ",
    "\t-C output classifier names (in quotes, separated by commas)",
    "\t-p feeder print-out frequency (default=1000 events)",
    "\t-s use output in range (-infty,+infty) instead of [0,1]",
    "\t-V include only these input variables              ",
    "\t-z exclude input variables from the list           ",
    "\t-Z exclude input variables from the list, "
    "but put them in the output file ",
    "\t-M map variable lists from trained classifiers onto",
    "\t\t variables available in input data.",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const unsigned lineCount = sizeof(usageLines)/sizeof(usageLines[0]);

  // Banner with the positional arguments.
  std::cout << "Usage:  " << prog << " list_of_classifier_config_files"
            << " input_data_file output_tuple_file" << std::endl;

  // Option descriptions, each terminated (and flushed) individually.
  for( unsigned line=0; line<lineCount; ++line )
    std::cout << usageLines[line] << std::endl;
}
00065 
00066 
00067 void cleanup(vector<SprAbsTrainedClassifier*>& trained) {
00068   for( int i=0;i<trained.size();i++ ) delete trained[i];
00069 }
00070 
00071 
00072 int main(int argc, char ** argv)
00073 {
00074   // check command line
00075   if( argc < 4 ) {
00076     help(argv[0]);
00077     return 1;
00078   }
00079 
00080   // init
00081   int readMode = 0;
00082   SprRWFactory::DataType writeMode = SprRWFactory::Root;
00083   int verbose = 0;
00084   bool scaleWeights = false;
00085   double sW = 1.;
00086   bool useStandard = false;
00087   string tupleName;
00088   string classifierNameList;
00089   string includeList, excludeList;
00090   string inputClassesString;
00091   int nPrintOut = 1000;
00092   string stringVarsDoNotFeed;
00093   bool mapTrainedVars = false;
00094   bool split = false;
00095   double splitFactor = 0;
00096   string transformerFile;
00097 
00098  
00099   // decode command line
00100   int c;
00101   extern char* optarg;
00102   extern int optind;
00103   while( (c = getopt(argc,argv,"hy:Q:a:AK:v:w:t:C:p:sV:z:Z:M")) != EOF ) {
00104     switch( c )
00105       {
00106       case 'h' :
00107         help(argv[0]);
00108         return 1;
00109       case 'y' :
00110         inputClassesString = optarg;
00111         break;
00112       case 'Q' :
00113         transformerFile = optarg;
00114         break;
00115       case 'a' :
00116         readMode = (optarg==0 ? 0 : atoi(optarg));
00117         break;
00118       case 'A' :
00119         writeMode = SprRWFactory::Ascii;
00120         break;
00121       case 'K' :
00122         split = true;
00123         splitFactor = (optarg==0 ? 0 : atof(optarg));
00124         break;
00125       case 'v' :
00126         verbose = (optarg==0 ? 0 : atoi(optarg));
00127         break;
00128       case 'w' :
00129         if( optarg != 0 ) {
00130           scaleWeights = true;
00131           sW = atof(optarg);
00132         }
00133         break;
00134       case 't' :
00135         tupleName = optarg;
00136         break;
00137       case 'C' :
00138         classifierNameList = optarg;
00139         break;
00140       case 'p' :
00141         nPrintOut = (optarg==0 ? 1000 : atoi(optarg));
00142         break;
00143       case 's' :
00144         useStandard = true;
00145         break;
00146       case 'V' :
00147         includeList = optarg;
00148         break;
00149       case 'z' :
00150         excludeList = optarg;
00151         break;
00152       case 'Z' :
00153         stringVarsDoNotFeed = optarg;
00154         break;
00155       case 'M' :
00156         mapTrainedVars = true;
00157         break;
00158       }
00159   }
00160 
00161   // Must have 3 arguments on the command line
00162   string configFileList = argv[argc-3];
00163   string dataFile       = argv[argc-2];
00164   string tupleFile      = argv[argc-1];
00165   if( configFileList.empty() ) {
00166     cerr << "No classifier configuration files are specified." << endl;
00167     return 1;
00168   }
00169   if( dataFile.empty() ) {
00170     cerr << "No input data file is specified." << endl;
00171     return 1;
00172   }
00173   if( tupleFile.empty() ) {
00174     cerr << "No output tuple file is specified." << endl;
00175     return 1;
00176   }
00177 
00178   // get classifier names and config files
00179   vector<vector<string> > classifierNames, configFiles;
00180   SprStringParser::parseToStrings(classifierNameList.c_str(),classifierNames);
00181   SprStringParser::parseToStrings(configFileList.c_str(),configFiles);
00182   if( configFiles.empty() || configFiles[0].empty() ) {
00183     cerr << "Unable to parse config file list: " 
00184          << configFileList.c_str() << endl;
00185     return 1;
00186   }
00187   int nTrained = configFiles[0].size();
00188   bool useClassifierNames 
00189     = (!classifierNames.empty() && !classifierNames[0].empty());
00190   if( useClassifierNames && (classifierNames[0].size()!=nTrained) ) {
00191     cerr << "Sizes of classifier name list and config file list do not match!"
00192          << endl;
00193     return 1;
00194   }
00195 
00196   // make reader
00197   SprRWFactory::DataType inputType 
00198     = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00199   auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00200 
00201   // include variables
00202   set<string> includeSet;
00203   if( !includeList.empty() ) {
00204     vector<vector<string> > includeVars;
00205     SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00206     assert( !includeVars.empty() );
00207     for( int i=0;i<includeVars[0].size();i++ ) 
00208       includeSet.insert(includeVars[0][i]);
00209     if( !reader->chooseVars(includeSet) ) {
00210       cerr << "Unable to include variables in training set." << endl;
00211       return 2;
00212     }
00213     else {
00214       cout << "Following variables have been included in optimization: ";
00215       for( set<string>::const_iterator 
00216              i=includeSet.begin();i!=includeSet.end();i++ )
00217         cout << "\"" << *i << "\"" << " ";
00218       cout << endl;
00219     }
00220   }
00221 
00222   // exclude variables
00223   set<string> excludeSet;
00224   if( !excludeList.empty() ) {
00225     vector<vector<string> > excludeVars;
00226     SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00227     assert( !excludeVars.empty() );
00228     for( int i=0;i<excludeVars[0].size();i++ ) 
00229       excludeSet.insert(excludeVars[0][i]);
00230     if( !reader->chooseAllBut(excludeSet) ) {
00231       cerr << "Unable to exclude variables from training set." << endl;
00232       return 2;
00233     }
00234     else {
00235       cout << "Following variables have been excluded from optimization: ";
00236       for( set<string>::const_iterator 
00237              i=excludeSet.begin();i!=excludeSet.end();i++ )
00238         cout << "\"" << *i << "\"" << " ";
00239       cout << endl;
00240     }
00241   }
00242 
00243   // read input data from file
00244   auto_ptr<SprAbsFilter> filter(reader->read(dataFile.c_str()));
00245   if( filter.get() == 0 ) {
00246     cerr << "Unable to read data from file " << dataFile.c_str() << endl;
00247     return 2;
00248   }
00249   vector<string> vars;
00250   filter->vars(vars);
00251   cout << "Read data from file " << dataFile.c_str() << " for variables";
00252   for( int i=0;i<vars.size();i++ ) 
00253     cout << " \"" << vars[i].c_str() << "\"";
00254   cout << endl;
00255   cout << "Total number of points read: " << filter->size() << endl;
00256 
00257   // filter training data by class
00258   vector<SprClass> inputClasses;
00259   if( !filter->filterByClass(inputClassesString.c_str()) ) {
00260     cerr << "Cannot choose input classes for string " 
00261          << inputClassesString << endl;
00262     return 2;
00263   }
00264   filter->classes(inputClasses);
00265   assert( inputClasses.size() > 1 );
00266   cout << "Training data filtered by class." << endl;
00267   for( int i=0;i<inputClasses.size();i++ ) {
00268     cout << "Points in class " << inputClasses[i] << ":   " 
00269          << filter->ptsInClass(inputClasses[i]) << endl;
00270   }
00271 
00272   // scale weights
00273   if( scaleWeights ) {
00274     cout << "Signal weights are multiplied by " << sW << endl;
00275     filter->scaleWeights(inputClasses[1],sW);
00276   }
00277 
00278   // apply transformation of variables to training and test data
00279   auto_ptr<SprAbsFilter> garbage_train;
00280   if( !transformerFile.empty() ) {
00281     SprVarTransformerReader transReader;
00282     const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00283     if( t == 0 ) {
00284       cerr << "Unable to read VarTransformer from file "
00285            << transformerFile.c_str() << endl;
00286       return 2;
00287     }
00288     SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00289     bool replaceOriginalData = true;
00290     if( !t_train->transform(t,replaceOriginalData) ) {
00291       cerr << "Unable to apply VarTransformer to training data." << endl;
00292       return 2;
00293     }
00294     cout << "Variable transformation from file "
00295          << transformerFile.c_str() << " has been applied to data." << endl;
00296     garbage_train.reset(filter.release());
00297     filter.reset(t_train);
00298     filter->vars(vars);
00299   }
00300 
00301   // split data if desired
00302   auto_ptr<SprAbsFilter> valFilter;
00303   if( split ) {
00304     cout << "Splitting data with factor " << splitFactor << endl;
00305     vector<double> weights;
00306     SprData* splitted = filter->split(splitFactor,weights,false);
00307     if( splitted == 0 ) {
00308       cerr << "Unable to split data." << endl;
00309       return 2;
00310     }
00311     bool ownData = true;
00312     valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00313     cout << "Data re-filtered:" << endl;
00314     for( int i=0;i<inputClasses.size();i++ ) {
00315       cout << "Points in class " << inputClasses[i] << ":   "
00316            << valFilter->ptsInClass(inputClasses[i]) << endl;
00317     }
00318   }
00319   else {
00320     valFilter.reset(filter.release());
00321   }
00322 
00323   // read classifier configuration
00324   vector<SprAbsTrainedClassifier*> trained(nTrained);
00325   vector<SprCoordinateMapper*> specificMappers(nTrained);
00326   for( int i=0;i<nTrained;i++ ) {
00327 
00328     // read classifier
00329     trained[i] 
00330       = SprClassifierReader::readTrained(configFiles[0][i].c_str(),verbose);
00331     if( trained[i] == 0 ) {
00332       cerr << "Unable to read classifier configuration from file "
00333            << configFiles[0][i].c_str() << endl;
00334       cleanup(trained);
00335       return 3;
00336     }
00337     cout << "Read classifier " << trained[i]->name().c_str()
00338          << " with dimensionality " << trained[i]->dim() << endl;
00339 
00340     // get a list of trained variables
00341     vector<string> trainedVars;
00342     trained[i]->vars(trainedVars);
00343     if( verbose > 0 ) {
00344       cout << "Variables:      " << endl;
00345       for( int j=0;j<trainedVars.size();j++ ) 
00346         cout << trainedVars[j].c_str() << " ";
00347       cout << endl;
00348     }
00349 
00350     // map trained-classifier variables onto data variables
00351     if( mapTrainedVars || trained[i]->name()=="Combiner" ) {
00352       specificMappers[i] 
00353         = SprCoordinateMapper::createMapper(trainedVars,vars);
00354     }
00355 
00356     // switch classifier output range
00357     if( useStandard ) {
00358       if(      trained[i]->name() == "AdaBoost" ) {
00359         SprTrainedAdaBoost* specific 
00360           = static_cast<SprTrainedAdaBoost*>(trained[i]);
00361         specific->useStandard();
00362       }
00363       else if( trained[i]->name() == "Fisher" ) {
00364         SprTrainedFisher* specific 
00365           = static_cast<SprTrainedFisher*>(trained[i]);
00366         specific->useStandard();
00367       }
00368       else if( trained[i]->name() == "LogitR" ) {
00369         SprTrainedLogitR* specific 
00370           = static_cast<SprTrainedLogitR*>(trained[i]);
00371         specific->useStandard();
00372       }
00373     }
00374   }
00375 
00376   // make tuple
00377   if( tupleName.empty() ) tupleName = "data";
00378   auto_ptr<SprAbsWriter> 
00379     tuple(SprRWFactory::makeWriter(writeMode,tupleName.c_str()));
00380   if( !tuple->init(tupleFile.c_str()) ) {
00381     cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00382     cleanup(trained);
00383     return 5;
00384   }
00385 
00386   // determine if certain variables are to be excluded from usage,
00387   // but included in the output storage file (-Z option)
00388   string printVarsDoNotFeed;
00389   vector<vector<string> > varsDoNotFeed;
00390   SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed);
00391   vector<unsigned> mapper;
00392   for( int d=0;d<vars.size();d++ ) {
00393     if( varsDoNotFeed.empty() ||
00394         (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d])
00395          ==varsDoNotFeed[0].end()) ) {
00396       mapper.push_back(d);
00397     }
00398     else {
00399       printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? "" : ", " );
00400       printVarsDoNotFeed += vars[d];
00401     }
00402   }
00403   if( !printVarsDoNotFeed.empty() ) {
00404     cout << "The following variables are not used in the algorithm, " 
00405          << "but will be included in the output file: " 
00406          << printVarsDoNotFeed.c_str() << endl;
00407   }
00408 
00409   // feed data into tuple
00410   SprDataFeeder feeder(valFilter.get(),tuple.get(),mapper);
00411   for( int i=0;i<nTrained;i++ ) {
00412     string useName;
00413     if( useClassifierNames ) 
00414       useName = classifierNames[0][i];
00415     else
00416       useName = trained[i]->name();
00417     feeder.addClassifier(trained[i],useName.c_str(),specificMappers[i]);
00418   }
00419   if( !feeder.feed(nPrintOut) ) {
00420     cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00421     cleanup(trained);
00422     return 6;
00423   }
00424 
00425   // exit
00426   cleanup(trained);
00427   return 0;
00428 }

Generated on Tue Jun 9 17:41:59 2009 for CMSSW by  doxygen 1.5.4