CMS 3D CMS Logo

SprVariableImportanceApp.cc File Reference

#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprPoint.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprCoordinateMapper.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedMultiClassLearner.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClassifierEvaluator.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <iostream>
#include <set>
#include <vector>
#include <memory>
#include <string>
#include <cassert>
#include <algorithm>
#include <utility>

Go to the source code of this file.

Functions

void help (const char *prog)
int main (int argc, char **argv)


Function Documentation

void help (const char *prog)

Definition at line 40 of file SprVariableImportanceApp.cc.

References std::cout and std::endl.

// Print the command-line synopsis and the list of supported options to
// standard output.  `prog` is the program name shown in the usage line.
void help(const char* prog)
{
  // One entry per printed line, in display order.  Adjacent string
  // literals inside an entry are concatenated by the compiler.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help                                        ",
    "\t-y list of input classes (see SprAbsFilter.hh)     ",
    "\t-Q apply variable transformation saved in file     ",
    "\t-a input ascii file mode (see SprSimpleReader.hh)  ",
    "\t-k keep the specified fraction in input data       ",
    "\t-K keep (1-this_fraction) in input data            ",
    "\t\t For consistency with other executables,         ",
    "\t\t this option will use \"test\" data to estimate "
    "variable importance.",
    "\t-m use multiclass learner                          ",
    "\t-n number of class permutations per variable (def=1)",
    "\t-S subset of variables used to compute interactions",
    "\t\t Variables must be entered in quotes, "
    "separated by commas.",
    "\t\t Interactions between this subset and each of the rest "
    "of variables will be computed.",
    "\t\t If an empty list is entered, interaction between "
    "each variable and all other variables will be computed.",
    "\t-N number of points used for data integration      ",
    "\t\t -N is only used for computation of interactions "
    "between variables. ",
    "\t\t The greater the N, the more accurate the estimate.",
    "\t\t The smaller the N, the faster you get results.",
    "\t\t N cannot exceed data size.",
    "\t\t By default all available points in data are used.",
    "\t-v verbose level (0=silent default,1,2)            ",
    "\t-w scale all signal weights by this factor         ",
    "\t-V include only these input variables              ",
    "\t-z exclude input variables from the list           ",
    "\t-M map variable lists from trained classifiers onto",
    "\t\t variables available in input data.",
    "\t\t Variables must be listed in quotes and separated by commas."
  };

  // The synopsis is the only line that depends on the caller's input.
  std::cout << "Usage:  " << prog << " classifier_config_file"
            << " input_data_file" << std::endl;

  const std::size_t lineCount = sizeof(optionLines) / sizeof(optionLines[0]);
  for( std::size_t k = 0; k < lineCount; ++k )
    std::cout << optionLines[k] << std::endl;
}

int main (int argc, char **argv)

Definition at line 81 of file SprVariableImportanceApp.cc.

References c, std::cerr, std::cout, d, std::endl, filter, first, help(), i, j, s, split, t, vars, and weights.

