00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00011
00012 #include <stdlib.h>
00013 #include <unistd.h>
00014 #include <iostream>
00015 #include <vector>
00016 #include <set>
00017 #include <string>
00018 #include <memory>
00019
00020 using namespace std;
00021
00022
// Print the command-line usage of this application to standard output.
//
// prog: program name shown in the "Usage:" line (normally argv[0]).
void help(const char* prog)
{
  std::cout << "Usage: " << prog
            << " training_data_file" << std::endl;
  std::cout << "\t Options: " << std::endl;
  std::cout << "\t-h --- help " << std::endl;
  std::cout << "\t-o output Tuple file " << std::endl;
  std::cout << "\t-a input ascii file mode (see SprSimpleReader.hh) " << std::endl;
  std::cout << "\t-A save output data in ascii instead of Root " << std::endl;
  std::cout << "\t-y list of input classes (see SprAbsFilter.hh) " << std::endl;
  std::cout << "\t-v verbose level (0=silent default,1,2) " << std::endl;
  std::cout << "\t-V include only these input variables " << std::endl;
  std::cout << "\t-z exclude input variables from the list " << std::endl;
  std::cout << "\t-Z exclude input variables from the list, "
            << "but put them in the output file " << std::endl;
  // Applies to all of -V, -z and -Z above; printed once.
  std::cout << "\t\t Variables must be listed in quotes and separated by commas."
            << std::endl;
}
00043
00044
00045 int main(int argc, char ** argv)
00046 {
00047
00048 if( argc < 2 ) {
00049 help(argv[0]);
00050 return 1;
00051 }
00052
00053
00054 string tupleFile;
00055 int readMode = 0;
00056 SprRWFactory::DataType writeMode = SprRWFactory::Root;
00057 int verbose = 0;
00058 string outFile;
00059 string includeList, excludeList;
00060 string inputClassesString;
00061 string stringVarsDoNotFeed;
00062
00063
00064 int c;
00065 extern char* optarg;
00066
00067 while( (c = getopt(argc,argv,"ho:a:Ay:v:V:z:Z:")) != EOF ) {
00068 switch( c )
00069 {
00070 case 'h' :
00071 help(argv[0]);
00072 return 1;
00073 case 'o' :
00074 tupleFile = optarg;
00075 break;
00076 case 'a' :
00077 readMode = (optarg==0 ? 0 : atoi(optarg));
00078 break;
00079 case 'A' :
00080 writeMode = SprRWFactory::Ascii;
00081 break;
00082 case 'y' :
00083 inputClassesString = optarg;
00084 break;
00085 case 'v' :
00086 verbose = (optarg==0 ? 0 : atoi(optarg));
00087 break;
00088 case 'V' :
00089 includeList = optarg;
00090 break;
00091 case 'z' :
00092 excludeList = optarg;
00093 break;
00094 case 'Z' :
00095 stringVarsDoNotFeed = optarg;
00096 break;
00097 }
00098 }
00099
00100
00101 string trFile = argv[argc-1];
00102 if( trFile.empty() ) {
00103 cerr << "No training file is specified." << endl;
00104 return 1;
00105 }
00106
00107
00108 SprRWFactory::DataType inputType
00109 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00110 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00111
00112
00113 set<string> includeSet;
00114 if( !includeList.empty() ) {
00115 vector<vector<string> > includeVars;
00116 SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00117 assert( !includeVars.empty() );
00118 for( int i=0;i<includeVars[0].size();i++ )
00119 includeSet.insert(includeVars[0][i]);
00120 if( !reader->chooseVars(includeSet) ) {
00121 cerr << "Unable to include variables in training set." << endl;
00122 return 2;
00123 }
00124 else {
00125 cout << "Following variables have been included in optimization: ";
00126 for( set<string>::const_iterator
00127 i=includeSet.begin();i!=includeSet.end();i++ )
00128 cout << "\"" << *i << "\"" << " ";
00129 cout << endl;
00130 }
00131 }
00132
00133
00134 set<string> excludeSet;
00135 if( !excludeList.empty() ) {
00136 vector<vector<string> > excludeVars;
00137 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00138 assert( !excludeVars.empty() );
00139 for( int i=0;i<excludeVars[0].size();i++ )
00140 excludeSet.insert(excludeVars[0][i]);
00141 if( !reader->chooseAllBut(excludeSet) ) {
00142 cerr << "Unable to exclude variables from training set." << endl;
00143 return 2;
00144 }
00145 else {
00146 cout << "Following variables have been excluded from optimization: ";
00147 for( set<string>::const_iterator
00148 i=excludeSet.begin();i!=excludeSet.end();i++ )
00149 cout << "\"" << *i << "\"" << " ";
00150 cout << endl;
00151 }
00152 }
00153
00154
00155 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00156 if( filter.get() == 0 ) {
00157 cerr << "Unable to read data from file " << trFile.c_str() << endl;
00158 return 2;
00159 }
00160 vector<string> vars;
00161 filter->vars(vars);
00162 cout << "Read data from file " << trFile.c_str()
00163 << " for variables";
00164 for( int i=0;i<vars.size();i++ )
00165 cout << " \"" << vars[i].c_str() << "\"";
00166 cout << endl;
00167 cout << "Total number of points read: " << filter->size() << endl;
00168
00169
00170 vector<SprClass> inputClasses;
00171 if( !filter->filterByClass(inputClassesString.c_str()) ) {
00172 cerr << "Cannot choose input classes for string "
00173 << inputClassesString << endl;
00174 return 2;
00175 }
00176 filter->classes(inputClasses);
00177 assert( inputClasses.size() > 1 );
00178 cout << "Training data filtered by class." << endl;
00179 for( int i=0;i<inputClasses.size();i++ ) {
00180 cout << "Points in class " << inputClasses[i] << ": "
00181 << filter->ptsInClass(inputClasses[i]) << endl;
00182 }
00183
00184
00185 if( tupleFile.empty() ) return 0;
00186
00187
00188 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00189 if( !tuple->init(tupleFile.c_str()) ) {
00190 cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00191 return 8;
00192 }
00193
00194
00195
00196 string printVarsDoNotFeed;
00197 vector<vector<string> > varsDoNotFeed;
00198 SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed);
00199 vector<unsigned> mapper;
00200 for( int d=0;d<vars.size();d++ ) {
00201 if( varsDoNotFeed.empty() ||
00202 (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d])
00203 ==varsDoNotFeed[0].end()) ) {
00204 mapper.push_back(d);
00205 }
00206 else {
00207 printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? "" : ", " );
00208 printVarsDoNotFeed += vars[d];
00209 }
00210 }
00211 if( !printVarsDoNotFeed.empty() ) {
00212 cout << "The following variables are not used in the algorithm, "
00213 << "but will be included in the output file: "
00214 << printVarsDoNotFeed.c_str() << endl;
00215 }
00216
00217
00218 SprDataFeeder feeder(filter.get(),tuple.get(),mapper);
00219 if( !feeder.feed(1000) ) {
00220 cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00221 return 9;
00222 }
00223
00224
00225 return 0;
00226 }