CMS 3D CMS Logo

SprCrossValidator.cc

Go to the documentation of this file.
00001 //$Id: SprCrossValidator.cc,v 1.2 2007/09/21 22:32:09 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprCrossValidator.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsClassifier.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprFomCalculator.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerPermutator.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00014 
00015 #include <list>
00016 #include <iostream>
00017 
00018 using namespace std;
00019 
00020 
00021 SprCrossValidator::~SprCrossValidator()
00022 {
00023   for( int i=0;i<samples_.size();i++ )
00024     delete samples_[i];
00025 }
00026 
00027 
00028 bool SprCrossValidator::divide(unsigned nPieces)
00029 {
00030   // sanity check
00031   unsigned size = data_->size();
00032   if( nPieces > size ) {
00033     cerr << "Too many pieces requested for cross-validation: " 
00034          << nPieces << ">" << size << endl;
00035     return false;
00036   }
00037 
00038   // randomize point indices
00039   vector<unsigned> index;
00040   SprIntegerPermutator permu(size);
00041   if( !permu.sequence(index) ) {
00042     cerr << "CrossValidator is unable to randomize indices." << endl;
00043     return false;
00044   }
00045   else
00046     cout << "Indices for cross-validation permuted." << endl;
00047 
00048   // fill subsamples
00049   unsigned nupdate = size/nPieces;
00050   for( unsigned i=0;i<nPieces;i++ ) {
00051     // subsamples own SprData which does not own points
00052     SprData* sample = data_->emptyCopy();
00053     vector<SprClass> classes;
00054     data_->classes(classes);
00055     samples_[i] = new SprEmptyFilter(sample,classes,true);
00056 
00057     //make subsample
00058     vector<double> w;
00059     for( unsigned j=0;j<nupdate;j++ ) {
00060       unsigned icurr = index[j+i*nupdate];
00061       sample->uncheckedInsert((*data_)[icurr]);
00062       w.push_back(data_->w(icurr));
00063     }
00064 
00065     // set weights
00066     if( !samples_[i]->setWeights(w) ) {
00067       cerr << "Unable to set weights for subsample " << i << endl;
00068       return false;
00069     }
00070 
00071     // sanity check
00072     assert( !samples_[i]->empty() );
00073 
00074     // message
00075     cout << "Obtained subsample " << i << " for cross-validation." << endl;
00076   }
00077 
00078   // exit
00079   return true;
00080 }
00081 
00082 
00083 bool SprCrossValidator::validate(const SprAbsTwoClassCriterion* crit,
00084                                  SprAverageLoss* loss,
00085                                  const std::vector<SprAbsClassifier*>& 
00086                                  classifiers,
00087                                  const SprClass& cls0, const SprClass& cls1,
00088                                  const SprCut& cut,
00089                                  std::vector<double>& crossFom,
00090                                  int verbose) 
00091   const
00092 {
00093   // print out
00094   if( verbose > 0 ) {
00095     cout << "Will cross-validate using " 
00096          << samples_.size() << " subsamples: " << endl;
00097     for( int i=0;i<samples_.size();i++ ) {
00098       cout << "Subsample " << i 
00099            << "  W1=" << samples_[i]->weightInClass(cls1)
00100            << "  W0=" << samples_[i]->weightInClass(cls0)
00101            << "  N1=" << samples_[i]->ptsInClass(cls1)
00102            << "  N0=" << samples_[i]->ptsInClass(cls0) << endl;
00103     }
00104   }
00105 
00106   // sanity check
00107   assert( !classifiers.empty() && !samples_.empty() );
00108 
00109   // make a local copy of data
00110   SprEmptyFilter data(data_);
00111 
00112   // init
00113   crossFom.clear();
00114   crossFom.resize(classifiers.size());
00115 
00116   // loop over classifiers
00117   for( int ic=0;ic<classifiers.size();ic++ ) {
00118     SprAbsClassifier* c = classifiers[ic];
00119     assert( c != 0 );
00120 
00121     // message
00122     cout << "Cross-validator processing classifier " << ic << endl;
00123 
00124     // init fom vector
00125     vector<double> fom;
00126 
00127     // loop over subsamples
00128     for( int i=0;i<samples_.size();i++ ) {
00129       // message
00130       cout << "Cross-validator processing sub-sample " << i 
00131            << " for classifier " << ic << endl;
00132 
00133       // remove subsample from training data
00134       data.clear();
00135       data.remove(samples_[i]->data());
00136 
00137       // print out
00138       if( verbose > 0 ) {
00139         cout << "Will train classifier " << c->name().c_str()
00140              << " on a sample: " << endl;
00141         cout << "  W1=" << data.weightInClass(cls1)
00142              << "  W0=" << data.weightInClass(cls0)
00143              << "  N1=" << data.ptsInClass(cls1)
00144              << "  N0=" << data.ptsInClass(cls0) << endl;
00145       }
00146 
00147       // reset classifier
00148       if( !c->setData(&data) ) {
00149         cerr << "Cross-validator unable to set data for classifier " 
00150              << ic << endl;
00151         return false;
00152       }
00153 
00154       // train
00155       if( !c->train(verbose-1) ) {
00156         cerr << "Unable to train classifier " << ic << endl;
00157         continue;
00158       }
00159       SprAbsTrainedClassifier* t = c->makeTrained();
00160       if( t == 0 ) {
00161         cerr << "Cross-validator unable to get trained classifier "
00162              << ic << " for subsample " << i << endl;
00163         continue;
00164       }
00165       t->setCut(cut);
00166 
00167       // compute FOM
00168       fom.push_back(SprFomCalculator::fom(samples_[i],t,crit,loss,cls0,cls1));
00169 
00170       // cleanup
00171       delete t;
00172     }// end loop over subsamples
00173 
00174     // sanity check
00175     if( fom.empty() ) {
00176       cerr << "Cross-validator unable to compute FOM for classifier " 
00177            << ic << endl;
00178       crossFom[ic] = SprUtils::min();
00179       continue;
00180     }
00181 
00182     // compute average FOM
00183     double ave = 0;
00184     for( int i=0;i<fom.size();i++ )
00185       ave += fom[i];
00186     ave /= fom.size();
00187 
00188     // print out
00189     if( verbose > 0 ) {
00190       cout << "Computed FOMs for subsamples:" << endl;
00191       for( int i=0;i<fom.size();i++ )
00192         cout << i << "   FOM=" << fom[i] << endl;
00193     }
00194 
00195     // fill cross-validation FOM
00196     crossFom[ic] = ave;
00197 
00198     // reset classifier to point to the original data
00199     if( !c->setData(const_cast<SprAbsFilter*>(data_)) ) {
00200       cerr << "Cross-validator unable to restore data for classifier " 
00201            << ic << endl;
00202       return false;
00203     }
00204   }// end loop over classifiers
00205 
00206   // exit
00207   return true;
00208 }

Generated on Tue Jun 9 17:42:02 2009 for CMSSW by  doxygen 1.5.4