CMS 3D CMS Logo

SprSplitterApp.cc

Go to the documentation of this file.
00001 //$Id: SprSplitterApp.cc,v 1.1 2007/12/01 01:29:41 narsky Exp $
00002 /*
00003   This executable splits input data into training and test data,
00004   optionally converting them into a different format (e.g., Ascii 
00005   instead of Root).
00006 */
00007 
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00017 
00018 #include <stdlib.h>
00019 #include <unistd.h>
00020 #include <iostream>
00021 #include <vector>
00022 #include <set>
00023 #include <string>
00024 #include <memory>
00025 
00026 using namespace std;
00027 
00028 
// Prints the command-line synopsis followed by one line per supported
// option to standard output.
//   prog - executable name to show in the usage line (normally argv[0]).
void help(const char* prog)
{
  // Option descriptions, emitted verbatim, one per output line.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help                                        ",
    "\t-a input ascii file mode (see SprSimpleReader.hh)  ",
    "\t-A save output data in ascii instead of Root       ",
    "\t-y list of input classes (see SprAbsFilter.hh)     ",
    "\t-v verbose level (0=silent default,1,2)            ",
    "\t-V include only these input variables              ",
    "\t-z exclude input variables from the list           ",
    "\t\t Variables must be listed in quotes and separated by commas.",
    "\t-K keep the specified fraction in input data       ",
    "\t\t If no fraction specified, 0.5 is assumed.       ",
    "\t-S random seed used for splitting.                 ",
    "\t\t If none, puts K into training and (1-K) into test data."
  };
  const unsigned int lineCount = sizeof(optionLines)/sizeof(optionLines[0]);

  // Synopsis: the three positional file arguments.
  std::cout << "Usage:  ";
  std::cout << prog;
  std::cout << " input_data_file output_training_data_file output_test_data_file";
  std::cout << std::endl;

  for( unsigned int line=0; line<lineCount; ++line )
    std::cout << optionLines[line] << std::endl;
}
00050 
00051 
00052 int main(int argc, char ** argv)
00053 {
00054   // check command line
00055   if( argc < 2 ) {
00056     help(argv[0]);
00057     return 1;
00058   }
00059 
00060   // init
00061   int readMode = 0;
00062   SprRWFactory::DataType writeMode = SprRWFactory::Root;
00063   int verbose = 0;
00064   string outFile;
00065   string includeList, excludeList;
00066   string inputClassesString;
00067   double splitFactor = 0.5;
00068   int seed = 0;
00069   bool splitRandomize = false;
00070 
00071   // decode command line
00072   int c;
00073   extern char* optarg;
00074   //  extern int optind;
00075   while( (c = getopt(argc,argv,"ha:Ay:v:V:z:K:S:")) != EOF ) {
00076     switch( c )
00077       {
00078       case 'h' :
00079         help(argv[0]);
00080         return 1;
00081        case 'a' :
00082         readMode = (optarg==0 ? 0 : atoi(optarg));
00083         break;
00084       case 'A' :
00085         writeMode = SprRWFactory::Ascii;
00086         break;
00087       case 'y' :
00088         inputClassesString = optarg;
00089         break;
00090       case 'v' :
00091         verbose = (optarg==0 ? 0 : atoi(optarg));
00092         break;
00093       case 'V' :
00094         includeList = optarg;
00095         break;
00096       case 'z' :
00097         excludeList = optarg;
00098         break;
00099       case 'K' :
00100         splitFactor = (optarg==0 ? 0.5 : atof(optarg));
00101         break;
00102       case 'S' :
00103         splitRandomize = true;
00104         seed = (optarg==0 ? 0 : atoi(optarg));
00105         break;
00106       }
00107   }
00108 
00109   // arguments
00110   string inputFile = argv[argc-3];
00111   if( inputFile.empty() ) {
00112     cerr << "No input file is specified." << endl;
00113     return 1;
00114   }
00115   string trainFile = argv[argc-2];
00116   if( trainFile.empty() ) {
00117     cerr << "No training file is specified." << endl;
00118     return 1;
00119   }
00120   string testFile = argv[argc-1];
00121   if( testFile.empty() ) {
00122     cerr << "No test file is specified." << endl;
00123     return 1;
00124   }
00125 
00126   // make reader
00127   SprRWFactory::DataType inputType 
00128     = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00129   auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00130 
00131   // include variables
00132   set<string> includeSet;
00133   if( !includeList.empty() ) {
00134     vector<vector<string> > includeVars;
00135     SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00136     assert( !includeVars.empty() );
00137     for( int i=0;i<includeVars[0].size();i++ ) 
00138       includeSet.insert(includeVars[0][i]);
00139     if( !reader->chooseVars(includeSet) ) {
00140       cerr << "Unable to include variables in input set." << endl;
00141       return 2;
00142     }
00143     else {
00144       cout << "Following variables have been included in optimization: ";
00145       for( set<string>::const_iterator 
00146              i=includeSet.begin();i!=includeSet.end();i++ )
00147         cout << "\"" << *i << "\"" << " ";
00148       cout << endl;
00149     }
00150   }
00151 
00152   // exclude variables
00153   set<string> excludeSet;
00154   if( !excludeList.empty() ) {
00155     vector<vector<string> > excludeVars;
00156     SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00157     assert( !excludeVars.empty() );
00158     for( int i=0;i<excludeVars[0].size();i++ ) 
00159       excludeSet.insert(excludeVars[0][i]);
00160     if( !reader->chooseAllBut(excludeSet) ) {
00161       cerr << "Unable to exclude variables from input set." << endl;
00162       return 2;
00163     }
00164     else {
00165       cout << "Following variables have been excluded from optimization: ";
00166       for( set<string>::const_iterator 
00167              i=excludeSet.begin();i!=excludeSet.end();i++ )
00168         cout << "\"" << *i << "\"" << " ";
00169       cout << endl;
00170     }
00171   }
00172 
00173   // read training data from file
00174   auto_ptr<SprAbsFilter> filter(reader->read(inputFile.c_str()));
00175   if( filter.get() == 0 ) {
00176     cerr << "Unable to read data from file " << inputFile.c_str() << endl;
00177     return 2;
00178   }
00179   vector<string> vars;
00180   filter->vars(vars);
00181   cout << "Read data from file " << inputFile.c_str() << " for variables";
00182   for( int i=0;i<vars.size();i++ ) 
00183     cout << " \"" << vars[i].c_str() << "\"";
00184   cout << endl;
00185   cout << "Total number of points read: " << filter->size() << endl;
00186 
00187   // filter training data by class
00188   vector<SprClass> inputClasses;
00189   if( !filter->filterByClass(inputClassesString.c_str()) ) {
00190     cerr << "Cannot choose input classes for string " 
00191          << inputClassesString << endl;
00192     return 2;
00193   }
00194   filter->classes(inputClasses);
00195   assert( inputClasses.size() > 1 );
00196   cout << "Input data filtered by class." << endl;
00197   for( int i=0;i<inputClasses.size();i++ ) {
00198     cout << "Points in class " << inputClasses[i] << ":   " 
00199          << filter->ptsInClass(inputClasses[i]) << endl;
00200   }
00201 
00202   // split data
00203   cout << "Splitting input data with factor " << splitFactor << endl;
00204   vector<double> weights;
00205   SprData* splitted = filter->split(splitFactor,weights,splitRandomize,seed);
00206   if( splitted == 0 ) {
00207     cerr << "Unable to split input data." << endl;
00208     return 2;
00209   }
00210   bool ownData = true;
00211   auto_ptr<SprAbsFilter> valFilter(new SprEmptyFilter(splitted,
00212                                                       weights,ownData));
00213   cout << "Data re-filtered:" << endl;
00214   cout << "Training data:" << endl;
00215   for( int i=0;i<inputClasses.size();i++ ) {
00216     cout << "Points in class " << inputClasses[i] << ":   " 
00217          << filter->ptsInClass(inputClasses[i]) << endl;
00218   }
00219   cout << "Test data:" << endl;
00220   for( int i=0;i<inputClasses.size();i++ ) {
00221     cout << "Points in class " << inputClasses[i] << ":   " 
00222          << valFilter->ptsInClass(inputClasses[i]) << endl;
00223   }
00224 
00225   // make a writer
00226   auto_ptr<SprAbsWriter> trainTuple(SprRWFactory::makeWriter(writeMode,
00227                                                              "training"));
00228   if( !trainTuple->init(trainFile.c_str()) ) {
00229     cerr << "Unable to open output file " << trainFile.c_str() << endl;
00230     return 3;
00231   }
00232   auto_ptr<SprAbsWriter> testTuple(SprRWFactory::makeWriter(writeMode,
00233                                                             "test"));
00234   if( !testTuple->init(testFile.c_str()) ) {
00235     cerr << "Unable to open output file " << testFile.c_str() << endl;
00236     return 3;
00237   }
00238 
00239   // feed
00240   SprDataFeeder feeder(filter.get(),trainTuple.get());
00241   if( !feeder.feed(1000) ) {
00242     cerr << "Cannot feed data into file " << trainFile.c_str() << endl;
00243     return 4;
00244   }
00245   SprDataFeeder valFeeder(valFilter.get(),testTuple.get());
00246   if( !valFeeder.feed(1000) ) {
00247     cerr << "Cannot feed data into file " << testFile.c_str() << endl;
00248     return 4;
00249   }
00250 
00251   // exit
00252   return 0;
00253 }

Generated on Tue Jun 9 17:41:59 2009 for CMSSW by  doxygen 1.5.4