#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsClassifier.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprPoint.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassLearner.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedMultiClassLearner.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassPlotter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <memory>
Go to the source code of this file.
Functions:
  void help(const char* prog)
  int main(int argc, char** argv)
  void prepareExit(vector<SprAbsTwoClassCriterion*>& criteria, vector<SprAbsClassifier*>& classifiers, vector<SprIntegerBootstrap*>& bstraps)
// Print the command-line usage of this application to std::cout.
//
// prog: program name shown in the "Usage:" line (normally argv[0]).
//
// Definition at line 54 of file SprMultiClassApp.cc.
// References std::cout and std::endl.
void help(const char* prog)
{
  // One entry per printed line, terminated by a null sentinel.
  // Adjacent string literals are joined by the compiler, so multi-part
  // entries print exactly as a single line.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-o output Tuple file ",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-A save output data in ascii instead of Root ",
    "\t-y list of input classes ",
    "\t\t Classes must be listed in quotes and separated by commas.",
    "\t-Q apply variable transformation saved in file ",
    "\t-e Multi class mode ",
    "\t\t 1 - OneVsAll (default) ",
    "\t\t 2 - OneVsOne ",
    "\t\t 3 - user-defined (must use -i option) ",
    "\t-i input file with user-defined indicator matrix ",
    "\t-c file with trainable classifier configurations ",
    "\t-g per-event loss to be displayed for each input class ",
    "\t\t 1 - quadratic loss (y-f(x))^2 ",
    "\t\t 2 - exponential loss exp(-y*f(x)) ",
    "\t-m replace data values below this cutoff with medians ",
    "\t-v verbose level (0=silent default,1,2) ",
    "\t-f store trained multi class learner to file ",
    "\t-r read multi class learner configuration stored in file",
    "\t-K keep this fraction in training set and ",
    "\t\t put the rest into validation set ",
    "\t-D randomize training set split-up ",
    "\t-t read validation/test data from a file ",
    "\t\t (must be in same format as input data!!! ",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t-Z exclude input variables from the list, "
    "but put them in the output file ",
    "\t\t Variables must be listed in quotes and separated by commas.",
    0
  };

  std::cout << "Usage: " << prog << " training_data_file" << std::endl;
  for( const char* const* line = optionLines; *line != 0; ++line )
    std::cout << *line << std::endl;
}
Definition at line 103 of file SprMultiClassApp.cc.
References begin, c, TestMuL1L2Filter_cff::cerr, GenMuonPlsPt100GeV_cfg::cout, d, end, lat::endl(), file, filter, find(), help(), i, j, output(), p, prepareExit(), s, size, split, t, pyDBSRunClass::temp, vars, w, and weights.
00104 { 00105 // check command line 00106 if( argc < 2 ) { 00107 help(argv[0]); 00108 return 1; 00109 } 00110 00111 // init 00112 string tupleFile; 00113 int readMode = 0; 00114 SprRWFactory::DataType writeMode = SprRWFactory::Root; 00115 int verbose = 0; 00116 string outFile; 00117 string resumeFile; 00118 string configFile; 00119 string valFile; 00120 bool scaleWeights = false; 00121 double sW = 1.; 00122 bool setLowCutoff = false; 00123 double lowCutoff = 0; 00124 string includeList, excludeList; 00125 string inputClassesString; 00126 int iLoss = 1; 00127 int iMode = 1; 00128 string indicatorFile; 00129 string stringVarsDoNotFeed; 00130 bool split = false; 00131 double splitFactor = 0; 00132 bool splitRandomize = false; 00133 string transformerFile; 00134 00135 // decode command line 00136 int c; 00137 extern char* optarg; 00138 // extern int optind; 00139 while( (c = getopt(argc,argv,"ho:a:Ay:Q:e:i:c:g:m:v:f:r:K:Dt:V:z:Z:")) != EOF ) { 00140 switch( c ) 00141 { 00142 case 'h' : 00143 help(argv[0]); 00144 return 1; 00145 case 'o' : 00146 tupleFile = optarg; 00147 break; 00148 case 'a' : 00149 readMode = (optarg==0 ? 0 : atoi(optarg)); 00150 break; 00151 case 'A' : 00152 writeMode = SprRWFactory::Ascii; 00153 break; 00154 case 'y' : 00155 inputClassesString = optarg; 00156 break; 00157 case 'Q' : 00158 transformerFile = optarg; 00159 break; 00160 case 'e' : 00161 iMode = (optarg==0 ? 1 : atoi(optarg)); 00162 break; 00163 case 'i' : 00164 indicatorFile = optarg; 00165 break; 00166 case 'c' : 00167 configFile = optarg; 00168 break; 00169 case 'g' : 00170 iLoss = (optarg==0 ? 1 : atoi(optarg)); 00171 break; 00172 case 'm' : 00173 if( optarg != 0 ) { 00174 setLowCutoff = true; 00175 lowCutoff = atof(optarg); 00176 } 00177 break; 00178 case 'v' : 00179 verbose = (optarg==0 ? 
0 : atoi(optarg)); 00180 break; 00181 case 'f' : 00182 outFile = optarg; 00183 break; 00184 case 'r' : 00185 resumeFile = optarg; 00186 break; 00187 case 'K' : 00188 split = true; 00189 splitFactor = (optarg==0 ? 0 : atof(optarg)); 00190 break; 00191 case 'D' : 00192 splitRandomize = true; 00193 break; 00194 case 't' : 00195 valFile = optarg; 00196 break; 00197 case 'w' : 00198 if( optarg != 0 ) { 00199 scaleWeights = true; 00200 sW = atof(optarg); 00201 } 00202 break; 00203 case 'V' : 00204 includeList = optarg; 00205 break; 00206 case 'z' : 00207 excludeList = optarg; 00208 break; 00209 case 'Z' : 00210 stringVarsDoNotFeed = optarg; 00211 break; 00212 } 00213 } 00214 00215 // sanity check 00216 if( configFile.empty() && resumeFile.empty()) { 00217 cerr << "No classifier configuration file specified." << endl; 00218 return 1; 00219 } 00220 if( !configFile.empty() && !resumeFile.empty() ) { 00221 cerr << "Cannot train and use saved configuration at the same time." << endl; 00222 return 1; 00223 } 00224 00225 // Must have 2 arguments after all options. 00226 string trFile = argv[argc-1]; 00227 if( trFile.empty() ) { 00228 cerr << "No training file is specified." << endl; 00229 return 1; 00230 } 00231 00232 // make reader 00233 SprRWFactory::DataType inputType 00234 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii ); 00235 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode)); 00236 00237 // include variables 00238 set<string> includeSet; 00239 if( !includeList.empty() ) { 00240 vector<vector<string> > includeVars; 00241 SprStringParser::parseToStrings(includeList.c_str(),includeVars); 00242 assert( !includeVars.empty() ); 00243 for( int i=0;i<includeVars[0].size();i++ ) 00244 includeSet.insert(includeVars[0][i]); 00245 if( !reader->chooseVars(includeSet) ) { 00246 cerr << "Unable to include variables in training set." 
<< endl; 00247 return 2; 00248 } 00249 else { 00250 cout << "Following variables have been included in optimization: "; 00251 for( set<string>::const_iterator 00252 i=includeSet.begin();i!=includeSet.end();i++ ) 00253 cout << "\"" << *i << "\"" << " "; 00254 cout << endl; 00255 } 00256 } 00257 00258 // exclude variables 00259 set<string> excludeSet; 00260 if( !excludeList.empty() ) { 00261 vector<vector<string> > excludeVars; 00262 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars); 00263 assert( !excludeVars.empty() ); 00264 for( int i=0;i<excludeVars[0].size();i++ ) 00265 excludeSet.insert(excludeVars[0][i]); 00266 if( !reader->chooseAllBut(excludeSet) ) { 00267 cerr << "Unable to exclude variables from training set." << endl; 00268 return 2; 00269 } 00270 else { 00271 cout << "Following variables have been excluded from optimization: "; 00272 for( set<string>::const_iterator 00273 i=excludeSet.begin();i!=excludeSet.end();i++ ) 00274 cout << "\"" << *i << "\"" << " "; 00275 cout << endl; 00276 } 00277 } 00278 00279 // read training data from file 00280 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str())); 00281 if( filter.get() == 0 ) { 00282 cerr << "Unable to read data from file " << trFile.c_str() << endl; 00283 return 2; 00284 } 00285 vector<string> vars; 00286 filter->vars(vars); 00287 cout << "Read data from file " << trFile.c_str() 00288 << " for variables"; 00289 for( int i=0;i<vars.size();i++ ) 00290 cout << " \"" << vars[i].c_str() << "\""; 00291 cout << endl; 00292 cout << "Total number of points read: " << filter->size() << endl; 00293 00294 // decode input classes 00295 if( inputClassesString.empty() ) { 00296 cerr << "No input classes specified." 
<< endl; 00297 return 2; 00298 } 00299 vector<vector<int> > inputIntClasses; 00300 SprStringParser::parseToInts(inputClassesString.c_str(),inputIntClasses); 00301 if( inputIntClasses.empty() || inputIntClasses[0].size()<2 ) { 00302 cerr << "Found less than 2 classes in the input class string." << endl; 00303 return 2; 00304 } 00305 vector<SprClass> inputClasses(inputIntClasses[0].size()); 00306 for( int i=0;i<inputIntClasses[0].size();i++ ) 00307 inputClasses[i] = inputIntClasses[0][i]; 00308 00309 // filter training data by class 00310 filter->chooseClasses(inputClasses); 00311 if( !filter->filter() ) { 00312 cerr << "Unable to filter training data by class." << endl; 00313 return 2; 00314 } 00315 cout << "Training data filtered by class." << endl; 00316 for( int i=0;i<inputClasses.size();i++ ) { 00317 unsigned npts = filter->ptsInClass(inputClasses[i]); 00318 if( npts == 0 ) { 00319 cerr << "Error!!! No points in class " << inputClasses[i] << endl; 00320 return 2; 00321 } 00322 cout << "Points in class " << inputClasses[i] << ": " << npts << endl; 00323 } 00324 00325 // apply low cutoff 00326 if( setLowCutoff ) { 00327 if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) { 00328 cerr << "Unable to replace missing values in training data." << endl; 00329 return 2; 00330 } 00331 else 00332 cout << "Values below " << lowCutoff << " in training data" 00333 << " have been replaced with medians." << endl; 00334 } 00335 00336 // read validation data from file 00337 auto_ptr<SprAbsFilter> valFilter; 00338 if( split && !valFile.empty() ) { 00339 cerr << "Unable to split training data and use validation data " 00340 << "from a separate file." << endl; 00341 return 2; 00342 } 00343 if( split ) { 00344 cout << "Splitting training data with factor " << splitFactor << endl; 00345 if( splitRandomize ) 00346 cout << "Will use randomized splitting." 
<< endl; 00347 vector<double> weights; 00348 SprData* splitted = filter->split(splitFactor,weights,splitRandomize); 00349 if( splitted == 0 ) { 00350 cerr << "Unable to split training data." << endl; 00351 return 2; 00352 } 00353 bool ownData = true; 00354 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData)); 00355 cout << "Training data re-filtered:" << endl; 00356 for( int i=0;i<inputClasses.size();i++ ) { 00357 cout << "Points in class " << inputClasses[i] << ": " 00358 << filter->ptsInClass(inputClasses[i]) << endl; 00359 } 00360 } 00361 if( !valFile.empty() ) { 00362 auto_ptr<SprAbsReader> 00363 valReader(SprRWFactory::makeReader(inputType,readMode)); 00364 if( !includeSet.empty() ) { 00365 if( !valReader->chooseVars(includeSet) ) { 00366 cerr << "Unable to include variables in validation set." << endl; 00367 return 2; 00368 } 00369 } 00370 if( !excludeSet.empty() ) { 00371 if( !valReader->chooseAllBut(excludeSet) ) { 00372 cerr << "Unable to exclude variables from validation set." << endl; 00373 return 2; 00374 } 00375 } 00376 valFilter.reset(valReader->read(valFile.c_str())); 00377 if( valFilter.get() == 0 ) { 00378 cerr << "Unable to read data from file " << valFile.c_str() << endl; 00379 return 2; 00380 } 00381 vector<string> valVars; 00382 valFilter->vars(valVars); 00383 cout << "Read validation data from file " << valFile.c_str() 00384 << " for variables"; 00385 for( int i=0;i<valVars.size();i++ ) 00386 cout << " \"" << valVars[i].c_str() << "\""; 00387 cout << endl; 00388 cout << "Total number of points read: " << valFilter->size() << endl; 00389 } 00390 00391 // filter validation data by class 00392 if( valFilter.get() != 0 ) { 00393 valFilter->chooseClasses(inputClasses); 00394 if( !valFilter->filter() ) { 00395 cerr << "Unable to filter validation data by class." << endl; 00396 return 2; 00397 } 00398 cout << "Validation data filtered by class." 
<< endl; 00399 for( int i=0;i<inputClasses.size();i++ ) { 00400 unsigned npts = valFilter->ptsInClass(inputClasses[i]); 00401 if( npts == 0 ) 00402 cerr << "Warning!!! No points in class " << inputClasses[i] << endl; 00403 cout << "Points in class " << inputClasses[i] << ": " << npts << endl; 00404 } 00405 } 00406 00407 // apply low cutoff 00408 if( setLowCutoff && valFilter.get()!=0 ) { 00409 if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) { 00410 cerr << "Unable to replace missing values in validation data." << endl; 00411 return 2; 00412 } 00413 else 00414 cout << "Values below " << lowCutoff << " in validation data" 00415 << " have been replaced with medians." << endl; 00416 } 00417 00418 // apply transformation of variables to training and test data 00419 auto_ptr<SprAbsFilter> garbage_train, garbage_valid; 00420 if( !transformerFile.empty() ) { 00421 SprVarTransformerReader transReader; 00422 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str()); 00423 if( t == 0 ) { 00424 cerr << "Unable to read VarTransformer from file " 00425 << transformerFile.c_str() << endl; 00426 return 2; 00427 } 00428 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get()); 00429 SprTransformerFilter* t_valid = 0; 00430 if( valFilter.get() != 0 ) 00431 t_valid = new SprTransformerFilter(valFilter.get()); 00432 bool replaceOriginalData = true; 00433 if( !t_train->transform(t,replaceOriginalData) ) { 00434 cerr << "Unable to apply VarTransformer to training data." << endl; 00435 return 2; 00436 } 00437 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) { 00438 cerr << "Unable to apply VarTransformer to validation data." << endl; 00439 return 2; 00440 } 00441 cout << "Variable transformation from file " 00442 << transformerFile.c_str() << " has been applied to " 00443 << "training and validation data." 
<< endl; 00444 garbage_train.reset(filter.release()); 00445 garbage_valid.reset(valFilter.release()); 00446 filter.reset(t_train); 00447 valFilter.reset(t_valid); 00448 } 00449 00450 // prepare trained classifier holder 00451 auto_ptr<SprTrainedMultiClassLearner> trainedMulti; 00452 00453 // prepare vectors of objects 00454 vector<SprAbsTwoClassCriterion*> criteria; 00455 vector<SprAbsClassifier*> destroyC;// classifiers to be deleted 00456 vector<SprIntegerBootstrap*> bstraps; 00457 vector<SprCCPair> useC;// classifiers and cuts to be used 00458 00459 // open file with classifier configs 00460 if( !configFile.empty() ) { 00461 ifstream file(configFile.c_str()); 00462 if( !file ) { 00463 cerr << "Unable to open file " << configFile.c_str() << endl; 00464 return 3; 00465 } 00466 00467 // read classifier params 00468 unsigned nLine = 0; 00469 bool discreteTree = false; 00470 bool mixedNodesTree = false; 00471 bool fastSort = false; 00472 bool readOneEntry = true; 00473 if( !SprClassifierReader::readTrainableConfig(file,nLine,filter.get(), 00474 discreteTree,mixedNodesTree, 00475 fastSort,criteria, 00476 bstraps,destroyC,useC, 00477 readOneEntry) ) { 00478 cerr << "Unable to read classifier configurations from file " 00479 << configFile.c_str() << endl; 00480 prepareExit(criteria,destroyC,bstraps); 00481 return 3; 00482 } 00483 cout << "Finished reading " << useC.size() << " classifiers from file " 00484 << configFile.c_str() << endl; 00485 assert( useC.size() == 1 ); 00486 SprAbsClassifier* trainable = useC[0].first; 00487 00488 // find the multi class mode 00489 SprMultiClassLearner::MultiClassMode multiClassMode 00490 = SprMultiClassLearner::OneVsAll; 00491 switch( iMode ) 00492 { 00493 case 1 : 00494 multiClassMode = SprMultiClassLearner::OneVsAll; 00495 cout << "Multi class learning mode set to OneVsAll." << endl; 00496 break; 00497 case 2 : 00498 multiClassMode = SprMultiClassLearner::OneVsOne; 00499 cout << "Multi class learning mode set to OneVsOne." 
<< endl; 00500 break; 00501 case 3: 00502 if( indicatorFile.empty() ) { 00503 cerr << "No indicator matrix specified." << endl; 00504 return 4; 00505 } 00506 multiClassMode = SprMultiClassLearner::User; 00507 cout << "Multi class learning mode set to User." << endl; 00508 break; 00509 default : 00510 cerr << "No multi class learning mode chosen." << endl; 00511 prepareExit(criteria,destroyC,bstraps); 00512 return 4; 00513 } 00514 00515 // get indicator matrix 00516 SprMatrix indicator; 00517 if( multiClassMode==SprMultiClassLearner::User 00518 && !indicatorFile.empty() ) { 00519 if( !SprMultiClassReader::readIndicatorMatrix(indicatorFile.c_str(), 00520 indicator) ) { 00521 cerr << "Unable to read indicator matrix from file " 00522 << indicatorFile.c_str() << endl; 00523 return 4; 00524 } 00525 } 00526 00527 // make a multi class learner 00528 SprMultiClassLearner multi(filter.get(),trainable,inputIntClasses[0], 00529 indicator,multiClassMode); 00530 00531 // train 00532 if( resumeFile.empty() ) { 00533 if( !multi.train(verbose) ) { 00534 cerr << "Unable to train Multi class learner." << endl; 00535 prepareExit(criteria,destroyC,bstraps); 00536 return 5; 00537 } 00538 else { 00539 trainedMulti.reset(multi.makeTrained()); 00540 cout << "Multi class learner finished successfully." 
<< endl; 00541 } 00542 } 00543 00544 // save trained multi class learner 00545 if( !outFile.empty() ) { 00546 if( !multi.store(outFile.c_str()) ) { 00547 cerr << "Cannot store multi class learner in file " 00548 << outFile.c_str() << endl; 00549 prepareExit(criteria,destroyC,bstraps); 00550 return 6; 00551 } 00552 } 00553 } 00554 00555 // read saved learner from file 00556 if( !resumeFile.empty() ) { 00557 SprMultiClassReader multiReader; 00558 if( !multiReader.read(resumeFile.c_str()) ) { 00559 cerr << "Failed to read saved multi class learner from file " 00560 << resumeFile.c_str() << endl; 00561 prepareExit(criteria,destroyC,bstraps); 00562 return 7; 00563 } 00564 else { 00565 trainedMulti.reset(multiReader.makeTrained()); 00566 cout << "Read saved multi class learner from file " 00567 << resumeFile.c_str() << endl; 00568 trainedMulti->printIndicatorMatrix(cout); 00569 } 00570 } 00571 00572 // by now the trained learner should be filled 00573 if( trainedMulti.get() == 0 ) { 00574 cerr << "Trained multi learner has not been set." << endl; 00575 prepareExit(criteria,destroyC,bstraps); 00576 return 8; 00577 } 00578 00579 // set loss 00580 switch( iLoss ) 00581 { 00582 case 1 : 00583 trainedMulti->setLoss(&SprLoss::quadratic, 00584 &SprTransformation::zeroOneToMinusPlusOne); 00585 cout << "Per-event loss set to " 00586 << "Quadratic loss (y-f(x))^2 " << endl; 00587 break; 00588 case 2 : 00589 trainedMulti->setLoss(&SprLoss::exponential, 00590 &SprTransformation::logitInverse); 00591 cout << "Per-event loss set to " 00592 << "Exponential loss exp(-y*f(x)) " << endl; 00593 break; 00594 default : 00595 cerr << "No per-event loss specified." 
<< endl; 00596 prepareExit(criteria,destroyC,bstraps); 00597 return 9; 00598 } 00599 00600 // analyze validation data 00601 if( valFilter.get() != 0 ) { 00602 00603 // compute response 00604 vector<SprMultiClassPlotter::Response> responses(valFilter->size()); 00605 for( int i=0;i<valFilter->size();i++ ) { 00606 if( ((i+1)%1000) == 0 ) 00607 cout << "Computing response for validation point " << i+1 << endl; 00608 00609 // get point, class and weight 00610 const SprPoint* p = (*(valFilter.get()))[i]; 00611 int cls = p->class_; 00612 double w = valFilter->w(i); 00613 00614 // compute loss 00615 map<int,double> output; 00616 int resp = trainedMulti->response(p,output); 00617 responses[i] = SprMultiClassPlotter::Response(cls,w,resp,output); 00618 } 00619 00620 // get the loss table 00621 SprMultiClassPlotter plotter(responses); 00622 vector<int> classes; 00623 trainedMulti->classes(classes); 00624 map<int,vector<double> > lossTable; 00625 map<int,double> weightInClass; 00626 double totalLoss = plotter.multiClassTable(classes,lossTable, 00627 weightInClass); 00628 00629 // print out 00630 cout << "=====================================" << endl; 00631 cout << "Overall validation misid fraction = " << totalLoss << endl; 00632 cout << "=====================================" << endl; 00633 cout << "Classification table: Fractions of total class weight" << endl; 00634 char s[200]; 00635 sprintf(s,"True Class \\ Classification |"); 00636 string temp = "------------------------------"; 00637 cout << s; 00638 for( int i=0;i<classes.size();i++ ) { 00639 sprintf(s," %5i |",classes[i]); 00640 cout << s; 00641 temp += "-------------"; 00642 } 00643 sprintf(s," Total weight in class |"); 00644 temp += "-------------------------"; 00645 cout << s << endl; 00646 cout << temp.c_str() << endl; 00647 for( map<int,vector<double> >::const_iterator 00648 i=lossTable.begin();i!=lossTable.end();i++ ) { 00649 sprintf(s,"%5i |",i->first); 00650 cout << s; 00651 for( int 
j=0;j<i->second.size();j++ ) { 00652 sprintf(s," %10.4f |",i->second[j]); 00653 cout << s; 00654 } 00655 sprintf(s," %10.4f |",weightInClass[i->first]); 00656 cout << s << endl; 00657 } 00658 cout << temp.c_str() << endl; 00659 } 00660 00661 // make histogram if requested 00662 if( tupleFile.empty() ) { 00663 prepareExit(criteria,destroyC,bstraps); 00664 return 0; 00665 } 00666 00667 // make a writer 00668 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training")); 00669 if( !tuple->init(tupleFile.c_str()) ) { 00670 cerr << "Unable to open output file " << tupleFile.c_str() << endl; 00671 prepareExit(criteria,destroyC,bstraps); 00672 return 10; 00673 } 00674 00675 // determine if certain variables are to be excluded from usage, 00676 // but included in the output storage file (-Z option) 00677 string printVarsDoNotFeed; 00678 vector<vector<string> > varsDoNotFeed; 00679 SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed); 00680 vector<unsigned> mapper; 00681 for( int d=0;d<vars.size();d++ ) { 00682 if( varsDoNotFeed.empty() || 00683 (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d]) 00684 ==varsDoNotFeed[0].end()) ) { 00685 mapper.push_back(d); 00686 } 00687 else { 00688 printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? 
"" : ", " ); 00689 printVarsDoNotFeed += vars[d]; 00690 } 00691 } 00692 if( !printVarsDoNotFeed.empty() ) { 00693 cout << "The following variables are not used in the algorithm, " 00694 << "but will be included in the output file: " 00695 << printVarsDoNotFeed.c_str() << endl; 00696 } 00697 00698 // feed 00699 SprDataFeeder feeder(filter.get(),tuple.get(),mapper); 00700 feeder.addMultiClassLearner(trainedMulti.get(),"multi"); 00701 if( !feeder.feed(1000) ) { 00702 cerr << "Cannot feed data into file " << tupleFile.c_str() << endl; 00703 prepareExit(criteria,destroyC,bstraps); 00704 return 11; 00705 } 00706 00707 // cleanup 00708 prepareExit(criteria,destroyC,bstraps); 00709 00710 // exit 00711 return 0; 00712 }
void prepareExit(vector<SprAbsTwoClassCriterion*>& criteria, vector<SprAbsClassifier*>& classifiers, vector<SprIntegerBootstrap*>& bstraps)