00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprDecisionTree.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprTreeNode.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00010
00011 #include <stdio.h>
00012 #include <functional>
00013 #include <algorithm>
00014 #include <cassert>
00015
00016 using namespace std;
00017
00018
00019 struct SDTCmpPairFirst
00020 : public binary_function<pair<double,const SprTreeNode*>,
00021 pair<double,const SprTreeNode*>,
00022 bool> {
00023 bool operator()(const pair<double,const SprTreeNode*>& l,
00024 const pair<double,const SprTreeNode*>& r)
00025 const {
00026 return (l.first < r.first);
00027 }
00028 };
00029
00030
00031 SprDecisionTree::~SprDecisionTree()
00032 {
00033 delete root_;
00034 }
00035
00036
00037 SprDecisionTree::SprDecisionTree(SprAbsFilter* data,
00038 const SprAbsTwoClassCriterion* crit,
00039 int nmin, bool doMerge, bool discrete,
00040 SprIntegerBootstrap* bootstrap)
00041 :
00042 SprAbsClassifier(data),
00043 cls0_(0),
00044 cls1_(1),
00045 crit_(crit),
00046 nmin_(nmin),
00047 doMerge_(doMerge),
00048 discrete_(discrete),
00049 canHavePureNodes_(true),
00050 fastSort_(false),
00051 showBackgroundNodes_(false),
00052 bootstrap_(bootstrap),
00053 root_(0),
00054 nodes1_(),
00055 nodes0_(),
00056 fullNodeList_(),
00057 fom_(0),
00058 w0_(0),
00059 w1_(0),
00060 n0_(0),
00061 n1_(0),
00062 splits_()
00063 {
00064
00065 if( nmin_ <= 0 ) {
00066 cout << "Resetting minimal number of events per node to 1." << endl;
00067 nmin_ = 1;
00068 }
00069 cout << "Decision tree initialized mith minimal number of events per node "
00070 << nmin_ << endl;
00071
00072
00073 if( bootstrap_ != 0 ) {
00074 cout << "Decision tree will resample at most "
00075 << bootstrap->nsample() << " features." << endl;
00076 }
00077
00078
00079 if( doMerge_ && !discrete_ ) {
00080 discrete_ = true;
00081 cout << "Warning: continuous output is not allowed for trees with "
00082 << "merged terminal nodes." << endl;
00083 cout << "Switching to discrete (0/1) tree output." << endl;
00084 }
00085
00086
00087 root_ = new SprTreeNode(crit,data,doMerge,nmin_,discrete_,
00088 canHavePureNodes_,fastSort_,bootstrap_);
00089
00090
00091 this->setClasses();
00092 bool status = root_->setClasses(cls0_,cls1_);
00093 assert ( status );
00094 }
00095
00096
00097 void SprDecisionTree::setClasses()
00098 {
00099 vector<SprClass> classes;
00100 data_->classes(classes);
00101 int size = classes.size();
00102 if( size > 0 ) cls0_ = classes[0];
00103 if( size > 1 ) cls1_ = classes[1];
00104
00105
00106 }
00107
00108
00109 SprTrainedDecisionTree* SprDecisionTree::makeTrained() const
00110 {
00111
00112 vector<SprBox> nodes1(nodes1_.size());
00113
00114
00115 for( int i=0;i<nodes1_.size();i++ )
00116 nodes1[i] = nodes1_[i]->limits_;
00117
00118
00119 SprTrainedDecisionTree* t = new SprTrainedDecisionTree(nodes1);
00120
00121
00122 vector<string> vars;
00123 data_->vars(vars);
00124 t->setVars(vars);
00125
00126
00127 return t;
00128 }
00129
00130
00131 const SprTreeNode* SprDecisionTree::next(const SprTreeNode* node) const
00132 {
00133
00134 const SprTreeNode* temp = node;
00135 while( temp->parent_!=0 && temp->parent_->right_==temp )
00136 temp = temp->parent_;
00137
00138
00139 if( temp->parent_ == 0 ) return 0;
00140
00141
00142 temp = temp->parent_->right_;
00143
00144
00145 while( temp->left_ != 0 )
00146 temp = temp->left_;
00147
00148
00149 return temp;
00150 }
00151
00152
00153 const SprTreeNode* SprDecisionTree::first() const
00154 {
00155 const SprTreeNode* temp = root_;
00156 while( temp->left_ != 0 )
00157 temp = temp->left_;
00158 return temp;
00159 }
00160
00161
00162 bool SprDecisionTree::train(int verbose)
00163 {
00164
00165 fullNodeList_.clear();
00166 fullNodeList_.push_back(root_);
00167 int splitIndex = 0;
00168 while( splitIndex < fullNodeList_.size() ) {
00169 SprTreeNode* node = fullNodeList_[splitIndex];
00170 if( !node->split(fullNodeList_,splits_,verbose) ) {
00171 cerr << "Unable to split node with index " << splitIndex << endl;
00172 return false;
00173 }
00174 splitIndex++;
00175 }
00176
00177
00178 if( !this->merge(1,doMerge_,nodes1_,fom_,w0_,w1_,n0_,n1_,verbose) ) {
00179 cerr << "Unable to merge signal nodes." << endl;
00180 return false;
00181 }
00182 if( doMerge_ ) showBackgroundNodes_ = false;
00183 if( showBackgroundNodes_ ) {
00184 double fom(0), w0(0), w1(0);
00185 unsigned n0(0), n1(0);
00186 if( !this->merge(0,false,nodes0_,fom,w0,w1,n0,n1,verbose) ) {
00187 cerr << "Unable to merge background nodes." << endl;
00188 return false;
00189 }
00190
00191 double totFom = crit_->fom(w0,w0_,w1_,w1);
00192 if( verbose > 0 ) {
00193 cout << "Included " << nodes1_.size()+nodes0_.size()
00194 << " nodes with overall FOM=" << totFom << endl;
00195 }
00196 }
00197
00198
00199 return true;
00200 }
00201
00202
00203 bool SprDecisionTree::reset()
00204 {
00205 delete root_;
00206 root_ = new SprTreeNode(crit_,data_,doMerge_,nmin_,discrete_,
00207 canHavePureNodes_,fastSort_,bootstrap_);
00208 if( !root_->setClasses(cls0_,cls1_) ) return false;
00209 nodes1_.clear();
00210 nodes0_.clear();
00211 fullNodeList_.clear();
00212 w0_ = 0; w1_ = 0;
00213 n0_ = 0; n1_ = 0;
00214 fom_ = SprUtils::min();
00215 return true;
00216 }
00217
00218
00219 bool SprDecisionTree::setData(SprAbsFilter* data)
00220 {
00221 assert( data != 0 );
00222 data_ = data;
00223 return this->reset();
00224 }
00225
00226
00227 bool SprDecisionTree::merge(int category, bool doMerge,
00228 std::vector<const SprTreeNode*>& nodes,
00229 double& fomtot, double& w0tot, double& w1tot,
00230 unsigned& n0tot, unsigned& n1tot, int verbose)
00231 const
00232 {
00233
00234 vector<const SprTreeNode*> collect;
00235 const SprTreeNode* temp = this->first();
00236 while( temp != 0 ) {
00237 if( temp->nodeClass() == category )
00238 collect.push_back(temp);
00239 temp = this->next(temp);
00240 }
00241 if( collect.empty() ) {
00242 if( verbose > 0 )
00243 cerr << "No leaf nodes found for category " << category << endl;
00244 return true;
00245 }
00246 int size = collect.size();
00247 if( verbose > 1 ) {
00248 cout << "Found " << size << " leaf nodes in category "
00249 << category << ": ";
00250 for( int i=0;i<size;i++ )
00251 cout << collect[i]->id() << " ";
00252 cout << endl;
00253 }
00254
00255
00256 vector<pair<double,const SprTreeNode*> > purity(size);
00257 for( int i=0;i<size;i++ ) {
00258 const SprTreeNode* node = collect[i];
00259 double w0 = node->w0();
00260 double w1 = node->w1();
00261 if( (w1+w0) < SprUtils::eps() ) {
00262 cerr << "Found a node without events: " << node->id() << endl;
00263 return false;
00264 }
00265 if( category == 1 )
00266 purity[i] = pair<double,const SprTreeNode*>(w1/(w1+w0),node);
00267 else if( category == 0 )
00268 purity[i] = pair<double,const SprTreeNode*>(w0/(w1+w0),node);
00269 }
00270 stable_sort(purity.begin(),purity.end(),not2(SDTCmpPairFirst()));
00271 for( int i=0;i<size;i++ ) {
00272 collect[i] = purity[i].second;
00273 }
00274 if( verbose > 1 ) {
00275 cout << "Nodes sorted by purity: " << endl;
00276 for( int i=0;i<size;i++ )
00277 cout << collect[i]->id() << " ";
00278 cout << endl;
00279 }
00280
00281
00282 vector<double> fomVec(size), w0Vec(size), w1Vec(size);
00283 vector<unsigned> n0Vec(size), n1Vec(size);
00284 double w0(0), w1(0);
00285 unsigned n0(0), n1(0);
00286 for( int j=0;j<size;j++ ) {
00287 const SprTreeNode* node = collect[j];
00288 double w0add = node->w0();
00289 double w1add = node->w1();
00290 w0 += w0add;
00291 w1 += w1add;
00292 n0 += node->n0();
00293 n1 += node->n1();
00294 double fom = 0;
00295 if( category == 1 )
00296 fom = crit_->fom(0,w0,w1,0);
00297 else if( category == 0 )
00298 fom = crit_->fom(w0,0,0,w1);
00299 fomVec[j] = fom;
00300 w0Vec[j] = w0;
00301 w1Vec[j] = w1;
00302 n0Vec[j] = n0;
00303 n1Vec[j] = n1;
00304 if( verbose > 1 ) {
00305 cout << "Adding node " << node->id()
00306 << " with " << w0add << " background and "
00307 << w1add << " signal weights at overall FOM=" << fom
00308 << endl;
00309 }
00310 }
00311
00312
00313 int best = size-1;
00314 if( doMerge ) {
00315
00316 vector<double>::reverse_iterator iter
00317 = max_element(fomVec.rbegin(),fomVec.rend());
00318 best = iter - fomVec.rbegin();
00319 best = size-1 - best;
00320 }
00321 double fom0 = fomVec[best];
00322 w0 = w0Vec[best];
00323 w1 = w1Vec[best];
00324 n0 = n0Vec[best];
00325 n1 = n1Vec[best];
00326 nodes.clear();
00327 for( int i=0;i<=best;i++ ) {
00328 nodes.push_back(collect[i]);
00329 }
00330
00331
00332 if( verbose > 0 ) {
00333 cout << "Included " << nodes.size()
00334 << " nodes in category " << category
00335 << " with overall FOM=" << fom0
00336 << " W1=" << w1 << " W0=" << w0
00337 << " N1=" << n1 << " N0=" << n0 << endl;
00338 }
00339 if( verbose > 1 ) {
00340 cout << "Node list: ";
00341 for( int i=0;i<nodes.size();i++ ) cout << nodes[i]->id() << " ";
00342 cout << endl;
00343 }
00344
00345
00346 fomtot = fom0;
00347 w0tot = w0;
00348 w1tot = w1;
00349 n0tot = n0;
00350 n1tot = n1;
00351
00352
00353 return true;
00354 }
00355
00356
00357 void SprDecisionTree::print(std::ostream& os) const
00358 {
00359
00360 char s [200];
00361 sprintf(s,"Trained DecisionTree %-6i signal nodes. Overall FOM=%-10g W0=%-10g W1=%-10g N0=%-10i N1=%-10i Version=%s",nodes1_.size(),fom_,w0_,w1_,n0_,n1_,SprVersion.c_str());
00362 os << s << endl;
00363 os << "-------------------------------------------------------" << endl;
00364
00365
00366 vector<string> vars;
00367 data_->vars(vars);
00368
00369
00370 os << "-------------------------------------------------------" << endl;
00371 os << "Signal nodes:" << endl;
00372 os << "-------------------------------------------------------" << endl;
00373 for( int i=0;i<nodes1_.size();i++ ) {
00374 const SprBox& limits = nodes1_[i]->limits_;
00375 int size = limits.size();
00376 char s [200];
00377 sprintf(s,"Node %6i Size %-4i FOM=%-10g W0=%-10g W1=%-10g N0=%-10i N1=%-10i",i,size,nodes1_[i]->fom(),nodes1_[i]->w0(),nodes1_[i]->w1(),nodes1_[i]->n0(),nodes1_[i]->n1());
00378 os << s << endl;
00379 for( SprBox::const_iterator iter =
00380 limits.begin();iter!=limits.end();iter++ ) {
00381 unsigned d = iter->first;
00382 assert( d < vars.size() );
00383 char s [200];
00384 sprintf(s,"Variable %30s Limits %15g %15g",
00385 vars[d].c_str(),iter->second.first,iter->second.second);
00386 os << s << endl;
00387 }
00388 os << "-------------------------------------------------------" << endl;
00389 }
00390
00391
00392 if( showBackgroundNodes_ ) {
00393 os << "-------------------------------------------------------" << endl;
00394 os << "Background nodes:" << endl;
00395 os << "-------------------------------------------------------" << endl;
00396 for( int i=0;i<nodes0_.size();i++ ) {
00397 const SprBox& limits = nodes0_[i]->limits_;
00398 int size = limits.size();
00399 char s [200];
00400 sprintf(s,"Node %6i Size %-4i FOM=%-10g W0=%-10g W1=%-10g N0=%-10i N1=%-10i",i,size,nodes0_[i]->fom(),nodes0_[i]->w0(),nodes0_[i]->w1(),nodes0_[i]->n0(),nodes0_[i]->n1());
00401 os << s << endl;
00402 for( SprBox::const_iterator iter =
00403 limits.begin();iter!=limits.end();iter++ ) {
00404 unsigned d = iter->first;
00405 assert( d < vars.size() );
00406 char s [200];
00407 sprintf(s,"Variable %30s Limits %15g %15g",
00408 vars[d].c_str(),iter->second.first,iter->second.second);
00409 os << s << endl;
00410 }
00411 os << "-------------------------------------------------------" << endl;
00412 }
00413 }
00414 }
00415
00416
00417 void SprDecisionTree::startSplitCounter()
00418 {
00419 splits_.clear();
00420 splits_.resize(data_->dim(),pair<int,double>(0,0));
00421 }
00422
00423
00424 void SprDecisionTree::printSplitCounter(std::ostream& os) const
00425 {
00426 unsigned dim = data_->dim();
00427 assert( splits_.size() == dim );
00428 vector<string> vars;
00429 data_->vars(vars);
00430 assert( vars.size() == dim );
00431 os << "Tree splits on variables:" << endl;
00432 for( int i=0;i<dim;i++ ) {
00433 char s [200];
00434 sprintf(s,"Variable %30s Splits %10i Delta FOM %10.5f",
00435 vars[i].c_str(),splits_[i].first,splits_[i].second);
00436 os << s << endl;
00437 }
00438 }
00439
00440
00441 bool SprDecisionTree::setClasses(const SprClass& cls0, const SprClass& cls1)
00442 {
00443 cls0_ = cls0;
00444 cls1_ = cls1;
00445 if( root_ != 0 )
00446 return root_->setClasses(cls0,cls1);
00447 return true;
00448 }