00001
00002
00003
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprCoordinateMapper.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedAdaBoost.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedFisher.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedLogitR.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00022
00023 #include <stdlib.h>
00024 #include <unistd.h>
00025 #include <iostream>
00026 #include <set>
00027 #include <vector>
00028 #include <memory>
00029 #include <string>
00030 #include <cassert>
00031 #include <algorithm>
00032
00033 using namespace std;
00034
00035
/**
 * Print the command-line usage summary of this executable to std::cout.
 *
 * @param prog  program name shown on the "Usage:" line (normally argv[0]).
 */
void help(const char* prog)
{
  // One entry per output line, printed in this order after the usage line.
  static const char* const optionLines[] = {
    "\t (List of files must be in quotes, separated by commas.)",
    "\t Options: ",
    "\t-h --- help ",
    "\t-y list of input classes (see SprAbsFilter.hh) ",
    "\t-Q apply variable transformation saved in file ",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-A save output data in ascii instead of Root ",
    "\t-K use 1-fraction of input data ",
    "\t\t This option is for consistency with other execs.",
    "\t-v verbose level (0=silent default,1,2) ",
    "\t-w scale all signal weights by this factor ",
    "\t-t output tuple name (default=data) ",
    "\t-C output classifier names (in quotes, separated by commas)",
    "\t-p feeder print-out frequency (default=1000 events)",
    "\t-s use output in range (-infty,+infty) instead of [0,1]",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t-Z exclude input variables from the list, "
      "but put them in the output file ",
    "\t-M map variable lists from trained classifiers onto",
    "\t\t variables available in input data.",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const size_t lineCount = sizeof(optionLines)/sizeof(optionLines[0]);

  std::cout << "Usage: " << prog << " list_of_classifier_config_files"
            << " input_data_file output_tuple_file" << std::endl;
  for( size_t k=0; k<lineCount; ++k )
    std::cout << optionLines[k] << std::endl;
}
00065
00066
00067 void cleanup(vector<SprAbsTrainedClassifier*>& trained) {
00068 for( int i=0;i<trained.size();i++ ) delete trained[i];
00069 }
00070
00071
00072 int main(int argc, char ** argv)
00073 {
00074
00075 if( argc < 4 ) {
00076 help(argv[0]);
00077 return 1;
00078 }
00079
00080
00081 int readMode = 0;
00082 SprRWFactory::DataType writeMode = SprRWFactory::Root;
00083 int verbose = 0;
00084 bool scaleWeights = false;
00085 double sW = 1.;
00086 bool useStandard = false;
00087 string tupleName;
00088 string classifierNameList;
00089 string includeList, excludeList;
00090 string inputClassesString;
00091 int nPrintOut = 1000;
00092 string stringVarsDoNotFeed;
00093 bool mapTrainedVars = false;
00094 bool split = false;
00095 double splitFactor = 0;
00096 string transformerFile;
00097
00098
00099
00100 int c;
00101 extern char* optarg;
00102 extern int optind;
00103 while( (c = getopt(argc,argv,"hy:Q:a:AK:v:w:t:C:p:sV:z:Z:M")) != EOF ) {
00104 switch( c )
00105 {
00106 case 'h' :
00107 help(argv[0]);
00108 return 1;
00109 case 'y' :
00110 inputClassesString = optarg;
00111 break;
00112 case 'Q' :
00113 transformerFile = optarg;
00114 break;
00115 case 'a' :
00116 readMode = (optarg==0 ? 0 : atoi(optarg));
00117 break;
00118 case 'A' :
00119 writeMode = SprRWFactory::Ascii;
00120 break;
00121 case 'K' :
00122 split = true;
00123 splitFactor = (optarg==0 ? 0 : atof(optarg));
00124 break;
00125 case 'v' :
00126 verbose = (optarg==0 ? 0 : atoi(optarg));
00127 break;
00128 case 'w' :
00129 if( optarg != 0 ) {
00130 scaleWeights = true;
00131 sW = atof(optarg);
00132 }
00133 break;
00134 case 't' :
00135 tupleName = optarg;
00136 break;
00137 case 'C' :
00138 classifierNameList = optarg;
00139 break;
00140 case 'p' :
00141 nPrintOut = (optarg==0 ? 1000 : atoi(optarg));
00142 break;
00143 case 's' :
00144 useStandard = true;
00145 break;
00146 case 'V' :
00147 includeList = optarg;
00148 break;
00149 case 'z' :
00150 excludeList = optarg;
00151 break;
00152 case 'Z' :
00153 stringVarsDoNotFeed = optarg;
00154 break;
00155 case 'M' :
00156 mapTrainedVars = true;
00157 break;
00158 }
00159 }
00160
00161
00162 string configFileList = argv[argc-3];
00163 string dataFile = argv[argc-2];
00164 string tupleFile = argv[argc-1];
00165 if( configFileList.empty() ) {
00166 cerr << "No classifier configuration files are specified." << endl;
00167 return 1;
00168 }
00169 if( dataFile.empty() ) {
00170 cerr << "No input data file is specified." << endl;
00171 return 1;
00172 }
00173 if( tupleFile.empty() ) {
00174 cerr << "No output tuple file is specified." << endl;
00175 return 1;
00176 }
00177
00178
00179 vector<vector<string> > classifierNames, configFiles;
00180 SprStringParser::parseToStrings(classifierNameList.c_str(),classifierNames);
00181 SprStringParser::parseToStrings(configFileList.c_str(),configFiles);
00182 if( configFiles.empty() || configFiles[0].empty() ) {
00183 cerr << "Unable to parse config file list: "
00184 << configFileList.c_str() << endl;
00185 return 1;
00186 }
00187 int nTrained = configFiles[0].size();
00188 bool useClassifierNames
00189 = (!classifierNames.empty() && !classifierNames[0].empty());
00190 if( useClassifierNames && (classifierNames[0].size()!=nTrained) ) {
00191 cerr << "Sizes of classifier name list and config file list do not match!"
00192 << endl;
00193 return 1;
00194 }
00195
00196
00197 SprRWFactory::DataType inputType
00198 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00199 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00200
00201
00202 set<string> includeSet;
00203 if( !includeList.empty() ) {
00204 vector<vector<string> > includeVars;
00205 SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00206 assert( !includeVars.empty() );
00207 for( int i=0;i<includeVars[0].size();i++ )
00208 includeSet.insert(includeVars[0][i]);
00209 if( !reader->chooseVars(includeSet) ) {
00210 cerr << "Unable to include variables in training set." << endl;
00211 return 2;
00212 }
00213 else {
00214 cout << "Following variables have been included in optimization: ";
00215 for( set<string>::const_iterator
00216 i=includeSet.begin();i!=includeSet.end();i++ )
00217 cout << "\"" << *i << "\"" << " ";
00218 cout << endl;
00219 }
00220 }
00221
00222
00223 set<string> excludeSet;
00224 if( !excludeList.empty() ) {
00225 vector<vector<string> > excludeVars;
00226 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00227 assert( !excludeVars.empty() );
00228 for( int i=0;i<excludeVars[0].size();i++ )
00229 excludeSet.insert(excludeVars[0][i]);
00230 if( !reader->chooseAllBut(excludeSet) ) {
00231 cerr << "Unable to exclude variables from training set." << endl;
00232 return 2;
00233 }
00234 else {
00235 cout << "Following variables have been excluded from optimization: ";
00236 for( set<string>::const_iterator
00237 i=excludeSet.begin();i!=excludeSet.end();i++ )
00238 cout << "\"" << *i << "\"" << " ";
00239 cout << endl;
00240 }
00241 }
00242
00243
00244 auto_ptr<SprAbsFilter> filter(reader->read(dataFile.c_str()));
00245 if( filter.get() == 0 ) {
00246 cerr << "Unable to read data from file " << dataFile.c_str() << endl;
00247 return 2;
00248 }
00249 vector<string> vars;
00250 filter->vars(vars);
00251 cout << "Read data from file " << dataFile.c_str() << " for variables";
00252 for( int i=0;i<vars.size();i++ )
00253 cout << " \"" << vars[i].c_str() << "\"";
00254 cout << endl;
00255 cout << "Total number of points read: " << filter->size() << endl;
00256
00257
00258 vector<SprClass> inputClasses;
00259 if( !filter->filterByClass(inputClassesString.c_str()) ) {
00260 cerr << "Cannot choose input classes for string "
00261 << inputClassesString << endl;
00262 return 2;
00263 }
00264 filter->classes(inputClasses);
00265 assert( inputClasses.size() > 1 );
00266 cout << "Training data filtered by class." << endl;
00267 for( int i=0;i<inputClasses.size();i++ ) {
00268 cout << "Points in class " << inputClasses[i] << ": "
00269 << filter->ptsInClass(inputClasses[i]) << endl;
00270 }
00271
00272
00273 if( scaleWeights ) {
00274 cout << "Signal weights are multiplied by " << sW << endl;
00275 filter->scaleWeights(inputClasses[1],sW);
00276 }
00277
00278
00279 auto_ptr<SprAbsFilter> garbage_train;
00280 if( !transformerFile.empty() ) {
00281 SprVarTransformerReader transReader;
00282 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00283 if( t == 0 ) {
00284 cerr << "Unable to read VarTransformer from file "
00285 << transformerFile.c_str() << endl;
00286 return 2;
00287 }
00288 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00289 bool replaceOriginalData = true;
00290 if( !t_train->transform(t,replaceOriginalData) ) {
00291 cerr << "Unable to apply VarTransformer to training data." << endl;
00292 return 2;
00293 }
00294 cout << "Variable transformation from file "
00295 << transformerFile.c_str() << " has been applied to data." << endl;
00296 garbage_train.reset(filter.release());
00297 filter.reset(t_train);
00298 filter->vars(vars);
00299 }
00300
00301
00302 auto_ptr<SprAbsFilter> valFilter;
00303 if( split ) {
00304 cout << "Splitting data with factor " << splitFactor << endl;
00305 vector<double> weights;
00306 SprData* splitted = filter->split(splitFactor,weights,false);
00307 if( splitted == 0 ) {
00308 cerr << "Unable to split data." << endl;
00309 return 2;
00310 }
00311 bool ownData = true;
00312 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00313 cout << "Data re-filtered:" << endl;
00314 for( int i=0;i<inputClasses.size();i++ ) {
00315 cout << "Points in class " << inputClasses[i] << ": "
00316 << valFilter->ptsInClass(inputClasses[i]) << endl;
00317 }
00318 }
00319 else {
00320 valFilter.reset(filter.release());
00321 }
00322
00323
00324 vector<SprAbsTrainedClassifier*> trained(nTrained);
00325 vector<SprCoordinateMapper*> specificMappers(nTrained);
00326 for( int i=0;i<nTrained;i++ ) {
00327
00328
00329 trained[i]
00330 = SprClassifierReader::readTrained(configFiles[0][i].c_str(),verbose);
00331 if( trained[i] == 0 ) {
00332 cerr << "Unable to read classifier configuration from file "
00333 << configFiles[0][i].c_str() << endl;
00334 cleanup(trained);
00335 return 3;
00336 }
00337 cout << "Read classifier " << trained[i]->name().c_str()
00338 << " with dimensionality " << trained[i]->dim() << endl;
00339
00340
00341 vector<string> trainedVars;
00342 trained[i]->vars(trainedVars);
00343 if( verbose > 0 ) {
00344 cout << "Variables: " << endl;
00345 for( int j=0;j<trainedVars.size();j++ )
00346 cout << trainedVars[j].c_str() << " ";
00347 cout << endl;
00348 }
00349
00350
00351 if( mapTrainedVars || trained[i]->name()=="Combiner" ) {
00352 specificMappers[i]
00353 = SprCoordinateMapper::createMapper(trainedVars,vars);
00354 }
00355
00356
00357 if( useStandard ) {
00358 if( trained[i]->name() == "AdaBoost" ) {
00359 SprTrainedAdaBoost* specific
00360 = static_cast<SprTrainedAdaBoost*>(trained[i]);
00361 specific->useStandard();
00362 }
00363 else if( trained[i]->name() == "Fisher" ) {
00364 SprTrainedFisher* specific
00365 = static_cast<SprTrainedFisher*>(trained[i]);
00366 specific->useStandard();
00367 }
00368 else if( trained[i]->name() == "LogitR" ) {
00369 SprTrainedLogitR* specific
00370 = static_cast<SprTrainedLogitR*>(trained[i]);
00371 specific->useStandard();
00372 }
00373 }
00374 }
00375
00376
00377 if( tupleName.empty() ) tupleName = "data";
00378 auto_ptr<SprAbsWriter>
00379 tuple(SprRWFactory::makeWriter(writeMode,tupleName.c_str()));
00380 if( !tuple->init(tupleFile.c_str()) ) {
00381 cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00382 cleanup(trained);
00383 return 5;
00384 }
00385
00386
00387
00388 string printVarsDoNotFeed;
00389 vector<vector<string> > varsDoNotFeed;
00390 SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed);
00391 vector<unsigned> mapper;
00392 for( int d=0;d<vars.size();d++ ) {
00393 if( varsDoNotFeed.empty() ||
00394 (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d])
00395 ==varsDoNotFeed[0].end()) ) {
00396 mapper.push_back(d);
00397 }
00398 else {
00399 printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? "" : ", " );
00400 printVarsDoNotFeed += vars[d];
00401 }
00402 }
00403 if( !printVarsDoNotFeed.empty() ) {
00404 cout << "The following variables are not used in the algorithm, "
00405 << "but will be included in the output file: "
00406 << printVarsDoNotFeed.c_str() << endl;
00407 }
00408
00409
00410 SprDataFeeder feeder(valFilter.get(),tuple.get(),mapper);
00411 for( int i=0;i<nTrained;i++ ) {
00412 string useName;
00413 if( useClassifierNames )
00414 useName = classifierNames[0][i];
00415 else
00416 useName = trained[i]->name();
00417 feeder.addClassifier(trained[i],useName.c_str(),specificMappers[i]);
00418 }
00419 if( !feeder.feed(nPrintOut) ) {
00420 cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00421 cleanup(trained);
00422 return 6;
00423 }
00424
00425
00426 cleanup(trained);
00427 return 0;
00428 }