CMS 3D CMS Logo

SprDecisionTree.cc

Go to the documentation of this file.
00001 //$Id: SprDecisionTree.cc,v 1.2 2007/09/21 22:32:10 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprDecisionTree.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprTreeNode.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00010 
00011 #include <stdio.h>
00012 #include <functional>
00013 #include <algorithm>
00014 #include <cassert>
00015 
00016 using namespace std;
00017 
00018 
00019 struct SDTCmpPairFirst
00020   : public binary_function<pair<double,const SprTreeNode*>,
00021                            pair<double,const SprTreeNode*>,
00022                            bool> {
00023   bool operator()(const pair<double,const SprTreeNode*>& l, 
00024                   const pair<double,const SprTreeNode*>& r)
00025     const {
00026     return (l.first < r.first);
00027   }
00028 };
00029  
00030 
00031 SprDecisionTree::~SprDecisionTree()
00032 {
00033   delete root_;
00034 }
00035 
00036 
00037 SprDecisionTree::SprDecisionTree(SprAbsFilter* data, 
00038                                  const SprAbsTwoClassCriterion* crit,
00039                                  int nmin, bool doMerge, bool discrete,
00040                                  SprIntegerBootstrap* bootstrap)
00041   :
00042   SprAbsClassifier(data),
00043   cls0_(0),
00044   cls1_(1),
00045   crit_(crit),
00046   nmin_(nmin),
00047   doMerge_(doMerge),
00048   discrete_(discrete),
00049   canHavePureNodes_(true),
00050   fastSort_(false),
00051   showBackgroundNodes_(false),
00052   bootstrap_(bootstrap),
00053   root_(0),
00054   nodes1_(),
00055   nodes0_(),
00056   fullNodeList_(),
00057   fom_(0),
00058   w0_(0),
00059   w1_(0),
00060   n0_(0),
00061   n1_(0),
00062   splits_()
00063 {
00064   // check nmin
00065   if( nmin_ <= 0 ) {
00066     cout << "Resetting minimal number of events per node to 1." << endl;
00067     nmin_ = 1;
00068   }
00069   cout << "Decision tree initialized mith minimal number of events per node "
00070        << nmin_ << endl;
00071 
00072   // check bootstrap
00073   if( bootstrap_ != 0 ) {
00074     cout << "Decision tree will resample at most " 
00075          << bootstrap->nsample() << " features." << endl;
00076   }
00077 
00078   // check discrete
00079   if( doMerge_ && !discrete_ ) {
00080     discrete_ = true;
00081     cout << "Warning: continuous output is not allowed for trees with "
00082          << "merged terminal nodes." << endl;
00083     cout << "Switching to discrete (0/1) tree output." << endl;
00084   }
00085 
00086   // make root
00087   root_ = new SprTreeNode(crit,data,doMerge,nmin_,discrete_,
00088                           canHavePureNodes_,fastSort_,bootstrap_);
00089 
00090   // set classes
00091   this->setClasses();
00092   bool status = root_->setClasses(cls0_,cls1_);
00093   assert ( status );
00094 }
00095 
00096 
00097 void SprDecisionTree::setClasses() 
00098 {
00099   vector<SprClass> classes;
00100   data_->classes(classes);
00101   int size = classes.size();
00102   if( size > 0 ) cls0_ = classes[0];
00103   if( size > 1 ) cls1_ = classes[1];
00104   //  cout << "Classes for decision tree are set to " 
00105   //       << cls0_ << " " << cls1_ << endl;
00106 }
00107 
00108 
00109 SprTrainedDecisionTree* SprDecisionTree::makeTrained() const 
00110 {
00111   // prepare vectors of accepted regions
00112   vector<SprBox> nodes1(nodes1_.size());
00113 
00114   // copy box limits
00115   for( int i=0;i<nodes1_.size();i++ )
00116     nodes1[i] = nodes1_[i]->limits_;
00117 
00118   // make tree
00119   SprTrainedDecisionTree* t =  new SprTrainedDecisionTree(nodes1);
00120 
00121   // vars
00122   vector<string> vars;
00123   data_->vars(vars);
00124   t->setVars(vars);
00125 
00126   // exit
00127   return t;
00128 }
00129 
00130 
00131 const SprTreeNode* SprDecisionTree::next(const SprTreeNode* node) const
00132 {
00133   // travel up
00134   const SprTreeNode* temp = node;
00135   while( temp->parent_!=0 && temp->parent_->right_==temp )
00136     temp = temp->parent_;
00137 
00138   // if root, exit
00139   if( temp->parent_ == 0 ) return 0;
00140 
00141   // go over the hill
00142   temp = temp->parent_->right_;
00143 
00144   // travel down
00145   while( temp->left_ != 0 )
00146     temp = temp->left_;
00147 
00148   // exit
00149   return temp;
00150 }
00151 
00152 
00153 const SprTreeNode* SprDecisionTree::first() const
00154 {
00155   const SprTreeNode* temp = root_;
00156   while( temp->left_ != 0 )
00157     temp = temp->left_;
00158   return temp;
00159 }
00160 
00161 
00162 bool SprDecisionTree::train(int verbose)
00163 {
00164   // train the tree
00165   fullNodeList_.clear();
00166   fullNodeList_.push_back(root_);
00167   int splitIndex = 0;
00168   while( splitIndex < fullNodeList_.size() ) {
00169     SprTreeNode* node = fullNodeList_[splitIndex];
00170     if( !node->split(fullNodeList_,splits_,verbose) ) {
00171       cerr << "Unable to split node with index " << splitIndex << endl;
00172       return false;
00173     }
00174     splitIndex++;
00175   }
00176 
00177   // merge
00178   if( !this->merge(1,doMerge_,nodes1_,fom_,w0_,w1_,n0_,n1_,verbose) ) {
00179     cerr << "Unable to merge signal nodes." << endl;
00180     return false;
00181   }
00182   if( doMerge_ ) showBackgroundNodes_ = false;
00183   if( showBackgroundNodes_ ) {
00184     double fom(0), w0(0), w1(0);
00185     unsigned n0(0), n1(0);
00186     if( !this->merge(0,false,nodes0_,fom,w0,w1,n0,n1,verbose) ) {
00187       cerr << "Unable to merge background nodes." << endl;
00188       return false;
00189     }
00190     // show overall FOM
00191     double totFom = crit_->fom(w0,w0_,w1_,w1);
00192     if( verbose > 0 ) {
00193       cout << "Included " << nodes1_.size()+nodes0_.size() 
00194            << " nodes with overall FOM=" << totFom << endl;
00195     }
00196   }
00197 
00198   // exit
00199   return true;
00200 }
00201 
00202 
00203 bool SprDecisionTree::reset()
00204 {
00205   delete root_;
00206   root_ = new SprTreeNode(crit_,data_,doMerge_,nmin_,discrete_,
00207                           canHavePureNodes_,fastSort_,bootstrap_);
00208   if( !root_->setClasses(cls0_,cls1_) ) return false;
00209   nodes1_.clear();
00210   nodes0_.clear();
00211   fullNodeList_.clear();
00212   w0_ = 0; w1_ = 0;
00213   n0_ = 0; n1_ = 0;
00214   fom_ = SprUtils::min();
00215   return true;
00216 }
00217 
00218 
00219 bool SprDecisionTree::setData(SprAbsFilter* data)
00220 {
00221   assert( data != 0 );
00222   data_ = data;
00223   return this->reset();
00224 }
00225 
00226 
00227 bool SprDecisionTree::merge(int category, bool doMerge,
00228                             std::vector<const SprTreeNode*>& nodes,
00229                             double& fomtot, double& w0tot, double& w1tot,
00230                             unsigned& n0tot, unsigned& n1tot, int verbose) 
00231   const
00232 {
00233   // find leaf nodes
00234   vector<const SprTreeNode*> collect;
00235   const SprTreeNode* temp = this->first();
00236   while( temp != 0 ) {
00237     if( temp->nodeClass() == category )
00238       collect.push_back(temp);
00239     temp = this->next(temp);
00240   }
00241   if( collect.empty() ) {
00242     if( verbose > 0 )
00243       cerr << "No leaf nodes found for category " << category << endl;
00244     return true;
00245   }
00246   int size = collect.size();
00247   if( verbose > 1 ) {
00248     cout << "Found " << size << " leaf nodes in category " 
00249          << category << ":     ";
00250     for( int i=0;i<size;i++ )
00251       cout << collect[i]->id() << " ";
00252     cout << endl;
00253   }
00254 
00255   // sort leaf nodes by purity
00256   vector<pair<double,const SprTreeNode*> > purity(size);
00257   for( int i=0;i<size;i++ ) {
00258     const SprTreeNode* node = collect[i];
00259     double w0 = node->w0();
00260     double w1 = node->w1();
00261     if( (w1+w0) < SprUtils::eps() ) {
00262       cerr << "Found a node without events: " << node->id() << endl;
00263       return false;
00264     }
00265     if(      category == 1 )
00266       purity[i] = pair<double,const SprTreeNode*>(w1/(w1+w0),node);
00267     else if( category == 0 )
00268       purity[i] = pair<double,const SprTreeNode*>(w0/(w1+w0),node);
00269   }
00270   stable_sort(purity.begin(),purity.end(),not2(SDTCmpPairFirst()));
00271   for( int i=0;i<size;i++ ) {
00272     collect[i] = purity[i].second;
00273   }
00274   if( verbose > 1 ) {
00275     cout << "Nodes sorted by purity: " << endl;
00276     for( int i=0;i<size;i++ )
00277       cout << collect[i]->id() << " ";
00278     cout << endl;
00279   }
00280 
00281   // add nodes in the order of decreasing purity
00282   vector<double> fomVec(size), w0Vec(size), w1Vec(size);
00283   vector<unsigned> n0Vec(size), n1Vec(size);
00284   double w0(0), w1(0);
00285   unsigned n0(0), n1(0);
00286   for( int j=0;j<size;j++ ) {
00287     const SprTreeNode* node = collect[j];
00288     double w0add = node->w0();
00289     double w1add = node->w1();
00290     w0 += w0add;
00291     w1 += w1add;
00292     n0 += node->n0();
00293     n1 += node->n1();
00294     double fom = 0;
00295     if(      category == 1 )
00296       fom = crit_->fom(0,w0,w1,0);
00297     else if( category == 0 ) 
00298       fom = crit_->fom(w0,0,0,w1);
00299     fomVec[j] = fom;
00300     w0Vec[j] = w0;
00301     w1Vec[j] = w1;
00302     n0Vec[j] = n0;
00303     n1Vec[j] = n1;
00304     if( verbose > 1 ) {
00305       cout << "Adding node " << node->id() 
00306            << " with " << w0add << " background and "
00307            << w1add << " signal weights at overall FOM=" << fom 
00308            << endl;
00309     }
00310   }
00311 
00312   // find the combination with largest FOM
00313   int best = size-1;
00314   if( doMerge ) {
00315     // if nodes have equal FOM's, prefer those with more events
00316     vector<double>::reverse_iterator iter 
00317       = max_element(fomVec.rbegin(),fomVec.rend());
00318     best = iter - fomVec.rbegin();
00319     best = size-1 - best;
00320   }
00321   double fom0 = fomVec[best];
00322   w0 = w0Vec[best];
00323   w1 = w1Vec[best];
00324   n0 = n0Vec[best];
00325   n1 = n1Vec[best];
00326   nodes.clear();
00327   for( int i=0;i<=best;i++ ) {
00328     nodes.push_back(collect[i]);
00329   }
00330 
00331   // message
00332   if( verbose > 0 ) {
00333     cout << "Included " << nodes.size() 
00334          << " nodes in category " << category
00335          << " with overall FOM=" << fom0 
00336          << "    W1=" << w1 << " W0=" << w0 
00337          << "    N1=" << n1 << " N0=" << n0 << endl;
00338   }
00339   if( verbose > 1 ) {
00340     cout << "Node list: ";
00341     for( int i=0;i<nodes.size();i++ ) cout << nodes[i]->id() << " ";
00342     cout << endl;
00343   }
00344 
00345   // assign FOM and weights
00346   fomtot = fom0;
00347   w0tot = w0;
00348   w1tot = w1;
00349   n0tot = n0;
00350   n1tot = n1;
00351 
00352   // exit
00353   return true;
00354 }
00355 
00356 
00357 void SprDecisionTree::print(std::ostream& os) const
00358 {
00359   // header
00360   char s [200];
00361   sprintf(s,"Trained DecisionTree %-6i signal nodes.    Overall FOM=%-10g W0=%-10g W1=%-10g N0=%-10i N1=%-10i    Version=%s",nodes1_.size(),fom_,w0_,w1_,n0_,n1_,SprVersion.c_str());
00362   os << s << endl;
00363   os << "-------------------------------------------------------" << endl;
00364 
00365   // get vars
00366   vector<string> vars;
00367   data_->vars(vars);
00368 
00369   // signal nodes
00370   os << "-------------------------------------------------------" << endl;
00371   os << "Signal nodes:" << endl;
00372   os << "-------------------------------------------------------" << endl;
00373   for( int i=0;i<nodes1_.size();i++ ) {
00374     const SprBox& limits = nodes1_[i]->limits_;
00375     int size = limits.size();
00376     char s [200];
00377     sprintf(s,"Node %6i    Size %-4i    FOM=%-10g W0=%-10g W1=%-10g N0=%-10i N1=%-10i",i,size,nodes1_[i]->fom(),nodes1_[i]->w0(),nodes1_[i]->w1(),nodes1_[i]->n0(),nodes1_[i]->n1());
00378     os << s << endl;
00379     for( SprBox::const_iterator iter = 
00380            limits.begin();iter!=limits.end();iter++ ) {
00381       unsigned d = iter->first;
00382       assert( d < vars.size() );
00383       char s [200];
00384       sprintf(s,"Variable %30s    Limits  %15g %15g",
00385               vars[d].c_str(),iter->second.first,iter->second.second);
00386       os << s << endl;
00387     }
00388     os << "-------------------------------------------------------" << endl;
00389   }
00390 
00391   // background nodes
00392   if( showBackgroundNodes_ ) {
00393     os << "-------------------------------------------------------" << endl;
00394     os << "Background nodes:" << endl;
00395     os << "-------------------------------------------------------" << endl;
00396     for( int i=0;i<nodes0_.size();i++ ) {
00397       const SprBox& limits = nodes0_[i]->limits_;
00398       int size = limits.size();
00399       char s [200];
00400       sprintf(s,"Node %6i    Size %-4i    FOM=%-10g W0=%-10g W1=%-10g N0=%-10i N1=%-10i",i,size,nodes0_[i]->fom(),nodes0_[i]->w0(),nodes0_[i]->w1(),nodes0_[i]->n0(),nodes0_[i]->n1());
00401       os << s << endl;
00402       for( SprBox::const_iterator iter = 
00403              limits.begin();iter!=limits.end();iter++ ) {
00404         unsigned d = iter->first;
00405         assert( d < vars.size() );
00406         char s [200];
00407         sprintf(s,"Variable %30s    Limits  %15g %15g",
00408                 vars[d].c_str(),iter->second.first,iter->second.second);
00409         os << s << endl;
00410       }
00411       os << "-------------------------------------------------------" << endl;
00412     }
00413   }
00414 }
00415 
00416 
00417 void SprDecisionTree::startSplitCounter()
00418 {
00419   splits_.clear();
00420   splits_.resize(data_->dim(),pair<int,double>(0,0));
00421 }
00422 
00423 
00424 void SprDecisionTree::printSplitCounter(std::ostream& os) const
00425 {
00426   unsigned dim = data_->dim();
00427   assert( splits_.size() == dim );
00428   vector<string> vars;
00429   data_->vars(vars);
00430   assert( vars.size() == dim );
00431   os << "Tree splits on variables:" << endl;
00432   for( int i=0;i<dim;i++ ) {
00433     char s [200];
00434     sprintf(s,"Variable %30s    Splits  %10i    Delta FOM  %10.5f",
00435             vars[i].c_str(),splits_[i].first,splits_[i].second);
00436     os << s << endl;
00437   }
00438 }
00439 
00440 
00441 bool SprDecisionTree::setClasses(const SprClass& cls0, const SprClass& cls1) 
00442 {
00443   cls0_ = cls0;
00444   cls1_ = cls1;
00445   if( root_ != 0 ) 
00446     return root_->setClasses(cls0,cls1);
00447   return true;
00448 }

Generated on Tue Jun 9 17:42:02 2009 for CMSSW by  doxygen 1.5.4