00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00098 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00099 #include "PhysicsTools/StatPatternRecognition/interface/SprDecisionTree.hh"
00100 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00101 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00102 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00103 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00104 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
00105 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerPermutator.hh"
00106
00107 #include <stdlib.h>
00108 #include <unistd.h>
00109 #include <iostream>
00110 #include <vector>
00111 #include <set>
00112 #include <string>
00113 #include <memory>
00114 #include <algorithm>
00115 #include <functional>
00116 #include <utility>
00117 #include <cassert>
00118
00119 using namespace std;
00120
00121
// Prints the command-line usage summary for this executable to standard output.
//   prog - program name shown in the "Usage:" line (normally argv[0]).
void help(const char* prog)
{
  // One entry per output line that follows the usage header.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-s random seed for permutations (default=0) ",
    "\t-n number of cycles for GoF evaluation ",
    "\t-l minimal number of entries per tree leaf (def=1) ",
    "\t-v verbose level (0=silent default,1,2) ",
    "\t-w scale all signal weights by this factor ",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const unsigned lineCount = sizeof(optionLines) / sizeof(optionLines[0]);

  std::cout << "Usage: " << prog << " training_data_file" << std::endl;
  for( unsigned line = 0; line < lineCount; ++line ) {
    // std::endl keeps the per-line flush of the stream.
    std::cout << optionLines[line] << std::endl;
  }
}
00139
00140
00141 int main(int argc, char ** argv)
00142 {
00143
00144 if( argc < 2 ) {
00145 help(argv[0]);
00146 return 1;
00147 }
00148
00149
00150 string tupleFile;
00151 int readMode = 0;
00152 unsigned cycles = 0;
00153 unsigned nmin = 1;
00154 int verbose = 0;
00155 string includeList, excludeList;
00156 int seed = 0;
00157 bool scaleWeights = false;
00158 double sW = 1.;
00159
00160
00161 int c;
00162 extern char* optarg;
00163
00164 while( (c = getopt(argc,argv,"ha:s:n:l:v:w:V:z:")) != EOF ) {
00165 switch( c )
00166 {
00167 case 'h' :
00168 help(argv[0]);
00169 return 1;
00170 case 'a' :
00171 readMode = (optarg==0 ? 0 : atoi(optarg));
00172 break;
00173 case 's' :
00174 seed = (optarg==0 ? 0 : atoi(optarg));
00175 break;
00176 case 'n' :
00177 cycles = (optarg==0 ? 1 : atoi(optarg));
00178 break;
00179 case 'l' :
00180 nmin = (optarg==0 ? 1 : atoi(optarg));
00181 break;
00182 case 'v' :
00183 verbose = (optarg==0 ? 0 : atoi(optarg));
00184 break;
00185 case 'w' :
00186 if( optarg != 0 ) {
00187 scaleWeights = true;
00188 sW = atof(optarg);
00189 }
00190 break;
00191 case 'V' :
00192 includeList = optarg;
00193 break;
00194 case 'z' :
00195 excludeList = optarg;
00196 break;
00197 }
00198 }
00199
00200
00201 string trFile = argv[argc-1];
00202 if( trFile.empty() ) {
00203 cerr << "No training file is specified." << endl;
00204 return 1;
00205 }
00206
00207
00208 SprRWFactory::DataType inputType
00209 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00210 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00211
00212
00213 set<string> includeSet;
00214 if( !includeList.empty() ) {
00215 vector<vector<string> > includeVars;
00216 SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00217 assert( !includeVars.empty() );
00218 for( int i=0;i<includeVars[0].size();i++ )
00219 includeSet.insert(includeVars[0][i]);
00220 if( !reader->chooseVars(includeSet) ) {
00221 cerr << "Unable to include variables in training set." << endl;
00222 return 2;
00223 }
00224 else {
00225 cout << "Following variables have been included in optimization: ";
00226 for( set<string>::const_iterator
00227 i=includeSet.begin();i!=includeSet.end();i++ )
00228 cout << "\"" << *i << "\"" << " ";
00229 cout << endl;
00230 }
00231 }
00232
00233
00234 set<string> excludeSet;
00235 if( !excludeList.empty() ) {
00236 vector<vector<string> > excludeVars;
00237 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00238 assert( !excludeVars.empty() );
00239 for( int i=0;i<excludeVars[0].size();i++ )
00240 excludeSet.insert(excludeVars[0][i]);
00241 if( !reader->chooseAllBut(excludeSet) ) {
00242 cerr << "Unable to exclude variables from training set." << endl;
00243 return 2;
00244 }
00245 else {
00246 cout << "Following variables have been excluded from optimization: ";
00247 for( set<string>::const_iterator
00248 i=excludeSet.begin();i!=excludeSet.end();i++ )
00249 cout << "\"" << *i << "\"" << " ";
00250 cout << endl;
00251 }
00252 }
00253
00254
00255 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00256 if( filter.get() == 0 ) {
00257 cerr << "Unable to read data from file " << trFile.c_str() << endl;
00258 return 2;
00259 }
00260 vector<string> vars;
00261 filter->vars(vars);
00262 cout << "Read data from file " << trFile.c_str()
00263 << " for variables";
00264 for( int i=0;i<vars.size();i++ )
00265 cout << " \"" << vars[i].c_str() << "\"";
00266 cout << endl;
00267 cout << "Total number of points read: " << filter->size() << endl;
00268 const unsigned n0 = filter->ptsInClass(0);
00269 const unsigned n1 = filter->ptsInClass(1);
00270 cout << "Points in class 0: " << n0 << " 1: " << n1 << endl;
00271
00272
00273 vector<double> origWeights;
00274 if( scaleWeights ) {
00275 filter->weights(origWeights);
00276 cout << "Signal weights are multiplied by " << sW << endl;
00277 filter->scaleWeights(1,sW);
00278 }
00279
00280
00281 SprTwoClassGiniIndex crit;
00282
00283
00284 bool doMerge = false;
00285 SprDecisionTree tree(filter.get(),&crit,nmin,doMerge,true);
00286
00287
00288 if( !tree.train(verbose) ) {
00289 cerr << "Unable to train decision tree." << endl;
00290 return 3;
00291 }
00292 double origFom = tree.fom();
00293
00294
00295 vector<pair<SprPoint*,int> > origLabels(filter->size());
00296 for( int i=0;i<filter->size();i++ ) {
00297 SprPoint* p = (*filter.get())[i];
00298 origLabels[i] = pair<SprPoint*,int>(p,p->class_);
00299 }
00300
00301
00302 cout << "Will perform " << cycles
00303 << " toy experiments for GoF calculation." << endl;
00304 vector<double> fom;
00305 SprIntegerPermutator permu(filter->size(),seed);
00306 assert( (n0+n1) == filter->size() );
00307 for( int ic=0;ic<cycles;ic++ ) {
00308
00309 if( (ic%10) == 0 )
00310 cout << "Performing toy experiment " << ic << endl;
00311
00312
00313 vector<unsigned> labels;
00314 permu.sequence(labels);
00315 for( int i=0;i<n0;i++ ) {
00316 unsigned ip = labels[i];
00317 (*filter.get())[ip]->class_ = 0;
00318 }
00319 for( int i=n0;i<n0+n1;i++ ) {
00320 unsigned ip = labels[i];
00321 (*filter.get())[ip]->class_ = 1;
00322 }
00323
00324
00325 if( scaleWeights ) {
00326 filter->setPermanentWeights(origWeights);
00327 filter->scaleWeights(1,sW);
00328 }
00329
00330
00331 tree.reset();
00332
00333
00334 if( !tree.train(verbose) ) continue;
00335 fom.push_back(tree.fom());
00336 }
00337 if( fom.empty() ) {
00338 cerr << "Failed to compute FOMs for any experiments." << endl;
00339 return 4;
00340 }
00341
00342
00343
00344
00345 for( int i=0;i<origLabels.size();i++ )
00346 origLabels[i].first->class_ = origLabels[i].second;
00347 if( scaleWeights ) filter->setPermanentWeights(origWeights);
00348
00349
00350 stable_sort(fom.begin(),fom.end());
00351 vector<double>::iterator iter = find_if(fom.begin(),fom.end(),
00352 bind2nd(greater<double>(),origFom));
00353 int below = iter - fom.begin();
00354 int above = fom.size() - below;
00355 cout << below << " experiments out of " << fom.size()
00356 << " have better GoF values than the data." << endl;
00357 cout << "GoF=" << double(above)/double(fom.size()) << endl;
00358
00359
00360 return 0;
00361 }