CMS 3D CMS Logo

SprArcE4.cc

Go to the documentation of this file.
00001 //$Id: SprArcE4.cc,v 1.2 2007/09/21 22:32:08 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprArcE4.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprBootstrap.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
00010 
00011 #include <cassert>
00012 #include <cmath>
00013 
00014 using namespace std;
00015 
00016 
00017 SprArcE4::SprArcE4(SprAbsFilter* data, 
00018                    unsigned cycles, bool discrete)
00019   : 
00020   SprBagger(data,cycles,discrete), 
00021   initialDataWeights_(),
00022   response_(data->size(),pair<double,double>(0,0))
00023 { 
00024   data_->weights(initialDataWeights_);
00025   cout << "ArcE4 initialized." << endl;
00026 }
00027 
00028 
00029 bool SprArcE4::setData(SprAbsFilter* data)
00030 {
00031   // reset base data
00032   if( !SprBagger::setData(data) ) {
00033     cerr << "Unable to set data for ArcE4." << endl;
00034     return false;
00035   }
00036 
00037   // copy weights
00038   data_->weights(initialDataWeights_);
00039   
00040   // init responses
00041   response_.clear();
00042   response_.resize(data_->size(),pair<double,double>(0,0));
00043 
00044   // exit
00045   return true;
00046 }
00047 
00048 
00049 bool SprArcE4::train(int verbose)
00050 {
00051   // sanity check
00052   if( cycles_==0 || trainable_.empty() ) {
00053     cout << "ArcE4 will exit without training." << endl;
00054     return this->prepareExit(true);
00055   }
00056 
00057   // if resume training, generate a seed from time of day
00058   if( cycles_>0 && !trained_.empty() ) {
00059     delete bootstrap_;
00060     bootstrap_ = new SprBootstrap(data_,-1);
00061     assert( bootstrap_ != 0 );
00062   }
00063 
00064   // update responses
00065   assert( data_->size() == response_.size() );
00066   for( int i=0;i<data_->size();i++ ) {
00067     const SprPoint* p = (*data_)[i];
00068     for( int j=0;j<trained_.size();j++ ) {
00069       double& resp = response_[i].first;
00070       double& wresp = response_[i].second;
00071       resp = wresp*resp + trained_[j].first->response(p);
00072       wresp += 1.;
00073       resp /= wresp;
00074     }
00075   }
00076 
00077   // after all betas are filled, do an overall validation
00078   if( valData_ != 0 ) {
00079     // compute cumulative beta weights for validation points
00080     valBeta_.clear();
00081     int vsize = valData_->size();
00082     valBeta_.resize(vsize,0);
00083     int tsize = trained_.size();
00084     for( int i=0;i<vsize;i++ ) {
00085       const SprPoint* p = (*valData_)[i];
00086       if( discrete_ ) {
00087         for( int j=0;j<tsize;j++ )
00088           valBeta_[i] += ( trained_[j].first->accept(p) ? 1 : -1 );
00089       }
00090       else {
00091         for( int j=0;j<tsize;j++ )
00092           valBeta_[i] += trained_[j].first->response(p);
00093       }
00094       if( tsize > 0 ) valBeta_[i] /= tsize;
00095     }
00096 
00097     // print out
00098     if( valPrint_ > 0 ) {
00099       if( !this->printValidation(0) ) {
00100         cerr << "Unable to print out validation data." << endl;
00101         return this->prepareExit(false);
00102       }
00103     }
00104   }
00105 
00106   // loop through trainable
00107   unsigned nCycle = 0;
00108   unsigned nFailed = 0;
00109   while( nCycle < cycles_ ) {
00110     for( int i=0;i<trainable_.size();i++ ) {
00111       // check cycles
00112       if( nCycle++ >= cycles_ ) return this->prepareExit((this->nTrained()>0));
00113 
00114       // generate replica
00115       auto_ptr<SprEmptyFilter> temp(bootstrap_->weightedReplica());
00116       if( temp->size() != data_->size() ) {
00117         cerr << "Failed to generate bootstrap replica." << endl;
00118         return this->prepareExit(false);
00119       }
00120 
00121       // get new classifier
00122       SprAbsClassifier* c = trainable_[i];
00123       if( !c->setData(temp.get()) ) {
00124         cerr << "Unable to set data for classifier " << i << endl;
00125         return this->prepareExit(false);
00126       }
00127       if( !c->train(verbose) ) {
00128         cerr << "ArcE4 failed to train classifier " << i 
00129              << ". Continuing..."<< endl;
00130         if( ++nFailed >= cycles_ ) {
00131           cout << "Exiting after failed to train " << nFailed 
00132                << " classifiers." << endl;
00133           return this->prepareExit((this->nTrained()>0));
00134         }
00135         else
00136           continue;
00137       }
00138 
00139       // register new trained classifier
00140       SprAbsTrainedClassifier* t = c->makeTrained();
00141       if( t == 0 ) {
00142         cerr << "ArcE4 failed to train classifier " << i 
00143              << ". Continuing..."<< endl;
00144         if( ++nFailed >= cycles_ ) {
00145           cout << "Exiting after failed to train " << nFailed 
00146                << " classifiers." << endl;
00147           return this->prepareExit((this->nTrained()>0));
00148         }
00149         else
00150           continue;
00151       }
00152       trained_.push_back(pair<const SprAbsTrainedClassifier*,bool>(t,true));
00153 
00154       // reweight events
00155       this->reweight(t);
00156       if( verbose > 1 ) {
00157         cout << "After reweighting:   W1=" << data_->weightInClass(cls1_)
00158              << " W0=" << data_->weightInClass(cls0_)
00159              << "    N1=" << data_->ptsInClass(cls1_)
00160              << " N0=" << data_->ptsInClass(cls0_) << endl;
00161       }
00162 
00163       // message
00164       if( verbose>1 || (nCycle%100)==0 ) {
00165         cout << "Done cycle " << nCycle << endl;
00166       }
00167 
00168       // validation
00169       if( valData_ != 0 ) {
00170         // update votes
00171         int tsize = trained_.size();
00172         for( int i=0;i<valData_->size();i++ ) {
00173           const SprPoint* p = (*valData_)[i];
00174           if( discrete_ ) {
00175             if( t->accept(p) ) 
00176               valBeta_[i] = ((tsize-1)*valBeta_[i] + 1)/tsize;
00177             else
00178               valBeta_[i] = ((tsize-1)*valBeta_[i] - 1)/tsize;
00179           }
00180           else
00181             valBeta_[i] = ((tsize-1)*valBeta_[i] + t->response(p))/tsize;
00182         }
00183 
00184         // print out
00185         if( valPrint_!=0 && (nCycle%valPrint_)==0 ) {
00186           if( !this->printValidation(nCycle) ) {
00187             cerr << "Unable to print out validation data." << endl;
00188             return this->prepareExit(false);
00189           }
00190         }
00191       }
00192     }
00193   }
00194 
00195   // normal exit
00196   return this->prepareExit((this->nTrained()>0));
00197 }
00198 
00199 
00200 bool SprArcE4::prepareExit(bool status)
00201 {
00202   // restore weights
00203   data_->setWeights(initialDataWeights_);
00204 
00205   // do basic restore
00206   return SprBagger::prepareExit(status);
00207 }
00208 
00209 
00210 void SprArcE4::reweight(const SprAbsTrainedClassifier* t)
00211 {
00212   unsigned size = data_->size();
00213   assert( size == initialDataWeights_.size() );
00214   assert( size == response_.size() );
00215   for( int i=0;i<size;i++ ) {
00216     const SprPoint* p = (*data_)[i];
00217 
00218     // update response
00219     double& resp = response_[i].first;
00220     double& wresp = response_[i].second;
00221     resp = wresp*resp + t->response(p);
00222     wresp += 1.;
00223     resp /= wresp;
00224 
00225     // reweight
00226     int cls = -1;
00227     if(      p->class_ == cls0_ ) 
00228       cls = 0;
00229     else if( p->class_ == cls1_ )
00230       cls = 1;
00231     if( cls > -1 ) {
00232       double error = wresp * (resp - cls);
00233       double w = initialDataWeights_[i] * (1.+pow(fabs(error),4));
00234       data_->setW(i,w);
00235     }
00236   }
00237 }

Generated on Tue Jun 9 17:42:00 2009 for CMSSW by  doxygen 1.5.4