CMS 3D CMS Logo

SprBagger.cc

Go to the documentation of this file.
00001 //$Id: SprBagger.cc,v 1.3 2007/10/30 18:56:14 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprBagger.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprBootstrap.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00012 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00013 
00014 #include <stdio.h>
00015 #include <functional>
00016 #include <algorithm>
00017 #include <cmath>
00018 
00019 using namespace std;
00020 
00021 
00022 struct SBTrainedOwned 
00023   : public unary_function<pair<const SprAbsTrainedClassifier*,bool>,bool> {
00024   bool operator()(const pair<const SprAbsTrainedClassifier*,bool>& p) const {
00025     return p.second;
00026   }
00027 };
00028 
00029 
00030 SprBagger::~SprBagger() 
00031 { 
00032   this->destroy(); 
00033   if( ownLoss_ ) {
00034     delete loss_;
00035     loss_ = 0;
00036     ownLoss_ = false;
00037   }
00038 }
00039 
00040 
00041 SprBagger::SprBagger(SprAbsFilter* data)
00042   : 
00043   SprAbsClassifier(data), 
00044   crit_(0),
00045   cls0_(0),
00046   cls1_(1),
00047   cycles_(0),
00048   discrete_(false),
00049   trained_(), 
00050   trainable_(), 
00051   bootstrap_(0),
00052   valData_(0),
00053   valBeta_(),
00054   valPrint_(0),
00055   loss_(0),
00056   ownLoss_(false)
00057 {}
00058 
00059 
00060 SprBagger::SprBagger(SprAbsFilter* data, unsigned cycles, bool discrete)
00061   : 
00062   SprAbsClassifier(data), 
00063   crit_(0),
00064   cls0_(0),
00065   cls1_(1),
00066   cycles_(cycles),
00067   discrete_(discrete),
00068   trained_(), 
00069   trainable_(), 
00070   bootstrap_(new SprBootstrap(data)),
00071   valData_(0),
00072   valBeta_(),
00073   valPrint_(0),
00074   loss_(0),
00075   ownLoss_(false)
00076 { 
00077   assert( bootstrap_ != 0 );
00078   this->setClasses();
00079   cout << "Bagger initialized with classes " << cls0_ << " " << cls1_
00080        << " with cycles " << cycles_ << endl;
00081 }
00082 
00083 
00084 void SprBagger::destroy()
00085 {
00086   for( int i=0;i<trained_.size();i++ ) {
00087     if( trained_[i].second )
00088       delete trained_[i].first;
00089   }
00090   trained_.erase(remove_if(trained_.begin(),trained_.end(),SBTrainedOwned()),
00091                  trained_.end());
00092   delete bootstrap_;
00093   bootstrap_ = 0;
00094 }
00095 
00096 
00097 void SprBagger::setClasses() 
00098 {
00099   vector<SprClass> classes;
00100   data_->classes(classes);
00101   int size = classes.size();
00102   if( size > 0 ) cls0_ = classes[0];
00103   if( size > 1 ) cls1_ = classes[1];
00104   cout << "Classes for Bagger are set to " << cls0_ << " " << cls1_ << endl;
00105 }
00106 
00107 
00108 bool SprBagger::reset() 
00109 {
00110   this->destroy();
00111   bootstrap_ = new SprBootstrap(data_,-1);
00112   return true;
00113 }
00114 
00115 
00116 bool SprBagger::setData(SprAbsFilter* data)
00117 {
00118   assert( data != 0 );
00119 
00120   // reset base data
00121   data_ = data;
00122 
00123   // reset data supplied to trainable classifiers
00124   for( int i=0;i<trainable_.size();i++ ) {
00125     if( !trainable_[i]->setData(data_) ) {
00126       cerr << "Cannot reset data for trainable classifier " << i << endl;
00127       return false;
00128     }
00129   }
00130 
00131   // basic reset
00132   return this->reset();
00133 }
00134 
00135 
00136 bool SprBagger::addTrained(const SprAbsTrainedClassifier* c, bool own)
00137 {
00138   if( c == 0 ) return false;
00139   trained_.push_back(pair<const SprAbsTrainedClassifier*,bool>(c,own));
00140   return true;
00141 }
00142 
00143 
00144 bool SprBagger::addTrainable(SprAbsClassifier* c)
00145 {
00146   if( c == 0 ) return false;
00147   trainable_.push_back(c);
00148   return true;
00149 }
00150 
00151 
00152 bool SprBagger::train(int verbose)
00153 {
00154   // sanity check
00155   if( cycles_==0 || trainable_.empty() ) {
00156     cout << "Bagger will exit without training." << endl;
00157     return this->prepareExit(true);
00158   }
00159 
00160   // if resume training, generate a seed from time of day
00161   if( cycles_>0 && !trained_.empty() ) {
00162     delete bootstrap_;
00163     bootstrap_ = new SprBootstrap(data_,-1);
00164     assert( bootstrap_ != 0 );
00165   }
00166 
00167   // after all betas are filled, do an overall validation
00168   if( valData_ != 0 ) {
00169     // compute cumulative beta weights for validation points
00170     valBeta_.clear();
00171     int vsize = valData_->size();
00172     valBeta_.resize(vsize,0);
00173     int tsize = trained_.size();
00174     for( int i=0;i<vsize;i++ ) {
00175       const SprPoint* p = (*valData_)[i];
00176       if( discrete_ ) {
00177         for( int j=0;j<tsize;j++ )
00178           valBeta_[i] += ( trained_[j].first->accept(p) ? 1 : -1 );
00179       }
00180       else {
00181         for( int j=0;j<tsize;j++ )
00182           valBeta_[i] += trained_[j].first->response(p);
00183       }
00184       if( tsize > 0 ) valBeta_[i] /= tsize;
00185     }
00186 
00187     // print out
00188     if( valPrint_ > 0 ) {
00189       if( !this->printValidation(0) ) {
00190         cerr << "Unable to print out validation data." << endl;
00191         return this->prepareExit(false);
00192       }
00193     }
00194   }
00195 
00196   // loop through trainable
00197   unsigned nCycle = 0;
00198   unsigned nFailed = 0;
00199   while( nCycle < cycles_ ) {
00200     for( int i=0;i<trainable_.size();i++ ) {
00201       // check cycles
00202       if( nCycle++ >= cycles_ ) return this->prepareExit((this->nTrained()>0));
00203 
00204       // generate replica
00205       auto_ptr<SprEmptyFilter> temp(bootstrap_->plainReplica());
00206       if( temp->size() != data_->size() ) {
00207         cerr << "Failed to generate bootstrap replica." << endl;
00208         return this->prepareExit(false);
00209       }
00210 
00211       // get new classifier
00212       SprAbsClassifier* c = trainable_[i];
00213       if( !c->setData(temp.get()) ) {
00214         cerr << "Unable to set data for classifier " << i << endl;
00215         return this->prepareExit(false);
00216       }
00217       if( !c->train(verbose) ) {
00218         cerr << "Bagger failed to train classifier " << i 
00219              << ". Continuing..."<< endl;
00220         if( ++nFailed >= cycles_ ) {
00221           cout << "Exiting after failed to train " << nFailed 
00222                << " classifiers." << endl;
00223           return this->prepareExit((this->nTrained()>0));
00224         }
00225         else
00226           continue;
00227       }
00228 
00229       // register new trained classifier
00230       SprAbsTrainedClassifier* t = c->makeTrained();
00231       if( t == 0 ) {
00232         cerr << "Bagger failed to train classifier " << i 
00233              << ". Continuing..."<< endl;
00234         if( ++nFailed >= cycles_ ) {
00235           cout << "Exiting after failed to train " << nFailed 
00236                << " classifiers." << endl;
00237           return this->prepareExit((this->nTrained()>0));
00238         }
00239         else
00240           continue;
00241       }
00242       trained_.push_back(pair<const SprAbsTrainedClassifier*,bool>(t,true));
00243 
00244       // message
00245       if( verbose>0 || (nCycle%100)==0 ) {
00246         cout << "Finished cycle " << nCycle 
00247              << " with classifier " << t->name().c_str() << endl;
00248       }
00249 
00250       // validation
00251       if( valData_ != 0 ) {
00252         // update votes
00253         int tsize = trained_.size();
00254         for( int i=0;i<valData_->size();i++ ) {
00255           const SprPoint* p = (*valData_)[i];
00256           if( discrete_ ) {
00257             if( t->accept(p) ) 
00258               valBeta_[i] = ((tsize-1)*valBeta_[i] + 1)/tsize;
00259             else
00260               valBeta_[i] = ((tsize-1)*valBeta_[i] - 1)/tsize;
00261           }
00262           else
00263             valBeta_[i] = ((tsize-1)*valBeta_[i] + t->response(p))/tsize;
00264         }
00265 
00266         // print out
00267         if( valPrint_!=0 && (nCycle%valPrint_)==0 ) {
00268           if( !this->printValidation(nCycle) ) {
00269             cerr << "Unable to print out validation data." << endl;
00270             return this->prepareExit(false);
00271           }
00272         }
00273       }
00274     }
00275   }
00276 
00277   // normal exit
00278   return this->prepareExit((this->nTrained()>0));
00279 }
00280 
00281 
00282 SprTrainedBagger* SprBagger::makeTrained() const
00283 {
00284   // sanity check
00285   if( trained_.empty() ) return 0;
00286 
00287   // prepare a vector of trained classifiers
00288   vector<pair<const SprAbsTrainedClassifier*,bool> > trained;
00289   for( int i=0;i<trained_.size();i++ ) {
00290     SprAbsTrainedClassifier* c = trained_[i].first->clone();
00291     trained.push_back(pair<const SprAbsTrainedClassifier*,bool>(c,true));
00292   }
00293 
00294   // make a trained bagger
00295   SprTrainedBagger* t = new SprTrainedBagger(trained,discrete_);
00296 
00297   // vars
00298   vector<string> vars;
00299   data_->vars(vars);
00300   t->setVars(vars);
00301 
00302   // exit
00303   return t;
00304 }
00305 
00306 
00307 void SprBagger::print(std::ostream& os) const
00308 {
00309   os << "Trained Bagger " << SprVersion << endl;
00310   os << "Classifiers: " << trained_.size() << endl;
00311   for( int i=0;i<trained_.size();i++ ) {
00312     os << "Classifier " << i 
00313        << " " << trained_[i].first->name().c_str() << endl;
00314     trained_[i].first->print(os);
00315   }
00316 }
00317 
00318 
00319 bool SprBagger::setValidation(const SprAbsFilter* valData, 
00320                               unsigned valPrint,
00321                               const SprAbsTwoClassCriterion* crit,
00322                               SprAverageLoss* loss) 
00323 {
00324   // sanity checks
00325   if( !valBeta_.empty() ) {
00326     cerr << "One cannot reset validation data after training has started." 
00327          << endl;
00328     return false;
00329   }
00330   assert( valData != 0 );
00331 
00332   // set 
00333   valData_ = valData;
00334   valPrint_ = valPrint;
00335   crit_ = crit;
00336   loss_ = loss;
00337 
00338   // make default loss if none supplied
00339   if( crit_==0 && loss_==0 ) {
00340     loss_ = new SprAverageLoss(&SprLoss::quadratic);
00341     ownLoss_ = true;
00342   }
00343 
00344   // check loss and discreteness
00345   if( loss_==0 && !discrete_ ) {
00346     cout << "Warning: you requested continuous output for validation,"
00347          << " yet you have not supplied average loss appropriate for "
00348          << "the continuous output. Do you know what you are doing?" << endl;
00349   }
00350 
00351   // exit
00352   return true;
00353 }
00354 
00355 
00356 bool SprBagger::printValidation(unsigned cycle)
00357 {
00358   // no print-out for zero training cycle
00359   if( cycle == 0 ) return true;
00360 
00361   // sanity check
00362   assert(valBeta_.size() == valData_->size());
00363 
00364   // reset loss
00365   if( loss_ != 0 ) loss_->reset();
00366 
00367   // loop through validation data
00368   int vsize = valData_->size();
00369   double wcor0(0), wcor1(0), wmis0(0), wmis1(0);
00370   for( int i=0;i<vsize;i++ ) {
00371     const SprPoint* p = (*valData_)[i];
00372     double w = valData_->w(i);
00373     if( p->class_!=cls0_ && p->class_!=cls1_ ) w = 0;
00374     if( loss_ == 0 ) {
00375       if( valBeta_[i] > 0 ) {
00376         if(      p->class_ == cls0_ )
00377           wmis0 += w;
00378         else if( p->class_ == cls1_ )
00379           wcor1 += w;
00380       }
00381       else {
00382         if(      p->class_ == cls0_ )
00383           wcor0 += w;
00384         else if( p->class_ == cls1_ )
00385           wmis1 += w;
00386       }
00387     }
00388     else {
00389       if(      p->class_ == cls0_ )
00390         loss_->update(0,valBeta_[i],w);
00391       else if( p->class_ == cls1_ )
00392         loss_->update(1,valBeta_[i],w);
00393     }
00394   }
00395 
00396   // compute fom
00397   double fom = 0;
00398   assert( crit_!=0 || loss_!=0 );
00399   if( loss_ == 0 )
00400     fom = crit_->fom(wcor0,wmis0,wcor1,wmis1);
00401   else
00402     fom = loss_->value();
00403   cout << "Validation FOM=" << fom << " at cycle " << cycle << endl;
00404 
00405   // exit
00406   return true;
00407 }
00408 
00409 
00410 bool SprBagger::prepareExit(bool status)
00411 {
00412   // restore the original data supplied to the classifiers
00413   for( int i=0;i<trainable_.size();i++ ) {
00414     SprAbsClassifier* c = trainable_[i];
00415     if( !c->setData(data_) )
00416       cerr << "Unable to restore original data for classifier " << i << endl;
00417   }
00418    
00419   // exit
00420   return status;
00421 }
00422 
00423 
00424 bool SprBagger::setClasses(const SprClass& cls0, const SprClass& cls1) 
00425 {
00426   for( int i=0;i<trainable_.size();i++ ) {
00427     if( !trainable_[i]->setClasses(cls0,cls1) ) {
00428       cerr << "Bagger unable to reset classes for classifier " << i << endl;
00429       return false;
00430     }
00431   }
00432   cls0_ = cls0; cls1_ = cls1;
00433   cout << "Classes for Bagger reset to " << cls0_ << " " << cls1_ << endl;
00434   return true;
00435 }
00436 
00437 
00438 bool SprBagger::initBootstrapFromTimeOfDay()
00439 {
00440   if( bootstrap_ == 0 ) {
00441     cerr << "No bootstrap object found for the Bagger." << endl;
00442     return false;
00443   }
00444   bootstrap_->init(-1);
00445   return true;
00446 }

Generated on Tue Jun 9 17:42:00 2009 for CMSSW by doxygen 1.5.4