CMS 3D CMS Logo

SprMultiClassApp.cc File Reference

#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsClassifier.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprPoint.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassLearner.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedMultiClassLearner.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassPlotter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

Go to the source code of this file.

Functions

void help (const char *prog)
int main (int argc, char **argv)
void prepareExit (vector< SprAbsTwoClassCriterion * > &criteria, vector< SprAbsClassifier * > &classifiers, vector< SprIntegerBootstrap * > &bstraps)


Function Documentation

void help ( const char *  prog  ) 

Definition at line 54 of file SprMultiClassApp.cc.

References std::cout and std::endl.

// Prints the command-line usage summary to standard output.
//
// prog is the program name shown on the "Usage:" line; callers pass argv[0].
void help(const char* prog)
{
  // One entry per output line. The trailing padding inside the literals is
  // part of the original text and is kept as is.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help                                             ",
    "\t-o output Tuple file                                    ",
    "\t-a input ascii file mode (see SprSimpleReader.hh)       ",
    "\t-A save output data in ascii instead of Root            ",
    "\t-y list of input classes                                ",
    "\t\t Classes must be listed in quotes and separated by commas.",
    "\t-Q apply variable transformation saved in file     ",
    "\t-e Multi class mode                                     ",
    "\t\t 1 - OneVsAll (default)                               ",
    "\t\t 2 - OneVsOne                                         ",
    "\t\t 3 - user-defined (must use -i option)                ",
    "\t-i input file with user-defined indicator matrix        ",
    "\t-c file with trainable classifier configurations        ",
    "\t-g per-event loss to be displayed for each input class  ",
    "\t\t 1 - quadratic loss (y-f(x))^2                        ",
    "\t\t 2 - exponential loss exp(-y*f(x))                    ",
    "\t-m replace data values below this cutoff with medians   ",
    "\t-v verbose level (0=silent default,1,2)                 ",
    "\t-f store trained multi class learner to file            ",
    "\t-r read multi class learner configuration stored in file",
    "\t-K keep this fraction in training set and          ",
    "\t\t put the rest into validation set                ",
    "\t-D randomize training set split-up                 ",
    "\t-t read validation/test data from a file           ",
    // The closing parenthesis was missing in the original text.
    "\t\t (must be in same format as input data!!!)       ",
    "\t-V include only these input variables              ",
    "\t-z exclude input variables from the list           ",
    "\t-Z exclude input variables from the list, "
    "but put them in the output file ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const std::size_t nLines = sizeof(optionLines)/sizeof(optionLines[0]);

  std::cout << "Usage:  " << prog << " training_data_file" << '\n';
  for( std::size_t i=0;i<nLines;i++ )
    std::cout << optionLines[i] << '\n';

  // Flush once at the end instead of after every line.
  std::cout.flush();
}

int main ( int  argc,
char **  argv 
)

Definition at line 103 of file SprMultiClassApp.cc.

References std::cerr, std::cout, std::endl, std::find(), help(), and prepareExit().

