00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprFisher.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprLogitR.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedFisher.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedLogitR.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00019 #include "PhysicsTools/StatPatternRecognition/src/SprVector.hh"
00020
00021 #include <stdlib.h>
00022 #include <unistd.h>
00023 #include <iostream>
00024 #include <vector>
00025 #include <set>
00026 #include <string>
00027 #include <memory>
00028 #include <cassert>
00029
00030 using namespace std;
00031
00032
/**
 * Print the command-line usage summary to standard output.
 *
 * The first line names the executable and the single positional argument
 * (the training data file); the remaining lines describe every option
 * letter accepted by main().
 *
 * @param prog  Name of the executable as it should appear in the usage line
 *              (main() passes argv[0]).
 */
void help(const char* prog)
{
  // One entry per output line, in the order they are printed.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-m order of Fisher ",
    "\t\t 1 = linear ",
    "\t\t 2 = quadratic ",
    "\t\t 3 = both ",
    "\t-l use logistic regression ",
    "\t-e accuracy for logistic regression (default=0.001)",
    "\t-u update factor for logistic regression (default=1)",
    "\t-i initialize logistic regression coeffs to 0 (def=LDA output)",
    "\t-y list of input classes (see SprAbsFilter.hh) ",
    "\t-Q apply variable transformation saved in file ",
    "\t-o output Tuple file ",
    "\t-s use standard output ranging from -infty to +infty",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-A save output data in ascii instead of Root ",
    "\t-v verbose level (0=silent default,1,2) ",
    "\t-f store classifier configuration to file ",
    "\t-K keep this fraction in training set and ",
    "\t\t put the rest into validation set ",
    "\t-D randomize training set split-up ",
    "\t-t read validation/test data from a file ",
    "\t\t (must be in same format as input data!!! ",
    "\t-p output file to store validation/test data ",
    "\t-w scale all signal weights by this factor ",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const char* const* const lastLine
    = optionLines + sizeof(optionLines)/sizeof(optionLines[0]);

  std::cout << "Usage: " << prog << " training_data_file" << std::endl;
  for( const char* const* line=optionLines; line!=lastLine; ++line )
    std::cout << *line << std::endl;
}
00067
00068
00069 int main(int argc, char ** argv)
00070 {
00071
00072 if( argc < 2 ) {
00073 help(argv[0]);
00074 return 1;
00075 }
00076
00077
00078 int fisherMode = 0;
00079 bool useLogit = false;
00080 double eps = 0.001;
00081 double updateFactor = 1;
00082 bool initToZero = false;
00083 string tupleFile;
00084 int readMode = 0;
00085 SprRWFactory::DataType writeMode = SprRWFactory::Root;
00086 int verbose = 0;
00087 string outFile;
00088 string valFile;
00089 string valHbkFile;
00090 bool scaleWeights = false;
00091 double sW = 1.;
00092 string includeList, excludeList;
00093 string inputClassesString;
00094 bool useStandard = false;
00095 bool split = false;
00096 double splitFactor = 0;
00097 bool splitRandomize = false;
00098 string transformerFile;
00099
00100
00101 int c;
00102 extern char* optarg;
00103 extern int optind;
00104 while( (c = getopt(argc,argv,"hm:le:u:iy:Q:o:sa:Av:f:K:Dt:p:w:V:z:")) != EOF ) {
00105 switch( c )
00106 {
00107 case 'h' :
00108 help(argv[0]);
00109 return 1;
00110 case 'm' :
00111 fisherMode = (optarg==0 ? 1 : atoi(optarg));
00112 break;
00113 case 'l' :
00114 useLogit = true;
00115 break;
00116 case 'e' :
00117 eps = (optarg==0 ? 0.001 : atof(optarg));
00118 break;
00119 case 'u' :
00120 updateFactor = (optarg==0 ? 1. : atof(optarg));
00121 break;
00122 case 'i' :
00123 initToZero = true;
00124 break;
00125 case 'y' :
00126 inputClassesString = optarg;
00127 break;
00128 case 'Q' :
00129 transformerFile = optarg;
00130 break;
00131 case 'o' :
00132 tupleFile = optarg;
00133 break;
00134 case 's' :
00135 useStandard = true;
00136 break;
00137 case 'a' :
00138 readMode = (optarg==0 ? 0 : atoi(optarg));
00139 break;
00140 case 'A' :
00141 writeMode = SprRWFactory::Ascii;
00142 break;
00143 case 'v' :
00144 verbose = (optarg==0 ? 0 : atoi(optarg));
00145 break;
00146 case 'f' :
00147 outFile = optarg;
00148 break;
00149 case 'K' :
00150 split = true;
00151 splitFactor = (optarg==0 ? 0 : atof(optarg));
00152 break;
00153 case 'D' :
00154 splitRandomize = true;
00155 break;
00156 case 't' :
00157 valFile = optarg;
00158 break;
00159 case 'p' :
00160 valHbkFile = optarg;
00161 break;
00162 case 'w' :
00163 if( optarg != 0 ) {
00164 scaleWeights = true;
00165 sW = atof(optarg);
00166 }
00167 break;
00168 case 'V' :
00169 includeList = optarg;
00170 break;
00171 case 'z' :
00172 excludeList = optarg;
00173 break;
00174 }
00175 }
00176
00177
00178
00179 string trFile;
00180 if( optind == argc-1 )
00181 trFile = argv[optind];
00182 if( trFile.empty() ) {
00183 cerr << "No training file is specified." << endl;
00184 return 1;
00185 }
00186
00187
00188 if( fisherMode==0 && !useLogit ) {
00189 cerr << "Neither Fisher nor logistic regression is requested." << endl;
00190 return 1;
00191 }
00192
00193
00194 SprRWFactory::DataType inputType
00195 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00196 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00197
00198
00199 set<string> includeSet;
00200 if( !includeList.empty() ) {
00201 vector<vector<string> > includeVars;
00202 SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00203 assert( !includeVars.empty() );
00204 for( int i=0;i<includeVars[0].size();i++ )
00205 includeSet.insert(includeVars[0][i]);
00206 if( !reader->chooseVars(includeSet) ) {
00207 cerr << "Unable to include variables in training set." << endl;
00208 return 2;
00209 }
00210 else {
00211 cout << "Following variables have been included in optimization: ";
00212 for( set<string>::const_iterator
00213 i=includeSet.begin();i!=includeSet.end();i++ )
00214 cout << "\"" << *i << "\"" << " ";
00215 cout << endl;
00216 }
00217 }
00218
00219
00220 set<string> excludeSet;
00221 if( !excludeList.empty() ) {
00222 vector<vector<string> > excludeVars;
00223 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00224 assert( !excludeVars.empty() );
00225 for( int i=0;i<excludeVars[0].size();i++ )
00226 excludeSet.insert(excludeVars[0][i]);
00227 if( !reader->chooseAllBut(excludeSet) ) {
00228 cerr << "Unable to exclude variables from training set." << endl;
00229 return 2;
00230 }
00231 else {
00232 cout << "Following variables have been excluded from optimization: ";
00233 for( set<string>::const_iterator
00234 i=excludeSet.begin();i!=excludeSet.end();i++ )
00235 cout << "\"" << *i << "\"" << " ";
00236 cout << endl;
00237 }
00238 }
00239
00240
00241 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00242 if( filter.get() == 0 ) {
00243 cerr << "Unable to read data from file " << trFile.c_str() << endl;
00244 return 2;
00245 }
00246 vector<string> vars;
00247 filter->vars(vars);
00248 cout << "Read data from file " << trFile.c_str()
00249 << " for variables";
00250 for( int i=0;i<vars.size();i++ )
00251 cout << " \"" << vars[i].c_str() << "\"";
00252 cout << endl;
00253 cout << "Total number of points read: " << filter->size() << endl;
00254
00255
00256 vector<SprClass> inputClasses;
00257 if( !filter->filterByClass(inputClassesString.c_str()) ) {
00258 cerr << "Cannot choose input classes for string "
00259 << inputClassesString << endl;
00260 return 2;
00261 }
00262 filter->classes(inputClasses);
00263 assert( inputClasses.size() > 1 );
00264 cout << "Training data filtered by class." << endl;
00265 for( int i=0;i<inputClasses.size();i++ ) {
00266 cout << "Points in class " << inputClasses[i] << ": "
00267 << filter->ptsInClass(inputClasses[i]) << endl;
00268 }
00269
00270
00271 if( scaleWeights ) {
00272 cout << "Signal weights are multiplied by " << sW << endl;
00273 filter->scaleWeights(inputClasses[1],sW);
00274 }
00275
00276
00277 auto_ptr<SprAbsFilter> valFilter;
00278 if( split && !valFile.empty() ) {
00279 cerr << "Unable to split training data and use validation data "
00280 << "from a separate file." << endl;
00281 return 2;
00282 }
00283 if( split ) {
00284 cout << "Splitting training data with factor " << splitFactor << endl;
00285 if( splitRandomize )
00286 cout << "Will use randomized splitting." << endl;
00287 vector<double> weights;
00288 SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00289 if( splitted == 0 ) {
00290 cerr << "Unable to split training data." << endl;
00291 return 2;
00292 }
00293 bool ownData = true;
00294 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00295 cout << "Training data re-filtered:" << endl;
00296 for( int i=0;i<inputClasses.size();i++ ) {
00297 cout << "Points in class " << inputClasses[i] << ": "
00298 << filter->ptsInClass(inputClasses[i]) << endl;
00299 }
00300 } if( !valFile.empty() ) {
00301 auto_ptr<SprAbsReader>
00302 valReader(SprRWFactory::makeReader(inputType,readMode));
00303 if( !includeSet.empty() ) {
00304 if( !valReader->chooseVars(includeSet) ) {
00305 cerr << "Unable to include variables in validation set." << endl;
00306 return 2;
00307 }
00308 }
00309 if( !excludeSet.empty() ) {
00310 if( !valReader->chooseAllBut(excludeSet) ) {
00311 cerr << "Unable to exclude variables from validation set." << endl;
00312 return 2;
00313 }
00314 }
00315 valFilter.reset(valReader->read(valFile.c_str()));
00316 if( valFilter.get() == 0 ) {
00317 cerr << "Unable to read data from file " << valFile.c_str() << endl;
00318 return 2;
00319 }
00320 vector<string> valVars;
00321 valFilter->vars(valVars);
00322 cout << "Read validation data from file " << valFile.c_str()
00323 << " for variables";
00324 for( int i=0;i<valVars.size();i++ )
00325 cout << " \"" << valVars[i].c_str() << "\"";
00326 cout << endl;
00327 cout << "Total number of points read: " << valFilter->size() << endl;
00328 }
00329
00330
00331 if( valFilter.get() != 0 ) {
00332 if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00333 cerr << "Cannot choose input classes for string "
00334 << inputClassesString << endl;
00335 return 2;
00336 }
00337 valFilter->classes(inputClasses);
00338 cout << "Validation data filtered by class." << endl;
00339 for( int i=0;i<inputClasses.size();i++ ) {
00340 cout << "Points in class " << inputClasses[i] << ": "
00341 << valFilter->ptsInClass(inputClasses[i]) << endl;
00342 }
00343 }
00344
00345
00346 if( scaleWeights && valFilter.get()!=0 )
00347 valFilter->scaleWeights(inputClasses[1],sW);
00348
00349
00350 auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00351 if( !transformerFile.empty() ) {
00352 SprVarTransformerReader transReader;
00353 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00354 if( t == 0 ) {
00355 cerr << "Unable to read VarTransformer from file "
00356 << transformerFile.c_str() << endl;
00357 return 2;
00358 }
00359 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00360 SprTransformerFilter* t_valid = 0;
00361 if( valFilter.get() != 0 )
00362 t_valid = new SprTransformerFilter(valFilter.get());
00363 bool replaceOriginalData = true;
00364 if( !t_train->transform(t,replaceOriginalData) ) {
00365 cerr << "Unable to apply VarTransformer to training data." << endl;
00366 return 2;
00367 }
00368 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00369 cerr << "Unable to apply VarTransformer to validation data." << endl;
00370 return 2;
00371 }
00372 cout << "Variable transformation from file "
00373 << transformerFile.c_str() << " has been applied to "
00374 << "training and validation data." << endl;
00375 garbage_train.reset(filter.release());
00376 garbage_valid.reset(valFilter.release());
00377 filter.reset(t_train);
00378 valFilter.reset(t_valid);
00379 }
00380
00381
00382 auto_ptr<SprFisher> fisher;
00383 auto_ptr<SprTrainedFisher> trainedFisher1, trainedFisher2;
00384 bool both = false;
00385 if( fisherMode != 0 ) {
00386 if( fisherMode!=1 && fisherMode!=2 && fisherMode!=3 ) {
00387 cerr << "Unknown mode for Fisher " << fisherMode << endl;
00388 return 3;
00389 }
00390 if( fisherMode == 3 ) {
00391 both = true;
00392 fisherMode = 1;
00393 }
00394 cout << "Initializing Fisher in mode " << fisherMode << endl;
00395 fisher.reset(new SprFisher(filter.get(),fisherMode));
00396 if( !fisher->train(verbose) ) {
00397 cerr << "Unable to train Fisher." << endl;
00398 return 3;
00399 }
00400 else {
00401 cout << "Trained Fisher:" << endl;
00402 fisher->print(cout);
00403 }
00404
00405
00406 trainedFisher1.reset(fisher->makeTrained());
00407 if( trainedFisher1.get() == 0 ) {
00408 cerr << "Unable to make a trained Fisher." << endl;
00409 return 4;
00410 }
00411 if( useStandard ) trainedFisher1->useStandard();
00412
00413
00414 if( both ) {
00415 fisher->setMode(2);
00416 if( !fisher->train(verbose) ) {
00417 cerr << "Unable to train 2nd Fisher." << endl;
00418 return 5;
00419 }
00420 else {
00421 cout << "Trained 2nd Fisher:" << endl;
00422 fisher->print(cout);
00423 }
00424 trainedFisher2.reset(fisher->makeTrained());
00425 if( trainedFisher2.get() == 0 ) {
00426 cerr << "Unable to make a trained 2nd Fisher." << endl;
00427 return 6;
00428 }
00429 if( useStandard ) trainedFisher2->useStandard();
00430 }
00431 }
00432
00433
00434 auto_ptr<SprLogitR> logit;
00435 auto_ptr<SprTrainedLogitR> trainedLogit;
00436 if( useLogit ) {
00437
00438 if( initToZero ) {
00439 SprVector beta(filter->dim());
00440 for( int i=0;i<filter->dim();i++ ) beta[i] = 0;
00441 logit.reset(new SprLogitR(filter.get(),0,beta,eps,updateFactor));
00442 }
00443 else {
00444 logit.reset(new SprLogitR(filter.get(),eps,updateFactor));
00445 }
00446
00447
00448 if( !logit->train(verbose) ) {
00449 cerr << "Unable to train logistic regression." << endl;
00450 return 7;
00451 }
00452 else {
00453 cout << "Trained Logistic Regression:" << endl;
00454 logit->print(cout);
00455 }
00456
00457
00458 trainedLogit.reset(logit->makeTrained());
00459 if( trainedLogit.get() == 0 ) {
00460 cerr << "Unable to make trained logistic regression." << endl;
00461 return 8;
00462 }
00463 if( useStandard ) trainedLogit->useStandard();
00464 }
00465
00466
00467 if( !outFile.empty() ) {
00468 if( both || (fisherMode>0 && useLogit) ) {
00469 cerr << "More than one classifier trained. "
00470 << "Cannot save classifier configurations to file." << endl;
00471 return 9;
00472 }
00473 SprAbsClassifier* trainable = 0;
00474 if( fisher.get() != 0 ) trainable = fisher.get();
00475 if( logit.get() != 0 ) trainable = logit.get();
00476 assert( trainable != 0 );
00477 if( !trainable->store(outFile.c_str()) ) {
00478 cerr << "Cannot store classifier in file " << outFile.c_str() << endl;
00479 return 9;
00480 }
00481 }
00482
00483
00484 if( tupleFile.empty() && valHbkFile.empty() )
00485 return 0;
00486
00487
00488 if( !tupleFile.empty() ) {
00489
00490 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00491 if( !tuple->init(tupleFile.c_str()) ) {
00492 cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00493 return 10;
00494 }
00495 string firstClassifier;
00496 if( trainedFisher2.get()!=0 || fisherMode==1 )
00497 firstClassifier = "lin";
00498 else
00499 firstClassifier = "qua";
00500
00501 SprDataFeeder feeder(filter.get(),tuple.get());
00502 feeder.addClassifier(trainedFisher1.get(),firstClassifier.c_str());
00503 feeder.addClassifier(trainedFisher2.get(),"qua");
00504 feeder.addClassifier(trainedLogit.get(),"logit");
00505 if( !feeder.feed(1000) ) {
00506 cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00507 return 11;
00508 }
00509 }
00510
00511
00512 if( !valHbkFile.empty() ) {
00513
00514 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"test"));
00515 if( !tuple->init(valHbkFile.c_str()) ) {
00516 cerr << "Unable to open output file " << valHbkFile.c_str() << endl;
00517 return 12;
00518 }
00519 string firstClassifier;
00520 if( trainedFisher2.get()!=0 || fisherMode==1 )
00521 firstClassifier = "lin";
00522 else
00523 firstClassifier = "qua";
00524
00525 SprDataFeeder feeder(valFilter.get(),tuple.get());
00526 feeder.addClassifier(trainedFisher1.get(),firstClassifier.c_str());
00527 feeder.addClassifier(trainedFisher2.get(),"qua");
00528 feeder.addClassifier(trainedLogit.get(),"logit");
00529 if( !feeder.feed(1000) ) {
00530 cerr << "Cannot feed data into file " << valHbkFile.c_str() << endl;
00531 return 13;
00532 }
00533 }
00534
00535
00536 return 0;
00537 }