CMS 3D CMS Logo

SprMultiClassApp.cc

Go to the documentation of this file.
00001 //$Id: SprMultiClassApp.cc,v 1.5 2007/11/12 06:19:11 narsky Exp $
00002 /*
00003   Note: "-y" option has a different meaning for this executable than
00004   for other executables in the package. Instead of specifying what
00005   classes should be treated as background and what classes should be
00006   treated as signal, the "-y" option simply selects input classes for
00007   inclusion in the multi-class algorithm. Therefore, entering groups
00008   of classes separated by semicolons or specifying "." as an input
00009   class would make no sense. This executable expects a list of classes
00010   separated by commas.
00011 */
00012 
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsClassifier.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprPoint.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassLearner.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedMultiClassLearner.hh"
00022 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00023 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00024 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00025 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00026 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00027 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00028 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00029 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00030 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00031 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00032 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00033 #include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassReader.hh"
00034 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00035 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00036 #include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassPlotter.hh"
00037 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00038 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00039 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00040 
#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <algorithm>
#include <cassert>
#include <iostream>
#include <fstream>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <memory>
00050 
00051 using namespace std;
00052 
00053 
// Print the command-line usage summary of this executable to standard output.
//   prog - program name shown on the "Usage:" line (normally argv[0]).
// Every line, including the usage line, is terminated with std::endl.
void help(const char* prog) 
{
  // Option descriptions, printed one per line after the usage line.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help                                             ",
    "\t-o output Tuple file                                    ",
    "\t-a input ascii file mode (see SprSimpleReader.hh)       ",
    "\t-A save output data in ascii instead of Root            ",
    "\t-y list of input classes                                ",
    "\t\t Classes must be listed in quotes and separated by commas.",
    "\t-Q apply variable transformation saved in file     ",
    "\t-e Multi class mode                                     ",
    "\t\t 1 - OneVsAll (default)                               ",
    "\t\t 2 - OneVsOne                                         ",
    "\t\t 3 - user-defined (must use -i option)                ",
    "\t-i input file with user-defined indicator matrix        ",
    "\t-c file with trainable classifier configurations        ",
    "\t-g per-event loss to be displayed for each input class  ",
    "\t\t 1 - quadratic loss (y-f(x))^2                        ",
    "\t\t 2 - exponential loss exp(-y*f(x))                    ",
    "\t-m replace data values below this cutoff with medians   ",
    "\t-v verbose level (0=silent default,1,2)                 ",
    "\t-f store trained multi class learner to file            ",
    "\t-r read multi class learner configuration stored in file",
    "\t-K keep this fraction in training set and          ",
    "\t\t put the rest into validation set                ",
    "\t-D randomize training set split-up                 ",
    "\t-t read validation/test data from a file           ",
    "\t\t (must be in same format as input data!!!        ",
    "\t-V include only these input variables              ",
    "\t-z exclude input variables from the list           ",
    "\t-Z exclude input variables from the list, "
    "but put them in the output file ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const unsigned nOptionLines = sizeof(optionLines)/sizeof(optionLines[0]);

  std::cout << "Usage:  " << prog << " training_data_file" << std::endl;
  for( unsigned line=0; line<nOptionLines; ++line )
    std::cout << optionLines[line] << std::endl;
}
00091 
00092 
00093 void prepareExit(vector<SprAbsTwoClassCriterion*>& criteria,
00094                  vector<SprAbsClassifier*>& classifiers,
00095                  vector<SprIntegerBootstrap*>& bstraps) 
00096 {
00097   for( int i=0;i<criteria.size();i++ ) delete criteria[i];
00098   for( int i=0;i<classifiers.size();i++ ) delete classifiers[i];
00099   for( int i=0;i<bstraps.size();i++ ) delete bstraps[i];
00100 }
00101 
00102 
00103 int main(int argc, char ** argv)
00104 {
00105   // check command line
00106   if( argc < 2 ) {
00107     help(argv[0]);
00108     return 1;
00109   }
00110 
00111   // init
00112   string tupleFile;
00113   int readMode = 0;
00114   SprRWFactory::DataType writeMode = SprRWFactory::Root;
00115   int verbose = 0;
00116   string outFile;
00117   string resumeFile;
00118   string configFile;
00119   string valFile;
00120   bool scaleWeights = false;
00121   double sW = 1.;
00122   bool setLowCutoff = false;
00123   double lowCutoff = 0;
00124   string includeList, excludeList;
00125   string inputClassesString;
00126   int iLoss = 1;
00127   int iMode = 1;
00128   string indicatorFile;
00129   string stringVarsDoNotFeed;
00130   bool split = false;
00131   double splitFactor = 0;
00132   bool splitRandomize = false;
00133   string transformerFile;
00134 
00135   // decode command line
00136   int c;
00137   extern char* optarg;
00138   //  extern int optind;
00139   while( (c = getopt(argc,argv,"ho:a:Ay:Q:e:i:c:g:m:v:f:r:K:Dt:V:z:Z:")) != EOF ) {
00140     switch( c )
00141       {
00142       case 'h' :
00143         help(argv[0]);
00144         return 1;
00145       case 'o' :
00146         tupleFile = optarg;
00147         break;
00148       case 'a' :
00149         readMode = (optarg==0 ? 0 : atoi(optarg));
00150         break;
00151       case 'A' :
00152         writeMode = SprRWFactory::Ascii;
00153         break;
00154       case 'y' :
00155         inputClassesString = optarg;
00156         break;
00157       case 'Q' :
00158         transformerFile = optarg;
00159         break;
00160       case 'e' :
00161         iMode = (optarg==0 ? 1 : atoi(optarg));
00162         break;
00163       case 'i' :
00164         indicatorFile = optarg;
00165         break;
00166       case 'c' :
00167         configFile = optarg;
00168         break;
00169       case 'g' :
00170         iLoss = (optarg==0 ? 1 : atoi(optarg));
00171         break;
00172       case 'm' :
00173         if( optarg != 0 ) {
00174           setLowCutoff = true;
00175           lowCutoff = atof(optarg);
00176         }
00177         break;
00178       case 'v' :
00179         verbose = (optarg==0 ? 0 : atoi(optarg));
00180         break;
00181       case 'f' :
00182         outFile = optarg;
00183         break;
00184       case 'r' :
00185         resumeFile = optarg;
00186         break;
00187       case 'K' :
00188         split = true;
00189         splitFactor = (optarg==0 ? 0 : atof(optarg));
00190         break;
00191       case 'D' :
00192         splitRandomize = true;
00193         break;
00194       case 't' :
00195         valFile = optarg;
00196         break;
00197       case 'w' :
00198         if( optarg != 0 ) {
00199           scaleWeights = true;
00200           sW = atof(optarg);
00201         }
00202         break;
00203       case 'V' :
00204         includeList = optarg;
00205         break;
00206       case 'z' :
00207         excludeList = optarg;
00208         break;
00209       case 'Z' :
00210         stringVarsDoNotFeed = optarg;
00211         break;
00212       }
00213   }
00214 
00215   // sanity check
00216   if( configFile.empty() && resumeFile.empty()) {
00217     cerr << "No classifier configuration file specified." << endl;
00218     return 1;
00219   }
00220   if( !configFile.empty() && !resumeFile.empty() ) {
00221     cerr << "Cannot train and use saved configuration at the same time." << endl;
00222     return 1;
00223   }
00224 
00225   // Must have 2 arguments after all options.
00226   string trFile = argv[argc-1];
00227   if( trFile.empty() ) {
00228     cerr << "No training file is specified." << endl;
00229     return 1;
00230   }
00231 
00232   // make reader
00233   SprRWFactory::DataType inputType 
00234     = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00235   auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00236 
00237   // include variables
00238   set<string> includeSet;
00239   if( !includeList.empty() ) {
00240     vector<vector<string> > includeVars;
00241     SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00242     assert( !includeVars.empty() );
00243     for( int i=0;i<includeVars[0].size();i++ ) 
00244       includeSet.insert(includeVars[0][i]);
00245     if( !reader->chooseVars(includeSet) ) {
00246       cerr << "Unable to include variables in training set." << endl;
00247       return 2;
00248     }
00249     else {
00250       cout << "Following variables have been included in optimization: ";
00251       for( set<string>::const_iterator 
00252              i=includeSet.begin();i!=includeSet.end();i++ )
00253         cout << "\"" << *i << "\"" << " ";
00254       cout << endl;
00255     }
00256   }
00257 
00258   // exclude variables
00259   set<string> excludeSet;
00260   if( !excludeList.empty() ) {
00261     vector<vector<string> > excludeVars;
00262     SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00263     assert( !excludeVars.empty() );
00264     for( int i=0;i<excludeVars[0].size();i++ ) 
00265       excludeSet.insert(excludeVars[0][i]);
00266     if( !reader->chooseAllBut(excludeSet) ) {
00267       cerr << "Unable to exclude variables from training set." << endl;
00268       return 2;
00269     }
00270     else {
00271       cout << "Following variables have been excluded from optimization: ";
00272       for( set<string>::const_iterator 
00273              i=excludeSet.begin();i!=excludeSet.end();i++ )
00274         cout << "\"" << *i << "\"" << " ";
00275       cout << endl;
00276     }
00277   }
00278 
00279   // read training data from file
00280   auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00281   if( filter.get() == 0 ) {
00282     cerr << "Unable to read data from file " << trFile.c_str() << endl;
00283     return 2;
00284   }
00285   vector<string> vars;
00286   filter->vars(vars);
00287   cout << "Read data from file " << trFile.c_str() 
00288        << " for variables";
00289   for( int i=0;i<vars.size();i++ ) 
00290     cout << " \"" << vars[i].c_str() << "\"";
00291   cout << endl;
00292   cout << "Total number of points read: " << filter->size() << endl;
00293 
00294   // decode input classes
00295   if( inputClassesString.empty() ) {
00296     cerr << "No input classes specified." << endl;
00297     return 2;
00298   }
00299   vector<vector<int> > inputIntClasses;
00300   SprStringParser::parseToInts(inputClassesString.c_str(),inputIntClasses);
00301   if( inputIntClasses.empty() || inputIntClasses[0].size()<2 ) {
00302     cerr << "Found less than 2 classes in the input class string." << endl;
00303     return 2;
00304   }
00305   vector<SprClass> inputClasses(inputIntClasses[0].size());
00306   for( int i=0;i<inputIntClasses[0].size();i++ )
00307     inputClasses[i] = inputIntClasses[0][i];
00308 
00309   // filter training data by class
00310   filter->chooseClasses(inputClasses);
00311   if( !filter->filter() ) {
00312     cerr << "Unable to filter training data by class." << endl;
00313     return 2;
00314   }
00315   cout << "Training data filtered by class." << endl;
00316   for( int i=0;i<inputClasses.size();i++ ) {
00317     unsigned npts = filter->ptsInClass(inputClasses[i]);
00318     if( npts == 0 ) {
00319       cerr << "Error!!! No points in class " << inputClasses[i] << endl;
00320       return 2;
00321     }
00322     cout << "Points in class " << inputClasses[i] << ":   " << npts << endl;
00323   }
00324 
00325   // apply low cutoff
00326   if( setLowCutoff ) {
00327     if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00328       cerr << "Unable to replace missing values in training data." << endl;
00329       return 2;
00330     }
00331     else
00332       cout << "Values below " << lowCutoff << " in training data"
00333            << " have been replaced with medians." << endl;
00334   }
00335 
00336   // read validation data from file
00337   auto_ptr<SprAbsFilter> valFilter;
00338   if( split && !valFile.empty() ) {
00339     cerr << "Unable to split training data and use validation data " 
00340          << "from a separate file." << endl;
00341     return 2;
00342   }
00343   if( split ) {
00344     cout << "Splitting training data with factor " << splitFactor << endl;
00345     if( splitRandomize )
00346       cout << "Will use randomized splitting." << endl;
00347     vector<double> weights;
00348     SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00349     if( splitted == 0 ) {
00350       cerr << "Unable to split training data." << endl;
00351       return 2;
00352     }
00353     bool ownData = true;
00354     valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00355     cout << "Training data re-filtered:" << endl;
00356     for( int i=0;i<inputClasses.size();i++ ) {
00357       cout << "Points in class " << inputClasses[i] << ":   " 
00358            << filter->ptsInClass(inputClasses[i]) << endl;
00359     }
00360   }
00361   if( !valFile.empty() ) {
00362     auto_ptr<SprAbsReader> 
00363       valReader(SprRWFactory::makeReader(inputType,readMode));
00364     if( !includeSet.empty() ) {
00365       if( !valReader->chooseVars(includeSet) ) {
00366         cerr << "Unable to include variables in validation set." << endl;
00367         return 2;
00368       }
00369     }
00370     if( !excludeSet.empty() ) {
00371       if( !valReader->chooseAllBut(excludeSet) ) {
00372         cerr << "Unable to exclude variables from validation set." << endl;
00373         return 2;
00374       }
00375     }
00376     valFilter.reset(valReader->read(valFile.c_str()));
00377     if( valFilter.get() == 0 ) {
00378       cerr << "Unable to read data from file " << valFile.c_str() << endl;
00379       return 2;
00380     }
00381     vector<string> valVars;
00382     valFilter->vars(valVars);
00383     cout << "Read validation data from file " << valFile.c_str() 
00384          << " for variables";
00385     for( int i=0;i<valVars.size();i++ ) 
00386       cout << " \"" << valVars[i].c_str() << "\"";
00387     cout << endl;
00388     cout << "Total number of points read: " << valFilter->size() << endl;
00389   }
00390 
00391   // filter validation data by class
00392   if( valFilter.get() != 0 ) {
00393     valFilter->chooseClasses(inputClasses);
00394     if( !valFilter->filter() ) {
00395       cerr << "Unable to filter validation data by class." << endl;
00396       return 2;
00397     }
00398     cout << "Validation data filtered by class." << endl;
00399     for( int i=0;i<inputClasses.size();i++ ) {
00400      unsigned npts = valFilter->ptsInClass(inputClasses[i]);
00401      if( npts == 0 )
00402        cerr << "Warning!!! No points in class " << inputClasses[i] << endl;
00403      cout << "Points in class " << inputClasses[i] << ":   " << npts << endl;
00404     }
00405   }
00406 
00407   // apply low cutoff
00408   if( setLowCutoff && valFilter.get()!=0 ) {
00409     if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00410       cerr << "Unable to replace missing values in validation data." << endl;
00411       return 2;
00412     }
00413     else
00414       cout << "Values below " << lowCutoff << " in validation data"
00415            << " have been replaced with medians." << endl;
00416   }
00417 
00418   // apply transformation of variables to training and test data
00419   auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00420   if( !transformerFile.empty() ) {
00421     SprVarTransformerReader transReader;
00422     const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00423     if( t == 0 ) {
00424       cerr << "Unable to read VarTransformer from file "
00425            << transformerFile.c_str() << endl;
00426       return 2;
00427     }
00428     SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00429     SprTransformerFilter* t_valid = 0;
00430     if( valFilter.get() != 0 )
00431       t_valid = new SprTransformerFilter(valFilter.get());
00432     bool replaceOriginalData = true;
00433     if( !t_train->transform(t,replaceOriginalData) ) {
00434       cerr << "Unable to apply VarTransformer to training data." << endl;
00435       return 2;
00436     }
00437     if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00438       cerr << "Unable to apply VarTransformer to validation data." << endl;
00439       return 2;
00440     }
00441     cout << "Variable transformation from file "
00442          << transformerFile.c_str() << " has been applied to "
00443          << "training and validation data." << endl;
00444     garbage_train.reset(filter.release());
00445     garbage_valid.reset(valFilter.release());
00446     filter.reset(t_train);
00447     valFilter.reset(t_valid);
00448   }
00449 
00450   // prepare trained classifier holder
00451   auto_ptr<SprTrainedMultiClassLearner> trainedMulti;
00452 
00453   // prepare vectors of objects
00454   vector<SprAbsTwoClassCriterion*> criteria;
00455   vector<SprAbsClassifier*> destroyC;// classifiers to be deleted
00456   vector<SprIntegerBootstrap*> bstraps;
00457   vector<SprCCPair> useC;// classifiers and cuts to be used
00458 
00459   // open file with classifier configs
00460   if( !configFile.empty() ) {
00461     ifstream file(configFile.c_str());
00462     if( !file ) {
00463       cerr << "Unable to open file " << configFile.c_str() << endl;
00464       return 3;
00465     }
00466     
00467     // read classifier params
00468     unsigned nLine = 0;
00469     bool discreteTree = false;
00470     bool mixedNodesTree = false;
00471     bool fastSort = false;
00472     bool readOneEntry = true;
00473     if( !SprClassifierReader::readTrainableConfig(file,nLine,filter.get(),
00474                                                   discreteTree,mixedNodesTree,
00475                                                   fastSort,criteria,
00476                                                   bstraps,destroyC,useC,
00477                                                   readOneEntry) ) {
00478       cerr << "Unable to read classifier configurations from file " 
00479            << configFile.c_str() << endl;
00480       prepareExit(criteria,destroyC,bstraps);
00481       return 3;
00482     }
00483     cout << "Finished reading " << useC.size() << " classifiers from file "
00484          << configFile.c_str() << endl;
00485     assert( useC.size() == 1 );
00486     SprAbsClassifier* trainable = useC[0].first;
00487 
00488     // find the multi class mode
00489     SprMultiClassLearner::MultiClassMode multiClassMode 
00490       = SprMultiClassLearner::OneVsAll;
00491     switch( iMode )
00492       {
00493       case 1 :
00494         multiClassMode = SprMultiClassLearner::OneVsAll;
00495         cout << "Multi class learning mode set to OneVsAll." << endl;
00496         break;
00497       case 2 :
00498         multiClassMode = SprMultiClassLearner::OneVsOne;
00499         cout << "Multi class learning mode set to OneVsOne." << endl;
00500         break;
00501       case 3:
00502         if( indicatorFile.empty() ) {
00503           cerr << "No indicator matrix specified." << endl;
00504           return 4;
00505         }
00506         multiClassMode = SprMultiClassLearner::User;
00507         cout << "Multi class learning mode set to User." << endl;
00508         break;
00509       default :
00510         cerr << "No multi class learning mode chosen." << endl;
00511         prepareExit(criteria,destroyC,bstraps);
00512         return 4;
00513       }
00514 
00515     // get indicator matrix
00516     SprMatrix indicator;
00517     if( multiClassMode==SprMultiClassLearner::User 
00518         && !indicatorFile.empty() ) {
00519       if( !SprMultiClassReader::readIndicatorMatrix(indicatorFile.c_str(),
00520                                                     indicator) ) {
00521         cerr << "Unable to read indicator matrix from file " 
00522              << indicatorFile.c_str() << endl;
00523         return 4;
00524       }
00525     }
00526 
00527     // make a multi class learner
00528     SprMultiClassLearner multi(filter.get(),trainable,inputIntClasses[0],
00529                                indicator,multiClassMode);
00530 
00531     // train
00532     if( resumeFile.empty() ) {
00533       if( !multi.train(verbose) ) {
00534         cerr << "Unable to train Multi class learner." << endl;
00535         prepareExit(criteria,destroyC,bstraps);
00536         return 5;
00537       }
00538       else {
00539         trainedMulti.reset(multi.makeTrained());
00540         cout << "Multi class learner finished successfully." << endl;
00541       }
00542     }
00543 
00544     // save trained multi class learner
00545     if( !outFile.empty() ) {
00546       if( !multi.store(outFile.c_str()) ) {
00547         cerr << "Cannot store multi class learner in file " 
00548              << outFile.c_str() << endl;
00549         prepareExit(criteria,destroyC,bstraps);
00550         return 6;
00551       }
00552     }
00553   }
00554 
00555   // read saved learner from file
00556   if( !resumeFile.empty() ) {
00557     SprMultiClassReader multiReader;
00558     if( !multiReader.read(resumeFile.c_str()) ) {
00559       cerr << "Failed to read saved multi class learner from file " 
00560            << resumeFile.c_str() << endl;
00561       prepareExit(criteria,destroyC,bstraps);
00562       return 7;
00563     }
00564     else {
00565       trainedMulti.reset(multiReader.makeTrained());
00566       cout << "Read saved multi class learner from file " 
00567            << resumeFile.c_str() << endl;
00568       trainedMulti->printIndicatorMatrix(cout);
00569     }
00570   }
00571 
00572   // by now the trained learner should be filled
00573   if( trainedMulti.get() == 0 ) {
00574     cerr << "Trained multi learner has not been set." << endl;
00575     prepareExit(criteria,destroyC,bstraps);
00576     return 8;
00577   }
00578 
00579   // set loss
00580   switch( iLoss )
00581     {
00582     case 1 :
00583       trainedMulti->setLoss(&SprLoss::quadratic,
00584                             &SprTransformation::zeroOneToMinusPlusOne);
00585       cout << "Per-event loss set to "
00586            << "Quadratic loss (y-f(x))^2 " << endl;
00587       break;
00588     case 2 :
00589       trainedMulti->setLoss(&SprLoss::exponential,
00590                             &SprTransformation::logitInverse);
00591       cout << "Per-event loss set to "
00592            << "Exponential loss exp(-y*f(x)) " << endl;
00593       break;
00594     default :
00595       cerr << "No per-event loss specified." << endl;
00596       prepareExit(criteria,destroyC,bstraps);
00597       return 9;
00598     }
00599 
00600   // analyze validation data
00601   if( valFilter.get() != 0 ) {
00602 
00603     // compute response
00604     vector<SprMultiClassPlotter::Response> responses(valFilter->size());
00605     for( int i=0;i<valFilter->size();i++ ) {
00606       if( ((i+1)%1000) == 0 )
00607         cout << "Computing response for validation point " << i+1 << endl;
00608 
00609       // get point, class and weight
00610       const SprPoint* p = (*(valFilter.get()))[i];
00611       int cls = p->class_;
00612       double w = valFilter->w(i);
00613 
00614       // compute loss
00615       map<int,double> output;
00616       int resp = trainedMulti->response(p,output);
00617       responses[i] = SprMultiClassPlotter::Response(cls,w,resp,output);
00618     }    
00619 
00620     // get the loss table
00621     SprMultiClassPlotter plotter(responses);
00622     vector<int> classes;
00623     trainedMulti->classes(classes);
00624     map<int,vector<double> > lossTable;
00625     map<int,double> weightInClass;
00626     double totalLoss = plotter.multiClassTable(classes,lossTable,
00627                                                weightInClass);
00628 
00629     // print out
00630     cout << "=====================================" << endl;
00631     cout << "Overall validation misid fraction = " << totalLoss << endl;
00632     cout << "=====================================" << endl;
00633     cout << "Classification table: Fractions of total class weight" << endl;
00634     char s[200];
00635     sprintf(s,"True Class \\ Classification |");
00636     string temp = "------------------------------";
00637     cout << s;
00638     for( int i=0;i<classes.size();i++ ) {
00639       sprintf(s," %5i      |",classes[i]);
00640       cout << s;
00641       temp += "-------------";
00642     }
00643     sprintf(s,"   Total weight in class |");
00644     temp += "-------------------------";
00645     cout << s << endl;
00646     cout << temp.c_str() << endl;
00647     for( map<int,vector<double> >::const_iterator
00648            i=lossTable.begin();i!=lossTable.end();i++ ) {
00649       sprintf(s,"%5i                       |",i->first);
00650       cout << s;
00651       for( int j=0;j<i->second.size();j++ ) {
00652         sprintf(s," %10.4f |",i->second[j]);
00653         cout << s;
00654       }
00655       sprintf(s,"              %10.4f |",weightInClass[i->first]);
00656       cout << s << endl;
00657     }
00658     cout << temp.c_str() << endl;
00659   }
00660 
00661   // make histogram if requested
00662   if( tupleFile.empty() ) {
00663     prepareExit(criteria,destroyC,bstraps);
00664     return 0;
00665   }
00666 
00667   // make a writer
00668   auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00669   if( !tuple->init(tupleFile.c_str()) ) {
00670     cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00671     prepareExit(criteria,destroyC,bstraps);
00672     return 10;
00673   }
00674 
00675   // determine if certain variables are to be excluded from usage,
00676   // but included in the output storage file (-Z option)
00677   string printVarsDoNotFeed;
00678   vector<vector<string> > varsDoNotFeed;
00679   SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed);
00680   vector<unsigned> mapper;
00681   for( int d=0;d<vars.size();d++ ) {
00682     if( varsDoNotFeed.empty() ||
00683         (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d])
00684          ==varsDoNotFeed[0].end()) ) {
00685       mapper.push_back(d);
00686     }
00687     else {
00688       printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? "" : ", " );
00689       printVarsDoNotFeed += vars[d];
00690     }
00691   }
00692   if( !printVarsDoNotFeed.empty() ) {
00693     cout << "The following variables are not used in the algorithm, " 
00694          << "but will be included in the output file: " 
00695          << printVarsDoNotFeed.c_str() << endl;
00696   }
00697 
00698   // feed
00699   SprDataFeeder feeder(filter.get(),tuple.get(),mapper);
00700   feeder.addMultiClassLearner(trainedMulti.get(),"multi");
00701   if( !feeder.feed(1000) ) {
00702     cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00703     prepareExit(criteria,destroyC,bstraps);
00704     return 11;
00705   }
00706 
00707   // cleanup
00708   prepareExit(criteria,destroyC,bstraps);
00709 
00710   // exit
00711   return 0;
00712 }

Generated on Tue Jun 9 17:41:59 2009 for CMSSW by  doxygen 1.5.4