00082 {
00083   // check command line
00084   if( argc < 3 ) {
00085     help(argv[0]);
00086     return 1;
00087   }
00088 
00089   // init
00090   int readMode = 0;
00091   int verbose = 0;
00092   bool scaleWeights = false;
00093   double sW = 1.;
00094   string includeList, excludeList;
00095   string inputClassesString;
00096   bool mapTrainedVars = false;
00097   int nPerm = 1;
00098   bool useMCLearner = false;
00099   string transformerFile;
00100   unsigned nPoints = 0;
00101   bool computeInteraction = false;
00102   string varList;
00103   bool split = false;
00104   double splitFactor = 0;
00105   bool useTrainingData = false;
00106 
00107   // decode command line
00108   int c;
00109   extern char* optarg;
00110   extern int optind;
00111   while( (c = getopt(argc,argv,"hy:Q:a:k:K:mn:S:N:v:w:V:z:M")) != EOF ) {
00112     switch( c )
00113       {
00114       case 'h' :
00115         help(argv[0]);
00116         return 1;
00117       case 'y' :
00118         inputClassesString = optarg;
00119         break;
00120       case 'Q' :
00121         transformerFile = optarg;
00122         break;
00123       case 'a' :
00124         readMode = (optarg==0 ? 0 : atoi(optarg));
00125         break;
00126       case 'k' :
00127         split = true;
00128         splitFactor = (optarg==0 ? 0 : atof(optarg));
00129         useTrainingData = true;
00130         break;
00131       case 'K' :
00132         split = true;
00133         splitFactor = (optarg==0 ? 0 : atof(optarg));
00134         useTrainingData = false;
00135         break;
00136       case 'm' :
00137         useMCLearner = true;
00138         break;
00139       case 'n' :
00140         nPerm = (optarg==0 ? 1 : atoi(optarg));
00141         break;
00142       case 'S' :
00143         computeInteraction = true;
00144         varList = (optarg==0 ? "" : optarg);
00145         break;
00146       case 'N' :
00147         nPoints = (optarg==0 ? 1 : atoi(optarg));
00148         break;
00149       case 'v' :
00150         verbose = (optarg==0 ? 0 : atoi(optarg));
00151         break;
00152       case 'w' :
00153         if( optarg != 0 ) {
00154           scaleWeights = true;
00155           sW = atof(optarg);
00156         }
00157         break;
00158       case 'V' :
00159         includeList = optarg;
00160         break;
00161       case 'z' :
00162         excludeList = optarg;
00163         break;
00164       case 'M' :
00165         mapTrainedVars = true;
00166         break;
00167       }
00168   }
00169 
00170   // Must have 3 arguments on the command line
00171   string configFile     = argv[argc-2];
00172   string dataFile       = argv[argc-1];
00173   if( configFile.empty() ) {
00174     cerr << "No classifier configuration file is specified." << endl;
00175     return 1;
00176   }
00177   if( dataFile.empty() ) {
00178     cerr << "No input data file is specified." << endl;
00179     return 1;
00180   }
00181 
00182   // make reader
00183   SprRWFactory::DataType inputType 
00184     = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00185   auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00186 
00187   // include variables
00188   set<string> includeSet;
00189   if( !includeList.empty() ) {
00190     vector<vector<string> > includeVars;
00191     SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00192     assert( !includeVars.empty() );
00193     for( int i=0;i<includeVars[0].size();i++ ) 
00194       includeSet.insert(includeVars[0][i]);
00195     if( !reader->chooseVars(includeSet) ) {
00196       cerr << "Unable to include variables in training set." << endl;
00197       return 2;
00198     }
00199     else {
00200       cout << "Following variables have been included in optimization: ";
00201       for( set<string>::const_iterator 
00202              i=includeSet.begin();i!=includeSet.end();i++ )
00203         cout << "\"" << *i << "\"" << " ";
00204       cout << endl;
00205     }
00206   }
00207 
00208   // exclude variables
00209   set<string> excludeSet;
00210   if( !excludeList.empty() ) {
00211     vector<vector<string> > excludeVars;
00212     SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00213     assert( !excludeVars.empty() );
00214     for( int i=0;i<excludeVars[0].size();i++ ) 
00215       excludeSet.insert(excludeVars[0][i]);
00216     if( !reader->chooseAllBut(excludeSet) ) {
00217       cerr << "Unable to exclude variables from training set." << endl;
00218       return 2;
00219     }
00220     else {
00221       cout << "Following variables have been excluded from optimization: ";
00222       for( set<string>::const_iterator 
00223              i=excludeSet.begin();i!=excludeSet.end();i++ )
00224         cout << "\"" << *i << "\"" << " ";
00225       cout << endl;
00226     }
00227   }
00228 
00229   // read input data from file
00230   auto_ptr<SprAbsFilter> filter(reader->read(dataFile.c_str()));
00231   if( filter.get() == 0 ) {
00232     cerr << "Unable to read data from file " << dataFile.c_str() << endl;
00233     return 2;
00234   }
00235   vector<string> vars;
00236   filter->vars(vars);
00237   cout << "Read data from file " << dataFile.c_str() << " for variables";
00238   for( int i=0;i<vars.size();i++ ) 
00239     cout << " \"" << vars[i].c_str() << "\"";
00240   cout << endl;
00241   cout << "Total number of points read: " << filter->size() << endl;
00242 
00243   // filter training data by class
00244   vector<SprClass> inputClasses;
00245   if( !filter->filterByClass(inputClassesString.c_str()) ) {
00246     cerr << "Cannot choose input classes for string " 
00247          << inputClassesString << endl;
00248     return 2;
00249   }
00250   filter->classes(inputClasses);
00251   assert( inputClasses.size() > 1 );
00252   cout << "Training data filtered by class." << endl;
00253   for( int i=0;i<inputClasses.size();i++ ) {
00254     cout << "Points in class " << inputClasses[i] << ":   " 
00255          << filter->ptsInClass(inputClasses[i]) << endl;
00256   }
00257 
00258   // scale weights
00259   if( scaleWeights ) {
00260     cout << "Signal weights are multiplied by " << sW << endl;
00261     filter->scaleWeights(inputClasses[1],sW);
00262   }
00263 
00264   // split data
00265   auto_ptr<SprAbsFilter> garbage_split;
00266   if( split ) {
00267     cout << "Splitting input data with factor " << splitFactor << endl;
00268     vector<double> weights;
00269     bool splitRandomize = false;
00270     SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00271     if( splitted == 0 ) {
00272       cerr << "Unable to split input data." << endl;
00273       return 2;
00274     }
00275     if( !useTrainingData ) {
00276       garbage_split.reset(filter.release());
00277       bool ownData = true;
00278       filter.reset(new SprEmptyFilter(splitted,weights,ownData));
00279     }
00280     cout << "Input data re-filtered:" << endl;
00281     for( int i=0;i<inputClasses.size();i++ ) {
00282       cout << "Points in class " << inputClasses[i] << ":   " 
00283            << filter->ptsInClass(inputClasses[i]) << endl;
00284     }
00285   }
00286 
00287   // apply transformation of variables to training and test data
00288   auto_ptr<SprAbsFilter> garbage_trans;
00289   if( !transformerFile.empty() ) {
00290     SprVarTransformerReader transReader;
00291     const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00292     if( t == 0 ) {
00293       cerr << "Unable to read VarTransformer from file "
00294            << transformerFile.c_str() << endl;
00295       return 2;
00296     }
00297     SprTransformerFilter* tf = new SprTransformerFilter(filter.get());
00298     bool replaceOriginalData = true;
00299     if( !tf->transform(t,replaceOriginalData) ) {
00300       cerr << "Unable to apply VarTransformer to training data." << endl;
00301       return 2;
00302     }
00303     cout << "Variable transformation from file "
00304          << transformerFile.c_str() << " has been applied to "
00305          << "training data." << endl;
00306     garbage_trans.reset(filter.release());
00307     filter.reset(tf);
00308     filter->vars(vars);
00309   }
00310 
00311   // read classifier configuration
00312   auto_ptr<SprAbsTrainedClassifier> trained;
00313   auto_ptr<SprTrainedMultiClassLearner> mcTrained;
00314   if( useMCLearner ) {
00315     SprMultiClassReader multiReader;
00316     if( !multiReader.read(configFile.c_str()) ) {
00317       cerr << "Failed to read saved multi class learner from file "
00318            << configFile.c_str() << endl;
00319       return 3;
00320     }
00321     mcTrained.reset(multiReader.makeTrained());
00322     cout << "Read classifier " << mcTrained->name().c_str()
00323          << " with dimensionality " << mcTrained->dim() << endl;
00324   }
00325   else {
00326     trained.reset(SprClassifierReader::readTrained(configFile.c_str(),
00327                                                    verbose));
00328     if( trained.get() == 0 ) {
00329       cerr << "Unable to read classifier configuration from file "
00330            << configFile.c_str() << endl;
00331       return 3;
00332     }
00333     cout << "Read classifier " << trained->name().c_str()
00334          << " with dimensionality " << trained->dim() << endl;
00335   }
00336 
00337   // get a list of trained variables
00338   vector<string> trainedVars;
00339   if( trained.get() != 0 )
00340     trained->vars(trainedVars);
00341   else
00342     mcTrained->vars(trainedVars);
00343   if( verbose > 0 ) {
00344     cout << "Variables:      " << endl;
00345     for( int j=0;j<trainedVars.size();j++ ) 
00346       cout << trainedVars[j].c_str() << " ";
00347     cout << endl;
00348   }
00349 
00350   // map trained-classifier variables onto data variables
00351   auto_ptr<SprCoordinateMapper> mapper;
00352   if( mapTrainedVars || 
00353       (trained.get()!=0 && trained->name()=="Combiner") ) {
00354     mapper.reset(SprCoordinateMapper::createMapper(trainedVars,vars));
00355     if( mapper.get() == 0 ) {
00356       cerr << "Unable to map trained classifier vars onto data vars." << endl;
00357       return 4;
00358     }
00359   }
00360 
00361   // call evaluator
00362   vector<SprClassifierEvaluator::NameAndValue> lossIncrease, 
00363     interaction(trainedVars.size());
00364   if( !SprClassifierEvaluator::variableImportance(filter.get(),
00365                                                   trained.get(),
00366                                                   mcTrained.get(),
00367                                                   mapper.get(),
00368                                                   nPerm,
00369                                                   lossIncrease) ) {
00370     cerr << "Unable to estimate variable importance." << endl;
00371     return 5;
00372   }
00373   if( computeInteraction &&
00374       !SprClassifierEvaluator::variableInteraction(filter.get(),
00375                                                    trained.get(),
00376                                                    mcTrained.get(),
00377                                                    mapper.get(),
00378                                                    varList.c_str(),
00379                                                    nPoints,
00380                                                    interaction,
00381                                                    verbose) ) {
00382     cerr << "Unable to estimate variable interactions." << endl;
00383     return 6;
00384   }
00385   assert( lossIncrease.size() == interaction.size() );
00386 
00387   //
00388   // process computed loss
00389   //
00390   cout << "==============================================================================================================================" << endl;
00391   if( computeInteraction ) {
00392     cout << "Displaying interactions with variable block " 
00393          << varList.c_str() << endl;
00394   }
00395   char t [200];
00396   sprintf(t,"%35s        %15s                      %15s","Variable",
00397           "Change in loss","Interaction");
00398   cout << t << endl;
00399   for( int d=0;d<lossIncrease.size();d++ ) {
00400     char s [200];
00401     sprintf(s,"%35s      %15.10f +- %15.10f      %15.10f +- %15.10f",
00402             lossIncrease[d].first.c_str(),
00403             lossIncrease[d].second.first,lossIncrease[d].second.second,
00404             interaction[d].second.first,interaction[d].second.second);
00405     cout << s << endl;
00406   }
00407   cout << "==============================================================================================================================" << endl;
00408 
00409   // exit
00410   return 0;
00411 }


Generated on Tue Jun 9 17:55:01 2009 for CMSSW by  doxygen 1.5.4