CMS 3D CMS Logo

SprRootAdapter.cc

Go to the documentation of this file.
00001 //$Id: SprRootAdapter.cc,v 1.4 2007/12/01 01:29:46 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprRootAdapter.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsClassifier.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassLearner.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedMultiClassLearner.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprSimpleReader.hh"
00014 #include "PhysicsTools/StatPatternRecognition/interface/SprRootReader.hh"
00015 #include "PhysicsTools/StatPatternRecognition/interface/SprPlotter.hh"
00016 #include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassPlotter.hh"
00017 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00018 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00019 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierReader.hh"
00020 #include "PhysicsTools/StatPatternRecognition/interface/SprMultiClassReader.hh"
00021 #include "PhysicsTools/StatPatternRecognition/interface/SprCoordinateMapper.hh"
00022 
00023 #include "PhysicsTools/StatPatternRecognition/interface/SprFisher.hh"
00024 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedFisher.hh"
00025 #include "PhysicsTools/StatPatternRecognition/interface/SprLogitR.hh"
00026 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedLogitR.hh"
00027 #include "PhysicsTools/StatPatternRecognition/interface/SprTopdownTree.hh"
00028 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedTopdownTree.hh"
00029 #include "PhysicsTools/StatPatternRecognition/interface/SprDecisionTree.hh"
00030 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedDecisionTree.hh"
00031 #include "PhysicsTools/StatPatternRecognition/interface/SprAdaBoost.hh"
00032 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedAdaBoost.hh"
00033 #include "PhysicsTools/StatPatternRecognition/interface/SprStdBackprop.hh"
00034 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedStdBackprop.hh"
00035 #include "PhysicsTools/StatPatternRecognition/interface/SprBagger.hh"
00036 #include "PhysicsTools/StatPatternRecognition/interface/SprArcE4.hh"
00037 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedBagger.hh"
00038 #include "PhysicsTools/StatPatternRecognition/interface/SprBinarySplit.hh"
00039 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedBinarySplit.hh"
00040 #include "PhysicsTools/StatPatternRecognition/interface/SprBumpHunter.hh"
00041 
00042 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00043 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassSignalSignif.hh"
00044 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassIDFraction.hh"
00045 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassTaggerEff.hh"
00046 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPurity.hh"
00047 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassGiniIndex.hh"
00048 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassCrossEntropy.hh"
00049 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassUniformPriorUL90.hh"
00050 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassBKDiscovery.hh"
00051 #include "PhysicsTools/StatPatternRecognition/interface/SprTwoClassPunzi.hh"
00052 
00053 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00054 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00055 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00056 #include "PhysicsTools/StatPatternRecognition/interface/SprDataMoments.hh"
00057 #include "PhysicsTools/StatPatternRecognition/interface/SprClassifierEvaluator.hh"
00058 
00059 #include "PhysicsTools/StatPatternRecognition/interface/SprRootWriter.hh"
00060 #include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
00061 
00062 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformerFilter.hh"
00063 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsVarTransformer.hh"
00064 #include "PhysicsTools/StatPatternRecognition/interface/SprVarTransformerReader.hh"
00065 #include "PhysicsTools/StatPatternRecognition/interface/SprPCATransformer.hh"
00066 
00067 #include "PhysicsTools/StatPatternRecognition/src/SprSymMatrix.hh"
00068 #include "PhysicsTools/StatPatternRecognition/src/SprVector.hh"
00069 
00070 #include <iostream>
00071 #include <cassert>
00072 #include <cmath>
00073 #include <utility>
00074 #include <algorithm>
00075 #include <memory>
00076 
00077 using namespace std;
00078 
00079 
00080 SprRootAdapter::~SprRootAdapter()
00081 {
00082   delete trainData_;
00083   delete testData_;
00084   delete trainGarbage_;
00085   delete testGarbage_;
00086   delete trans_;
00087   delete showCrit_;
00088   this->clearClassifiers();
00089 }
00090 
00091 
00092 SprRootAdapter::SprRootAdapter()
00093   :
00094   includeVars_(),
00095   excludeVars_(),
00096   trainData_(0),
00097   testData_(0),
00098   needToTest_(true),
00099   trainGarbage_(),
00100   testGarbage_(),
00101   trainable_(),
00102   trained_(),
00103   mcTrainable_(0),
00104   mcTrained_(0),
00105   mapper_(),
00106   mcMapper_(),
00107   trans_(0),
00108   plotter_(0),
00109   mcPlotter_(0),
00110   showCrit_(0),
00111   crit_(),
00112   bootstrap_(),
00113   aux_(),
00114   loss_()
00115 {}
00116 
00117 
00118 void SprRootAdapter::chooseVars(int nVars, const char vars[][200])
00119 {
00120   includeVars_.clear();
00121   for( int i=0;i<nVars;i++ )
00122     includeVars_.insert(vars[i]);
00123 }
00124 
00125 
00126 void SprRootAdapter::chooseAllBut(int nVars, const char vars[][200])
00127 {
00128   excludeVars_.clear();
00129   for( int i=0;i<nVars;i++ )
00130     excludeVars_.insert(vars[i]);
00131 }
00132 
00133 
00134 void SprRootAdapter::chooseAllVars()
00135 {
00136   includeVars_.clear();
00137   excludeVars_.clear();
00138 }
00139 
00140 
00141 bool SprRootAdapter::loadDataFromAscii(int mode, 
00142                                        const char* filename,
00143                                        const char* datatype)
00144 {
00145   // init reader
00146   SprSimpleReader reader(mode);
00147   if( !reader.chooseVars(includeVars_) 
00148       || !reader.chooseAllBut(excludeVars_) ) {
00149     cerr << "Unable to choose variables." << endl;
00150     return false;
00151   }
00152 
00153   // get data type
00154   string sdatatype = datatype;
00155   if(      sdatatype == "train" ) {
00156     cout << "Warning: training data will be reloaded." << endl;
00157     this->clearClassifiers();
00158     delete trainData_;
00159     delete trainGarbage_;
00160     trainGarbage_ = 0;
00161     trainData_ = reader.read(filename);
00162     if( trainData_ == 0 ) {
00163       cerr << "Failed to read training data from file " 
00164            << filename << endl;
00165       return false;
00166     }
00167     return true;
00168   }
00169   else if( sdatatype == "test" ) {
00170     cout << "Warning: test data will be reloaded." << endl;
00171     needToTest_ = true;
00172     delete testData_;
00173     delete testGarbage_;
00174     testGarbage_ = 0;
00175     testData_ = reader.read(filename);
00176     if( testData_ == 0 ) {
00177       cerr << "Failed to read test data from file " 
00178            << filename << endl;
00179       return false;
00180     }
00181     return true;
00182   }
00183   cerr << "Unknown data type. Must be train or test." << endl;
00184 
00185   // exit
00186   return false;
00187 }
00188 
00189 
00190 bool SprRootAdapter::loadDataFromRoot(const char* filename,
00191                                       const char* datatype)
00192 {
00193   SprRootReader reader;
00194   string sdatatype = datatype;
00195   if(      sdatatype == "train" ) {
00196     this->clearClassifiers();
00197     delete trainData_;
00198     delete trainGarbage_;
00199     trainGarbage_ = 0;
00200     trainData_ = reader.read(filename);
00201     if( trainData_ == 0 ) {
00202       cerr << "Failed to read training data from file " 
00203            << filename << endl;
00204       return false;
00205     }
00206     return true;
00207   }
00208   else if( sdatatype == "test" ) {
00209     needToTest_ = true;
00210     delete testData_;
00211     delete testGarbage_;
00212     testGarbage_ = 0;
00213     testData_ = reader.read(filename);
00214     if( testData_ == 0 ) {
00215       cerr << "Failed to read test data from file " 
00216            << filename << endl;
00217       return false;
00218     }
00219     return true;
00220   }
00221   cerr << "Unknown data type. Must be train or test." << endl;
00222   return false;
00223 }
00224 
00225 
00226 unsigned SprRootAdapter::dim() const
00227 {
00228   if( trainData_ == 0 ) {
00229     cerr << "Training data has not been loaded." << endl;
00230     return 0;
00231   }
00232   return trainData_->dim();
00233 }
00234 
00235 
00236 bool SprRootAdapter::vars(char vars[][200]) const
00237 {
00238   if( trainData_ == 0 ) {
00239     cerr << "Training data has not been loaded." << endl;
00240     return false;
00241   }
00242   vector<string> svars;
00243   trainData_->vars(svars);
00244   assert( svars.size() == trainData_->dim() );
00245   for( int i=0;i<svars.size();i++ )
00246     strcpy(vars[i],svars[i].c_str());
00247   return true;
00248 }
00249 
00250 
00251 unsigned SprRootAdapter::nClassifierVars(const char* classifierName) const
00252 {
00253   string sclassifier = classifierName;
00254   if( sclassifier == "MultiClassLearner" ) {
00255     if( mcTrained_ == 0 ) {
00256       cerr << "Classifier MultiClassLearner not found." << endl;
00257       return 0;
00258     }
00259     return mcTrained_->dim();
00260   }
00261   else {
00262     map<string,SprAbsTrainedClassifier*>::const_iterator found
00263       = trained_.find(sclassifier);
00264     if( found == trained_.end() ) {
00265       cerr << "Classifier " << sclassifier.c_str() << " not found." << endl;
00266       return 0;
00267     }
00268     return found->second->dim();
00269   }
00270   return 0;
00271 }
00272 
00273 
00274 bool SprRootAdapter::classifierVars(const char* classifierName, 
00275                                     char vars[][200]) const
00276 {
00277   string sclassifier = classifierName;
00278   vector<string> cVars;
00279   if( sclassifier == "MultiClassLearner" ) {
00280     if( mcTrained_ == 0 ) {
00281       cerr << "Classifier MultiClassLearner not found." << endl;
00282       return false;
00283     }
00284     mcTrained_->vars(cVars);
00285   }
00286   else {
00287     map<string,SprAbsTrainedClassifier*>::const_iterator found
00288       = trained_.find(sclassifier);
00289     if( found == trained_.end() ) {
00290       cerr << "Classifier " << sclassifier.c_str() << " not found." << endl;
00291       return false;
00292     }
00293     found->second->vars(cVars);
00294   }
00295   for( int i=0;i<cVars.size();i++ )
00296     strcpy(vars[i],cVars[i].c_str());
00297   return true;
00298 }
00299 
00300 
00301 int SprRootAdapter::nClasses() const
00302 {
00303   if( trainData_ == 0 ) return 0;
00304   vector<SprClass> classes;
00305   trainData_->classes(classes);
00306   return classes.size();
00307 }
00308 
00309 
00310 bool SprRootAdapter::chooseClasses(const char* inputClassString)
00311 {
00312   // sanity check
00313   if( trainData_ == 0 ) {
00314     cerr << "Training data has not been loaded." << endl;
00315     return false;
00316   }
00317   if( testData_ == 0 ) {
00318     cerr << "Test data has not been loaded." << endl;
00319     return false;
00320   }
00321 
00322   // set classes in data
00323   if( !trainData_->filterByClass(inputClassString) ) {
00324     cerr << "Unable to filter training data by class." << endl;
00325     return false;
00326   }
00327   if( !testData_->filterByClass(inputClassString) ) {
00328     cerr << "Unable to filter test data by class." << endl;
00329     return false;
00330   }
00331 
00332   // clean up
00333   this->clearClassifiers();
00334 
00335   // exit
00336   return true;
00337 }
00338 
00339 
00340 bool SprRootAdapter::scaleWeights(double w, const char* classtype)
00341 {
00342   if( !this->checkData() ) return false;
00343   vector<SprClass> classes;
00344   trainData_->classes(classes);
00345   if(      classtype == "signal" ) {
00346     trainData_->scaleWeights(classes[1],w);
00347     testData_->scaleWeights(classes[1],w);
00348   }
00349   else if( classtype == "background" ) {
00350     trainData_->scaleWeights(classes[0],w);
00351     testData_->scaleWeights(classes[0],w);
00352   }
00353   return true;
00354 }
00355 
00356 
00357 bool SprRootAdapter::split(double fractionForTraining, 
00358                            bool randomize, int seed)
00359 {
00360   // sanity check
00361   if( trainData_ == 0 ) {
00362     cerr << "Training data has not been loaded." << endl;
00363     return false;
00364   }
00365 
00366   // if test data was specified, issue a warning
00367   if( testData_ != 0 ) {
00368     cout << "Test data will be deleted." << endl;
00369     delete testData_;
00370     delete testGarbage_;
00371     testData_ = 0;
00372     testGarbage_ = 0;
00373   }
00374 
00375   // split training data
00376   vector<double> weights;
00377   SprData* splitted 
00378     = trainData_->split(fractionForTraining,weights,randomize,seed);
00379   if( splitted == 0 ) {
00380     cerr << "Unable to split training data." << endl;
00381     return false;
00382   }
00383 
00384   // make test data
00385   bool ownData = true;
00386   testData_ = new SprEmptyFilter(splitted,weights,ownData);
00387 
00388   // clear classifiers
00389   this->clearClassifiers();
00390   needToTest_ = true;
00391 
00392   // exit
00393   return true;
00394 }
00395 
00396 
00397 void SprRootAdapter::removeClassifier(const char* classifierName)
00398 {
00399   bool removed = false;
00400   string sclassifier = classifierName;
00401 
00402   // remove multi-class learner
00403   if( sclassifier == "MultiClassLearner" ) {
00404     if( mcTrainable_!=0 || mcTrained_!=0 )
00405       cout << "Removing multi-class learner." << endl;
00406     else
00407       cout << "Multi-class learner not found." << endl;
00408     delete mcTrainable_; mcTrainable_ = 0;
00409     delete mcTrained_; mcTrained_ = 0;
00410     delete mcMapper_; mcMapper_ = 0;
00411     return;
00412   }
00413 
00414   // remove trainable
00415   map<string,SprAbsClassifier*>::iterator i1 = trainable_.find(sclassifier);
00416   if( i1 != trainable_.end() ) {
00417     delete i1->second;
00418     trainable_.erase(i1);
00419     cout << "Removed trainable classifier " << sclassifier.c_str() << endl;
00420     removed = true;
00421   }
00422 
00423   // remove trained
00424   map<string,SprAbsTrainedClassifier*>::iterator i2 
00425     = trained_.find(sclassifier);
00426   if( i2 != trained_.end() ) {
00427     map<SprAbsTrainedClassifier*,SprCoordinateMapper*>::iterator im
00428       = mapper_.find(i2->second);
00429     if( im != mapper_.end() ) {
00430       delete im->second;
00431       mapper_.erase(im);
00432     }
00433     delete i2->second;
00434     trained_.erase(i2);
00435     cout << "Removed trained classifier " << sclassifier.c_str() << endl;
00436     removed = true;
00437   }
00438 
00439   // exit
00440   if( !removed ) {
00441     cout << "Unable to remove. Classifier " << sclassifier.c_str()
00442          << " not found." << endl;
00443   }
00444 }
00445 
00446 
00447 bool SprRootAdapter::saveClassifier(const char* classifierName,
00448                                     const char* filename) const
00449 {
00450   string sclassifier = classifierName;
00451 
00452   if( sclassifier == "MultiClassLearner" ) {
00453     if( mcTrained_ == 0 ) {
00454       cerr << "MultiClassLearner not found. Unable to save." << endl;
00455       return false;
00456     }
00457     if( !mcTrained_->store(filename) ) {    
00458       cerr << "Unable to store MultiClassLearner "
00459            << " into file " << filename << endl;
00460       return false;
00461     }
00462     return true;
00463   }
00464 
00465   map<string,SprAbsTrainedClassifier*>::const_iterator found 
00466     = trained_.find(sclassifier);
00467   if( found == trained_.end() ) {
00468     cerr << "Classifier " << sclassifier.c_str() << " not found." << endl;
00469     return false;
00470   }
00471   if( !found->second->store(filename) ) {
00472     cerr << "Unable to store classifier " << sclassifier.c_str()
00473          << " into file " << filename << endl;
00474     return false;
00475   }
00476   return true;
00477 }
00478 
00479 
00480 bool SprRootAdapter::loadClassifier(const char* classifierName,
00481                                     const char* filename)
00482 
00483 {
00484   // sanity check
00485   if( testData_ == 0 ) {
00486     cerr << "Test data has not been loaded." << endl;
00487     return false;
00488   }
00489 
00490   // string
00491   string sclassifier = classifierName;
00492 
00493   // load multi-class learner
00494   if( sclassifier == "MultiClassLearner" ) {
00495     if( mcTrained_ != 0 ) {
00496       cerr << "MultiClassLearner already exists. " 
00497            << "Unable to load." << endl;
00498       return false;
00499     }
00500     SprMultiClassReader reader;
00501     if( !reader.read(filename) ) {
00502       cerr << "Unable to read classifier from file " << filename << endl;
00503       return false;
00504     }
00505     mcTrained_ = reader.makeTrained();
00506     assert( mcTrained_ != 0 );
00507     return true;
00508   }
00509 
00510   // check if exists
00511   map<string,SprAbsTrainedClassifier*>::iterator found 
00512     = trained_.find(sclassifier);
00513   if( found != trained_.end() ) {
00514     cerr << "Classifier " << sclassifier.c_str() << " already exists. " 
00515          << "Unable to load." << endl;
00516     return false;
00517   }
00518 
00519   // read trained classifier
00520   SprAbsTrainedClassifier* t = SprClassifierReader::readTrained(filename);
00521   if( t == 0 ) {
00522     cerr << "Unable to read classifier from file " << filename << endl;
00523     return false;
00524   }
00525   if( !trained_.insert(pair<const string,
00526                        SprAbsTrainedClassifier*>(sclassifier,t)).second ) {
00527     cerr << "Unable to add classifier " << sclassifier.c_str() 
00528          << " to list." << endl;
00529     return false;
00530   }
00531 
00532   // exit
00533   return true;
00534 }
00535 
00536 
00537 bool SprRootAdapter::mapVars(SprAbsTrainedClassifier* t)
00538 {
00539   // sanity check
00540   assert( t != 0 );
00541   if( testData_ == 0 ) {
00542     cerr << "Test data has not been loaded." << endl;
00543     return false;
00544   }
00545 
00546   // get var lists
00547   vector<string> trainVars;
00548   vector<string> testVars;
00549   t->vars(trainVars);
00550   testData_->vars(testVars);
00551 
00552   // make mapper and insert if it does not exist yet
00553   SprCoordinateMapper* mapper 
00554     = SprCoordinateMapper::createMapper(trainVars,testVars);
00555   map<SprAbsTrainedClassifier*,SprCoordinateMapper*>::iterator
00556     found = mapper_.find(t);
00557   if( found == mapper_.end() ) {
00558     if( !mapper_.insert(pair<SprAbsTrainedClassifier* const,
00559                         SprCoordinateMapper*>(t,mapper)).second ) {
00560       cerr << "Unable to insert mapper." << endl;
00561       delete mapper;
00562       return false;
00563     }
00564   }
00565   else {
00566     delete found->second;
00567     found->second = mapper;
00568   }
00569 
00570   // exit
00571   return true;
00572 }
00573 
00574 
00575 bool SprRootAdapter::mapMCVars(const SprTrainedMultiClassLearner* t)
00576 {
00577   assert( t != 0 );
00578   if( testData_ == 0 ) {
00579     cerr << "Test data has not been loaded." << endl;
00580     return false;
00581   }
00582   vector<string> trainVars;
00583   vector<string> testVars;
00584   t->vars(trainVars);
00585   testData_->vars(testVars);
00586   delete mcMapper_;
00587   mcMapper_ = SprCoordinateMapper::createMapper(trainVars,testVars);
00588   return true;
00589 }
00590 
00591 
00592 void SprRootAdapter::clearPlotters()
00593 {
00594   delete plotter_; plotter_ = 0;
00595   delete mcPlotter_; mcPlotter_ = 0;
00596 }
00597 
00598 
00599 void SprRootAdapter::clearClassifiers()
00600 {
00601   // multiclass
00602   delete mcTrainable_;
00603   delete mcTrained_;
00604   delete mcMapper_;
00605   mcTrainable_ = 0;
00606   mcTrained_ = 0;
00607   mcMapper_ = 0;
00608 
00609   // others
00610   for( map<string,SprAbsClassifier*>::const_iterator 
00611          i=trainable_.begin();i!=trainable_.end();i++ )
00612     delete i->second;
00613   for( map<string,SprAbsTrainedClassifier*>::const_iterator 
00614          i=trained_.begin();i!=trained_.end();i++ )
00615     delete i->second;
00616   trainable_.clear();
00617   trained_.clear();
00618   for( map<SprAbsTrainedClassifier*,SprCoordinateMapper*>::const_iterator
00619          i=mapper_.begin();i!=mapper_.end();i++ )
00620     delete i->second;
00621   mapper_.clear();
00622   for( int i=0;i<crit_.size();i++ )
00623     delete crit_[i];
00624   crit_.clear();
00625   for( int i=0;i<bootstrap_.size();i++ )
00626     delete bootstrap_[i];
00627   bootstrap_.clear();  
00628   for( set<SprAbsClassifier*>::const_iterator 
00629          i=aux_.begin();i!=aux_.end();i++ ) delete *i;
00630   aux_.clear();
00631   for( int i=0;i<loss_.size();i++ )
00632     delete loss_[i];
00633   loss_.clear();
00634 
00635   // plotters
00636   this->clearPlotters();
00637 }
00638 
00639 
00640 bool SprRootAdapter::showDataInClasses(char classes[][200],
00641                                        int* events,
00642                                        double* weights,
00643                                        const char* datatype) const
00644 {
00645   // check if classes have been set
00646   if( trainData_ == 0 ) {
00647     cerr << "Training data has not been loaded." << endl;
00648     return false;
00649   }
00650   vector<SprClass> found;
00651   trainData_->classes(found);
00652   if( found.size() < 2 ) {
00653     cerr << "Classes have not been chosen." << endl;
00654     return false;
00655   }
00656 
00657   // check data type
00658   string sdatatype = datatype;
00659   SprAbsFilter* data = 0;
00660   if(      sdatatype == "train" )
00661     data = trainData_;
00662   else if( sdatatype == "test" )
00663     data = testData_;
00664   if( data == 0 ) {
00665     cerr << "Data of type " << sdatatype.c_str() 
00666          << " has not been loaded." << endl;
00667     return false;
00668   }
00669 
00670   // get events and weights
00671   for( int i=0;i<found.size();i++ ) {
00672     strcpy(classes[i],found[i].toString().c_str());
00673     events[i] = data->ptsInClass(found[i]);
00674     weights[i] = data->weightInClass(found[i]);
00675   }
00676 
00677   // exit
00678   return true;
00679 }
00680 
00681 
00682 SprAbsClassifier* SprRootAdapter::addFisher(const char* classifierName, 
00683                                             int mode)
00684 {
00685   if( !this->checkData() ) return 0;
00686   SprFisher* c = new SprFisher(trainData_,mode);
00687   if( !this->addTrainable(classifierName,c) ) return 0;
00688   return c;
00689 }
00690 
00691 
00692 SprAbsClassifier* SprRootAdapter::addLogitR(const char* classifierName,
00693                                             double eps,
00694                                             double updateFactor)
00695 {
00696   if( !this->checkData() ) return 0;
00697   SprLogitR* c = new SprLogitR(trainData_,eps,updateFactor);
00698   if( !this->addTrainable(classifierName,c) ) return 0;
00699   return c;
00700 }
00701 
00702 
00703 SprAbsClassifier* SprRootAdapter::addBumpHunter(const char* classifierName,
00704                                                 const char* criterion,
00705                                                 unsigned minEventsPerBump,
00706                                                 double peel)
00707 {
00708   // sanity check
00709   if( !this->checkData() ) return 0;
00710 
00711   // make criterion
00712   const SprAbsTwoClassCriterion* crit = SprRootAdapter::makeCrit(criterion);
00713   crit_.push_back(crit);
00714 
00715   // make bump hunter
00716   SprBumpHunter* c = new SprBumpHunter(trainData_,crit,1,
00717                                        minEventsPerBump,peel);
00718 
00719   // exit
00720   if( !this->addTrainable(classifierName,c) ) return 0;
00721   return c;
00722 }
00723 
00724 
00725 SprAbsClassifier* SprRootAdapter::addDecisionTree(const char* classifierName,
00726                                                   const char* criterion,
00727                                                   unsigned leafSize)
00728 {
00729   // sanity check
00730   if( !this->checkData() ) return 0;
00731 
00732   // make criterion
00733   const SprAbsTwoClassCriterion* crit = SprRootAdapter::makeCrit(criterion);
00734   crit_.push_back(crit);
00735 
00736   // params
00737   bool doMerge = !crit->symmetric();
00738   bool discrete = true;
00739 
00740   // make a tree
00741   SprDecisionTree* c = new SprDecisionTree(trainData_,crit,leafSize,
00742                                            doMerge,discrete,0);
00743 
00744   // exit
00745   if( !this->addTrainable(classifierName,c) ) return 0;
00746   return c;
00747 }
00748 
00749 
00750 SprAbsClassifier* SprRootAdapter::addTopdownTree(const char* classifierName,
00751                                                  const char* criterion,
00752                                                  unsigned leafSize,
00753                                                  unsigned nFeaturesToSample,
00754                                                  bool discrete)
00755 {
00756   // sanity check
00757   if( !this->checkData() ) return 0;
00758 
00759   // make criterion
00760   const SprAbsTwoClassCriterion* crit = SprRootAdapter::makeCrit(criterion);
00761   crit_.push_back(crit);
00762 
00763   // check
00764   bool doMerge = !crit->symmetric();
00765   if( doMerge ) {
00766     cout << "Warning: Merging has no effect for Topdown trees. "
00767          << "Use addDecisionTree() for asymmetric optimization criteria."
00768          << endl;
00769   }
00770 
00771   // params
00772   SprIntegerBootstrap* bs = 0;
00773   if( nFeaturesToSample > 0 ) {
00774     bs = new SprIntegerBootstrap(trainData_->dim(),nFeaturesToSample);
00775     bootstrap_.push_back(bs);
00776   }
00777   
00778   // make a tree
00779   SprTopdownTree* c = new SprTopdownTree(trainData_,crit,leafSize,
00780                                          discrete,bs);
00781 
00782   // exit
00783   if( !this->addTrainable(classifierName,c) ) return 0;
00784   return c;
00785 }
00786 
00787 
00788 SprAbsClassifier* SprRootAdapter::addStdBackprop(const char* classifierName,
00789                                                  const char* structure,
00790                                                  unsigned ncycles,
00791                                                  double eta,
00792                                                  double initEta,
00793                                                  unsigned nInitPoints,
00794                                                  unsigned nValidate)
00795 {
00796   // sanity check
00797   if( !this->checkData() ) return 0;
00798   
00799   // make neural net
00800   SprStdBackprop* c = new SprStdBackprop(trainData_,
00801                                          structure,
00802                                          ncycles,
00803                                          eta);
00804   if( !c->init(initEta,nInitPoints) ) {
00805     cerr << "Unable to initialize neural net." << endl;
00806     return 0;
00807   } 
00808   if( nValidate > 0 ) {
00809     if( testData_==0 || !c->setValidation(testData_,nValidate) ) {
00810         cout << "Unable to set validation data for classifier "
00811              << classifierName << endl;
00812     }
00813   }
00814 
00815   // exit
00816   if( !this->addTrainable(classifierName,c) ) return 0;
00817   return c;
00818 }
00819 
00820 
00821 SprAbsClassifier* SprRootAdapter::addAdaBoost(const char* classifierName,
00822                                               int nClassifier,
00823                                               SprAbsClassifier** classifier,
00824                                               bool* useCut,
00825                                               double* cut,
00826                                               unsigned ncycles,
00827                                               int mode,
00828                                               bool bagInput,
00829                                               double epsilon,
00830                                               unsigned nValidate)
00831 {
00832   // sanity check
00833   if( !this->checkData() ) return 0;
00834 
00835   // make AdaBoost mode
00836   SprTrainedAdaBoost::AdaBoostMode abMode = SprTrainedAdaBoost::Discrete;
00837   switch( mode )
00838     {
00839     case 1 :
00840       abMode = SprTrainedAdaBoost::Discrete;
00841       cout << "Will train Discrete AdaBoost." << endl;
00842       break;
00843     case 2 :
00844       abMode = SprTrainedAdaBoost::Real;
00845       cout << "Will train Real AdaBoost." << endl;
00846       break;
00847     case 3 :
00848       abMode = SprTrainedAdaBoost::Epsilon;
00849       cout << "Will train Epsilon AdaBoost." << endl;
00850       break;
00851    default :
00852       cout << "Will train Discrete AdaBoost." << endl;
00853       break;
00854     }
00855   
00856   // make AdaBoost
00857   bool useStandard = false;
00858   SprAdaBoost* c = new SprAdaBoost(trainData_,ncycles,
00859                                    useStandard,abMode,bagInput);
00860   c->setEpsilon(epsilon);
00861   if( nValidate > 0 ) {
00862     SprAverageLoss* loss = new SprAverageLoss(&SprLoss::exponential);
00863     loss_.push_back(loss);
00864     if( testData_==0 || !c->setValidation(testData_,nValidate,loss) ) {
00865         cout << "Unable to set validation data for classifier "
00866              << classifierName << endl;
00867     }
00868   }
00869 
00870   // add weak classifiers
00871   for( int i=0;i<nClassifier;i++ ) {
00872     bool status = false;
00873     if( useCut[i] )
00874       status = c->addTrainable(classifier[i],SprUtils::lowerBound(cut[i]));
00875     else
00876       status = c->addTrainable(classifier[i]);
00877     if( !status ) {
00878       cerr << "Unable to add classifier " << i << " to AdaBoost." << endl;
00879       return 0;
00880     }
00881   }
00882 
00883   // exit
00884   if( !this->addTrainable(classifierName,c) ) return 0;
00885   return c;
00886 }
00887 
00888 
00889 
00890 SprAbsClassifier* SprRootAdapter::addBagger(const char* classifierName,
00891                                             int nClassifier,
00892                                             SprAbsClassifier** classifier,
00893                                             unsigned ncycles,
00894                                             bool discrete,
00895                                             unsigned nValidate)
00896 {
00897   // sanity check
00898   if( !this->checkData() ) return 0;
00899 
00900   // make bagger
00901   SprBagger* c = new SprBagger(trainData_,ncycles,discrete);
00902   if( nValidate > 0 ) {
00903     SprAverageLoss* loss = new SprAverageLoss(&SprLoss::quadratic);
00904     loss_.push_back(loss);
00905     if( testData_==0 || !c->setValidation(testData_,nValidate,0,loss) ) {
00906         cout << "Unable to set validation data for classifier "
00907              << classifierName << endl;
00908     }
00909   }
00910 
00911   // add weak classifiers
00912   for( int i=0;i<nClassifier;i++ ) {
00913     if( !c->addTrainable(classifier[i]) ) {
00914       cerr << "Unable to add classifier " << i << " to Bagger." << endl;
00915       return 0;
00916     }
00917   }
00918 
00919   // exit
00920   if( !this->addTrainable(classifierName,c) ) return 0;
00921   return c;
00922 }
00923 
00924 
00925 SprAbsClassifier* SprRootAdapter::addBoostedDecisionTree(
00926                                             const char* classifierName,
00927                                             int leafSize,
00928                                             unsigned nTrees,
00929                                             unsigned nValidate)
00930 {
00931   // sanity check
00932   if( !this->checkData() ) return 0;
00933 
00934   // make a decision tree
00935   const SprAbsTwoClassCriterion* crit = new SprTwoClassGiniIndex;
00936   crit_.push_back(crit);
00937   bool doMerge = false;
00938   bool discrete = true;
00939   SprTopdownTree* tree = new SprTopdownTree(trainData_,crit,leafSize,
00940                                             discrete,0);
00941   aux_.insert(tree);
00942   
00943   // make AdaBoost
00944   bool useStandard = false;
00945   bool bagInput = false;
00946   SprAdaBoost* c = new SprAdaBoost(trainData_,nTrees,useStandard,
00947                                    SprTrainedAdaBoost::Discrete,bagInput);
00948   if( nValidate > 0 ) {
00949     SprAverageLoss* loss = new SprAverageLoss(&SprLoss::exponential);
00950     loss_.push_back(loss);
00951     if( testData_==0 || !c->setValidation(testData_,nValidate,loss) ) {
00952         cout << "Unable to set validation data for classifier "
00953              << classifierName << endl;
00954     }
00955   }
00956 
00957   // add classifier
00958   if( !c->addTrainable(tree,SprUtils::lowerBound(0.5)) ) {
00959     cerr << "Cannot add decision tree to AdaBoost." << endl;
00960     return 0;
00961   }
00962 
00963   // exit
00964   if( !this->addTrainable(classifierName,c) ) return 0;
00965   return c;
00966 }
00967 
00968 
00969 SprAbsClassifier* SprRootAdapter::addBoostedBinarySplits(
00970                                           const char* classifierName,
00971                                           unsigned nSplitsPerDim,
00972                                           unsigned nValidate)
00973 {
00974   // sanity check
00975   if( !this->checkData() ) return 0;
00976 
00977   // make AdaBoost
00978   bool useStandard = false;
00979   bool bagInput = false;
00980   SprAdaBoost* c = new SprAdaBoost(trainData_,
00981                                    nSplitsPerDim*trainData_->dim(),
00982                                    useStandard,
00983                                    SprTrainedAdaBoost::Discrete,bagInput);
00984   if( nValidate > 0 ) {
00985     SprAverageLoss* loss = new SprAverageLoss(&SprLoss::exponential);
00986     loss_.push_back(loss);
00987     if( testData_==0 || !c->setValidation(testData_,nValidate,loss) ) {
00988         cout << "Unable to set validation data for classifier "
00989              << classifierName << endl;
00990     }
00991   }
00992 
00993   // add splits to AdaBoost
00994   const SprAbsTwoClassCriterion* crit = new SprTwoClassIDFraction;
00995   crit_.push_back(crit);
00996   for( int i=0;i<trainData_->dim();i++ ) {
00997     SprBinarySplit* split = new SprBinarySplit(trainData_,crit,i);
00998     aux_.insert(split);
00999     if( !c->addTrainable(split,SprUtils::lowerBound(0.5)) ) {
01000       cerr << "Cannot add binary split to AdaBoost." << endl;
01001       delete c;
01002       return 0;
01003     }
01004   }
01005 
01006   // exit
01007   if( !this->addTrainable(classifierName,c) ) return 0;
01008   return c;
01009 }
01010 
01011 
01012 SprAbsClassifier* SprRootAdapter::addRandomForest(const char* classifierName,
01013                                                   int leafSize,
01014                                                   unsigned nTrees,
01015                                                   unsigned nFeaturesToSample,
01016                                                   unsigned nValidate,
01017                                                   bool useArcE4)
01018 {
01019   // sanity check
01020   if( !this->checkData() ) return 0;
01021 
01022   // make a decision tree
01023   const SprAbsTwoClassCriterion* crit = new SprTwoClassGiniIndex;
01024   crit_.push_back(crit);
01025   SprIntegerBootstrap* bs = 0;
01026   if( nFeaturesToSample > 0 ) {
01027     bs = new SprIntegerBootstrap(trainData_->dim(),nFeaturesToSample);
01028     bootstrap_.push_back(bs);
01029   }
01030   bool doMerge = false;
01031   bool discrete = false;
01032   SprTopdownTree* tree = new SprTopdownTree(trainData_,crit,leafSize,
01033                                             discrete,bs);
01034   aux_.insert(tree);
01035   
01036   // make Bagger
01037   SprBagger* c = 0;
01038   if( useArcE4 )
01039     c = new SprArcE4(trainData_,nTrees,discrete);
01040   else
01041     c = new SprBagger(trainData_,nTrees,discrete);
01042   if( nValidate > 0 ) {
01043     SprAverageLoss* loss = new SprAverageLoss(&SprLoss::quadratic);
01044     loss_.push_back(loss);
01045     if( testData_==0 || !c->setValidation(testData_,nValidate,0,loss) ) {
01046         cout << "Unable to set validation data for classifier "
01047              << classifierName << endl;
01048     }
01049   }
01050 
01051   // add classifier
01052   if( !c->addTrainable(tree) ) {
01053     cerr << "Cannot add decision tree to RandomForest." << endl;
01054     return 0;
01055   }
01056 
01057   // exit
01058   if( !this->addTrainable(classifierName,c) ) return 0;
01059   return c;
01060 }
01061 
01062 
01063 SprMultiClassLearner* SprRootAdapter::setMultiClassLearner(
01064                                              SprAbsClassifier* classifier,
01065                                              int nClass,
01066                                              const int* classes,
01067                                              const char* mode)
01068 {
01069   // sanity check
01070   if( !this->checkData() ) return 0;
01071 
01072   // check if there is a multi-class learner already
01073   if( mcTrainable_ != 0 ) {
01074     cerr << "MultiClassLearner already exists. "
01075          << "Must delete before making a new one." << endl;
01076     return 0;
01077   }
01078 
01079   // prepare vector of classes
01080   assert( nClass > 0 );
01081   vector<int> vclasses(&classes[0],&classes[nClass]);
01082 
01083   // decode mode
01084   string smode = mode;
01085   SprMultiClassLearner::MultiClassMode mcMode = SprMultiClassLearner::User;
01086   if(      smode == "One-vs-All" )
01087     mcMode = SprMultiClassLearner::OneVsAll;
01088   else if( smode == "One-vs-One" )
01089     mcMode = SprMultiClassLearner::OneVsOne;
01090   else {
01091     cerr << "Unknown mode for MultiClassLearner." << endl;
01092     return 0;
01093   }
01094 
01095   // make the learner
01096   SprMatrix indicator;
01097   mcTrainable_ = new SprMultiClassLearner(trainData_,classifier,vclasses,
01098                                           indicator,mcMode);
01099 
01100   // move the classifier from trainable list to aux
01101   for( map<std::string,SprAbsClassifier*>::iterator i=trainable_.begin();
01102        i!=trainable_.end();i++ ) {
01103     if( i->second == classifier ) 
01104       trainable_.erase(i);
01105   }
01106   aux_.insert(classifier);
01107 
01108   // exit
01109   return mcTrainable_;
01110 }
01111 
01112 
01113 bool SprRootAdapter::checkData() const
01114 {
01115   if( trainData_ == 0 ) {
01116     cerr << "Training data has not been loaded." << endl;
01117     return false;
01118   }
01119   if( testData_ == 0 ) {
01120     cerr << "Test data has not been loaded." << endl;
01121     return false;
01122   }
01123   vector<SprClass> classes;
01124   trainData_->classes(classes);
01125   if( classes.size() < 2 ) {
01126     cerr << "Classes have not been chosen." << endl;
01127     return false;
01128   }
01129   return true;
01130 }
01131 
01132 
01133 bool SprRootAdapter::addTrainable(const char* classifierName, 
01134                                   SprAbsClassifier* c)
01135 {
01136   assert( c != 0 );
01137   string sclassifier = classifierName;
01138 
01139   // check that classifier does not exist
01140   map<string,SprAbsClassifier*>::const_iterator found =
01141     trainable_.find(sclassifier);
01142   if( found != trainable_.end() ) {
01143     cerr << "Classifier " << sclassifier.c_str() 
01144          << " already exists." << endl;
01145     delete c;
01146     return false;
01147   }
01148 
01149   // add
01150   if( !trainable_.insert(pair<const string,
01151                               SprAbsClassifier*>(sclassifier,c)).second ) {
01152     cerr << "Unable to add classifier " << sclassifier.c_str() << endl;
01153     delete c;
01154     return false;
01155   }
01156 
01157   // exit
01158   return true;
01159 }
01160 
01161 
01162 void SprRootAdapter::useInftyRange() const
01163 {
01164   for( map<string,SprAbsTrainedClassifier*>::const_iterator
01165          i=trained_.begin();i!=trained_.end();i++ ) {
01166     SprAbsTrainedClassifier* trained = i->second;
01167     if(      trained->name() == "AdaBoost" ) {
01168       SprTrainedAdaBoost* specific 
01169         = static_cast<SprTrainedAdaBoost*>(trained);
01170       specific->useStandard();
01171     }
01172     else if( trained->name() == "Fisher" ) {
01173       SprTrainedFisher* specific 
01174         = static_cast<SprTrainedFisher*>(trained);
01175       specific->useStandard();
01176     }
01177     else if( trained->name() == "LogitR" ) {
01178       SprTrainedLogitR* specific 
01179         = static_cast<SprTrainedLogitR*>(trained);
01180       specific->useStandard();
01181     }
01182   }
01183 }
01184 
01185 
01186 void SprRootAdapter::use01Range() const
01187 {
01188   for( map<string,SprAbsTrainedClassifier*>::const_iterator
01189          i=trained_.begin();i!=trained_.end();i++ ) { 
01190     SprAbsTrainedClassifier* trained = i->second; 
01191     if(      trained->name() == "AdaBoost" ) { 
01192       SprTrainedAdaBoost* specific  
01193         = static_cast<SprTrainedAdaBoost*>(trained);
01194       specific->useNormalized(); 
01195     } 
01196     else if( trained->name() == "Fisher" ) { 
01197       SprTrainedFisher* specific  
01198         = static_cast<SprTrainedFisher*>(trained);
01199       specific->useNormalized(); 
01200     } 
01201     else if( trained->name() == "LogitR" ) { 
01202       SprTrainedLogitR* specific  
01203         = static_cast<SprTrainedLogitR*>(trained);
01204       specific->useNormalized(); 
01205     } 
01206   }
01207 }
01208 
01209 
01210 bool SprRootAdapter::train(int verbose)
01211 {
01212   // sanity check
01213   if( !this->checkData() ) return false;
01214   if( trainable_.empty() && mcTrainable_==0 ) {
01215     cerr << "No classifiers selected for training." << endl;
01216     return false;
01217   }
01218 
01219   // clean up responses
01220   this->clearPlotters();
01221 
01222   // train
01223   bool oneSuccess = false;
01224   for( map<string,SprAbsClassifier*>::const_iterator
01225         i=trainable_.begin();i!=trainable_.end();i++ ) {
01226     if( trained_.find(i->first) != trained_.end() ) {
01227       cout << "Trained classifier " << i->first.c_str()
01228            << " already exists. Skipping..." << endl;
01229       continue;
01230     }
01231     cout << "Training classifier " << i->first.c_str() << endl;
01232     if( !i->second->train(verbose) ) {
01233       cerr << "Unable to train classifier " << i->first.c_str() << endl;
01234       continue;
01235     }
01236     SprAbsTrainedClassifier* t = i->second->makeTrained();
01237     if( t == 0 ) {
01238       cerr << "Failed to make trained classifier " << i->first.c_str() << endl;
01239       continue;
01240     }
01241     if( !trained_.insert(pair<const string,
01242                          SprAbsTrainedClassifier*>(i->first,t)).second ) {
01243       cerr << "Failed to insert trained classifier." << endl;
01244       return false;
01245     }
01246     oneSuccess = true;
01247   }
01248 
01249   // multi class learner
01250   if( mcTrainable_ != 0 ) {
01251     if( mcTrained_ == 0 ) {
01252       cout << "Training MultiClassLearner." << endl;
01253       if( mcTrainable_->train(verbose) ) {
01254         mcTrained_ = mcTrainable_->makeTrained();
01255         if( mcTrained_ == 0 )
01256           cerr << "Failed to make trained MultiClassLearner." << endl;
01257         else
01258           oneSuccess = true;
01259       }
01260       else
01261         cerr << "Failed to train MultiClassLearner." << endl;
01262     }
01263     else
01264       cout << "Trained MultiClassLearner already exists. Skipping..." << endl;
01265   }
01266 
01267   // check if any classifiers succeeded
01268   if( !oneSuccess ) {
01269     cerr << "No classifiers have been trained successfully." << endl;
01270     return false;
01271   }
01272 
01273   // exit
01274   return true;
01275 }
01276 
01277 
01278 bool SprRootAdapter::test()
01279 {
01280   // sanity check
01281   if( trained_.empty() && mcTrained_==0 ) {
01282     cerr << "No classifiers have been trained." << endl;
01283     return false;
01284   }
01285   if( testData_==0 || testData_->empty() ) {
01286     cerr << "No test data available." << endl;
01287     return false;
01288   }
01289 
01290   // check classes
01291   vector<SprClass> classes;
01292   testData_->classes(classes);
01293   if( classes.size() < 2 ) {
01294     cerr << "Less than 2 classes found in test data." << endl;
01295     return false;
01296   }
01297 
01298   // cleaned up responses
01299   this->clearPlotters();
01300 
01301   // get data size
01302   int N = testData_->size();
01303 
01304   // all two-class classifiers
01305   if( !trained_.empty() ) {
01306     // map variables
01307     for( map<string,SprAbsTrainedClassifier*>::const_iterator 
01308          i=trained_.begin();i!=trained_.end();i++ ) {
01309       if( !this->mapVars(i->second) ) {
01310         cerr << "Unable to map variables for classifier " 
01311              << i->first.c_str() << endl;
01312         return false;
01313       }
01314     }
01315 
01316     // compute responses
01317     vector<SprPlotter::Response> responses;
01318     for( int n=0;n<N;n++ ) {
01319       const SprPoint* p = (*testData_)[n];
01320       int cls = -1;
01321       if(      classes[0] == p->class_ ) 
01322         cls = 0;
01323       else if( classes[1] == p->class_ )
01324         cls = 1;
01325       else
01326         continue;
01327       double w = testData_->w(n);
01328       SprPlotter::Response resp(cls,w);
01329       for( map<string,SprAbsTrainedClassifier*>::const_iterator
01330              i=trained_.begin();i!=trained_.end();i++ ) {
01331         vector<double> mapped;
01332         map<SprAbsTrainedClassifier*,SprCoordinateMapper*>::const_iterator 
01333           found = mapper_.find(i->second);
01334         assert( found != mapper_.end() );
01335         found->second->map(p->x_,mapped);
01336         resp.set(i->first.c_str(),i->second->response(mapped));
01337       }
01338       responses.push_back(resp);
01339     }
01340     
01341     // make plotter
01342     plotter_ = new SprPlotter(responses);
01343     plotter_->setCrit(showCrit_);
01344   }
01345 
01346   // multi class
01347   if( mcTrained_ != 0 ) {
01348     if( !this->mapMCVars(mcTrained_) ) {
01349       cerr << "Unable to map variables for classifier MultiClassLearner." 
01350            << endl;
01351       return false;
01352     }
01353     vector<int> mcClasses;
01354     mcTrained_->classes(mcClasses);
01355     vector<SprMultiClassPlotter::Response> responses;
01356     for( int n=0;n<N;n++ ) {
01357       const SprPoint* p = (*testData_)[n];
01358       int cls = p->class_;
01359       if( find(mcClasses.begin(),mcClasses.end(),cls) == mcClasses.end() )
01360         continue;
01361       double w = testData_->w(n);
01362       vector<double> mapped;
01363       assert( mcMapper_ != 0 );
01364       mcMapper_->map(p->x_,mapped);
01365       map<int,double> output;
01366       int assigned = mcTrained_->response(mapped,output);
01367       responses.push_back(SprMultiClassPlotter::Response(cls,w,
01368                                                          assigned,output));
01369     }
01370     mcPlotter_ = new SprMultiClassPlotter(responses);
01371   }
01372 
01373   // exit
01374   needToTest_ = false;
01375   return true;
01376 }
01377 
01378 
01379 bool SprRootAdapter::setCrit(const char* criterion)
01380 {
01381   showCrit_ = SprRootAdapter::makeCrit(criterion);
01382   if( showCrit_ == 0 ) return false;
01383   if( plotter_ != 0 ) 
01384     plotter_->setCrit(showCrit_);
01385   return true;
01386 }
01387 
01388 
01389 bool SprRootAdapter::setEffCurveMode(const char* mode)
01390 {
01391   string smode = mode;
01392   if( plotter_ == 0 ) {
01393     cerr << "Unable to set the efficiency plotting mode. "
01394          << "Run test() first to fill out the plotter." << endl;
01395     return false;
01396   }
01397   if(      smode == "relative" )
01398     plotter_->useRelative();
01399   else if( smode == "absolute" )
01400     plotter_->useAbsolute();
01401   else {
01402     cerr << "Unknown mode for efficiency curve." << endl;
01403     return false;
01404   }
01405   return true;
01406 }
01407 
01408 
01409 bool SprRootAdapter::effCurve(const char* classifierName,
01410                               int npts, const double* signalEff,
01411                               double* bgrndEff, double* bgrndErr, double* fom) 
01412   const
01413 {
01414   string sclassifier = classifierName;
01415 
01416   // sanity check
01417   if( npts == 0 ) return true;
01418   if( plotter_ == 0 ) {
01419     cerr << "No responses for test data have been computed. " 
01420          << "Run test() first." << endl;
01421     return false;
01422   }
01423 
01424   // make vector of signal efficiencies
01425   vector<double> vSignalEff(npts);
01426   for( int i=0;i<npts;i++ )
01427     vSignalEff[i] = signalEff[i];
01428 
01429   // compute the curve
01430   vector<SprPlotter::FigureOfMerit> vBgrndEff;
01431   if( !plotter_->backgroundCurve(vSignalEff,sclassifier.c_str(),vBgrndEff) ) {
01432     cerr << "Unable to compute the background curve for classifier "
01433          << sclassifier.c_str() << endl;
01434     return false;
01435   }
01436   assert( vBgrndEff.size() == npts );
01437 
01438   // convert the vector into arrays
01439   double bgrW = plotter_->bgrndWeight();
01440   for( int i=0;i<npts;i++ ) {
01441     bgrndEff[i] = vBgrndEff[i].bgrWeight;
01442     bgrndErr[i] = ( vBgrndEff[i].bgrNevts==0 ? 0 
01443                     : bgrndEff[i]/sqrt(double(vBgrndEff[i].bgrNevts)) );
01444     fom[i] = vBgrndEff[i].fom;
01445   }
01446 
01447   // exit
01448   return true;
01449 }
01450 
01451 
01452 bool SprRootAdapter::allEffCurves(int npts, const double* signalEff,
01453                                   char classifiers[][200],
01454                                   double* bgrndEff, double* bgrndErr,
01455                                   double* fom) const
01456 {
01457   if( trained_.empty() || plotter_==0 ) {
01458     cerr << "Efficiency curves cannot be computed." << endl;
01459     return false;
01460   }
01461   double* eff = bgrndEff;
01462   double* err = bgrndErr;
01463   double* myfom = fom;
01464   int curr = 0;
01465   for( map<string,SprAbsTrainedClassifier*>::const_iterator
01466          i=trained_.begin();i!=trained_.end();i++ ) {
01467     if( !this->effCurve(i->first.c_str(),npts,signalEff,eff,err,myfom) ) {
01468       cerr << "Unable to compute efficiency for classifier "
01469            << i->first.c_str() << endl;
01470       return false;
01471     }
01472     strcpy(classifiers[curr++],i->first.c_str());
01473     eff += npts;
01474     err += npts;
01475     myfom += npts;
01476   }
01477   return true;
01478 }
01479 
01480 
01481 bool SprRootAdapter::correlation(int cls, double* corr, const char* datatype) 
01482   const
01483 {
01484   // sanity check
01485   string sdatatype = datatype;
01486   SprAbsFilter* data = 0;
01487   if(      sdatatype == "train" )
01488     data = trainData_;
01489   else if( sdatatype == "test" )
01490     data = testData_;
01491   if( data == 0 ) {
01492     cerr << "Data of type " << sdatatype.c_str()
01493          << " has not been loaded." << endl;
01494     return false;
01495   }
01496 
01497   // make a temp copy of data
01498   SprEmptyFilter tempData(data);
01499 
01500   // check classes
01501   vector<SprClass> classes;
01502   tempData.classes(classes);
01503   if( (cls+1) > classes.size() ) {
01504     cerr << "Class " << cls << " is not found in data." << endl;
01505     return false;
01506   }
01507   SprClass chosenClass = classes[cls];
01508 
01509   // filter data by class
01510   vector<SprClass> chosen(1,chosenClass);
01511   tempData.chooseClasses(chosen);
01512   if( !tempData.filter() ) {
01513     cerr << "Unable to filter data on class " << cls << endl;
01514     return false;
01515   }
01516 
01517   // compute
01518   SprDataMoments moms(&tempData);
01519   SprSymMatrix cov;
01520   SprVector mean;
01521   if( !moms.covariance(cov,mean) ) {
01522     cerr << "Unable to compute covariance matrix." << endl;
01523     return false;
01524   }
01525 
01526   // compute variances
01527   int N = cov.num_row();
01528   vector<double> rms(N);
01529   vector<int> positive(N,0);
01530   for( int i=0;i<N;i++ ) {
01531     if( cov[i][i] < SprUtils::eps() ) {
01532       cout << "Variance for variable " << i << " is zero." << endl;
01533       rms[i] = 0;
01534     }
01535     else {
01536       rms[i] = sqrt(cov[i][i]);
01537       positive[i] = 1;
01538     }
01539   }
01540 
01541   // fill out array
01542   for( int i=0;i<N-1;i++ ) {
01543     for( int j=i+1;j<N;j++ ) {
01544       int ind = i*N+j;
01545       if( positive[i]*positive[j] > 0 ) 
01546         corr[ind] = cov[i][j]/rms[i]/rms[j];
01547       else
01548         corr[ind] = 0;
01549     }
01550   }
01551   for( int i=0;i<N;i++ ) corr[i*(N+1)] = 1.;
01552   for( int i=1;i<N;i++ ) {
01553     for( int j=0;j<i;j++ ) {
01554       corr[i*N+j] = corr[i+j*N];
01555     }
01556   }
01557 
01558   // exit
01559   return true;
01560 }
01561 
01562 
01563 bool SprRootAdapter::correlationClassLabel(const char* mode,
01564                                            char vars[][200],
01565                                            double* corr, 
01566                                            const char* datatype) const
01567 {
01568   // sanity check
01569   string sdatatype = datatype;
01570   SprAbsFilter* data = 0;
01571   if(      sdatatype == "train" )
01572     data = trainData_;
01573   else if( sdatatype == "test" )
01574     data = testData_;
01575   if( data == 0 ) {
01576     cerr << "Data of type " << sdatatype.c_str()
01577          << " has not been loaded." << endl;
01578     return false;
01579   }
01580 
01581   // fill out vars
01582   unsigned dim = data->dim();
01583   vector<string> dataVars;
01584   data->vars(dataVars);
01585   assert( dataVars.size() == dim );
01586   for( int d=0;d<dim;d++ )
01587     strcpy(vars[d],dataVars[d].c_str());
01588 
01589   // compute correlation
01590   SprDataMoments moms(data);
01591   string smode = mode;
01592   double mean(0), var(0);
01593   if(      smode == "normal" ) {
01594     for( int d=0;d<dim;d++ )
01595       corr[d] = moms.correlClassLabel(d,mean,var);
01596   }
01597   else if( smode == "abs" ) {
01598     for( int d=0;d<dim;d++ )
01599       corr[d] = moms.absCorrelClassLabel(d,mean,var);
01600   }
01601   else {
01602     cerr << "Unknown mode in correlationClassLabel." << endl;
01603     return false;
01604   }
01605 
01606   // exit
01607   return true;
01608 }
01609 
01610 
01611 bool SprRootAdapter::variableImportance(const char* classifierName,
01612                                         unsigned nPerm,
01613                                         char vars[][200], 
01614                                         double* importance,
01615                                         double* error) const
01616 {
01617   // sanity check
01618   if( testData_ == 0 ) {
01619     cerr << "Test data has not been loaded." << endl;
01620     return false;
01621   }
01622   if( needToTest_ ) {
01623     cerr << "Test data has changed. Need to run test() again." << endl;
01624     return false;
01625   }
01626 
01627   // find classifier and mapper
01628   string sclassifier = classifierName;
01629   SprCoordinateMapper* mapper = 0;
01630   SprAbsTrainedClassifier* trained = 0;
01631   SprTrainedMultiClassLearner* mcTrained = 0;
01632   if( sclassifier == "MultiClassLearner" ) {
01633     mapper = mcMapper_;
01634     if( mcTrained_ == 0 ) {
01635       cerr << "Classifier MultiClassLearner not found." << endl;
01636       return false;
01637     }
01638     mcTrained = mcTrained_;
01639   }
01640   else {
01641     map<string,SprAbsTrainedClassifier*>::const_iterator ic
01642       = trained_.find(sclassifier);
01643     if( ic == trained_.end() ) {
01644       cerr << "Classifier " << sclassifier.c_str() << " not found." << endl;
01645       return false;
01646     }
01647     trained = ic->second;
01648     assert( trained != 0 );
01649     map<SprAbsTrainedClassifier*,SprCoordinateMapper*>::const_iterator im
01650       = mapper_.find(trained);
01651     if( im != mapper_.end() )
01652       mapper = im->second;
01653   }
01654 
01655   // compute importance
01656   vector<SprClassifierEvaluator::NameAndValue> lossIncrease;
01657   if( !SprClassifierEvaluator::variableImportance(testData_,
01658                                                   trained,
01659                                                   mcTrained,
01660                                                   mapper,
01661                                                   nPerm,
01662                                                   lossIncrease) ) {
01663     cerr << "Unable to estimate variable importance." << endl;
01664     return false;
01665   }
01666 
01667   // convert result into arrays
01668   for( int d=0;d<lossIncrease.size();d++ ) {
01669     strcpy(vars[d],lossIncrease[d].first.c_str());
01670     importance[d] = lossIncrease[d].second.first;
01671     error[d] = lossIncrease[d].second.second;
01672   }
01673 
01674   // exit
01675   return true;
01676 }
01677 
01678 
01679 bool SprRootAdapter::variableInteraction(const char* classifierName,
01680                                          const char* subset,
01681                                          unsigned nPoints,
01682                                          char vars[][200],
01683                                          double* interaction,
01684                                          double* error,
01685                                          int verbose) const
01686 {
01687   // sanity check
01688   if( testData_ == 0 ) {
01689     cerr << "Test data has not been loaded." << endl;
01690     return false;
01691   }
01692   if( needToTest_ ) {
01693     cerr << "Test data has changed. Need to run test() again." << endl;
01694     return false;
01695   }
01696 
01697   // find classifier and mapper
01698   string sclassifier = classifierName;
01699   SprCoordinateMapper* mapper = 0;
01700   SprAbsTrainedClassifier* trained = 0;
01701   SprTrainedMultiClassLearner* mcTrained = 0;
01702   if( sclassifier == "MultiClassLearner" ) {
01703     mapper = mcMapper_;
01704     if( mcTrained_ == 0 ) {
01705       cerr << "Classifier MultiClassLearner not found." << endl;
01706       return false;
01707     }
01708     mcTrained = mcTrained_;
01709   }
01710   else {
01711     map<string,SprAbsTrainedClassifier*>::const_iterator ic
01712       = trained_.find(sclassifier);
01713     if( ic == trained_.end() ) {
01714       cerr << "Classifier " << sclassifier.c_str() << " not found." << endl;
01715       return false;
01716     }
01717     trained = ic->second;
01718     assert( trained != 0 );
01719     map<SprAbsTrainedClassifier*,SprCoordinateMapper*>::const_iterator im
01720       = mapper_.find(trained);
01721     if( im != mapper_.end() )
01722       mapper = im->second;
01723   }
01724 
01725   // compute interaction
01726   vector<SprClassifierEvaluator::NameAndValue> varInteraction;
01727   if( !SprClassifierEvaluator::variableInteraction(testData_,
01728                                                    trained,
01729                                                    mcTrained,
01730                                                    mapper,
01731                                                    subset,
01732                                                    nPoints,
01733                                                    varInteraction,
01734                                                    verbose) ) {
01735     cerr << "Unable to estimate variable interactions." << endl;
01736     return false;
01737   }
01738 
01739   // convert result into arrays
01740   for( int d=0;d<varInteraction.size();d++ ) {
01741     strcpy(vars[d],varInteraction[d].first.c_str());
01742     interaction[d] = varInteraction[d].second.first;
01743     error[d] = varInteraction[d].second.second;
01744   }
01745 
01746   // exit
01747   return true;
01748 }
01749 
01750 
01751 bool SprRootAdapter::histogram(const char* classifierName,
01752                                double xlo, double xhi, int nbin,
01753                                double* sig, double* sigerr,
01754                                double* bgr, double* bgrerr) const
01755 {
01756   // sanity check
01757   if( plotter_ == 0 ) {
01758     cerr << "No response vectors found. Nothing to histogram." << endl;
01759     return false;
01760   }
01761   if( xhi < xlo ) {
01762     cerr << "requested lower X limit greater than upper X limit." << endl;
01763     return false;
01764   }
01765 
01766   // call through
01767   double dx = (xhi-xlo) / nbin;
01768   vector<pair<double,double> > sigHist;
01769   vector<pair<double,double> > bgrHist;
01770   int nFilledBins = plotter_->histogram(classifierName,
01771                                         xlo,xhi,dx,sigHist,bgrHist);
01772   if( nFilledBins < nbin ) {
01773     cerr << "Requested " << nbin << " bins but filled only "
01774          << nFilledBins << ". Unable to plot histogram." << endl;
01775     return false;
01776   }
01777 
01778   // copy histogram content
01779   for( int i=0;i<nbin;i++ ) {
01780     sig[i]    = sigHist[i].first;
01781     sigerr[i] = sigHist[i].second;
01782     bgr[i]    = bgrHist[i].first;
01783     bgrerr[i] = bgrHist[i].second;
01784   }
01785 
01786   // exit
01787   return true;
01788 }
01789 
01790 
01791 SprAbsTwoClassCriterion* SprRootAdapter::makeCrit(const char* criterion)
01792 {
01793   SprAbsTwoClassCriterion* crit = 0;
01794   string scrit = criterion;
01795   if(      scrit == "correct_id" ) {
01796       crit = new SprTwoClassIDFraction;
01797       cout << "Optimization criterion set to "
01798            << "Fraction of correctly classified events " << endl;
01799   }
01800   else if( scrit == "S/sqrt(S+B)" ) {
01801     crit = new SprTwoClassSignalSignif;
01802     cout << "Optimization criterion set to "
01803          << "Signal significance S/sqrt(S+B) " << endl;
01804   }
01805   else if( scrit == "S/(S+B)" ) {
01806     crit = new SprTwoClassPurity;
01807     cout << "Optimization criterion set to "
01808          << "Purity S/(S+B) " << endl;
01809   }
01810   else if( scrit == "TaggerEff" ) {
01811     crit = new SprTwoClassTaggerEff;
01812     cout << "Optimization criterion set to "
01813          << "Tagging efficiency Q = e*(1-2w)^2 " << endl;
01814   }
01815   else if( scrit == "Gini" ) {
01816     crit = new SprTwoClassGiniIndex;
01817     cout << "Optimization criterion set to "
01818          << "Gini index  -1+p^2+q^2 " << endl;
01819   }
01820   else if( scrit == "CrossEntropy" ) {
01821     crit = new SprTwoClassCrossEntropy;
01822     cout << "Optimization criterion set to "
01823          << "Cross-entropy p*log(p)+q*log(q) " << endl;
01824   }
01825   else if( scrit == "CrossEntropy" ) {
01826     crit = new SprTwoClassUniformPriorUL90;
01827     cout << "Optimization criterion set to "
01828          << "Inverse of 90% Bayesian upper limit with uniform prior" << endl;
01829   }
01830   else if( scrit == "BKDiscovery" ) {
01831     crit = new SprTwoClassBKDiscovery;
01832     cout << "Optimization criterion set to "
01833          << "Discovery potential 2*(sqrt(S+B)-sqrt(B))" << endl;
01834   }
01835   else if( scrit == "Punzi" ) {
01836     crit = new SprTwoClassPunzi(1.);
01837     cout << "Optimization criterion set to "
01838          << "Punzi's sensitivity S/(0.5*nSigma+sqrt(B))" << endl;
01839   }
01840   else {
01841     cerr << "Unknown criterion specified." << endl;
01842     return 0;
01843   }
01844   return crit;
01845 }
01846 
01847 
01848 bool SprRootAdapter::multiClassTable(int nClass,
01849                                      const int* classes,
01850                                      double* classificationTable) const
01851 {
01852   // sanity check
01853   if( mcPlotter_ == 0 ) {
01854     cerr << "No response vectors found. "
01855          << "Cannot compute classification table." << endl;
01856     return false;
01857   }
01858 
01859   // make a list of classes to be included
01860   vector<int> vclasses(&classes[0],&classes[nClass]);
01861 
01862   // call mutliclass plotter
01863   map<int,vector<double> > mcClassificationTable;
01864   map<int,double> weightInClass;
01865   double loss = mcPlotter_->multiClassTable(vclasses,
01866                                             mcClassificationTable,
01867                                             weightInClass);
01868 
01869   // convert the map into an array
01870   for( int ic=0;ic<nClass;ic++ ) {
01871     map<int,vector<double> >::const_iterator found 
01872       = mcClassificationTable.find(classes[ic]);
01873     if( found == mcClassificationTable.end() ) {
01874       for( int j=0;j<nClass;j++ )
01875         classificationTable[j+ic*nClass] = 0;
01876     }
01877     else {
01878       assert( found->second.size() == nClass );
01879       for( int j=0;j<nClass;j++ )
01880         classificationTable[j+ic*nClass] = (found->second)[j];
01881     }
01882   }
01883 
01884   // exit
01885   return true;
01886 }
01887 
01888 
01889 bool SprRootAdapter::saveTestData(const char* filename) const
01890 {
01891   // sanity check
01892   if( testData_ == 0 ) {
01893     cerr << "Test data has not been loaded." << endl;
01894     return false;
01895   }
01896   if( (!trained_.empty() || mcTrained_!=0) && needToTest_ ) {
01897     cerr << "Test data has changed. Need to run test() again." << endl;
01898     return false;
01899   }
01900   if( trained_.empty() && mcTrained_==0 ) {
01901     cout << "No trained classifiers found. " 
01902          << "Data will be saved without any classifiers." << endl;
01903   }
01904 
01905   // create writer and feeder
01906   SprRootWriter writer("TestData");
01907   if( !writer.init(filename) ) {
01908     cerr << "Unable to open output file " << filename << endl;
01909     return false;
01910   }
01911   SprDataFeeder feeder(testData_,&writer);
01912 
01913   // add classifiers
01914   for( map<string,SprAbsTrainedClassifier*>::const_iterator 
01915        i=trained_.begin();i!=trained_.end();i++ ) {
01916     SprCoordinateMapper* mapper = 0;
01917     map<SprAbsTrainedClassifier*,SprCoordinateMapper*>::const_iterator
01918       found = mapper_.find(i->second);
01919     if( found != mapper_.end() )
01920       mapper = found->second->clone();
01921     if( !feeder.addClassifier(i->second,i->first.c_str(),mapper) ) {
01922       cerr << "Unable to add classifier " << i->first.c_str() 
01923            << " to feeder." << endl;
01924       return false;
01925     }
01926   }
01927   if( mcTrained_ != 0 ) {
01928     SprCoordinateMapper* mapper = ( mcMapper_==0 ? 0 : mcMapper_->clone() );
01929     if( !feeder.addMultiClassLearner(mcTrained_,"MultiClassLearner",mapper) ) {
01930       cerr << "Unable to add MultiClassLearner to feeder." << endl;
01931       return false;
01932     }
01933   }
01934 
01935   // feed
01936   if( !feeder.feed(1000) ) {
01937     cerr << "Unable to feed data into writer." << endl;
01938     return false;
01939   }
01940 
01941   // exit
01942   return true;
01943 }
01944 
01945 
01946 bool SprRootAdapter::trainVarTransformer(const char* name, int verbose)
01947 {
01948   // sanity check
01949   if( trainData_ == 0 ) {
01950     cerr << "Training data has not been loaded." << endl;
01951     return false;
01952   }
01953 
01954   // make a transformer
01955   if( trans_ != 0 ) delete trans_;
01956   string sname = name;
01957   if(      sname == "PCA" )
01958     trans_ = new SprPCATransformer();
01959   else {
01960     cerr << "Unknown VarTransformer type requested: " << sname.c_str() << endl;
01961     return false;
01962   }
01963 
01964   // train
01965   if( !trans_->train(trainData_,verbose) ) {
01966     cerr << "Unable to train VarTransformer." << endl;
01967     return false;
01968   }
01969 
01970   // exit
01971   return true;
01972 }
01973 
01974 
01975 bool SprRootAdapter::saveVarTransformer(const char* filename) const
01976 {
01977   // sanity check
01978   if( trans_ == 0 ) {
01979     cerr << "No VarTransformer found. Unable to save." << endl;
01980     return false;
01981   }
01982 
01983   // save
01984   if( !trans_->store(filename) ) {
01985     cerr << "Unable to save VarTransformer to file " << filename << endl;
01986     return false;
01987   }
01988 
01989   // exit
01990   return true;
01991 }
01992 
01993 
01994 bool SprRootAdapter::loadVarTransformer(const char* filename)
01995 {
01996   if( trans_ != 0 ) delete trans_;
01997   trans_ = SprVarTransformerReader::read(filename);
01998   if( trans_ == 0 ) {
01999     cerr << "Unable to load VarTransformer from file " << filename << endl;
02000     return false;
02001   }
02002   return true;
02003 }
02004 
02005 
02006 bool SprRootAdapter::transform()
02007 {
02008   // sanity check
02009   if( trainData_ == 0 ) {
02010     cerr << "Training data has not been loaded. Unable to transform." << endl;
02011     return false;
02012   }
02013   if( testData_ == 0 ) {
02014     cerr << "Test data has not been loaded. Unable to transform." << endl;
02015     return false;
02016   }
02017   if( trans_ == 0 ) {
02018     cerr << "No VarTransformer found. Unable to transform." << endl;
02019     return false;
02020   }
02021 
02022   // make new data filters
02023   SprTransformerFilter* trainData = new SprTransformerFilter(trainData_);
02024   SprTransformerFilter* testData = new SprTransformerFilter(testData_);
02025 
02026   // transform
02027   bool replaceOriginalData = true;
02028   if( !trainData->transform(trans_,replaceOriginalData) ) {
02029     cerr << "Unable to transform training data." << endl;
02030     return false;
02031   }
02032   if( !testData->transform(trans_,replaceOriginalData) ) {
02033     cerr << "Unable to transform test data." << endl;
02034     return false;
02035   }
02036 
02037   // get rid of old non-transformed data
02038   if( trainGarbage_ == 0 )
02039     trainGarbage_ = trainData_;
02040   else
02041     delete trainData_;
02042   if( testGarbage_ == 0 )
02043     testGarbage_ = testData_;
02044   else
02045     delete testData_;
02046   trainData_ = trainData;
02047   testData_ = testData;
02048 
02049   // exit
02050   return true;
02051 }

Generated on Tue Jun 9 17:42:03 2009 for CMSSW by  doxygen 1.5.4