00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprCrossValidator.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsClassifier.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprFomCalculator.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerPermutator.hh"
00013 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00014
00015 #include <list>
00016 #include <iostream>
00017
00018 using namespace std;
00019
00020
00021 SprCrossValidator::~SprCrossValidator()
00022 {
00023 for( int i=0;i<samples_.size();i++ )
00024 delete samples_[i];
00025 }
00026
00027
00028 bool SprCrossValidator::divide(unsigned nPieces)
00029 {
00030
00031 unsigned size = data_->size();
00032 if( nPieces > size ) {
00033 cerr << "Too many pieces requested for cross-validation: "
00034 << nPieces << ">" << size << endl;
00035 return false;
00036 }
00037
00038
00039 vector<unsigned> index;
00040 SprIntegerPermutator permu(size);
00041 if( !permu.sequence(index) ) {
00042 cerr << "CrossValidator is unable to randomize indices." << endl;
00043 return false;
00044 }
00045 else
00046 cout << "Indices for cross-validation permuted." << endl;
00047
00048
00049 unsigned nupdate = size/nPieces;
00050 for( unsigned i=0;i<nPieces;i++ ) {
00051
00052 SprData* sample = data_->emptyCopy();
00053 vector<SprClass> classes;
00054 data_->classes(classes);
00055 samples_[i] = new SprEmptyFilter(sample,classes,true);
00056
00057
00058 vector<double> w;
00059 for( unsigned j=0;j<nupdate;j++ ) {
00060 unsigned icurr = index[j+i*nupdate];
00061 sample->uncheckedInsert((*data_)[icurr]);
00062 w.push_back(data_->w(icurr));
00063 }
00064
00065
00066 if( !samples_[i]->setWeights(w) ) {
00067 cerr << "Unable to set weights for subsample " << i << endl;
00068 return false;
00069 }
00070
00071
00072 assert( !samples_[i]->empty() );
00073
00074
00075 cout << "Obtained subsample " << i << " for cross-validation." << endl;
00076 }
00077
00078
00079 return true;
00080 }
00081
00082
00083 bool SprCrossValidator::validate(const SprAbsTwoClassCriterion* crit,
00084 SprAverageLoss* loss,
00085 const std::vector<SprAbsClassifier*>&
00086 classifiers,
00087 const SprClass& cls0, const SprClass& cls1,
00088 const SprCut& cut,
00089 std::vector<double>& crossFom,
00090 int verbose)
00091 const
00092 {
00093
00094 if( verbose > 0 ) {
00095 cout << "Will cross-validate using "
00096 << samples_.size() << " subsamples: " << endl;
00097 for( int i=0;i<samples_.size();i++ ) {
00098 cout << "Subsample " << i
00099 << " W1=" << samples_[i]->weightInClass(cls1)
00100 << " W0=" << samples_[i]->weightInClass(cls0)
00101 << " N1=" << samples_[i]->ptsInClass(cls1)
00102 << " N0=" << samples_[i]->ptsInClass(cls0) << endl;
00103 }
00104 }
00105
00106
00107 assert( !classifiers.empty() && !samples_.empty() );
00108
00109
00110 SprEmptyFilter data(data_);
00111
00112
00113 crossFom.clear();
00114 crossFom.resize(classifiers.size());
00115
00116
00117 for( int ic=0;ic<classifiers.size();ic++ ) {
00118 SprAbsClassifier* c = classifiers[ic];
00119 assert( c != 0 );
00120
00121
00122 cout << "Cross-validator processing classifier " << ic << endl;
00123
00124
00125 vector<double> fom;
00126
00127
00128 for( int i=0;i<samples_.size();i++ ) {
00129
00130 cout << "Cross-validator processing sub-sample " << i
00131 << " for classifier " << ic << endl;
00132
00133
00134 data.clear();
00135 data.remove(samples_[i]->data());
00136
00137
00138 if( verbose > 0 ) {
00139 cout << "Will train classifier " << c->name().c_str()
00140 << " on a sample: " << endl;
00141 cout << " W1=" << data.weightInClass(cls1)
00142 << " W0=" << data.weightInClass(cls0)
00143 << " N1=" << data.ptsInClass(cls1)
00144 << " N0=" << data.ptsInClass(cls0) << endl;
00145 }
00146
00147
00148 if( !c->setData(&data) ) {
00149 cerr << "Cross-validator unable to set data for classifier "
00150 << ic << endl;
00151 return false;
00152 }
00153
00154
00155 if( !c->train(verbose-1) ) {
00156 cerr << "Unable to train classifier " << ic << endl;
00157 continue;
00158 }
00159 SprAbsTrainedClassifier* t = c->makeTrained();
00160 if( t == 0 ) {
00161 cerr << "Cross-validator unable to get trained classifier "
00162 << ic << " for subsample " << i << endl;
00163 continue;
00164 }
00165 t->setCut(cut);
00166
00167
00168 fom.push_back(SprFomCalculator::fom(samples_[i],t,crit,loss,cls0,cls1));
00169
00170
00171 delete t;
00172 }
00173
00174
00175 if( fom.empty() ) {
00176 cerr << "Cross-validator unable to compute FOM for classifier "
00177 << ic << endl;
00178 crossFom[ic] = SprUtils::min();
00179 continue;
00180 }
00181
00182
00183 double ave = 0;
00184 for( int i=0;i<fom.size();i++ )
00185 ave += fom[i];
00186 ave /= fom.size();
00187
00188
00189 if( verbose > 0 ) {
00190 cout << "Computed FOMs for subsamples:" << endl;
00191 for( int i=0;i<fom.size();i++ )
00192 cout << i << " FOM=" << fom[i] << endl;
00193 }
00194
00195
00196 crossFom[ic] = ave;
00197
00198
00199 if( !c->setData(const_cast<SprAbsFilter*>(data_)) ) {
00200 cerr << "Cross-validator unable to restore data for classifier "
00201 << ic << endl;
00202 return false;
00203 }
00204 }
00205
00206
00207 return true;
00208 }