00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedBagger.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00007
00008 #include <stdio.h>
00009 #include <cassert>
00010
00011 using namespace std;
00012
00013
00014 SprTrainedBagger::SprTrainedBagger(const std::vector<
00015 std::pair<const SprAbsTrainedClassifier*,bool> >&
00016 trained, bool discrete)
00017 :
00018 SprAbsTrainedClassifier(),
00019 trained_(trained),
00020 discrete_(discrete)
00021 {
00022 assert( !trained_.empty() );
00023 this->setCut(SprUtils::lowerBound(0.5));
00024 }
00025
00026
00027 SprTrainedBagger::SprTrainedBagger(const SprTrainedBagger& other)
00028 :
00029 SprAbsTrainedClassifier(other),
00030 trained_(),
00031 discrete_(other.discrete_)
00032 {
00033 for( int i=0;i<other.trained_.size();i++ )
00034 trained_.push_back(pair<const SprAbsTrainedClassifier*,bool>
00035 (other.trained_[i].first->clone(),true));
00036 }
00037
00038
00039 double SprTrainedBagger::response(const std::vector<double>& v) const
00040 {
00041
00042 double r = 0;
00043
00044
00045 if( discrete_ ) {
00046 int out = 0;
00047 for( int i=0;i<trained_.size();i++ )
00048 out += ( trained_[i].first->accept(v) ? 1 : -1 );
00049 r = out;
00050 r /= 2.*trained_.size();
00051 r += 0.5;
00052 }
00053 else {
00054 for( int i=0;i<trained_.size();i++ )
00055 r += trained_[i].first->response(v);
00056 r /= trained_.size();
00057 }
00058
00059
00060 return r;
00061 }
00062
00063
00064 void SprTrainedBagger::destroy()
00065 {
00066 for( int i=0;i<trained_.size();i++ ) {
00067 if( trained_[i].second )
00068 delete trained_[i].first;
00069 }
00070 }
00071
00072
00073 void SprTrainedBagger::print(std::ostream& os) const
00074 {
00075 os << "Trained Bagger " << SprVersion << endl;
00076 os << "Classifiers: " << trained_.size() << endl;
00077 for( int i=0;i<trained_.size();i++ ) {
00078 os << "Classifier " << i
00079 << " " << trained_[i].first->name().c_str() << endl;
00080 trained_[i].first->print(os);
00081 }
00082 }
00083
00084
00085 bool SprTrainedBagger::generateCode(std::ostream& os) const
00086 {
00087
00088 for( int i=0;i<trained_.size();i++ ) {
00089 string name = trained_[i].first->name();
00090 os << " // Classifier " << i
00091 << " \"" << name.c_str() << "\"" << endl;
00092 if( !trained_[i].first->generateCode(os) ) {
00093 cerr << "Unable to generate code for classifier " << name.c_str()
00094 << endl;
00095 return false;
00096 }
00097 if( i < trained_.size()-1 ) os << endl;
00098 }
00099
00100
00101 return true;
00102 }
00103
00104
00105 SprTrainedBagger& SprTrainedBagger::operator+=(const SprTrainedBagger& other)
00106 {
00107
00108 if( vars_.size() != other.vars_.size() ) {
00109 cerr << "Unable to add Bagger: variable lists do not match." << endl;
00110 return *this;
00111 }
00112 for( int i=0;i<vars_.size();i++ ) {
00113 if( vars_[i] != other.vars_[i] ) {
00114 cerr << "Unable to add Bagger: variable lists do not match." << endl;
00115 cerr << "Variables " << i << ": "
00116 << vars_[i] << " " << other.vars_[i] << endl;
00117 return *this;
00118 }
00119 }
00120
00121
00122 if( discrete_ != other.discrete_ ) {
00123 cerr << "Unable to add Bagger: discreteness does not match." << endl;
00124 return *this;
00125 }
00126
00127
00128 for( int i=0;i<other.trained_.size();i++ ) {
00129 trained_.push_back(pair<const SprAbsTrainedClassifier*,
00130 bool>(other.trained_[i].first->clone(),true));
00131 }
00132 this->setCut(SprUtils::lowerBound(0.5));
00133
00134
00135 return *this;
00136 }
00137
00138
00139 const SprTrainedBagger operator+(const SprTrainedBagger& l,
00140 const SprTrainedBagger& r)
00141 {
00142
00143 assert( l.vars_.size() == r.vars_.size() );
00144 for( int i=0;i<l.vars_.size();i++ )
00145 assert( l.vars_[i] == r.vars_[i] );
00146
00147
00148 vector<pair<const SprAbsTrainedClassifier*,bool> > trained;
00149 for( int i=0;i<l.trained_.size();i++ ) {
00150 trained.push_back(pair<const SprAbsTrainedClassifier*,
00151 bool>(l.trained_[i].first->clone(),true));
00152 }
00153
00154 for( int i=0;i<r.trained_.size();i++ ) {
00155 trained.push_back(pair<const SprAbsTrainedClassifier*,
00156 bool>(r.trained_[i].first->clone(),true));
00157 }
00158
00159
00160 assert( l.discrete_ == r.discrete_ );
00161
00162
00163 SprTrainedBagger newBagger(trained,l.discrete_);
00164 newBagger.setCut(SprUtils::lowerBound(0.5));
00165
00166
00167 return newBagger;
00168 }