00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprBagger.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprBootstrap.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00013
00014 #include <stdio.h>
00015 #include <functional>
00016 #include <algorithm>
00017 #include <cmath>
00018
00019 using namespace std;
00020
00021
00022 struct SBTrainedOwned
00023 : public unary_function<pair<const SprAbsTrainedClassifier*,bool>,bool> {
00024 bool operator()(const pair<const SprAbsTrainedClassifier*,bool>& p) const {
00025 return p.second;
00026 }
00027 };
00028
00029
00030 SprBagger::~SprBagger()
00031 {
00032 this->destroy();
00033 if( ownLoss_ ) {
00034 delete loss_;
00035 loss_ = 0;
00036 ownLoss_ = false;
00037 }
00038 }
00039
00040
00041 SprBagger::SprBagger(SprAbsFilter* data)
00042 :
00043 SprAbsClassifier(data),
00044 crit_(0),
00045 cls0_(0),
00046 cls1_(1),
00047 cycles_(0),
00048 discrete_(false),
00049 trained_(),
00050 trainable_(),
00051 bootstrap_(0),
00052 valData_(0),
00053 valBeta_(),
00054 valPrint_(0),
00055 loss_(0),
00056 ownLoss_(false)
00057 {}
00058
00059
00060 SprBagger::SprBagger(SprAbsFilter* data, unsigned cycles, bool discrete)
00061 :
00062 SprAbsClassifier(data),
00063 crit_(0),
00064 cls0_(0),
00065 cls1_(1),
00066 cycles_(cycles),
00067 discrete_(discrete),
00068 trained_(),
00069 trainable_(),
00070 bootstrap_(new SprBootstrap(data)),
00071 valData_(0),
00072 valBeta_(),
00073 valPrint_(0),
00074 loss_(0),
00075 ownLoss_(false)
00076 {
00077 assert( bootstrap_ != 0 );
00078 this->setClasses();
00079 cout << "Bagger initialized with classes " << cls0_ << " " << cls1_
00080 << " with cycles " << cycles_ << endl;
00081 }
00082
00083
00084 void SprBagger::destroy()
00085 {
00086 for( int i=0;i<trained_.size();i++ ) {
00087 if( trained_[i].second )
00088 delete trained_[i].first;
00089 }
00090 trained_.erase(remove_if(trained_.begin(),trained_.end(),SBTrainedOwned()),
00091 trained_.end());
00092 delete bootstrap_;
00093 bootstrap_ = 0;
00094 }
00095
00096
00097 void SprBagger::setClasses()
00098 {
00099 vector<SprClass> classes;
00100 data_->classes(classes);
00101 int size = classes.size();
00102 if( size > 0 ) cls0_ = classes[0];
00103 if( size > 1 ) cls1_ = classes[1];
00104 cout << "Classes for Bagger are set to " << cls0_ << " " << cls1_ << endl;
00105 }
00106
00107
00108 bool SprBagger::reset()
00109 {
00110 this->destroy();
00111 bootstrap_ = new SprBootstrap(data_,-1);
00112 return true;
00113 }
00114
00115
00116 bool SprBagger::setData(SprAbsFilter* data)
00117 {
00118 assert( data != 0 );
00119
00120
00121 data_ = data;
00122
00123
00124 for( int i=0;i<trainable_.size();i++ ) {
00125 if( !trainable_[i]->setData(data_) ) {
00126 cerr << "Cannot reset data for trainable classifier " << i << endl;
00127 return false;
00128 }
00129 }
00130
00131
00132 return this->reset();
00133 }
00134
00135
00136 bool SprBagger::addTrained(const SprAbsTrainedClassifier* c, bool own)
00137 {
00138 if( c == 0 ) return false;
00139 trained_.push_back(pair<const SprAbsTrainedClassifier*,bool>(c,own));
00140 return true;
00141 }
00142
00143
00144 bool SprBagger::addTrainable(SprAbsClassifier* c)
00145 {
00146 if( c == 0 ) return false;
00147 trainable_.push_back(c);
00148 return true;
00149 }
00150
00151
00152 bool SprBagger::train(int verbose)
00153 {
00154
00155 if( cycles_==0 || trainable_.empty() ) {
00156 cout << "Bagger will exit without training." << endl;
00157 return this->prepareExit(true);
00158 }
00159
00160
00161 if( cycles_>0 && !trained_.empty() ) {
00162 delete bootstrap_;
00163 bootstrap_ = new SprBootstrap(data_,-1);
00164 assert( bootstrap_ != 0 );
00165 }
00166
00167
00168 if( valData_ != 0 ) {
00169
00170 valBeta_.clear();
00171 int vsize = valData_->size();
00172 valBeta_.resize(vsize,0);
00173 int tsize = trained_.size();
00174 for( int i=0;i<vsize;i++ ) {
00175 const SprPoint* p = (*valData_)[i];
00176 if( discrete_ ) {
00177 for( int j=0;j<tsize;j++ )
00178 valBeta_[i] += ( trained_[j].first->accept(p) ? 1 : -1 );
00179 }
00180 else {
00181 for( int j=0;j<tsize;j++ )
00182 valBeta_[i] += trained_[j].first->response(p);
00183 }
00184 if( tsize > 0 ) valBeta_[i] /= tsize;
00185 }
00186
00187
00188 if( valPrint_ > 0 ) {
00189 if( !this->printValidation(0) ) {
00190 cerr << "Unable to print out validation data." << endl;
00191 return this->prepareExit(false);
00192 }
00193 }
00194 }
00195
00196
00197 unsigned nCycle = 0;
00198 unsigned nFailed = 0;
00199 while( nCycle < cycles_ ) {
00200 for( int i=0;i<trainable_.size();i++ ) {
00201
00202 if( nCycle++ >= cycles_ ) return this->prepareExit((this->nTrained()>0));
00203
00204
00205 auto_ptr<SprEmptyFilter> temp(bootstrap_->plainReplica());
00206 if( temp->size() != data_->size() ) {
00207 cerr << "Failed to generate bootstrap replica." << endl;
00208 return this->prepareExit(false);
00209 }
00210
00211
00212 SprAbsClassifier* c = trainable_[i];
00213 if( !c->setData(temp.get()) ) {
00214 cerr << "Unable to set data for classifier " << i << endl;
00215 return this->prepareExit(false);
00216 }
00217 if( !c->train(verbose) ) {
00218 cerr << "Bagger failed to train classifier " << i
00219 << ". Continuing..."<< endl;
00220 if( ++nFailed >= cycles_ ) {
00221 cout << "Exiting after failed to train " << nFailed
00222 << " classifiers." << endl;
00223 return this->prepareExit((this->nTrained()>0));
00224 }
00225 else
00226 continue;
00227 }
00228
00229
00230 SprAbsTrainedClassifier* t = c->makeTrained();
00231 if( t == 0 ) {
00232 cerr << "Bagger failed to train classifier " << i
00233 << ". Continuing..."<< endl;
00234 if( ++nFailed >= cycles_ ) {
00235 cout << "Exiting after failed to train " << nFailed
00236 << " classifiers." << endl;
00237 return this->prepareExit((this->nTrained()>0));
00238 }
00239 else
00240 continue;
00241 }
00242 trained_.push_back(pair<const SprAbsTrainedClassifier*,bool>(t,true));
00243
00244
00245 if( verbose>0 || (nCycle%100)==0 ) {
00246 cout << "Finished cycle " << nCycle
00247 << " with classifier " << t->name().c_str() << endl;
00248 }
00249
00250
00251 if( valData_ != 0 ) {
00252
00253 int tsize = trained_.size();
00254 for( int i=0;i<valData_->size();i++ ) {
00255 const SprPoint* p = (*valData_)[i];
00256 if( discrete_ ) {
00257 if( t->accept(p) )
00258 valBeta_[i] = ((tsize-1)*valBeta_[i] + 1)/tsize;
00259 else
00260 valBeta_[i] = ((tsize-1)*valBeta_[i] - 1)/tsize;
00261 }
00262 else
00263 valBeta_[i] = ((tsize-1)*valBeta_[i] + t->response(p))/tsize;
00264 }
00265
00266
00267 if( valPrint_!=0 && (nCycle%valPrint_)==0 ) {
00268 if( !this->printValidation(nCycle) ) {
00269 cerr << "Unable to print out validation data." << endl;
00270 return this->prepareExit(false);
00271 }
00272 }
00273 }
00274 }
00275 }
00276
00277
00278 return this->prepareExit((this->nTrained()>0));
00279 }
00280
00281
00282 SprTrainedBagger* SprBagger::makeTrained() const
00283 {
00284
00285 if( trained_.empty() ) return 0;
00286
00287
00288 vector<pair<const SprAbsTrainedClassifier*,bool> > trained;
00289 for( int i=0;i<trained_.size();i++ ) {
00290 SprAbsTrainedClassifier* c = trained_[i].first->clone();
00291 trained.push_back(pair<const SprAbsTrainedClassifier*,bool>(c,true));
00292 }
00293
00294
00295 SprTrainedBagger* t = new SprTrainedBagger(trained,discrete_);
00296
00297
00298 vector<string> vars;
00299 data_->vars(vars);
00300 t->setVars(vars);
00301
00302
00303 return t;
00304 }
00305
00306
00307 void SprBagger::print(std::ostream& os) const
00308 {
00309 os << "Trained Bagger " << SprVersion << endl;
00310 os << "Classifiers: " << trained_.size() << endl;
00311 for( int i=0;i<trained_.size();i++ ) {
00312 os << "Classifier " << i
00313 << " " << trained_[i].first->name().c_str() << endl;
00314 trained_[i].first->print(os);
00315 }
00316 }
00317
00318
00319 bool SprBagger::setValidation(const SprAbsFilter* valData,
00320 unsigned valPrint,
00321 const SprAbsTwoClassCriterion* crit,
00322 SprAverageLoss* loss)
00323 {
00324
00325 if( !valBeta_.empty() ) {
00326 cerr << "One cannot reset validation data after training has started."
00327 << endl;
00328 return false;
00329 }
00330 assert( valData != 0 );
00331
00332
00333 valData_ = valData;
00334 valPrint_ = valPrint;
00335 crit_ = crit;
00336 loss_ = loss;
00337
00338
00339 if( crit_==0 && loss_==0 ) {
00340 loss_ = new SprAverageLoss(&SprLoss::quadratic);
00341 ownLoss_ = true;
00342 }
00343
00344
00345 if( loss_==0 && !discrete_ ) {
00346 cout << "Warning: you requested continuous output for validation,"
00347 << " yet you have not supplied average loss appropriate for "
00348 << "the continuous output. Do you know what you are doing?" << endl;
00349 }
00350
00351
00352 return true;
00353 }
00354
00355
00356 bool SprBagger::printValidation(unsigned cycle)
00357 {
00358
00359 if( cycle == 0 ) return true;
00360
00361
00362 assert(valBeta_.size() == valData_->size());
00363
00364
00365 if( loss_ != 0 ) loss_->reset();
00366
00367
00368 int vsize = valData_->size();
00369 double wcor0(0), wcor1(0), wmis0(0), wmis1(0);
00370 for( int i=0;i<vsize;i++ ) {
00371 const SprPoint* p = (*valData_)[i];
00372 double w = valData_->w(i);
00373 if( p->class_!=cls0_ && p->class_!=cls1_ ) w = 0;
00374 if( loss_ == 0 ) {
00375 if( valBeta_[i] > 0 ) {
00376 if( p->class_ == cls0_ )
00377 wmis0 += w;
00378 else if( p->class_ == cls1_ )
00379 wcor1 += w;
00380 }
00381 else {
00382 if( p->class_ == cls0_ )
00383 wcor0 += w;
00384 else if( p->class_ == cls1_ )
00385 wmis1 += w;
00386 }
00387 }
00388 else {
00389 if( p->class_ == cls0_ )
00390 loss_->update(0,valBeta_[i],w);
00391 else if( p->class_ == cls1_ )
00392 loss_->update(1,valBeta_[i],w);
00393 }
00394 }
00395
00396
00397 double fom = 0;
00398 assert( crit_!=0 || loss_!=0 );
00399 if( loss_ == 0 )
00400 fom = crit_->fom(wcor0,wmis0,wcor1,wmis1);
00401 else
00402 fom = loss_->value();
00403 cout << "Validation FOM=" << fom << " at cycle " << cycle << endl;
00404
00405
00406 return true;
00407 }
00408
00409
00410 bool SprBagger::prepareExit(bool status)
00411 {
00412
00413 for( int i=0;i<trainable_.size();i++ ) {
00414 SprAbsClassifier* c = trainable_[i];
00415 if( !c->setData(data_) )
00416 cerr << "Unable to restore original data for classifier " << i << endl;
00417 }
00418
00419
00420 return status;
00421 }
00422
00423
00424 bool SprBagger::setClasses(const SprClass& cls0, const SprClass& cls1)
00425 {
00426 for( int i=0;i<trainable_.size();i++ ) {
00427 if( !trainable_[i]->setClasses(cls0,cls1) ) {
00428 cerr << "Bagger unable to reset classes for classifier " << i << endl;
00429 return false;
00430 }
00431 }
00432 cls0_ = cls0; cls1_ = cls1;
00433 cout << "Classes for Bagger reset to " << cls0_ << " " << cls1_ << endl;
00434 return true;
00435 }
00436
00437
00438 bool SprBagger::initBootstrapFromTimeOfDay()
00439 {
00440 if( bootstrap_ == 0 ) {
00441 cerr << "No bootstrap object found for the Bagger." << endl;
00442 return false;
00443 }
00444 bootstrap_->init(-1);
00445 return true;
00446 }