00104 {
00105   // check command line
00106   if( argc < 2 ) {
00107     help(argv[0]);
00108     return 1;
00109   }
00110 
00111   // init
00112   string tupleFile;
00113   int readMode = 0;
00114   SprRWFactory::DataType writeMode = SprRWFactory::Root;
00115   int verbose = 0;
00116   string outFile;
00117   string resumeFile;
00118   string configFile;
00119   string valFile;
00120   bool scaleWeights = false;
00121   double sW = 1.;
00122   bool setLowCutoff = false;
00123   double lowCutoff = 0;
00124   string includeList, excludeList;
00125   string inputClassesString;
00126   int iLoss = 1;
00127   int iMode = 1;
00128   string indicatorFile;
00129   string stringVarsDoNotFeed;
00130   bool split = false;
00131   double splitFactor = 0;
00132   bool splitRandomize = false;
00133   string transformerFile;
00134 
00135   // decode command line
00136   int c;
00137   extern char* optarg;
00138   //  extern int optind;
00139   while( (c = getopt(argc,argv,"ho:a:Ay:Q:e:i:c:g:m:v:f:r:K:Dt:V:z:Z:")) != EOF ) {
00140     switch( c )
00141       {
00142       case 'h' :
00143         help(argv[0]);
00144         return 1;
00145       case 'o' :
00146         tupleFile = optarg;
00147         break;
00148       case 'a' :
00149         readMode = (optarg==0 ? 0 : atoi(optarg));
00150         break;
00151       case 'A' :
00152         writeMode = SprRWFactory::Ascii;
00153         break;
00154       case 'y' :
00155         inputClassesString = optarg;
00156         break;
00157       case 'Q' :
00158         transformerFile = optarg;
00159         break;
00160       case 'e' :
00161         iMode = (optarg==0 ? 1 : atoi(optarg));
00162         break;
00163       case 'i' :
00164         indicatorFile = optarg;
00165         break;
00166       case 'c' :
00167         configFile = optarg;
00168         break;
00169       case 'g' :
00170         iLoss = (optarg==0 ? 1 : atoi(optarg));
00171         break;
00172       case 'm' :
00173         if( optarg != 0 ) {
00174           setLowCutoff = true;
00175           lowCutoff = atof(optarg);
00176         }
00177         break;
00178       case 'v' :
00179         verbose = (optarg==0 ? 0 : atoi(optarg));
00180         break;
00181       case 'f' :
00182         outFile = optarg;
00183         break;
00184       case 'r' :
00185         resumeFile = optarg;
00186         break;
00187       case 'K' :
00188         split = true;
00189         splitFactor = (optarg==0 ? 0 : atof(optarg));
00190         break;
00191       case 'D' :
00192         splitRandomize = true;
00193         break;
00194       case 't' :
00195         valFile = optarg;
00196         break;
00197       case 'w' :
00198         if( optarg != 0 ) {
00199           scaleWeights = true;
00200           sW = atof(optarg);
00201         }
00202         break;
00203       case 'V' :
00204         includeList = optarg;
00205         break;
00206       case 'z' :
00207         excludeList = optarg;
00208         break;
00209       case 'Z' :
00210         stringVarsDoNotFeed = optarg;
00211         break;
00212       }
00213   }
00214 
00215   // sanity check
00216   if( configFile.empty() && resumeFile.empty()) {
00217     cerr << "No classifier configuration file specified." << endl;
00218     return 1;
00219   }
00220   if( !configFile.empty() && !resumeFile.empty() ) {
00221     cerr << "Cannot train and use saved configuration at the same time." << endl;
00222     return 1;
00223   }
00224 
00225   // Must have 2 arguments after all options.
00226   string trFile = argv[argc-1];
00227   if( trFile.empty() ) {
00228     cerr << "No training file is specified." << endl;
00229     return 1;
00230   }
00231 
00232   // make reader
00233   SprRWFactory::DataType inputType 
00234     = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00235   auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00236 
00237   // include variables
00238   set<string> includeSet;
00239   if( !includeList.empty() ) {
00240     vector<vector<string> > includeVars;
00241     SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00242     assert( !includeVars.empty() );
00243     for( int i=0;i<includeVars[0].size();i++ ) 
00244       includeSet.insert(includeVars[0][i]);
00245     if( !reader->chooseVars(includeSet) ) {
00246       cerr << "Unable to include variables in training set." << endl;
00247       return 2;
00248     }
00249     else {
00250       cout << "Following variables have been included in optimization: ";
00251       for( set<string>::const_iterator 
00252              i=includeSet.begin();i!=includeSet.end();i++ )
00253         cout << "\"" << *i << "\"" << " ";
00254       cout << endl;
00255     }
00256   }
00257 
00258   // exclude variables
00259   set<string> excludeSet;
00260   if( !excludeList.empty() ) {
00261     vector<vector<string> > excludeVars;
00262     SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00263     assert( !excludeVars.empty() );
00264     for( int i=0;i<excludeVars[0].size();i++ ) 
00265       excludeSet.insert(excludeVars[0][i]);
00266     if( !reader->chooseAllBut(excludeSet) ) {
00267       cerr << "Unable to exclude variables from training set." << endl;
00268       return 2;
00269     }
00270     else {
00271       cout << "Following variables have been excluded from optimization: ";
00272       for( set<string>::const_iterator 
00273              i=excludeSet.begin();i!=excludeSet.end();i++ )
00274         cout << "\"" << *i << "\"" << " ";
00275       cout << endl;
00276     }
00277   }
00278 
00279   // read training data from file
00280   auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00281   if( filter.get() == 0 ) {
00282     cerr << "Unable to read data from file " << trFile.c_str() << endl;
00283     return 2;
00284   }
00285   vector<string> vars;
00286   filter->vars(vars);
00287   cout << "Read data from file " << trFile.c_str() 
00288        << " for variables";
00289   for( int i=0;i<vars.size();i++ ) 
00290     cout << " \"" << vars[i].c_str() << "\"";
00291   cout << endl;
00292   cout << "Total number of points read: " << filter->size() << endl;
00293 
00294   // decode input classes
00295   if( inputClassesString.empty() ) {
00296     cerr << "No input classes specified." << endl;
00297     return 2;
00298   }
00299   vector<vector<int> > inputIntClasses;
00300   SprStringParser::parseToInts(inputClassesString.c_str(),inputIntClasses);
00301   if( inputIntClasses.empty() || inputIntClasses[0].size()<2 ) {
00302     cerr << "Found less than 2 classes in the input class string." << endl;
00303     return 2;
00304   }
00305   vector<SprClass> inputClasses(inputIntClasses[0].size());
00306   for( int i=0;i<inputIntClasses[0].size();i++ )
00307     inputClasses[i] = inputIntClasses[0][i];
00308 
00309   // filter training data by class
00310   filter->chooseClasses(inputClasses);
00311   if( !filter->filter() ) {
00312     cerr << "Unable to filter training data by class." << endl;
00313     return 2;
00314   }
00315   cout << "Training data filtered by class." << endl;
00316   for( int i=0;i<inputClasses.size();i++ ) {
00317     unsigned npts = filter->ptsInClass(inputClasses[i]);
00318     if( npts == 0 ) {
00319       cerr << "Error!!! No points in class " << inputClasses[i] << endl;
00320       return 2;
00321     }
00322     cout << "Points in class " << inputClasses[i] << ":   " << npts << endl;
00323   }
00324 
00325   // apply low cutoff
00326   if( setLowCutoff ) {
00327     if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00328       cerr << "Unable to replace missing values in training data." << endl;
00329       return 2;
00330     }
00331     else
00332       cout << "Values below " << lowCutoff << " in training data"
00333            << " have been replaced with medians." << endl;
00334   }
00335 
00336   // read validation data from file
00337   auto_ptr<SprAbsFilter> valFilter;
00338   if( split && !valFile.empty() ) {
00339     cerr << "Unable to split training data and use validation data " 
00340          << "from a separate file." << endl;
00341     return 2;
00342   }
00343   if( split ) {
00344     cout << "Splitting training data with factor " << splitFactor << endl;
00345     if( splitRandomize )
00346       cout << "Will use randomized splitting." << endl;
00347     vector<double> weights;
00348     SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00349     if( splitted == 0 ) {
00350       cerr << "Unable to split training data." << endl;
00351       return 2;
00352     }
00353     bool ownData = true;
00354     valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00355     cout << "Training data re-filtered:" << endl;
00356     for( int i=0;i<inputClasses.size();i++ ) {
00357       cout << "Points in class " << inputClasses[i] << ":   " 
00358            << filter->ptsInClass(inputClasses[i]) << endl;
00359     }
00360   }
00361   if( !valFile.empty() ) {
00362     auto_ptr<SprAbsReader> 
00363       valReader(SprRWFactory::makeReader(inputType,readMode));
00364     if( !includeSet.empty() ) {
00365       if( !valReader->chooseVars(includeSet) ) {
00366         cerr << "Unable to include variables in validation set." << endl;
00367         return 2;
00368       }
00369     }
00370     if( !excludeSet.empty() ) {
00371       if( !valReader->chooseAllBut(excludeSet) ) {
00372         cerr << "Unable to exclude variables from validation set." << endl;
00373         return 2;
00374       }
00375     }
00376     valFilter.reset(valReader->read(valFile.c_str()));
00377     if( valFilter.get() == 0 ) {
00378       cerr << "Unable to read data from file " << valFile.c_str() << endl;
00379       return 2;
00380     }
00381     vector<string> valVars;
00382     valFilter->vars(valVars);
00383     cout << "Read validation data from file " << valFile.c_str() 
00384          << " for variables";
00385     for( int i=0;i<valVars.size();i++ ) 
00386       cout << " \"" << valVars[i].c_str() << "\"";
00387     cout << endl;
00388     cout << "Total number of points read: " << valFilter->size() << endl;
00389   }
00390 
00391   // filter validation data by class
00392   if( valFilter.get() != 0 ) {
00393     valFilter->chooseClasses(inputClasses);
00394     if( !valFilter->filter() ) {
00395       cerr << "Unable to filter validation data by class." << endl;
00396       return 2;
00397     }
00398     cout << "Validation data filtered by class." << endl;
00399     for( int i=0;i<inputClasses.size();i++ ) {
00400      unsigned npts = valFilter->ptsInClass(inputClasses[i]);
00401      if( npts == 0 )
00402        cerr << "Warning!!! No points in class " << inputClasses[i] << endl;
00403      cout << "Points in class " << inputClasses[i] << ":   " << npts << endl;
00404     }
00405   }
00406 
00407   // apply low cutoff
00408   if( setLowCutoff && valFilter.get()!=0 ) {
00409     if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00410       cerr << "Unable to replace missing values in validation data." << endl;
00411       return 2;
00412     }
00413     else
00414       cout << "Values below " << lowCutoff << " in validation data"
00415            << " have been replaced with medians." << endl;
00416   }
00417 
00418   // apply transformation of variables to training and test data
00419   auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00420   if( !transformerFile.empty() ) {
00421     SprVarTransformerReader transReader;
00422     const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00423     if( t == 0 ) {
00424       cerr << "Unable to read VarTransformer from file "
00425            << transformerFile.c_str() << endl;
00426       return 2;
00427     }
00428     SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00429     SprTransformerFilter* t_valid = 0;
00430     if( valFilter.get() != 0 )
00431       t_valid = new SprTransformerFilter(valFilter.get());
00432     bool replaceOriginalData = true;
00433     if( !t_train->transform(t,replaceOriginalData) ) {
00434       cerr << "Unable to apply VarTransformer to training data." << endl;
00435       return 2;
00436     }
00437     if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00438       cerr << "Unable to apply VarTransformer to validation data." << endl;
00439       return 2;
00440     }
00441     cout << "Variable transformation from file "
00442          << transformerFile.c_str() << " has been applied to "
00443          << "training and validation data." << endl;
00444     garbage_train.reset(filter.release());
00445     garbage_valid.reset(valFilter.release());
00446     filter.reset(t_train);
00447     valFilter.reset(t_valid);
00448   }
00449 
00450   // prepare trained classifier holder
00451   auto_ptr<SprTrainedMultiClassLearner> trainedMulti;
00452 
00453   // prepare vectors of objects
00454   vector<SprAbsTwoClassCriterion*> criteria;
00455   vector<SprAbsClassifier*> destroyC;// classifiers to be deleted
00456   vector<SprIntegerBootstrap*> bstraps;
00457   vector<SprCCPair> useC;// classifiers and cuts to be used
00458 
00459   // open file with classifier configs
00460   if( !configFile.empty() ) {
00461     ifstream file(configFile.c_str());
00462     if( !file ) {
00463       cerr << "Unable to open file " << configFile.c_str() << endl;
00464       return 3;
00465     }
00466     
00467     // read classifier params
00468     unsigned nLine = 0;
00469     bool discreteTree = false;
00470     bool mixedNodesTree = false;
00471     bool fastSort = false;
00472     bool readOneEntry = true;
00473     if( !SprClassifierReader::readTrainableConfig(file,nLine,filter.get(),
00474                                                   discreteTree,mixedNodesTree,
00475                                                   fastSort,criteria,
00476                                                   bstraps,destroyC,useC,
00477                                                   readOneEntry) ) {
00478       cerr << "Unable to read classifier configurations from file " 
00479            << configFile.c_str() << endl;
00480       prepareExit(criteria,destroyC,bstraps);
00481       return 3;
00482     }
00483     cout << "Finished reading " << useC.size() << " classifiers from file "
00484          << configFile.c_str() << endl;
00485     assert( useC.size() == 1 );
00486     SprAbsClassifier* trainable = useC[0].first;
00487 
00488     // find the multi class mode
00489     SprMultiClassLearner::MultiClassMode multiClassMode 
00490       = SprMultiClassLearner::OneVsAll;
00491     switch( iMode )
00492       {
00493       case 1 :
00494         multiClassMode = SprMultiClassLearner::OneVsAll;
00495         cout << "Multi class learning mode set to OneVsAll." << endl;
00496         break;
00497       case 2 :
00498         multiClassMode = SprMultiClassLearner::OneVsOne;
00499         cout << "Multi class learning mode set to OneVsOne." << endl;
00500         break;
00501       case 3:
00502         if( indicatorFile.empty() ) {
00503           cerr << "No indicator matrix specified." << endl;
00504           return 4;
00505         }
00506         multiClassMode = SprMultiClassLearner::User;
00507         cout << "Multi class learning mode set to User." << endl;
00508         break;
00509       default :
00510         cerr << "No multi class learning mode chosen." << endl;
00511         prepareExit(criteria,destroyC,bstraps);
00512         return 4;
00513       }
00514 
00515     // get indicator matrix
00516     SprMatrix indicator;
00517     if( multiClassMode==SprMultiClassLearner::User 
00518         && !indicatorFile.empty() ) {
00519       if( !SprMultiClassReader::readIndicatorMatrix(indicatorFile.c_str(),
00520                                                     indicator) ) {
00521         cerr << "Unable to read indicator matrix from file " 
00522              << indicatorFile.c_str() << endl;
00523         return 4;
00524       }
00525     }
00526 
00527     // make a multi class learner
00528     SprMultiClassLearner multi(filter.get(),trainable,inputIntClasses[0],
00529                                indicator,multiClassMode);
00530 
00531     // train
00532     if( resumeFile.empty() ) {
00533       if( !multi.train(verbose) ) {
00534         cerr << "Unable to train Multi class learner." << endl;
00535         prepareExit(criteria,destroyC,bstraps);
00536         return 5;
00537       }
00538       else {
00539         trainedMulti.reset(multi.makeTrained());
00540         cout << "Multi class learner finished successfully." << endl;
00541       }
00542     }
00543 
00544     // save trained multi class learner
00545     if( !outFile.empty() ) {
00546       if( !multi.store(outFile.c_str()) ) {
00547         cerr << "Cannot store multi class learner in file " 
00548              << outFile.c_str() << endl;
00549         prepareExit(criteria,destroyC,bstraps);
00550         return 6;
00551       }
00552     }
00553   }
00554 
00555   // read saved learner from file
00556   if( !resumeFile.empty() ) {
00557     SprMultiClassReader multiReader;
00558     if( !multiReader.read(resumeFile.c_str()) ) {
00559       cerr << "Failed to read saved multi class learner from file " 
00560            << resumeFile.c_str() << endl;
00561       prepareExit(criteria,destroyC,bstraps);
00562       return 7;
00563     }
00564     else {
00565       trainedMulti.reset(multiReader.makeTrained());
00566       cout << "Read saved multi class learner from file " 
00567            << resumeFile.c_str() << endl;
00568       trainedMulti->printIndicatorMatrix(cout);
00569     }
00570   }
00571 
00572   // by now the trained learner should be filled
00573   if( trainedMulti.get() == 0 ) {
00574     cerr << "Trained multi learner has not been set." << endl;
00575     prepareExit(criteria,destroyC,bstraps);
00576     return 8;
00577   }
00578 
00579   // set loss
00580   switch( iLoss )
00581     {
00582     case 1 :
00583       trainedMulti->setLoss(&SprLoss::quadratic,
00584                             &SprTransformation::zeroOneToMinusPlusOne);
00585       cout << "Per-event loss set to "
00586            << "Quadratic loss (y-f(x))^2 " << endl;
00587       break;
00588     case 2 :
00589       trainedMulti->setLoss(&SprLoss::exponential,
00590                             &SprTransformation::logitInverse);
00591       cout << "Per-event loss set to "
00592            << "Exponential loss exp(-y*f(x)) " << endl;
00593       break;
00594     default :
00595       cerr << "No per-event loss specified." << endl;
00596       prepareExit(criteria,destroyC,bstraps);
00597       return 9;
00598     }
00599 
00600   // analyze validation data
00601   if( valFilter.get() != 0 ) {
00602 
00603     // compute response
00604     vector<SprMultiClassPlotter::Response> responses(valFilter->size());
00605     for( int i=0;i<valFilter->size();i++ ) {
00606       if( ((i+1)%1000) == 0 )
00607         cout << "Computing response for validation point " << i+1 << endl;
00608 
00609       // get point, class and weight
00610       const SprPoint* p = (*(valFilter.get()))[i];
00611       int cls = p->class_;
00612       double w = valFilter->w(i);
00613 
00614       // compute loss
00615       map<int,double> output;
00616       int resp = trainedMulti->response(p,output);
00617       responses[i] = SprMultiClassPlotter::Response(cls,w,resp,output);
00618     }    
00619 
00620     // get the loss table
00621     SprMultiClassPlotter plotter(responses);
00622     vector<int> classes;
00623     trainedMulti->classes(classes);
00624     map<int,vector<double> > lossTable;
00625     map<int,double> weightInClass;
00626     double totalLoss = plotter.multiClassTable(classes,lossTable,
00627                                                weightInClass);
00628 
00629     // print out
00630     cout << "=====================================" << endl;
00631     cout << "Overall validation misid fraction = " << totalLoss << endl;
00632     cout << "=====================================" << endl;
00633     cout << "Classification table: Fractions of total class weight" << endl;
00634     char s[200];
00635     sprintf(s,"True Class \\ Classification |");
00636     string temp = "------------------------------";
00637     cout << s;
00638     for( int i=0;i<classes.size();i++ ) {
00639       sprintf(s," %5i      |",classes[i]);
00640       cout << s;
00641       temp += "-------------";
00642     }
00643     sprintf(s,"   Total weight in class |");
00644     temp += "-------------------------";
00645     cout << s << endl;
00646     cout << temp.c_str() << endl;
00647     for( map<int,vector<double> >::const_iterator
00648            i=lossTable.begin();i!=lossTable.end();i++ ) {
00649       sprintf(s,"%5i                       |",i->first);
00650       cout << s;
00651       for( int j=0;j<i->second.size();j++ ) {
00652         sprintf(s," %10.4f |",i->second[j]);
00653         cout << s;
00654       }
00655       sprintf(s,"              %10.4f |",weightInClass[i->first]);
00656       cout << s << endl;
00657     }
00658     cout << temp.c_str() << endl;
00659   }
00660 
00661   // make histogram if requested
00662   if( tupleFile.empty() ) {
00663     prepareExit(criteria,destroyC,bstraps);
00664     return 0;
00665   }
00666 
00667   // make a writer
00668   auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00669   if( !tuple->init(tupleFile.c_str()) ) {
00670     cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00671     prepareExit(criteria,destroyC,bstraps);
00672     return 10;
00673   }
00674 
00675   // determine if certain variables are to be excluded from usage,
00676   // but included in the output storage file (-Z option)
00677   string printVarsDoNotFeed;
00678   vector<vector<string> > varsDoNotFeed;
00679   SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed);
00680   vector<unsigned> mapper;
00681   for( int d=0;d<vars.size();d++ ) {
00682     if( varsDoNotFeed.empty() ||
00683         (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d])
00684          ==varsDoNotFeed[0].end()) ) {
00685       mapper.push_back(d);
00686     }
00687     else {
00688       printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? "" : ", " );
00689       printVarsDoNotFeed += vars[d];
00690     }
00691   }
00692   if( !printVarsDoNotFeed.empty() ) {
00693     cout << "The following variables are not used in the algorithm, " 
00694          << "but will be included in the output file: " 
00695          << printVarsDoNotFeed.c_str() << endl;
00696   }
00697 
00698   // feed
00699   SprDataFeeder feeder(filter.get(),tuple.get(),mapper);
00700   feeder.addMultiClassLearner(trainedMulti.get(),"multi");
00701   if( !feeder.feed(1000) ) {
00702     cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00703     prepareExit(criteria,destroyC,bstraps);
00704     return 11;
00705   }
00706 
00707   // cleanup
00708   prepareExit(criteria,destroyC,bstraps);
00709 
00710   // exit
00711   return 0;
00712 }

void prepareExit ( vector< SprAbsTwoClassCriterion * > &  criteria,
vector< SprAbsClassifier * > &  classifiers,
vector< SprIntegerBootstrap * > &  bstraps 
)

Definition at line 93 of file SprMultiClassApp.cc.

References i.

00096 {
00097   for( int i=0;i<criteria.size();i++ ) delete criteria[i];
00098   for( int i=0;i<classifiers.size();i++ ) delete classifiers[i];
00099   for( int i=0;i<bstraps.size();i++ ) delete bstraps[i];
00100 }


Generated on Tue Jun 9 17:55:01 2009 for CMSSW by  doxygen 1.5.4