00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprPCATransformer.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00013
#include <stdlib.h>
#include <unistd.h>

#include <cassert>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>
00021
00022 using namespace std;
00023
00024
/*
  Print the command-line usage of this executable to standard output.

  prog - program name shown in the usage banner (callers pass argv[0]).

  The banner is printed first, followed by one line per entry of the
  option table below, each terminated with endl.
*/
void help(const char* prog)
{
  // Option descriptions, printed verbatim in this order after the banner.
  static const char* const optionText[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-o output Tuple file ",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-A save output data in ascii instead of Root ",
    "\t-y list of input classes (see SprAbsFilter.hh) ",
    "\t-v verbose level (0=silent default,1,2) ",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const size_t nText = sizeof(optionText)/sizeof(optionText[0]);

  cout << "Usage: " << prog
       << " training_data_file output_file_for_PCA_coefficients" << endl;
  for( size_t k=0; k<nText; ++k )
    cout << optionText[k] << endl;
}
00041
00042
00043 int main(int argc, char ** argv)
00044 {
00045
00046 if( argc < 2 ) {
00047 help(argv[0]);
00048 return 1;
00049 }
00050
00051
00052 string tupleFile;
00053 int readMode = 0;
00054 SprRWFactory::DataType writeMode = SprRWFactory::Root;
00055 int verbose = 0;
00056 string outFile;
00057 string includeList, excludeList;
00058 string inputClassesString;
00059 string stringVarsDoNotFeed;
00060
00061
00062 int c;
00063 extern char* optarg;
00064
00065 while( (c = getopt(argc,argv,"ho:a:Ay:v:K:DV:z:Z:")) != EOF ) {
00066 switch( c )
00067 {
00068 case 'h' :
00069 help(argv[0]);
00070 return 1;
00071 case 'o' :
00072 tupleFile = optarg;
00073 break;
00074 case 'a' :
00075 readMode = (optarg==0 ? 0 : atoi(optarg));
00076 break;
00077 case 'A' :
00078 writeMode = SprRWFactory::Ascii;
00079 break;
00080 case 'y' :
00081 inputClassesString = optarg;
00082 break;
00083 case 'v' :
00084 verbose = (optarg==0 ? 0 : atoi(optarg));
00085 break;
00086 case 'V' :
00087 includeList = optarg;
00088 break;
00089 case 'z' :
00090 excludeList = optarg;
00091 break;
00092 case 'Z' :
00093 stringVarsDoNotFeed = optarg;
00094 break;
00095 }
00096 }
00097
00098
00099 string trFile = argv[argc-2];
00100 string pcaFile = argv[argc-1];
00101 if( trFile.empty() ) {
00102 cerr << "No training file is specified." << endl;
00103 return 1;
00104 }
00105 if( pcaFile.empty() ) {
00106 cerr << "No file for storing PCA coefficients is specified." << endl;
00107 return 1;
00108 }
00109
00110
00111 SprRWFactory::DataType inputType
00112 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00113 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00114
00115
00116 set<string> includeSet;
00117 if( !includeList.empty() ) {
00118 vector<vector<string> > includeVars;
00119 SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00120 assert( !includeVars.empty() );
00121 for( int i=0;i<includeVars[0].size();i++ )
00122 includeSet.insert(includeVars[0][i]);
00123 if( !reader->chooseVars(includeSet) ) {
00124 cerr << "Unable to include variables in training set." << endl;
00125 return 2;
00126 }
00127 else {
00128 cout << "Following variables have been included in optimization: ";
00129 for( set<string>::const_iterator
00130 i=includeSet.begin();i!=includeSet.end();i++ )
00131 cout << "\"" << *i << "\"" << " ";
00132 cout << endl;
00133 }
00134 }
00135
00136
00137 set<string> excludeSet;
00138 if( !excludeList.empty() ) {
00139 vector<vector<string> > excludeVars;
00140 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00141 assert( !excludeVars.empty() );
00142 for( int i=0;i<excludeVars[0].size();i++ )
00143 excludeSet.insert(excludeVars[0][i]);
00144 if( !reader->chooseAllBut(excludeSet) ) {
00145 cerr << "Unable to exclude variables from training set." << endl;
00146 return 2;
00147 }
00148 else {
00149 cout << "Following variables have been excluded from optimization: ";
00150 for( set<string>::const_iterator
00151 i=excludeSet.begin();i!=excludeSet.end();i++ )
00152 cout << "\"" << *i << "\"" << " ";
00153 cout << endl;
00154 }
00155 }
00156
00157
00158 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00159 if( filter.get() == 0 ) {
00160 cerr << "Unable to read data from file " << trFile.c_str() << endl;
00161 return 2;
00162 }
00163 vector<string> vars;
00164 filter->vars(vars);
00165 cout << "Read data from file " << trFile.c_str()
00166 << " for variables";
00167 for( int i=0;i<vars.size();i++ )
00168 cout << " \"" << vars[i].c_str() << "\"";
00169 cout << endl;
00170 cout << "Total number of points read: " << filter->size() << endl;
00171
00172
00173 vector<SprClass> inputClasses;
00174 if( !filter->filterByClass(inputClassesString.c_str()) ) {
00175 cerr << "Cannot choose input classes for string "
00176 << inputClassesString << endl;
00177 return 2;
00178 }
00179 filter->classes(inputClasses);
00180 assert( inputClasses.size() > 1 );
00181 cout << "Training data filtered by class." << endl;
00182 for( int i=0;i<inputClasses.size();i++ ) {
00183 cout << "Points in class " << inputClasses[i] << ": "
00184 << filter->ptsInClass(inputClasses[i]) << endl;
00185 }
00186
00187
00188 SprPCATransformer pca;
00189 if( !pca.train(filter.get(),verbose) ) {
00190 cerr << "Unable to compute PCA coefficients." << endl;
00191 return 3;
00192 }
00193 cout << "Computed PCA coefficients." << endl;
00194
00195
00196 if( !pca.store(pcaFile.c_str()) ) {
00197 cerr << "Unable to store PCA coefficients." << endl;
00198 return 4;
00199 }
00200 cout << "Stored PCA coefficients." << endl;
00201
00202
00203 if( tupleFile.empty() ) return 0;
00204
00205
00206 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00207 if( !tuple->init(tupleFile.c_str()) ) {
00208 cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00209 return 5;
00210 }
00211
00212
00213 SprTransformerFilter trans(filter.get());
00214 bool replaceOriginalData = true;
00215 if( !trans.transform(&pca,replaceOriginalData) ) {
00216 cerr << "Unable to transform input data." << endl;
00217 return 6;
00218 }
00219
00220
00221 SprDataFeeder feeder(&trans,tuple.get());
00222 if( !feeder.feed(1000) ) {
00223 cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00224 return 9;
00225 }
00226
00227
00228 return 0;
00229 }