00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsClassifier.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprPoint.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassLearner.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedMultiClassLearner.hh"
00022 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00023 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
00024 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
00025 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00026 #include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"
00027 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00028 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00029 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00030 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00031 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00032 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00033 #include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassReader.hh"
00034 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00035 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00036 #include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassPlotter.hh"
00037 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00038 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00039 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00040
00041 #include <stdlib.h>
00042 #include <unistd.h>
00043 #include <iostream>
00044 #include <fstream>
00045 #include <vector>
00046 #include <set>
00047 #include <map>
00048 #include <string>
00049 #include <memory>
00050
00051 using namespace std;
00052
00053
// Print the usage line for this executable followed by a description of
// every supported command-line option. All output goes to standard output.
//   prog - name of the executable, shown in the usage line
void help(const char* prog)
{
  // Text printed after the usage line; each entry becomes one output line.
  static const char* const optionLines[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-o output Tuple file ",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-A save output data in ascii instead of Root ",
    "\t-y list of input classes ",
    "\t\t Classes must be listed in quotes and separated by commas.",
    "\t-Q apply variable transformation saved in file ",
    "\t-e Multi class mode ",
    "\t\t 1 - OneVsAll (default) ",
    "\t\t 2 - OneVsOne ",
    "\t\t 3 - user-defined (must use -i option) ",
    "\t-i input file with user-defined indicator matrix ",
    "\t-c file with trainable classifier configurations ",
    "\t-g per-event loss to be displayed for each input class ",
    "\t\t 1 - quadratic loss (y-f(x))^2 ",
    "\t\t 2 - exponential loss exp(-y*f(x)) ",
    "\t-m replace data values below this cutoff with medians ",
    "\t-v verbose level (0=silent default,1,2) ",
    "\t-f store trained multi class learner to file ",
    "\t-r read multi class learner configuration stored in file",
    "\t-K keep this fraction in training set and ",
    "\t\t put the rest into validation set ",
    "\t-D randomize training set split-up ",
    "\t-t read validation/test data from a file ",
    "\t\t (must be in same format as input data!!! ",
    "\t-V include only these input variables ",
    "\t-z exclude input variables from the list ",
    "\t-Z exclude input variables from the list, "
    "but put them in the output file ",
    "\t\t Variables must be listed in quotes and separated by commas."
  };
  const unsigned nOptionLines = sizeof(optionLines)/sizeof(optionLines[0]);

  cout << "Usage: " << prog << " training_data_file" << endl;
  for( unsigned line=0;line<nOptionLines;line++ )
    cout << optionLines[line] << endl;
}
00091
00092
00093 void prepareExit(vector<SprAbsTwoClassCriterion*>& criteria,
00094 vector<SprAbsClassifier*>& classifiers,
00095 vector<SprIntegerBootstrap*>& bstraps)
00096 {
00097 for( int i=0;i<criteria.size();i++ ) delete criteria[i];
00098 for( int i=0;i<classifiers.size();i++ ) delete classifiers[i];
00099 for( int i=0;i<bstraps.size();i++ ) delete bstraps[i];
00100 }
00101
00102
00103 int main(int argc, char ** argv)
00104 {
00105
00106 if( argc < 2 ) {
00107 help(argv[0]);
00108 return 1;
00109 }
00110
00111
00112 string tupleFile;
00113 int readMode = 0;
00114 SprRWFactory::DataType writeMode = SprRWFactory::Root;
00115 int verbose = 0;
00116 string outFile;
00117 string resumeFile;
00118 string configFile;
00119 string valFile;
00120 bool scaleWeights = false;
00121 double sW = 1.;
00122 bool setLowCutoff = false;
00123 double lowCutoff = 0;
00124 string includeList, excludeList;
00125 string inputClassesString;
00126 int iLoss = 1;
00127 int iMode = 1;
00128 string indicatorFile;
00129 string stringVarsDoNotFeed;
00130 bool split = false;
00131 double splitFactor = 0;
00132 bool splitRandomize = false;
00133 string transformerFile;
00134
00135
00136 int c;
00137 extern char* optarg;
00138
00139 while( (c = getopt(argc,argv,"ho:a:Ay:Q:e:i:c:g:m:v:f:r:K:Dt:V:z:Z:")) != EOF ) {
00140 switch( c )
00141 {
00142 case 'h' :
00143 help(argv[0]);
00144 return 1;
00145 case 'o' :
00146 tupleFile = optarg;
00147 break;
00148 case 'a' :
00149 readMode = (optarg==0 ? 0 : atoi(optarg));
00150 break;
00151 case 'A' :
00152 writeMode = SprRWFactory::Ascii;
00153 break;
00154 case 'y' :
00155 inputClassesString = optarg;
00156 break;
00157 case 'Q' :
00158 transformerFile = optarg;
00159 break;
00160 case 'e' :
00161 iMode = (optarg==0 ? 1 : atoi(optarg));
00162 break;
00163 case 'i' :
00164 indicatorFile = optarg;
00165 break;
00166 case 'c' :
00167 configFile = optarg;
00168 break;
00169 case 'g' :
00170 iLoss = (optarg==0 ? 1 : atoi(optarg));
00171 break;
00172 case 'm' :
00173 if( optarg != 0 ) {
00174 setLowCutoff = true;
00175 lowCutoff = atof(optarg);
00176 }
00177 break;
00178 case 'v' :
00179 verbose = (optarg==0 ? 0 : atoi(optarg));
00180 break;
00181 case 'f' :
00182 outFile = optarg;
00183 break;
00184 case 'r' :
00185 resumeFile = optarg;
00186 break;
00187 case 'K' :
00188 split = true;
00189 splitFactor = (optarg==0 ? 0 : atof(optarg));
00190 break;
00191 case 'D' :
00192 splitRandomize = true;
00193 break;
00194 case 't' :
00195 valFile = optarg;
00196 break;
00197 case 'w' :
00198 if( optarg != 0 ) {
00199 scaleWeights = true;
00200 sW = atof(optarg);
00201 }
00202 break;
00203 case 'V' :
00204 includeList = optarg;
00205 break;
00206 case 'z' :
00207 excludeList = optarg;
00208 break;
00209 case 'Z' :
00210 stringVarsDoNotFeed = optarg;
00211 break;
00212 }
00213 }
00214
00215
00216 if( configFile.empty() && resumeFile.empty()) {
00217 cerr << "No classifier configuration file specified." << endl;
00218 return 1;
00219 }
00220 if( !configFile.empty() && !resumeFile.empty() ) {
00221 cerr << "Cannot train and use saved configuration at the same time." << endl;
00222 return 1;
00223 }
00224
00225
00226 string trFile = argv[argc-1];
00227 if( trFile.empty() ) {
00228 cerr << "No training file is specified." << endl;
00229 return 1;
00230 }
00231
00232
00233 SprRWFactory::DataType inputType
00234 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00235 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00236
00237
00238 set<string> includeSet;
00239 if( !includeList.empty() ) {
00240 vector<vector<string> > includeVars;
00241 SprStringParser::parseToStrings(includeList.c_str(),includeVars);
00242 assert( !includeVars.empty() );
00243 for( int i=0;i<includeVars[0].size();i++ )
00244 includeSet.insert(includeVars[0][i]);
00245 if( !reader->chooseVars(includeSet) ) {
00246 cerr << "Unable to include variables in training set." << endl;
00247 return 2;
00248 }
00249 else {
00250 cout << "Following variables have been included in optimization: ";
00251 for( set<string>::const_iterator
00252 i=includeSet.begin();i!=includeSet.end();i++ )
00253 cout << "\"" << *i << "\"" << " ";
00254 cout << endl;
00255 }
00256 }
00257
00258
00259 set<string> excludeSet;
00260 if( !excludeList.empty() ) {
00261 vector<vector<string> > excludeVars;
00262 SprStringParser::parseToStrings(excludeList.c_str(),excludeVars);
00263 assert( !excludeVars.empty() );
00264 for( int i=0;i<excludeVars[0].size();i++ )
00265 excludeSet.insert(excludeVars[0][i]);
00266 if( !reader->chooseAllBut(excludeSet) ) {
00267 cerr << "Unable to exclude variables from training set." << endl;
00268 return 2;
00269 }
00270 else {
00271 cout << "Following variables have been excluded from optimization: ";
00272 for( set<string>::const_iterator
00273 i=excludeSet.begin();i!=excludeSet.end();i++ )
00274 cout << "\"" << *i << "\"" << " ";
00275 cout << endl;
00276 }
00277 }
00278
00279
00280 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00281 if( filter.get() == 0 ) {
00282 cerr << "Unable to read data from file " << trFile.c_str() << endl;
00283 return 2;
00284 }
00285 vector<string> vars;
00286 filter->vars(vars);
00287 cout << "Read data from file " << trFile.c_str()
00288 << " for variables";
00289 for( int i=0;i<vars.size();i++ )
00290 cout << " \"" << vars[i].c_str() << "\"";
00291 cout << endl;
00292 cout << "Total number of points read: " << filter->size() << endl;
00293
00294
00295 if( inputClassesString.empty() ) {
00296 cerr << "No input classes specified." << endl;
00297 return 2;
00298 }
00299 vector<vector<int> > inputIntClasses;
00300 SprStringParser::parseToInts(inputClassesString.c_str(),inputIntClasses);
00301 if( inputIntClasses.empty() || inputIntClasses[0].size()<2 ) {
00302 cerr << "Found less than 2 classes in the input class string." << endl;
00303 return 2;
00304 }
00305 vector<SprClass> inputClasses(inputIntClasses[0].size());
00306 for( int i=0;i<inputIntClasses[0].size();i++ )
00307 inputClasses[i] = inputIntClasses[0][i];
00308
00309
00310 filter->chooseClasses(inputClasses);
00311 if( !filter->filter() ) {
00312 cerr << "Unable to filter training data by class." << endl;
00313 return 2;
00314 }
00315 cout << "Training data filtered by class." << endl;
00316 for( int i=0;i<inputClasses.size();i++ ) {
00317 unsigned npts = filter->ptsInClass(inputClasses[i]);
00318 if( npts == 0 ) {
00319 cerr << "Error!!! No points in class " << inputClasses[i] << endl;
00320 return 2;
00321 }
00322 cout << "Points in class " << inputClasses[i] << ": " << npts << endl;
00323 }
00324
00325
00326 if( setLowCutoff ) {
00327 if( !filter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00328 cerr << "Unable to replace missing values in training data." << endl;
00329 return 2;
00330 }
00331 else
00332 cout << "Values below " << lowCutoff << " in training data"
00333 << " have been replaced with medians." << endl;
00334 }
00335
00336
00337 auto_ptr<SprAbsFilter> valFilter;
00338 if( split && !valFile.empty() ) {
00339 cerr << "Unable to split training data and use validation data "
00340 << "from a separate file." << endl;
00341 return 2;
00342 }
00343 if( split ) {
00344 cout << "Splitting training data with factor " << splitFactor << endl;
00345 if( splitRandomize )
00346 cout << "Will use randomized splitting." << endl;
00347 vector<double> weights;
00348 SprData* splitted = filter->split(splitFactor,weights,splitRandomize);
00349 if( splitted == 0 ) {
00350 cerr << "Unable to split training data." << endl;
00351 return 2;
00352 }
00353 bool ownData = true;
00354 valFilter.reset(new SprEmptyFilter(splitted,weights,ownData));
00355 cout << "Training data re-filtered:" << endl;
00356 for( int i=0;i<inputClasses.size();i++ ) {
00357 cout << "Points in class " << inputClasses[i] << ": "
00358 << filter->ptsInClass(inputClasses[i]) << endl;
00359 }
00360 }
00361 if( !valFile.empty() ) {
00362 auto_ptr<SprAbsReader>
00363 valReader(SprRWFactory::makeReader(inputType,readMode));
00364 if( !includeSet.empty() ) {
00365 if( !valReader->chooseVars(includeSet) ) {
00366 cerr << "Unable to include variables in validation set." << endl;
00367 return 2;
00368 }
00369 }
00370 if( !excludeSet.empty() ) {
00371 if( !valReader->chooseAllBut(excludeSet) ) {
00372 cerr << "Unable to exclude variables from validation set." << endl;
00373 return 2;
00374 }
00375 }
00376 valFilter.reset(valReader->read(valFile.c_str()));
00377 if( valFilter.get() == 0 ) {
00378 cerr << "Unable to read data from file " << valFile.c_str() << endl;
00379 return 2;
00380 }
00381 vector<string> valVars;
00382 valFilter->vars(valVars);
00383 cout << "Read validation data from file " << valFile.c_str()
00384 << " for variables";
00385 for( int i=0;i<valVars.size();i++ )
00386 cout << " \"" << valVars[i].c_str() << "\"";
00387 cout << endl;
00388 cout << "Total number of points read: " << valFilter->size() << endl;
00389 }
00390
00391
00392 if( valFilter.get() != 0 ) {
00393 valFilter->chooseClasses(inputClasses);
00394 if( !valFilter->filter() ) {
00395 cerr << "Unable to filter validation data by class." << endl;
00396 return 2;
00397 }
00398 cout << "Validation data filtered by class." << endl;
00399 for( int i=0;i<inputClasses.size();i++ ) {
00400 unsigned npts = valFilter->ptsInClass(inputClasses[i]);
00401 if( npts == 0 )
00402 cerr << "Warning!!! No points in class " << inputClasses[i] << endl;
00403 cout << "Points in class " << inputClasses[i] << ": " << npts << endl;
00404 }
00405 }
00406
00407
00408 if( setLowCutoff && valFilter.get()!=0 ) {
00409 if( !valFilter->replaceMissing(SprUtils::lowerBound(lowCutoff),1) ) {
00410 cerr << "Unable to replace missing values in validation data." << endl;
00411 return 2;
00412 }
00413 else
00414 cout << "Values below " << lowCutoff << " in validation data"
00415 << " have been replaced with medians." << endl;
00416 }
00417
00418
00419 auto_ptr<SprAbsFilter> garbage_train, garbage_valid;
00420 if( !transformerFile.empty() ) {
00421 SprVarTransformerReader transReader;
00422 const SprAbsVarTransformer* t = transReader.read(transformerFile.c_str());
00423 if( t == 0 ) {
00424 cerr << "Unable to read VarTransformer from file "
00425 << transformerFile.c_str() << endl;
00426 return 2;
00427 }
00428 SprTransformerFilter* t_train = new SprTransformerFilter(filter.get());
00429 SprTransformerFilter* t_valid = 0;
00430 if( valFilter.get() != 0 )
00431 t_valid = new SprTransformerFilter(valFilter.get());
00432 bool replaceOriginalData = true;
00433 if( !t_train->transform(t,replaceOriginalData) ) {
00434 cerr << "Unable to apply VarTransformer to training data." << endl;
00435 return 2;
00436 }
00437 if( t_valid!=0 && !t_valid->transform(t,replaceOriginalData) ) {
00438 cerr << "Unable to apply VarTransformer to validation data." << endl;
00439 return 2;
00440 }
00441 cout << "Variable transformation from file "
00442 << transformerFile.c_str() << " has been applied to "
00443 << "training and validation data." << endl;
00444 garbage_train.reset(filter.release());
00445 garbage_valid.reset(valFilter.release());
00446 filter.reset(t_train);
00447 valFilter.reset(t_valid);
00448 }
00449
00450
00451 auto_ptr<SprTrainedMultiClassLearner> trainedMulti;
00452
00453
00454 vector<SprAbsTwoClassCriterion*> criteria;
00455 vector<SprAbsClassifier*> destroyC;
00456 vector<SprIntegerBootstrap*> bstraps;
00457 vector<SprCCPair> useC;
00458
00459
00460 if( !configFile.empty() ) {
00461 ifstream file(configFile.c_str());
00462 if( !file ) {
00463 cerr << "Unable to open file " << configFile.c_str() << endl;
00464 return 3;
00465 }
00466
00467
00468 unsigned nLine = 0;
00469 bool discreteTree = false;
00470 bool mixedNodesTree = false;
00471 bool fastSort = false;
00472 bool readOneEntry = true;
00473 if( !SprClassifierReader::readTrainableConfig(file,nLine,filter.get(),
00474 discreteTree,mixedNodesTree,
00475 fastSort,criteria,
00476 bstraps,destroyC,useC,
00477 readOneEntry) ) {
00478 cerr << "Unable to read classifier configurations from file "
00479 << configFile.c_str() << endl;
00480 prepareExit(criteria,destroyC,bstraps);
00481 return 3;
00482 }
00483 cout << "Finished reading " << useC.size() << " classifiers from file "
00484 << configFile.c_str() << endl;
00485 assert( useC.size() == 1 );
00486 SprAbsClassifier* trainable = useC[0].first;
00487
00488
00489 SprMultiClassLearner::MultiClassMode multiClassMode
00490 = SprMultiClassLearner::OneVsAll;
00491 switch( iMode )
00492 {
00493 case 1 :
00494 multiClassMode = SprMultiClassLearner::OneVsAll;
00495 cout << "Multi class learning mode set to OneVsAll." << endl;
00496 break;
00497 case 2 :
00498 multiClassMode = SprMultiClassLearner::OneVsOne;
00499 cout << "Multi class learning mode set to OneVsOne." << endl;
00500 break;
00501 case 3:
00502 if( indicatorFile.empty() ) {
00503 cerr << "No indicator matrix specified." << endl;
00504 return 4;
00505 }
00506 multiClassMode = SprMultiClassLearner::User;
00507 cout << "Multi class learning mode set to User." << endl;
00508 break;
00509 default :
00510 cerr << "No multi class learning mode chosen." << endl;
00511 prepareExit(criteria,destroyC,bstraps);
00512 return 4;
00513 }
00514
00515
00516 SprMatrix indicator;
00517 if( multiClassMode==SprMultiClassLearner::User
00518 && !indicatorFile.empty() ) {
00519 if( !SprMultiClassReader::readIndicatorMatrix(indicatorFile.c_str(),
00520 indicator) ) {
00521 cerr << "Unable to read indicator matrix from file "
00522 << indicatorFile.c_str() << endl;
00523 return 4;
00524 }
00525 }
00526
00527
00528 SprMultiClassLearner multi(filter.get(),trainable,inputIntClasses[0],
00529 indicator,multiClassMode);
00530
00531
00532 if( resumeFile.empty() ) {
00533 if( !multi.train(verbose) ) {
00534 cerr << "Unable to train Multi class learner." << endl;
00535 prepareExit(criteria,destroyC,bstraps);
00536 return 5;
00537 }
00538 else {
00539 trainedMulti.reset(multi.makeTrained());
00540 cout << "Multi class learner finished successfully." << endl;
00541 }
00542 }
00543
00544
00545 if( !outFile.empty() ) {
00546 if( !multi.store(outFile.c_str()) ) {
00547 cerr << "Cannot store multi class learner in file "
00548 << outFile.c_str() << endl;
00549 prepareExit(criteria,destroyC,bstraps);
00550 return 6;
00551 }
00552 }
00553 }
00554
00555
00556 if( !resumeFile.empty() ) {
00557 SprMultiClassReader multiReader;
00558 if( !multiReader.read(resumeFile.c_str()) ) {
00559 cerr << "Failed to read saved multi class learner from file "
00560 << resumeFile.c_str() << endl;
00561 prepareExit(criteria,destroyC,bstraps);
00562 return 7;
00563 }
00564 else {
00565 trainedMulti.reset(multiReader.makeTrained());
00566 cout << "Read saved multi class learner from file "
00567 << resumeFile.c_str() << endl;
00568 trainedMulti->printIndicatorMatrix(cout);
00569 }
00570 }
00571
00572
00573 if( trainedMulti.get() == 0 ) {
00574 cerr << "Trained multi learner has not been set." << endl;
00575 prepareExit(criteria,destroyC,bstraps);
00576 return 8;
00577 }
00578
00579
00580 switch( iLoss )
00581 {
00582 case 1 :
00583 trainedMulti->setLoss(&SprLoss::quadratic,
00584 &SprTransformation::zeroOneToMinusPlusOne);
00585 cout << "Per-event loss set to "
00586 << "Quadratic loss (y-f(x))^2 " << endl;
00587 break;
00588 case 2 :
00589 trainedMulti->setLoss(&SprLoss::exponential,
00590 &SprTransformation::logitInverse);
00591 cout << "Per-event loss set to "
00592 << "Exponential loss exp(-y*f(x)) " << endl;
00593 break;
00594 default :
00595 cerr << "No per-event loss specified." << endl;
00596 prepareExit(criteria,destroyC,bstraps);
00597 return 9;
00598 }
00599
00600
00601 if( valFilter.get() != 0 ) {
00602
00603
00604 vector<SprMultiClassPlotter::Response> responses(valFilter->size());
00605 for( int i=0;i<valFilter->size();i++ ) {
00606 if( ((i+1)%1000) == 0 )
00607 cout << "Computing response for validation point " << i+1 << endl;
00608
00609
00610 const SprPoint* p = (*(valFilter.get()))[i];
00611 int cls = p->class_;
00612 double w = valFilter->w(i);
00613
00614
00615 map<int,double> output;
00616 int resp = trainedMulti->response(p,output);
00617 responses[i] = SprMultiClassPlotter::Response(cls,w,resp,output);
00618 }
00619
00620
00621 SprMultiClassPlotter plotter(responses);
00622 vector<int> classes;
00623 trainedMulti->classes(classes);
00624 map<int,vector<double> > lossTable;
00625 map<int,double> weightInClass;
00626 double totalLoss = plotter.multiClassTable(classes,lossTable,
00627 weightInClass);
00628
00629
00630 cout << "=====================================" << endl;
00631 cout << "Overall validation misid fraction = " << totalLoss << endl;
00632 cout << "=====================================" << endl;
00633 cout << "Classification table: Fractions of total class weight" << endl;
00634 char s[200];
00635 sprintf(s,"True Class \\ Classification |");
00636 string temp = "------------------------------";
00637 cout << s;
00638 for( int i=0;i<classes.size();i++ ) {
00639 sprintf(s," %5i |",classes[i]);
00640 cout << s;
00641 temp += "-------------";
00642 }
00643 sprintf(s," Total weight in class |");
00644 temp += "-------------------------";
00645 cout << s << endl;
00646 cout << temp.c_str() << endl;
00647 for( map<int,vector<double> >::const_iterator
00648 i=lossTable.begin();i!=lossTable.end();i++ ) {
00649 sprintf(s,"%5i |",i->first);
00650 cout << s;
00651 for( int j=0;j<i->second.size();j++ ) {
00652 sprintf(s," %10.4f |",i->second[j]);
00653 cout << s;
00654 }
00655 sprintf(s," %10.4f |",weightInClass[i->first]);
00656 cout << s << endl;
00657 }
00658 cout << temp.c_str() << endl;
00659 }
00660
00661
00662 if( tupleFile.empty() ) {
00663 prepareExit(criteria,destroyC,bstraps);
00664 return 0;
00665 }
00666
00667
00668 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00669 if( !tuple->init(tupleFile.c_str()) ) {
00670 cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00671 prepareExit(criteria,destroyC,bstraps);
00672 return 10;
00673 }
00674
00675
00676
00677 string printVarsDoNotFeed;
00678 vector<vector<string> > varsDoNotFeed;
00679 SprStringParser::parseToStrings(stringVarsDoNotFeed.c_str(),varsDoNotFeed);
00680 vector<unsigned> mapper;
00681 for( int d=0;d<vars.size();d++ ) {
00682 if( varsDoNotFeed.empty() ||
00683 (find(varsDoNotFeed[0].begin(),varsDoNotFeed[0].end(),vars[d])
00684 ==varsDoNotFeed[0].end()) ) {
00685 mapper.push_back(d);
00686 }
00687 else {
00688 printVarsDoNotFeed += ( printVarsDoNotFeed.empty() ? "" : ", " );
00689 printVarsDoNotFeed += vars[d];
00690 }
00691 }
00692 if( !printVarsDoNotFeed.empty() ) {
00693 cout << "The following variables are not used in the algorithm, "
00694 << "but will be included in the output file: "
00695 << printVarsDoNotFeed.c_str() << endl;
00696 }
00697
00698
00699 SprDataFeeder feeder(filter.get(),tuple.get(),mapper);
00700 feeder.addMultiClassLearner(trainedMulti.get(),"multi");
00701 if( !feeder.feed(1000) ) {
00702 cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00703 prepareExit(criteria,destroyC,bstraps);
00704 return 11;
00705 }
00706
00707
00708 prepareExit(criteria,destroyC,bstraps);
00709
00710
00711 return 0;
00712 }