#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprFisher.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprLogitR.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedFisher.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedLogitR.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include "PhysicsTools/StatPatternRecognition/src/SprVector.hh"
#include <stdlib.h>
#include <unistd.h>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>
Go to the source code of this file.
Functions:
  void help(const char *prog)
  int  main(int argc, char **argv)
// Print the command-line usage of the Fisher / logistic-regression
// application to standard output, one option per line.
// (Doxygen: defined at line 33 of SprFisherLogitApp.cc; writes via cout/endl.)
//
// prog - program name shown in the "Usage:" line (normally argv[0]).
//        A null pointer (argv[0] is null when argc==0) is replaced by a
//        fixed name, because streaming a null const char* is undefined
//        behavior.
void help(const char* prog)
{
  // Kept as a table so the list is easy to compare against the getopt()
  // option string used in main().
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-m order of Fisher ",
    "\t\t 1 = linear ",
    "\t\t 2 = quadratic ",
    "\t\t 3 = both ",
    "\t-l use logistic regression ",
    "\t-e accuracy for logistic regression (default=0.001)",
    "\t-u update factor for logistic regression (default=1)",
    "\t-i initialize logistic regression coeffs to 0 (def=LDA output)",
    "\t-y list of input classes (see SprAbsFilter.hh) ",
    "\t-Q apply variable transformation saved in file ",
    "\t-o output Tuple file ",
    "\t-s use standard output ranging from -infty to +infty",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-A save output data in ascii instead of Root ",
    "\t-v verbose level (0=silent default,1,2) ",
    "\t-f store classifier configuration to file ",
    "\t-K keep this fraction in training set and ",
    "\t\t put the rest into validation set ",
    "\t-D randomize training set split-up ",
    "\t-t read validation/test data from a file ",
    "\t\t (must be in same format as input data!!!) ",
    "\t-p output file to store validation/test data ",
    "\t-w scale all signal weights by this factor ",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const std::size_t nLines = sizeof(optionLines)/sizeof(optionLines[0]);

  std::cout << "Usage: " << (prog != 0 ? prog : "SprFisherLogitApp")
            << " training_data_file" << std::endl;
  for( std::size_t i=0; i<nLines; ++i )
    std::cout << optionLines[i] << std::endl;
}
Definition at line 69 of file SprFisherLogitApp.cc.
References DeDxTools::beta(), c, TestMuL1L2Filter_cff::cerr, GenMuonPlsPt100GeV_cfg::cout, lat::endl(), filter, help(), i, split, t, vars, and weights.
00070 { 00071 // check command line 00072 if( argc < 2 ) { 00073 help(argv[0]); 00074 return 1; 00075 } 00076 00077 // init 00078 int fisherMode = 0; 00079 bool useLogit = false; 00080 double eps = 0.001; 00081 double updateFactor = 1; 00082 bool initToZero = false; 00083 string tupleFile; 00084 int readMode = 0; 00085 SprRWFactory::DataType writeMode = SprRWFactory::Root; 00086 int verbose = 0; 00087 string outFile; 00088 string valFile; 00089 string valHbkFile; 00090 bool scaleWeights = false; 00091 double sW = 1.; 00092 string includeList, excludeList; 00093 string inputClassesString; 00094 bool useStandard = false; 00095 bool split = false; 00096 double splitFactor = 0; 00097 bool splitRandomize = false; 00098 string transformerFile; 00099 00100 // decode command line 00101 int c; 00102 extern char* optarg; 00103 extern int optind; 00104 while( (c = getopt(argc,argv,"hm:le:u:iy:Q:o:sa:Av:f:K:Dt:p:w:V:z:")) != EOF ) { 00105 switch( c ) 00106 { 00107 case 'h' : 00108 help(argv[0]); 00109 return 1; 00110 case 'm' : 00111 fisherMode = (optarg==0 ? 1 : atoi(optarg)); 00112 break; 00113 case 'l' : 00114 useLogit = true; 00115 break; 00116 case 'e' : 00117 eps = (optarg==0 ? 0.001 : atof(optarg)); 00118 break; 00119 case 'u' : 00120 updateFactor = (optarg==0 ? 1. : atof(optarg)); 00121 break; 00122 case 'i' : 00123 initToZero = true; 00124 break; 00125 case 'y' : 00126 inputClassesString = optarg; 00127 break; 00128 case 'Q' : 00129 transformerFile = optarg; 00130 break; 00131 case 'o' : 00132 tupleFile = optarg; 00133 break; 00134 case 's' : 00135 useStandard = true; 00136 break; 00137 case 'a' : 00138 readMode = (optarg==0 ? 0 : atoi(optarg)); 00139 break; 00140 case 'A' : 00141 writeMode = SprRWFactory::Ascii; 00142 break; 00143 case 'v' : 00144 verbose = (optarg==0 ? 0 : atoi(optarg)); 00145 break; 00146 case 'f' : 00147 outFile = optarg; 00148 break; 00149 case 'K' : 00150 split = true; 00151 splitFactor = (optarg==0 ? 
0 : atof(optarg)); 00152 break; 00153 case 'D' : 00154 splitRandomize = true; 00155 break; 00156 case 't' : 00157 valFile = optarg; 00158 break; 00159 case 'p' : 00160 valHbkFile = optarg; 00161 break; 00162 case 'w' : 00163 if( optarg != 0 ) { 00164 scaleWeights = true; 00165 sW = atof(optarg); 00166 } 00167 break; 00168 case 'V' : 00169 includeList = optarg; 00170 break; 00171 case 'z' : 00172 excludeList = optarg; 00173 break; 00174 } 00175 } 00176 00177 // training file name must be the only argument that appears 00178 // after all options on the command line 00179 string trFile; 00180 if( optind == argc-1 ) 00181 trFile = argv[optind]; 00182 if( trFile.empty() ) { 00183 cerr << "No training file is specified." << endl; 00184 return 1; 00185 } 00186 00187 // sanity check 00188 if( fisherMode==0 && !useLogit ) { 00189 cerr << "Neither Fisher nor logistic regression is requested." << endl; 00190 return 1; 00191 } 00192 00193 // make reader 00194 SprRWFactory::DataType inputType 00195 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii ); 00196 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode)); 00197 00198 // include variables 00199 set<string> includeSet; 00200 if( !includeList.empty() ) { 00201 vector<vector<string> > includeVars; 00202 SprStringParser::parseToStrings(includeList.c_str(),includeVars); 00203 assert( !includeVars.empty() ); 00204 for( int i=0;i<includeVars[0].size();i++ ) 00205 includeSet.insert(includeVars[0][i]); 00206 if( !reader->chooseVars(includeSet) ) { 00207 cerr << "Unable to include variables in training set." 
<< endl; 00208 return 2; 00209 } 00210 else { 00211 cout << "Following variables have been included in optimization: "; 00212 for( set<string>::const_iterator 00213 i=includeSet.begin();i!=includeSet.end();i++ ) 00214 cout << "\"" << *i << "\"" << " "; 00215 cout << endl; 00216 } 00217 } 00218 00219 // exclude variables 00220 set<string> excludeSet; 00221 if( !excludeList.empty() ) { 00222 vector<vector<string> > excludeVars; 00223 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars); 00224 assert( !excludeVars.empty() ); 00225 for( int i=0;i<excludeVars[0].size();i++ ) 00226 excludeSet.insert(excludeVars[0][i]); 00227 if( !reader->chooseAllBut(excludeSet) ) { 00228 cerr << "Unable to exclude variables from training set." << endl; 00229 return 2; 00230 } 00231 else { 00232 cout << "Following variables have been excluded from optimization: "; 00233 for( set<string>::const_iterator 00234 i=excludeSet.begin();i!=excludeSet.end();i++ ) 00235 cout << "\"" << *i << "\"" << " "; 00236 cout << endl; 00237 } 00238 } 00239 00240 // read training data from file 00241 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str())); 00242 if( filter.get() == 0 ) { 00243 cerr << "Unable to read data from file " << trFile.c_str() << endl; 00244 return 2; 00245 } 00246 vector<string> vars; 00247 filter->vars(vars); 00248 cout << "Read data from file " << trFile.c_str() 00249 << " for variables"; 00250 for( int i=0;i<vars.size();i++ ) 00251 cout << " \"" << vars[i].c_str() << "\""; 00252 cout << endl; 00253 cout << "Total number of points read: " << filter->size() << endl; 00254 00255 // filter training data by class 00256 vector<SprClass> inputClasses; 00257 if( !filter->filterByClass(inputClassesString.c_str()) ) { 00258 cerr << "Cannot choose input classes for string " 00259 << inputClassesString << endl; 00260 return 2; 00261 } 00262 filter->classes(inputClasses); 00263 assert( inputClasses.size() > 1 ); 00264 cout << "Training data filtered by class." 
<< endl; 00265 for( int i=0;i<inputClasses.size();i++ ) { 00266 cout << "Points in class " << inputClasses[i] << ": " 00267 << filter->ptsInClass(inputClasses[i]) << endl; 00268 } 00269 00270 // scale weights 00271 if( scaleWeights ) { 00272 cout << "Signal weights are multiplied by " << sW << endl; 00273 filter->scaleWeights(inputClasses[1],sW); 00274 } 00275 00276 // read validation data from file 00277 auto_ptr<SprAbsFilter> valFilter; 00278 if( split && !valFile.empty() ) { 00279 cerr << "Unable to split training data and use validation data " 00280 << "from a separate file." << endl; 00281 return 2; 00282 } 00283 if( split ) { 00284 cout << "Splitting training data with factor " << splitFactor << endl; 00285 if( splitRandomize ) 00286 cout << "Will use randomized splitting." << endl; 00287 vector<double> weights; 00288 SprData* splitted = filter->split(splitFactor,weights,splitRandomize); 00289 if( splitted == 0 ) { 00290 cerr << "Unable to split training data." << endl; 00291 return 2; 00292 } 00293 bool ownData = true; 00294 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData)); 00295 cout << "Training data re-filtered:" << endl; 00296 for( int i=0;i<inputClasses.size();i++ ) { 00297 cout << "Points in class " << inputClasses[i] << ": " 00298 << filter->ptsInClass(inputClasses[i]) << endl; 00299 } 00300 } if( !valFile.empty() ) { 00301 auto_ptr<SprAbsReader> 00302 valReader(SprRWFactory::makeReader(inputType,readMode)); 00303 if( !includeSet.empty() ) { 00304 if( !valReader->chooseVars(includeSet) ) { 00305 cerr << "Unable to include variables in validation set." << endl; 00306 return 2; 00307 } 00308 } 00309 if( !excludeSet.empty() ) { 00310 if( !valReader->chooseAllBut(excludeSet) ) { 00311 cerr << "Unable to exclude variables from validation set." 
<< endl; 00312 return 2; 00313 } 00314 } 00315 valFilter.reset(valReader->read(valFile.c_str())); 00316 if( valFilter.get() == 0 ) { 00317 cerr << "Unable to read data from file " << valFile.c_str() << endl; 00318 return 2; 00319 } 00320 vector<string> valVars; 00321 valFilter->vars(valVars); 00322 cout << "Read validation data from file " << valFile.c_str() 00323 << " for variables"; 00324 for( int i=0;i<valVars.size();i++ ) 00325 cout << " \"" << valVars[i].c_str() << "\""; 00326 cout << endl; 00327 cout << "Total number of points read: " << valFilter->size() << endl; 00328 } 00329 00330 // filter validation data by class 00331 if( valFilter.get() != 0 ) { 00332 if( !valFilter->filterByClass(inputClassesString.c_str()) ) { 00333 cerr << "Cannot choose input classes for string " 00334 << inputClassesString << endl; 00335 return 2; 00336 } 00337 valFilter->classes(inputClasses); 00338 cout << "Validation data filtered by class." << endl; 00339 for( int i=0;i<inputClasses.size();i++ ) { 00340 cout << "Points in class " << inputClasses[i] << ": " 00341 << valFilter->ptsInClass(inputClasses[i]) << endl; 00342 } 00343 } 00344 00345 // scale weights 00346 if( scaleWeights && valFilter.get()!=0 ) 00347 valFilter->scaleWeights(inputClasses[1],sW); 00348 00349 // apply transformation of variables to training and test data 00350 auto_ptr<SprAbsFilter> garbage_train, garbage_valid; 00351 if( !transformerFile.empty() ) { 00352 SprVarTransformerReader transReader; 00353 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str()); 00354 if( t == 0 ) { 00355 cerr << "Unable to read VarTransformer from file " 00356 << transformerFile.c_str() << endl; 00357 return 2; 00358 } 00359 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get()); 00360 SprTransformerFilter* t_valid = 0; 00361 if( valFilter.get() != 0 ) 00362 t_valid = new SprTransformerFilter(valFilter.get()); 00363 bool replaceOriginalData = true; 00364 if( 
!t_train->transform(t,replaceOriginalData) ) { 00365 cerr << "Unable to apply VarTransformer to training data." << endl; 00366 return 2; 00367 } 00368 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) { 00369 cerr << "Unable to apply VarTransformer to validation data." << endl; 00370 return 2; 00371 } 00372 cout << "Variable transformation from file " 00373 << transformerFile.c_str() << " has been applied to " 00374 << "training and validation data." << endl; 00375 garbage_train.reset(filter.release()); 00376 garbage_valid.reset(valFilter.release()); 00377 filter.reset(t_train); 00378 valFilter.reset(t_valid); 00379 } 00380 00381 // train Fisher 00382 auto_ptr<SprFisher> fisher; 00383 auto_ptr<SprTrainedFisher> trainedFisher1, trainedFisher2; 00384 bool both = false; 00385 if( fisherMode != 0 ) { 00386 if( fisherMode!=1 && fisherMode!=2 && fisherMode!=3 ) { 00387 cerr << "Unknown mode for Fisher " << fisherMode << endl; 00388 return 3; 00389 } 00390 if( fisherMode == 3 ) { 00391 both = true; 00392 fisherMode = 1; 00393 } 00394 cout << "Initializing Fisher in mode " << fisherMode << endl; 00395 fisher.reset(new SprFisher(filter.get(),fisherMode)); 00396 if( !fisher->train(verbose) ) { 00397 cerr << "Unable to train Fisher." << endl; 00398 return 3; 00399 } 00400 else { 00401 cout << "Trained Fisher:" << endl; 00402 fisher->print(cout); 00403 } 00404 00405 // make a trained Fisher 00406 trainedFisher1.reset(fisher->makeTrained()); 00407 if( trainedFisher1.get() == 0 ) { 00408 cerr << "Unable to make a trained Fisher." << endl; 00409 return 4; 00410 } 00411 if( useStandard ) trainedFisher1->useStandard(); 00412 00413 // train another one if necessary 00414 if( both ) { 00415 fisher->setMode(2); 00416 if( !fisher->train(verbose) ) { 00417 cerr << "Unable to train 2nd Fisher." 
<< endl; 00418 return 5; 00419 } 00420 else { 00421 cout << "Trained 2nd Fisher:" << endl; 00422 fisher->print(cout); 00423 } 00424 trainedFisher2.reset(fisher->makeTrained()); 00425 if( trainedFisher2.get() == 0 ) { 00426 cerr << "Unable to make a trained 2nd Fisher." << endl; 00427 return 6; 00428 } 00429 if( useStandard ) trainedFisher2->useStandard(); 00430 } 00431 } 00432 00433 // train logistic regression 00434 auto_ptr<SprLogitR> logit; 00435 auto_ptr<SprTrainedLogitR> trainedLogit; 00436 if( useLogit ) { 00437 // init 00438 if( initToZero ) { 00439 SprVector beta(filter->dim()); 00440 for( int i=0;i<filter->dim();i++ ) beta[i] = 0; 00441 logit.reset(new SprLogitR(filter.get(),0,beta,eps,updateFactor)); 00442 } 00443 else { 00444 logit.reset(new SprLogitR(filter.get(),eps,updateFactor)); 00445 } 00446 00447 // train 00448 if( !logit->train(verbose) ) { 00449 cerr << "Unable to train logistic regression." << endl; 00450 return 7; 00451 } 00452 else { 00453 cout << "Trained Logistic Regression:" << endl; 00454 logit->print(cout); 00455 } 00456 00457 // make trained logit 00458 trainedLogit.reset(logit->makeTrained()); 00459 if( trainedLogit.get() == 0 ) { 00460 cerr << "Unable to make trained logistic regression." << endl; 00461 return 8; 00462 } 00463 if( useStandard ) trainedLogit->useStandard(); 00464 } 00465 00466 // save classifier configuration into file 00467 if( !outFile.empty() ) { 00468 if( both || (fisherMode>0 && useLogit) ) { 00469 cerr << "More than one classifier trained. " 00470 << "Cannot save classifier configurations to file." 
<< endl; 00471 return 9; 00472 } 00473 SprAbsClassifier* trainable = 0; 00474 if( fisher.get() != 0 ) trainable = fisher.get(); 00475 if( logit.get() != 0 ) trainable = logit.get(); 00476 assert( trainable != 0 ); 00477 if( !trainable->store(outFile.c_str()) ) { 00478 cerr << "Cannot store classifier in file " << outFile.c_str() << endl; 00479 return 9; 00480 } 00481 } 00482 00483 // make histogram if requested 00484 if( tupleFile.empty() && valHbkFile.empty() ) 00485 return 0; 00486 00487 // feed training data 00488 if( !tupleFile.empty() ) { 00489 // make a writer 00490 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training")); 00491 if( !tuple->init(tupleFile.c_str()) ) { 00492 cerr << "Unable to open output file " << tupleFile.c_str() << endl; 00493 return 10; 00494 } 00495 string firstClassifier; 00496 if( trainedFisher2.get()!=0 || fisherMode==1 ) 00497 firstClassifier = "lin"; 00498 else 00499 firstClassifier = "qua"; 00500 // feed 00501 SprDataFeeder feeder(filter.get(),tuple.get()); 00502 feeder.addClassifier(trainedFisher1.get(),firstClassifier.c_str()); 00503 feeder.addClassifier(trainedFisher2.get(),"qua"); 00504 feeder.addClassifier(trainedLogit.get(),"logit"); 00505 if( !feeder.feed(1000) ) { 00506 cerr << "Cannot feed data into file " << tupleFile.c_str() << endl; 00507 return 11; 00508 } 00509 } 00510 00511 // feed validation data 00512 if( !valHbkFile.empty() ) { 00513 // make a writer 00514 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"test")); 00515 if( !tuple->init(valHbkFile.c_str()) ) { 00516 cerr << "Unable to open output file " << valHbkFile.c_str() << endl; 00517 return 12; 00518 } 00519 string firstClassifier; 00520 if( trainedFisher2.get()!=0 || fisherMode==1 ) 00521 firstClassifier = "lin"; 00522 else 00523 firstClassifier = "qua"; 00524 // feed 00525 SprDataFeeder feeder(valFilter.get(),tuple.get()); 00526 feeder.addClassifier(trainedFisher1.get(),firstClassifier.c_str()); 00527 
feeder.addClassifier(trainedFisher2.get(),"qua"); 00528 feeder.addClassifier(trainedLogit.get(),"logit"); 00529 if( !feeder.feed(1000) ) { 00530 cerr << "Cannot feed data into file " << valHbkFile.c_str() << endl; 00531 return 13; 00532 } 00533 } 00534 00535 // exit 00536 return 0; 00537 }