CMS 3D CMS Logo

SprFisherLogitApp.cc

Go to the documentation of this file.
00001 //$Id: SprFisherLogitApp.cc,v 1.4 2007/11/12 06:19:11 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprFisher.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprLogitR.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedFisher.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedLogitR.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00019 #include "PhysicsTools/StatPatternRecognition/src/SprVector.hh"
00020 
00021 #include <stdlib.h>
00022 #include <unistd.h>
00023 #include <iostream>
00024 #include <vector>
00025 #include <set>
00026 #include <string>
00027 #include <memory>
00028 #include <cassert>
00029 
00030 using namespace std;
00031 
00032 
// Print the command-line usage summary of this executable to standard output.
//
// prog: program name shown on the "Usage:" line; main() passes argv[0].
//       A null pointer is tolerated (argv[0] may be null when argc==0) and
//       replaced by the executable's name, because streaming a null
//       const char* is undefined behaviour and would suppress the help text.
void help(const char* prog) 
{
  using std::cout;
  using std::endl;

  const char* name = ( prog!=0 ? prog : "SprFisherLogitApp" );

  cout << "Usage:  " << name << " training_data_file" << endl;
  cout << "\t Options: " << endl;
  cout << "\t-h --- help                                        " << endl;
  cout << "\t-m order of Fisher                                 " << endl;
  cout << "\t\t 1 = linear                                      " << endl;
  cout << "\t\t 2 = quadratic                                   " << endl;
  cout << "\t\t 3 = both                                        " << endl;
  cout << "\t-l use logistic regression                         " << endl;
  cout << "\t-e accuracy for logistic regression (default=0.001)" << endl;
  cout << "\t-u update factor for logistic regression (default=1)"<< endl;
  cout << "\t-i initialize logistic regression coeffs to 0 (def=LDA output)"
       << endl;
  cout << "\t-y list of input classes (see SprAbsFilter.hh)     " << endl;
  cout << "\t-Q apply variable transformation saved in file     " << endl;
  cout << "\t-o output Tuple file                               " << endl;
  cout << "\t-s use standard output ranging from -infty to +infty"<< endl;
  cout << "\t-a input ascii file mode (see SprSimpleReader.hh)  " << endl;
  cout << "\t-A save output data in ascii instead of Root       " << endl;
  cout << "\t-v verbose level (0=silent default,1,2)            " << endl;
  cout << "\t-f store classifier configuration to file          " << endl;
  cout << "\t-K keep this fraction in training set and          " << endl;
  cout << "\t\t put the rest into validation set                " << endl;
  cout << "\t-D randomize training set split-up                 " << endl;
  cout << "\t-t read validation/test data from a file           " << endl;
  // The closing parenthesis was missing in the original message.
  cout << "\t\t (must be in same format as input data!!!)       " << endl;
  cout << "\t-p output file to store validation/test data       " << endl;
  cout << "\t-w scale all signal weights by this factor         " << endl;
  cout << "\t-V include only these input variables              " << endl;
  cout << "\t-z exclude input variables from the list           " << endl;
  cout << "\t\t Variables must be listed in quotes and separated by commas." 
       << endl;
}
00067 
00068 
00069 int main(int argc, char ** argv)
00070 {
00071   // check command line
00072   if( argc < 2 ) {
00073     help(argv[0]);
00074     return 1;
00075   }
00076 
00077   // init
00078   int fisherMode = 0;
00079   bool useLogit = false;
00080   double eps = 0.001;
00081   double updateFactor = 1;
00082   bool initToZero = false;
00083   string tupleFile;
00084   int readMode = 0;
00085   SprRWFactory::DataType writeMode = SprRWFactory::Root;
00086   int verbose = 0;
00087   string outFile;
00088   string valFile;
00089   string valHbkFile;
00090   bool scaleWeights = false;
00091   double sW = 1.;
00092   string includeList, excludeList;
00093   string inputClassesString;
00094   bool useStandard = false;
00095   bool split = false;
00096   double splitFactor = 0;
00097   bool splitRandomize = false;
00098   string transformerFile;
00099 
00100   // decode command line
00101   int c;
00102   extern char* optarg;
00103   extern int optind;
00104   while( (c = getopt(argc,argv,"hm:le:u:iy:Q:o:sa:Av:f:K:Dt:p:w:V:z:")) != EOF ) {
00105     switch( c )
00106       {
00107       case 'h' :
00108         help(argv[0]);
00109         return 1;
00110       case 'm' :
00111         fisherMode = (optarg==0 ? 1 : atoi(optarg));
00112         break;
00113       case 'l' :
00114         useLogit = true;
00115         break;
00116       case 'e' :
00117         eps = (optarg==0 ? 0.001 : atof(optarg));
00118         break;
00119       case 'u' :
00120         updateFactor = (optarg==0 ? 1. : atof(optarg));
00121         break;
00122       case 'i' :
00123         initToZero = true;
00124         break;
00125       case 'y' :
00126         inputClassesString = optarg;
00127         break;
00128       case 'Q' :
00129         transformerFile = optarg;
00130         break;
00131       case 'o' :
00132         tupleFile = optarg;
00133         break;
00134       case 's' :
00135         useStandard = true;
00136         break;
00137       case 'a' :
00138         readMode = (optarg==0 ? 0 : atoi(optarg));
00139         break;
00140       case 'A' :
00141         writeMode = SprRWFactory::Ascii;
00142         break;
00143       case 'v' :
00144         verbose = (optarg==0 ? 0 : atoi(optarg));
00145         break;
00146       case 'f' :
00147         outFile = optarg;
00148         break;
00149       case 'K' :
00150         split = true;
00151         splitFactor = (optarg==0 ? 0 : atof(optarg));
00152         break;
00153       case 'D' :
00154         splitRandomize = true;
00155         break;
00156       case 't' :
00157         valFile = optarg;
00158         break;
00159       case 'p' :
00160         valHbkFile = optarg;
00161         break;
00162       case 'w' :
00163         if( optarg != 0 ) {
00164           scaleWeights = true;
00165           sW = atof(optarg);
00166         }
00167         break;
00168       case 'V' :
00169         includeList = optarg;
00170         break;
00171       case 'z' :
00172         excludeList = optarg;
00173         break;
00174       }
00175   }
00176 
00177   // training file name must be the only argument that appears
00178   // after all options on the command line
00179   string trFile;
00180   if( optind == argc-1 )
00181     trFile = argv[optind];
00182   if( trFile.empty() ) {
00183     cerr << "No training file is specified." << endl;
00184     return 1;
00185   }
00186 
00187   // sanity check
00188   if( fisherMode==0 && !useLogit ) {
00189     cerr << "Neither Fisher nor logistic regression is requested." << endl;
00190     return 1;
00191   }
00192 
00193   // make reader
00194   SprRWFactory::DataType inputType 
00195     = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00196   auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00197 
00198   // include variables
00199   set<string> includeSet;
00200   if( !includeList.empty() ) {
00201     vector<vector<string> > includeVars;
00202     SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00203     assert( !includeVars.empty() );
00204     for( int i=0;i<includeVars[0].size();i++ ) 
00205       includeSet.insert(includeVars[0][i]);
00206     if( !reader->chooseVars(includeSet) ) {
00207       cerr << "Unable to include variables in training set." << endl;
00208       return 2;
00209     }
00210     else {
00211       cout << "Following variables have been included in optimization: ";
00212       for( set<string>::const_iterator 
00213              i=includeSet.begin();i!=includeSet.end();i++ )
00214         cout << "\"" << *i << "\"" << " ";
00215       cout << endl;
00216     }
00217   }
00218 
00219   // exclude variables
00220   set<string> excludeSet;
00221   if( !excludeList.empty() ) {
00222     vector<vector<string> > excludeVars;
00223     SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00224     assert( !excludeVars.empty() );
00225     for( int i=0;i<excludeVars[0].size();i++ ) 
00226       excludeSet.insert(excludeVars[0][i]);
00227     if( !reader->chooseAllBut(excludeSet) ) {
00228       cerr << "Unable to exclude variables from training set." << endl;
00229       return 2;
00230     }
00231     else {
00232       cout << "Following variables have been excluded from optimization: ";
00233       for( set<string>::const_iterator 
00234              i=excludeSet.begin();i!=excludeSet.end();i++ )
00235         cout << "\"" << *i << "\"" << " ";
00236       cout << endl;
00237     }
00238   }
00239 
00240   // read training data from file
00241   auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00242   if( filter.get() == 0 ) {
00243     cerr << "Unable to read data from file " << trFile.c_str() << endl;
00244     return 2;
00245   }
00246   vector<string> vars;
00247   filter->vars(vars);
00248   cout << "Read data from file " << trFile.c_str() 
00249        << " for variables";
00250   for( int i=0;i<vars.size();i++ ) 
00251     cout << " \"" << vars[i].c_str() << "\"";
00252   cout << endl;
00253   cout << "Total number of points read: " << filter->size() << endl;
00254 
00255   // filter training data by class
00256   vector<SprClass> inputClasses;
00257   if( !filter->filterByClass(inputClassesString.c_str()) ) {
00258     cerr << "Cannot choose input classes for string " 
00259          << inputClassesString << endl;
00260     return 2;
00261   }
00262   filter->classes(inputClasses);
00263   assert( inputClasses.size() > 1 );
00264   cout << "Training data filtered by class." << endl;
00265   for( int i=0;i<inputClasses.size();i++ ) {
00266     cout << "Points in class " << inputClasses[i] << ":   " 
00267          << filter->ptsInClass(inputClasses[i]) << endl;
00268   }
00269 
00270   // scale weights
00271   if( scaleWeights ) {
00272     cout << "Signal weights are multiplied by " << sW << endl;
00273     filter->scaleWeights(inputClasses[1],sW);
00274   }
00275 
00276   // read validation data from file
00277   auto_ptr<SprAbsFilter> valFilter;
00278   if( split && !valFile.empty() ) {
00279     cerr << "Unable to split training data and use validation data " 
00280          << "from a separate file." << endl;
00281     return 2;
00282   }
00283   if( split ) {
00284     cout << "Splitting training data with factor " << splitFactor << endl;
00285     if( splitRandomize )
00286       cout << "Will use randomized splitting." << endl;
00287     vector<double> weights;
00288     SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00289     if( splitted == 0 ) {
00290       cerr << "Unable to split training data." << endl;
00291       return 2;
00292     }
00293     bool ownData = true;
00294     valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00295     cout << "Training data re-filtered:" << endl;
00296     for( int i=0;i<inputClasses.size();i++ ) {
00297       cout << "Points in class " << inputClasses[i] << ":   " 
00298            << filter->ptsInClass(inputClasses[i]) << endl;
00299     }
00300   }  if( !valFile.empty() ) {
00301     auto_ptr<SprAbsReader> 
00302       valReader(SprRWFactory::makeReader(inputType,readMode));
00303     if( !includeSet.empty() ) {
00304       if( !valReader->chooseVars(includeSet) ) {
00305         cerr << "Unable to include variables in validation set." << endl;
00306         return 2;
00307       }
00308     }
00309     if( !excludeSet.empty() ) {
00310       if( !valReader->chooseAllBut(excludeSet) ) {
00311         cerr << "Unable to exclude variables from validation set." << endl;
00312         return 2;
00313       }
00314     }
00315     valFilter.reset(valReader->read(valFile.c_str()));
00316     if( valFilter.get() == 0 ) {
00317       cerr << "Unable to read data from file " << valFile.c_str() << endl;
00318       return 2;
00319     }
00320     vector<string> valVars;
00321     valFilter->vars(valVars);
00322     cout << "Read validation data from file " << valFile.c_str() 
00323          << " for variables";
00324     for( int i=0;i<valVars.size();i++ ) 
00325       cout << " \"" << valVars[i].c_str() << "\"";
00326     cout << endl;
00327     cout << "Total number of points read: " << valFilter->size() << endl;
00328   }
00329 
00330   // filter validation data by class
00331   if( valFilter.get() != 0 ) {
00332     if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00333       cerr << "Cannot choose input classes for string " 
00334            << inputClassesString << endl;
00335       return 2;
00336     }
00337     valFilter->classes(inputClasses);
00338     cout << "Validation data filtered by class." << endl;
00339     for( int i=0;i<inputClasses.size();i++ ) {
00340       cout << "Points in class " << inputClasses[i] << ":   " 
00341            << valFilter->ptsInClass(inputClasses[i]) << endl;
00342     }
00343   }
00344 
00345   // scale weights
00346   if( scaleWeights && valFilter.get()!=0 )
00347     valFilter->scaleWeights(inputClasses[1],sW);
00348 
00349   // apply transformation of variables to training and test data
00350   auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00351   if( !transformerFile.empty() ) {
00352     SprVarTransformerReader transReader;
00353     const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00354     if( t == 0 ) {
00355       cerr << "Unable to read VarTransformer from file "
00356            << transformerFile.c_str() << endl;
00357       return 2;
00358     }
00359     SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00360     SprTransformerFilter* t_valid = 0;
00361     if( valFilter.get() != 0 )
00362       t_valid = new SprTransformerFilter(valFilter.get());
00363     bool replaceOriginalData = true;
00364     if( !t_train->transform(t,replaceOriginalData) ) {
00365       cerr << "Unable to apply VarTransformer to training data." << endl;
00366       return 2;
00367     }
00368     if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00369       cerr << "Unable to apply VarTransformer to validation data." << endl;
00370       return 2;
00371     }
00372     cout << "Variable transformation from file "
00373          << transformerFile.c_str() << " has been applied to "
00374          << "training and validation data." << endl;
00375     garbage_train.reset(filter.release());
00376     garbage_valid.reset(valFilter.release());
00377     filter.reset(t_train);
00378     valFilter.reset(t_valid);
00379   }
00380 
00381   // train Fisher
00382   auto_ptr<SprFisher> fisher;
00383   auto_ptr<SprTrainedFisher> trainedFisher1, trainedFisher2;
00384   bool both = false;
00385   if( fisherMode != 0 ) {
00386     if( fisherMode!=1 && fisherMode!=2 && fisherMode!=3 ) {
00387       cerr << "Unknown mode for Fisher " << fisherMode << endl;
00388       return 3;
00389     }
00390     if( fisherMode == 3 ) {
00391       both = true;
00392       fisherMode = 1;
00393     }
00394     cout << "Initializing Fisher in mode " << fisherMode << endl;
00395     fisher.reset(new SprFisher(filter.get(),fisherMode));
00396     if( !fisher->train(verbose) ) {
00397       cerr << "Unable to train Fisher." << endl;
00398       return 3;
00399     }
00400     else {
00401       cout << "Trained Fisher:" << endl;
00402       fisher->print(cout);
00403     }
00404 
00405     // make a trained Fisher
00406     trainedFisher1.reset(fisher->makeTrained());
00407     if( trainedFisher1.get() == 0 ) {
00408       cerr << "Unable to make a trained Fisher." << endl;
00409       return 4;
00410     }
00411     if( useStandard ) trainedFisher1->useStandard();
00412 
00413     // train another one if necessary
00414     if( both ) {
00415       fisher->setMode(2);
00416       if( !fisher->train(verbose) ) {
00417         cerr << "Unable to train 2nd Fisher." << endl;
00418         return 5;
00419       }
00420       else {
00421         cout << "Trained 2nd Fisher:" << endl;
00422         fisher->print(cout);
00423       }
00424       trainedFisher2.reset(fisher->makeTrained());
00425       if( trainedFisher2.get() == 0 ) {
00426         cerr << "Unable to make a trained 2nd Fisher." << endl;
00427         return 6;
00428       }
00429       if( useStandard ) trainedFisher2->useStandard();
00430     }
00431   }
00432 
00433   // train logistic regression
00434   auto_ptr<SprLogitR> logit;
00435   auto_ptr<SprTrainedLogitR> trainedLogit;
00436   if( useLogit ) {
00437     // init
00438     if( initToZero ) {
00439       SprVector beta(filter->dim());
00440       for( int i=0;i<filter->dim();i++ ) beta[i] = 0;
00441       logit.reset(new SprLogitR(filter.get(),0,beta,eps,updateFactor));
00442     }
00443     else {
00444       logit.reset(new SprLogitR(filter.get(),eps,updateFactor));
00445     }
00446 
00447     // train
00448     if( !logit->train(verbose) ) {
00449       cerr << "Unable to train logistic regression." << endl;
00450       return 7;
00451     }
00452     else {
00453       cout << "Trained Logistic Regression:" << endl;
00454       logit->print(cout);
00455     }
00456 
00457     // make trained logit
00458     trainedLogit.reset(logit->makeTrained());
00459     if( trainedLogit.get() == 0 ) {
00460       cerr << "Unable to make trained logistic regression." << endl;
00461       return 8;
00462     }
00463     if( useStandard ) trainedLogit->useStandard();
00464   }
00465 
00466   // save classifier configuration into file
00467   if( !outFile.empty() ) {
00468     if( both || (fisherMode>0 && useLogit) ) {
00469       cerr << "More than one classifier trained. " 
00470            << "Cannot save classifier configurations to file." << endl;
00471       return 9;
00472     }
00473     SprAbsClassifier* trainable = 0;
00474     if( fisher.get() != 0 ) trainable = fisher.get();
00475     if( logit.get() != 0 ) trainable = logit.get();
00476     assert( trainable != 0 );
00477     if( !trainable->store(outFile.c_str()) ) {
00478       cerr << "Cannot store classifier in file " << outFile.c_str() << endl;
00479       return 9;
00480     }
00481   }
00482 
00483   // make histogram if requested
00484   if( tupleFile.empty() && valHbkFile.empty() ) 
00485     return 0;
00486 
00487   // feed training data
00488   if( !tupleFile.empty() ) {
00489     // make a writer
00490     auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00491     if( !tuple->init(tupleFile.c_str()) ) {
00492       cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00493       return 10;
00494     }
00495     string firstClassifier;
00496     if( trainedFisher2.get()!=0 || fisherMode==1 )
00497       firstClassifier = "lin";
00498     else
00499       firstClassifier = "qua";
00500     // feed
00501     SprDataFeeder feeder(filter.get(),tuple.get());
00502     feeder.addClassifier(trainedFisher1.get(),firstClassifier.c_str());
00503     feeder.addClassifier(trainedFisher2.get(),"qua");
00504     feeder.addClassifier(trainedLogit.get(),"logit");
00505     if( !feeder.feed(1000) ) {
00506       cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00507       return 11;
00508     }
00509   }
00510 
00511   // feed validation data
00512   if( !valHbkFile.empty() ) {
00513     // make a writer
00514     auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"test"));
00515     if( !tuple->init(valHbkFile.c_str()) ) {
00516       cerr << "Unable to open output file " << valHbkFile.c_str() << endl;
00517       return 12;
00518     }
00519     string firstClassifier;
00520     if( trainedFisher2.get()!=0 || fisherMode==1 )
00521       firstClassifier = "lin";
00522     else
00523       firstClassifier = "qua";
00524     // feed
00525     SprDataFeeder feeder(valFilter.get(),tuple.get());
00526     feeder.addClassifier(trainedFisher1.get(),firstClassifier.c_str());
00527     feeder.addClassifier(trainedFisher2.get(),"qua");
00528     feeder.addClassifier(trainedLogit.get(),"logit");
00529     if( !feeder.feed(1000) ) {
00530       cerr << "Cannot feed data into file " << valHbkFile.c_str() << endl;
00531       return 13;
00532     }
00533   }
00534 
00535   // exit
00536   return 0;
00537 }

Generated on Tue Jun 9 17:41:59 2009 for CMSSW by  doxygen 1.5.4