00001
00002
00003
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprCombiner.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsClassifier.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00022 #include "PhysicsTools/StatPatternRecognition/interface/SprAdaBoost.hh"
00023 #include "PhysicsTools/StatPatternRecognition/interface/SprBagger.hh"
00024 #include "PhysicsTools/StatPatternRecognition/interface/SprStdBackprop.hh"
00025 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00026 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00027 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00028
00029
00030 #include <stdlib.h>
00031 #include <unistd.h>
00032 #include <iostream>
00033 #include <fstream>
00034 #include <sstream>
00035 #include <vector>
00036 #include <memory>
00037 #include <string>
00038 #include <cassert>
00039 #include <map>
00040 #include <utility>
00041
00042 using namespace std;
00043
00044
// Print the command-line synopsis followed by one line per supported
// option to standard output.
//
// prog: program name shown in the "Usage:" line (callers pass argv[0]).
void help(const char* prog)
{
  std::cout << "Usage: " << prog
            << " list_of_input_config_subclassifier_files"
            << " input_config_file_for_global_classifier"
            << " input_data_file" << std::endl;

  // Option descriptions, emitted verbatim in this order, one per line.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-y list of input classes (see SprAbsFilter.hh) ",
    "\t-Q apply variable transformation saved in file ",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-v verbose level (0=silent default,1,2) ",
    "\t-w scale all signal weights by this factor ",
    "\t-f save trained classifier configuration to file ",
    "\t-K keep this fraction in training set and ",
    "\t\t put the rest into validation set ",
    "\t-D randomize training set split-up ",
    "\t-t read validation/test data from a file ",
    "\t\t (must be in same format as input data!!! ",
    "\t-d frequency of print-outs for validation data "
  };
  const unsigned int lineCount = sizeof(optionLines) / sizeof(optionLines[0]);

  for( unsigned int idx = 0; idx < lineCount; ++idx )
    std::cout << optionLines[idx] << std::endl;
}
00066
00067
00068 void prepareExit(vector<SprAbsTwoClassCriterion*>& criteria,
00069 vector<SprIntegerBootstrap*>& bstraps,
00070 vector<SprAbsClassifier*>& classifiers)
00071 {
00072 for( int i=0;i<criteria.size();i++ ) delete criteria[i];
00073 for( int i=0;i<classifiers.size();i++ ) delete classifiers[i];
00074 for( int i=0;i<bstraps.size();i++ ) delete bstraps[i];
00075 }
00076
00077
00078 int main(int argc, char ** argv)
00079 {
00080
00081 if( argc < 4 ) {
00082 help(argv[0]);
00083 return 1;
00084 }
00085
00086
00087 int readMode = 0;
00088 int verbose = 0;
00089 bool scaleWeights = false;
00090 double sW = 1.;
00091 bool useStandard = false;
00092 string inputClassesString;
00093 string valFile;
00094 unsigned valPrint = 0;
00095 string outFile;
00096 bool split = false;
00097 double splitFactor = 0;
00098 bool splitRandomize = false;
00099 string transformerFile;
00100
00101
00102 int c;
00103 extern char* optarg;
00104 extern int optind;
00105 while( (c = getopt(argc,argv,"hy:a:v:w:f:K:Dt:d:")) != EOF ) {
00106 switch( c )
00107 {
00108 case 'h' :
00109 help(argv[0]);
00110 return 1;
00111 case 'y' :
00112 inputClassesString = optarg;
00113 break;
00114 case 'Q' :
00115 transformerFile = optarg;
00116 break;
00117 case 'a' :
00118 readMode = (optarg==0 ? 0 : atoi(optarg));
00119 break;
00120 case 'v' :
00121 verbose = (optarg==0 ? 0 : atoi(optarg));
00122 break;
00123 case 'w' :
00124 if( optarg != 0 ) {
00125 scaleWeights = true;
00126 sW = atof(optarg);
00127 }
00128 break;
00129 case 'f' :
00130 outFile = optarg;
00131 break;
00132 case 'K' :
00133 split = true;
00134 splitFactor = (optarg==0 ? 0 : atof(optarg));
00135 break;
00136 case 'D' :
00137 splitRandomize = true;
00138 break;
00139 case 't' :
00140 valFile = optarg;
00141 break;
00142 case 'd' :
00143 valPrint = (optarg==0 ? 0 : atoi(optarg));
00144 break;
00145 }
00146 }
00147
00148
00149 string trainFile = argv[argc-1];
00150 if( trainFile.empty() ) {
00151 cerr << "No input data file is specified." << endl;
00152 return 1;
00153 }
00154 cout << "Will read input data from file " << trainFile.c_str() << endl;
00155 string configFile = argv[argc-2];
00156 if( configFile.empty() ) {
00157 cerr << "No config file for the global classifier specified." << endl;
00158 return 1;
00159 }
00160 cout << "Will read global classifier config from file "
00161 << configFile.c_str() << endl;
00162 string subConfigList = argv[argc-3];
00163 if( subConfigList.empty() ) {
00164 cerr << "No config file list found for sub-classifiers." << endl;
00165 return 1;
00166 }
00167 cout << "Will read sub-classifier configs from files "
00168 << subConfigList.c_str() << endl;
00169
00170
00171 if( subConfigList.empty() || configFile.empty() ) {
00172 cerr << "User must specify combiner configuration." << endl;
00173 return 1;
00174 }
00175
00176
00177 vector<vector<string> > subConfigFiles;
00178 SprStringParser::parseToStrings(subConfigList.c_str(),subConfigFiles);
00179 bool useSubConfig
00180 = ( !subConfigFiles.empty() && !subConfigFiles[0].empty() );
00181 if( !useSubConfig ) {
00182 cerr << "Unable to process list of sub-classifier config files." << endl;
00183 return 1;
00184 }
00185 int nTrained = subConfigFiles[0].size();
00186
00187
00188 SprRWFactory::DataType inputType
00189 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00190 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00191
00192
00193 auto_ptr<SprAbsFilter> filter(reader->read(trainFile.c_str()));
00194 if( filter.get() == 0 ) {
00195 cerr << "Unable to read data from file " << trainFile.c_str() << endl;
00196 return 2;
00197 }
00198 vector<string> vars;
00199 filter->vars(vars);
00200 cout << "Read data from file " << trainFile.c_str() << " for variables";
00201 for( int i=0;i<vars.size();i++ )
00202 cout << " \"" << vars[i].c_str() << "\"";
00203 cout << endl;
00204 cout << "Total number of points read: " << filter->size() << endl;
00205
00206
00207 vector<SprClass> inputClasses;
00208 if( !filter->filterByClass(inputClassesString.c_str()) ) {
00209 cerr << "Cannot choose input classes for string "
00210 << inputClassesString << endl;
00211 return 2;
00212 }
00213 filter->classes(inputClasses);
00214 assert( inputClasses.size() > 1 );
00215 cout << "Training data filtered by class." << endl;
00216 for( int i=0;i<inputClasses.size();i++ ) {
00217 cout << "Points in class " << inputClasses[i] << ": "
00218 << filter->ptsInClass(inputClasses[i]) << endl;
00219 }
00220
00221
00222 if( scaleWeights ) {
00223 cout << "Signal weights are multiplied by " << sW << endl;
00224 filter->scaleWeights(inputClasses[1],sW);
00225 }
00226
00227
00228 auto_ptr<SprAbsFilter> valFilter;
00229 if( split && !valFile.empty() ) {
00230 cerr << "Unable to split training data and use validation data "
00231 << "from a separate file." << endl;
00232 return 2;
00233 }
00234 if( split ) {
00235 cout << "Splitting training data with factor " << splitFactor << endl;
00236 if( splitRandomize )
00237 cout << "Will use randomized splitting." << endl;
00238 vector<double> weights;
00239 SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00240 if( splitted == 0 ) {
00241 cerr << "Unable to split training data." << endl;
00242 return 2;
00243 }
00244 bool ownData = true;
00245 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00246 cout << "Training data re-filtered:" << endl;
00247 for( int i=0;i<inputClasses.size();i++ ) {
00248 cout << "Points in class " << inputClasses[i] << ": "
00249 << filter->ptsInClass(inputClasses[i]) << endl;
00250 }
00251 }
00252 if( !valFile.empty() ) {
00253
00254 auto_ptr<SprAbsReader>
00255 valReader(SprRWFactory::makeReader(inputType,readMode));
00256
00257
00258 valFilter.reset(valReader->read(valFile.c_str()));
00259 if( valFilter.get() == 0 ) {
00260 cerr << "Unable to read data from file " << valFile.c_str() << endl;
00261 return 2;
00262 }
00263 vector<string> valVars;
00264 valFilter->vars(valVars);
00265 cout << "Read data from file " << valFile.c_str() << " for variables";
00266 for( int i=0;i<valVars.size();i++ )
00267 cout << " \"" << valVars[i].c_str() << "\"";
00268 cout << endl;
00269 cout << "Total number of points read: " << valFilter->size() << endl;
00270
00271
00272 if( !valFilter->filterByClass(inputClassesString.c_str()) ) {
00273 cerr << "Cannot choose input classes for string "
00274 << inputClassesString << endl;
00275 return 2;
00276 }
00277 valFilter->classes(inputClasses);
00278 assert( inputClasses.size() > 1 );
00279 cout << "Validation data filtered by class." << endl;
00280 for( int i=0;i<inputClasses.size();i++ ) {
00281 cout << "Points in class " << inputClasses[i] << ": "
00282 << valFilter->ptsInClass(inputClasses[i]) << endl;
00283 }
00284
00285
00286 if( scaleWeights ) {
00287 cout << "Signal weights are multiplied by " << sW << endl;
00288 valFilter->scaleWeights(inputClasses[1],sW);
00289 }
00290 }
00291
00292
00293 auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00294 if( !transformerFile.empty() ) {
00295 SprVarTransformerReader transReader;
00296 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00297 if( t == 0 ) {
00298 cerr << "Unable to read VarTransformer from file "
00299 << transformerFile.c_str() << endl;
00300 return 2;
00301 }
00302 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00303 SprTransformerFilter* t_valid = 0;
00304 if( valFilter.get() != 0 )
00305 t_valid = new SprTransformerFilter(valFilter.get());
00306 bool replaceOriginalData = true;
00307 if( !t_train->transform(t,replaceOriginalData) ) {
00308 cerr << "Unable to apply VarTransformer to training data." << endl;
00309 return 2;
00310 }
00311 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00312 cerr << "Unable to apply VarTransformer to validation data." << endl;
00313 return 2;
00314 }
00315 cout << "Variable transformation from file "
00316 << transformerFile.c_str() << " has been applied to "
00317 << "training and validation data." << endl;
00318 garbage_train.reset(filter.release());
00319 garbage_valid.reset(valFilter.release());
00320 filter.reset(t_train);
00321 valFilter.reset(t_valid);
00322 }
00323
00324
00325
00326
00327 SprCombiner combiner(filter.get());
00328
00329
00330
00331
00332 for( int ic=0;ic<nTrained;ic++ ) {
00333
00334
00335 string fname = subConfigFiles[0][ic];
00336 ifstream file(fname.c_str());
00337 if( !file ) {
00338 cerr << "Unable to open file " << fname.c_str() << endl;
00339 return 3;
00340 }
00341 cout << "Reading classifier configuration from file "
00342 << fname.c_str() << endl;
00343
00344
00345 string line;
00346 unsigned nLine = 1;
00347 if( !getline(file,line) ) {
00348 cerr << "Cannot read line " << nLine
00349 << " from file " << fname.c_str() << endl;
00350 return 3;
00351 }
00352 string pathToConfig, dummy;
00353 istringstream istpath(line);
00354 istpath >> dummy >> pathToConfig;
00355 if( pathToConfig.empty() ) {
00356 cerr << "Path to classifier not specified in file "
00357 << fname.c_str() << endl;
00358 }
00359
00360
00361 nLine++;
00362 if( !getline(file,line) ) {
00363 cerr << "Cannot read line " << nLine
00364 << " from file " << fname.c_str() << endl;
00365 return 3;
00366 }
00367 string subName;
00368 istringstream istname(line);
00369 istname >> dummy >> subName;
00370 if( subName.empty() ) {
00371 cout << "Name for classifier " << ic << " not specified."
00372 << " Will use the default." << endl;
00373 }
00374
00375
00376 nLine++;
00377 if( !getline(file,line) ) {
00378 cerr << "Cannot read line " << nLine
00379 << " from file " << fname.c_str() << endl;
00380 return 3;
00381 }
00382 double defaultValue = 0;
00383 istringstream istdefault(line);
00384 istdefault >> dummy >> defaultValue;
00385 cout << "Will use default response " << defaultValue
00386 << " for classifier " << ic << endl;
00387
00388
00389 nLine++;
00390 if( !getline(file,line) ) {
00391 cerr << "Cannot read line " << nLine
00392 << " from file " << fname.c_str() << endl;
00393 return 3;
00394 }
00395 unsigned nConstraints = 0;
00396 istringstream istconst(line);
00397 istconst >> dummy >> nConstraints;
00398 cout << "Will use " << nConstraints << " constraints "
00399 << "for classifier " << ic << endl;
00400
00401
00402 map<string,SprCut> constraints;
00403 for( int j=0;j<nConstraints;j++ ) {
00404 nLine++;
00405 if( !getline(file,line) ) {
00406 cerr << "Cannot read line " << nLine
00407 << " from file " << fname.c_str() << endl;
00408 return 3;
00409 }
00410 istringstream ist(line);
00411 string varName;
00412 unsigned nCut = 0;
00413 ist >> varName >> nCut;
00414 if( varName.empty() ) {
00415 cerr << "Unable to read variable name on line " << nLine
00416 << " in file " << fname.c_str() << endl;
00417 }
00418 SprCut cut;
00419 double xa(0), xb(0);
00420 for( unsigned k=0;k<nCut;k++ ) {
00421 ist >> xa >> xb;
00422 cut.push_back(SprInterval(xa,xb));
00423 }
00424 cout << "Applying constraint on variable " << varName.c_str()
00425 << " for classifier " << ic << " : ";
00426 for( int k=0;k<cut.size();k++ )
00427 cout << cut[k].first << " " << cut[k].second << " | ";
00428 cout << endl;
00429 constraints.insert(pair<const string,SprCut>(varName,cut));
00430 }
00431
00432
00433 SprAbsTrainedClassifier* trained
00434 = SprClassifierReader::readTrained(pathToConfig.c_str(),verbose);
00435 if( trained == 0 ) {
00436 cerr << "Unable to read classifier configuration from file "
00437 << pathToConfig.c_str() << endl;
00438 return 3;
00439 }
00440 cout << "Read classifier " << trained->name().c_str()
00441 << " with dimensionality " << trained->dim() << endl;
00442
00443
00444 vector<string> trainedVars;
00445 trained->vars(trainedVars);
00446 cout << "Variables: " << endl;
00447 for( int j=0;j<trainedVars.size();j++ )
00448 cout << trainedVars[j].c_str() << " ";
00449 cout << endl;
00450
00451
00452 bool ownTrained = true;
00453 if( !combiner.addTrained(trained,subName.c_str(),constraints,
00454 defaultValue,ownTrained) ) {
00455 cerr << "Unable to add trained classifier " << ic
00456 << " to combiner." << endl;
00457 return 3;
00458 }
00459 }
00460
00461
00462 if( !combiner.closeClassifierList() ) {
00463 cerr << "Unable to close the trained classifier list for the combiner."
00464 << endl;
00465 return 4;
00466 }
00467 SprEmptyFilter* features = combiner.features();
00468
00469
00470
00471
00472 ifstream file(configFile.c_str());
00473 if( !file ) {
00474 cerr << "Unable to open file " << configFile.c_str() << endl;
00475 return 5;
00476 }
00477 cout << "Reading classifier configuration from file "
00478 << configFile.c_str() << endl;
00479 unsigned nLine = 0;
00480 bool discreteTree = false;
00481 bool mixedNodesTree = false;
00482 bool fastSort = false;
00483 bool readOneEntry = true;
00484 vector<SprAbsTwoClassCriterion*> crits;
00485 vector<SprIntegerBootstrap*> bstraps;
00486 vector<SprAbsClassifier*> destroyC;
00487 vector<SprCCPair> useC;
00488 if( !SprClassifierReader::readTrainableConfig(file,nLine,features,
00489 discreteTree,mixedNodesTree,
00490 fastSort,crits,
00491 bstraps,destroyC,
00492 useC,readOneEntry) ) {
00493 cerr << "Unable to read trainable classifier config from file "
00494 << configFile.c_str() << endl;
00495 prepareExit(crits,bstraps,destroyC);
00496 return 5;
00497 }
00498 SprAbsClassifier* trainable = useC[0].first;
00499 cout << "Setting trainable classifier for combiner to "
00500 << trainable->name() << endl;
00501 combiner.setTrainable(trainable);
00502
00503
00504 auto_ptr<SprAverageLoss> loss;
00505 if( valFilter.get()!=0 && valPrint>0 ) {
00506 string trainableName = trainable->name();
00507 if( trainableName=="AdaBoost" || trainableName=="Bagger"
00508 || trainableName=="ArcE4" || trainableName=="StdBackprop" ) {
00509 cout << "For simplicity only quadratic loss can be displayed." << endl;
00510 if( trainableName=="AdaBoost" ) {
00511 loss.reset(new SprAverageLoss(&SprLoss::quadratic,
00512 &SprTransformation::logit));
00513 }
00514 else {
00515 loss.reset(new SprAverageLoss(&SprLoss::quadratic));
00516 }
00517 if( trainableName=="AdaBoost" ) {
00518 if( !static_cast<SprAdaBoost*>(trainable)
00519 ->setValidation(features,valPrint,loss.get()) ) {
00520 cerr << "Unable to set validation loss." << endl;
00521 return 6;
00522 }
00523 }
00524 else if( trainableName=="Bagger" || trainableName=="ArcE4" ) {
00525 if( !static_cast<SprBagger*>(trainable)
00526 ->setValidation(features,valPrint,0,loss.get()) ) {
00527 cerr << "Unable to set validation loss." << endl;
00528 return 6;
00529 }
00530 }
00531 else if( trainableName=="StdBackprop" ) {
00532 if( !static_cast<SprStdBackprop*>(trainable)
00533 ->setValidation(features,valPrint,loss.get()) ) {
00534 cerr << "Unable to set validation loss." << endl;
00535 return 6;
00536 }
00537 }
00538 }
00539 }
00540
00541
00542 if( !combiner.train(verbose) ) {
00543 cerr << "Combiner finished with error." << endl;
00544 prepareExit(crits,bstraps,destroyC);
00545 return 7;
00546 }
00547
00548
00549 if( !outFile.empty() ) {
00550 if( !combiner.store(outFile.c_str()) ) {
00551 cerr << "Cannot store Combiner to file " << outFile.c_str() << endl;
00552 prepareExit(crits,bstraps,destroyC);
00553 return 8;
00554 }
00555 }
00556
00557
00558 prepareExit(crits,bstraps,destroyC);
00559 return 0;
00560 }