00001
00002
00003
00004
00005
00006
00007
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00017
#include <stdlib.h>
#include <unistd.h>

#include <cassert>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>
00025
00026 using namespace std;
00027
00028
/*
  Print command-line usage for this program to standard output.

  prog - program name to show in the "Usage:" line (normally argv[0]).
*/
void help(const char* prog)
{
  // Option summary, emitted one entry per output line after the usage line.
  static const char* const optionText[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-A save output data in ascii instead of Root ",
    "\t-y list of input classes (see SprAbsFilter.hh) ",
    "\t-v verbose level (0=silent default,1,2) ",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t\t Variables must be listed in quotes and separated by commas.",
    "\t-K keep the specified fraction in input data ",
    "\t\t If no fraction specified, 0.5 is assumed. ",
    "\t-S random seed used for splitting. ",
    "\t\t If none, puts K into training and (1-K) into test data."
  };
  const size_t nText = sizeof(optionText)/sizeof(optionText[0]);

  cout << "Usage: " << prog
       << " input_data_file output_training_data_file output_test_data_file"
       << endl;
  for( size_t line=0;line<nText;++line )
    cout << optionText[line] << endl;
}
00050
00051
00052 int main(int argc, char ** argv)
00053 {
00054
00055 if( argc < 2 ) {
00056 help(argv[0]);
00057 return 1;
00058 }
00059
00060
00061 int readMode = 0;
00062 SprRWFactory::DataType writeMode = SprRWFactory::Root;
00063 int verbose = 0;
00064 string outFile;
00065 string includeList, excludeList;
00066 string inputClassesString;
00067 double splitFactor = 0.5;
00068 int seed = 0;
00069 bool splitRandomize = false;
00070
00071
00072 int c;
00073 extern char* optarg;
00074
00075 while( (c = getopt(argc,argv,"ha:Ay:v:V:z:K:S:")) != EOF ) {
00076 switch( c )
00077 {
00078 case 'h' :
00079 help(argv[0]);
00080 return 1;
00081 case 'a' :
00082 readMode = (optarg==0 ? 0 : atoi(optarg));
00083 break;
00084 case 'A' :
00085 writeMode = SprRWFactory::Ascii;
00086 break;
00087 case 'y' :
00088 inputClassesString = optarg;
00089 break;
00090 case 'v' :
00091 verbose = (optarg==0 ? 0 : atoi(optarg));
00092 break;
00093 case 'V' :
00094 includeList = optarg;
00095 break;
00096 case 'z' :
00097 excludeList = optarg;
00098 break;
00099 case 'K' :
00100 splitFactor = (optarg==0 ? 0.5 : atof(optarg));
00101 break;
00102 case 'S' :
00103 splitRandomize = true;
00104 seed = (optarg==0 ? 0 : atoi(optarg));
00105 break;
00106 }
00107 }
00108
00109
00110 string inputFile = argv[argc-3];
00111 if( inputFile.empty() ) {
00112 cerr << "No input file is specified." << endl;
00113 return 1;
00114 }
00115 string trainFile = argv[argc-2];
00116 if( trainFile.empty() ) {
00117 cerr << "No training file is specified." << endl;
00118 return 1;
00119 }
00120 string testFile = argv[argc-1];
00121 if( testFile.empty() ) {
00122 cerr << "No test file is specified." << endl;
00123 return 1;
00124 }
00125
00126
00127 SprRWFactory::DataType inputType
00128 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00129 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00130
00131
00132 set<string> includeSet;
00133 if( !includeList.empty() ) {
00134 vector<vector<string> > includeVars;
00135 SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00136 assert( !includeVars.empty() );
00137 for( int i=0;i<includeVars[0].size();i++ )
00138 includeSet.insert(includeVars[0][i]);
00139 if( !reader->chooseVars(includeSet) ) {
00140 cerr << "Unable to include variables in input set." << endl;
00141 return 2;
00142 }
00143 else {
00144 cout << "Following variables have been included in optimization: ";
00145 for( set<string>::const_iterator
00146 i=includeSet.begin();i!=includeSet.end();i++ )
00147 cout << "\"" << *i << "\"" << " ";
00148 cout << endl;
00149 }
00150 }
00151
00152
00153 set<string> excludeSet;
00154 if( !excludeList.empty() ) {
00155 vector<vector<string> > excludeVars;
00156 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00157 assert( !excludeVars.empty() );
00158 for( int i=0;i<excludeVars[0].size();i++ )
00159 excludeSet.insert(excludeVars[0][i]);
00160 if( !reader->chooseAllBut(excludeSet) ) {
00161 cerr << "Unable to exclude variables from input set." << endl;
00162 return 2;
00163 }
00164 else {
00165 cout << "Following variables have been excluded from optimization: ";
00166 for( set<string>::const_iterator
00167 i=excludeSet.begin();i!=excludeSet.end();i++ )
00168 cout << "\"" << *i << "\"" << " ";
00169 cout << endl;
00170 }
00171 }
00172
00173
00174 auto_ptr<SprAbsFilter> filter(reader->read(inputFile.c_str()));
00175 if( filter.get() == 0 ) {
00176 cerr << "Unable to read data from file " << inputFile.c_str() << endl;
00177 return 2;
00178 }
00179 vector<string> vars;
00180 filter->vars(vars);
00181 cout << "Read data from file " << inputFile.c_str() << " for variables";
00182 for( int i=0;i<vars.size();i++ )
00183 cout << " \"" << vars[i].c_str() << "\"";
00184 cout << endl;
00185 cout << "Total number of points read: " << filter->size() << endl;
00186
00187
00188 vector<SprClass> inputClasses;
00189 if( !filter->filterByClass(inputClassesString.c_str()) ) {
00190 cerr << "Cannot choose input classes for string "
00191 << inputClassesString << endl;
00192 return 2;
00193 }
00194 filter->classes(inputClasses);
00195 assert( inputClasses.size() > 1 );
00196 cout << "Input data filtered by class." << endl;
00197 for( int i=0;i<inputClasses.size();i++ ) {
00198 cout << "Points in class " << inputClasses[i] << ": "
00199 << filter->ptsInClass(inputClasses[i]) << endl;
00200 }
00201
00202
00203 cout << "Splitting input data with factor " << splitFactor << endl;
00204 vector<double> weights;
00205 SprData* splitted = filter->split(splitFactor,weights,splitRandomize,seed);
00206 if( splitted == 0 ) {
00207 cerr << "Unable to split input data." << endl;
00208 return 2;
00209 }
00210 bool ownData = true;
00211 auto_ptr<SprAbsFilter> valFilter(new SprEmptyFilter(splitted,
00212 weights,ownData));
00213 cout << "Data re-filtered:" << endl;
00214 cout << "Training data:" << endl;
00215 for( int i=0;i<inputClasses.size();i++ ) {
00216 cout << "Points in class " << inputClasses[i] << ": "
00217 << filter->ptsInClass(inputClasses[i]) << endl;
00218 }
00219 cout << "Test data:" << endl;
00220 for( int i=0;i<inputClasses.size();i++ ) {
00221 cout << "Points in class " << inputClasses[i] << ": "
00222 << valFilter->ptsInClass(inputClasses[i]) << endl;
00223 }
00224
00225
00226 auto_ptr<SprAbsWriter> trainTuple(SprRWFactory::makeWriter(writeMode,
00227 "training"));
00228 if( !trainTuple->init(trainFile.c_str()) ) {
00229 cerr << "Unable to open output file " << trainFile.c_str() << endl;
00230 return 3;
00231 }
00232 auto_ptr<SprAbsWriter> testTuple(SprRWFactory::makeWriter(writeMode,
00233 "test"));
00234 if( !testTuple->init(testFile.c_str()) ) {
00235 cerr << "Unable to open output file " << testFile.c_str() << endl;
00236 return 3;
00237 }
00238
00239
00240 SprDataFeeder feeder(filter.get(),trainTuple.get());
00241 if( !feeder.feed(1000) ) {
00242 cerr << "Cannot feed data into file " << trainFile.c_str() << endl;
00243 return 4;
00244 }
00245 SprDataFeeder valFeeder(valFilter.get(),testTuple.get());
00246 if( !valFeeder.feed(1000) ) {
00247 cerr << "Cannot feed data into file " << testFile.c_str() << endl;
00248 return 4;
00249 }
00250
00251
00252 return 0;
00253 }