CMS 3D CMS Logo

SprTrainedBagger.cc

Go to the documentation of this file.
00001 //$Id: SprTrainedBagger.cc,v 1.3 2007/10/30 18:56:14 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedBagger.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00007 
00008 #include <stdio.h>
00009 #include <cassert>
00010 
00011 using namespace std;
00012 
00013 
00014 SprTrainedBagger::SprTrainedBagger(const std::vector<
00015                    std::pair<const SprAbsTrainedClassifier*,bool> >& 
00016                    trained, bool discrete) 
00017   : 
00018   SprAbsTrainedClassifier(),
00019   trained_(trained),
00020   discrete_(discrete)
00021 {
00022   assert( !trained_.empty() );
00023   this->setCut(SprUtils::lowerBound(0.5));
00024 }
00025 
00026 
00027 SprTrainedBagger::SprTrainedBagger(const SprTrainedBagger& other)
00028   :
00029   SprAbsTrainedClassifier(other),
00030   trained_(),
00031   discrete_(other.discrete_)
00032 {
00033   for( int i=0;i<other.trained_.size();i++ )
00034     trained_.push_back(pair<const SprAbsTrainedClassifier*,bool>
00035                        (other.trained_[i].first->clone(),true));
00036 }
00037 
00038 
00039 double SprTrainedBagger::response(const std::vector<double>& v) const
00040 {
00041   // init
00042   double r = 0;
00043 
00044   // discrete/continuous
00045   if( discrete_ ) {
00046     int out = 0;
00047     for( int i=0;i<trained_.size();i++ )
00048       out += ( trained_[i].first->accept(v) ? 1 : -1 );
00049     r = out;
00050     r /= 2.*trained_.size();
00051     r += 0.5;
00052   }
00053   else {
00054     for( int i=0;i<trained_.size();i++ )
00055       r += trained_[i].first->response(v);
00056     r /= trained_.size();
00057   }
00058 
00059   // exit
00060   return r;
00061 }
00062 
00063 
00064 void SprTrainedBagger::destroy()
00065 {
00066   for( int i=0;i<trained_.size();i++ ) {
00067     if( trained_[i].second )
00068       delete trained_[i].first;
00069   }
00070 }
00071 
00072 
00073 void SprTrainedBagger::print(std::ostream& os) const
00074 {
00075   os << "Trained Bagger " << SprVersion << endl;
00076   os << "Classifiers: " << trained_.size() << endl;
00077   for( int i=0;i<trained_.size();i++ ) {
00078     os << "Classifier " << i 
00079        << " " << trained_[i].first->name().c_str() << endl;
00080     trained_[i].first->print(os);
00081   }
00082 }
00083 
00084 
00085 bool SprTrainedBagger::generateCode(std::ostream& os) const 
00086 { 
00087   // generate weak classifiers
00088   for( int i=0;i<trained_.size();i++ ) { 
00089     string name = trained_[i].first->name();
00090     os << " // Classifier " << i  
00091        << " \"" << name.c_str() << "\"" << endl; 
00092     if( !trained_[i].first->generateCode(os) ) {
00093       cerr << "Unable to generate code for classifier " << name.c_str() 
00094            << endl;
00095       return false;
00096     }
00097     if( i < trained_.size()-1 ) os << endl; 
00098   }
00099 
00100   // exit
00101   return true; 
00102 } 
00103 
00104 
00105 SprTrainedBagger& SprTrainedBagger::operator+=(const SprTrainedBagger& other)
00106 {
00107   // check vars
00108   if( vars_.size() != other.vars_.size() ) {
00109     cerr << "Unable to add Bagger: variable lists do not match." << endl;
00110     return *this;
00111   }
00112   for( int i=0;i<vars_.size();i++ ) {
00113     if( vars_[i] != other.vars_[i] ) {
00114       cerr << "Unable to add Bagger: variable lists do not match." << endl;
00115       cerr << "Variables " << i << ": " 
00116            << vars_[i] << " " << other.vars_[i] << endl;
00117       return *this;
00118     }
00119   }
00120 
00121   // check discreteness
00122   if( discrete_ != other.discrete_ ) {
00123     cerr << "Unable to add Bagger: discreteness does not match." << endl;
00124     return *this;
00125   }
00126 
00127   // add
00128   for( int i=0;i<other.trained_.size();i++ ) {
00129     trained_.push_back(pair<const SprAbsTrainedClassifier*,
00130                        bool>(other.trained_[i].first->clone(),true));
00131   }
00132   this->setCut(SprUtils::lowerBound(0.5));
00133 
00134   // exit
00135   return *this;
00136 }
00137 
00138 
00139 const SprTrainedBagger operator+(const SprTrainedBagger& l,
00140                                  const SprTrainedBagger& r)
00141 {
00142   // check variable list
00143   assert( l.vars_.size() == r.vars_.size() );
00144   for( int i=0;i<l.vars_.size();i++ )
00145     assert( l.vars_[i] == r.vars_[i] );
00146 
00147   // add classifiers
00148   vector<pair<const SprAbsTrainedClassifier*,bool> > trained;
00149   for( int i=0;i<l.trained_.size();i++ ) {
00150     trained.push_back(pair<const SprAbsTrainedClassifier*,
00151                       bool>(l.trained_[i].first->clone(),true));
00152   }
00153   
00154   for( int i=0;i<r.trained_.size();i++ ) {
00155     trained.push_back(pair<const SprAbsTrainedClassifier*,
00156                       bool>(r.trained_[i].first->clone(),true));
00157   }
00158 
00159   // add discrete
00160   assert( l.discrete_ == r.discrete_ );
00161 
00162   // make bagger and set cut
00163   SprTrainedBagger newBagger(trained,l.discrete_);
00164   newBagger.setCut(SprUtils::lowerBound(0.5));
00165 
00166   // exit
00167   return newBagger;
00168 }

Generated on Tue Jun 9 17:42:04 2009 for CMSSW by  doxygen 1.5.4