CMS 3D CMS Logo

SprTreeNode.cc

Go to the documentation of this file.
00001 //$Id: SprTreeNode.cc,v 1.2 2007/09/21 22:32:10 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprTreeNode.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprPoint.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedNode.hh"
00011 
00012 #include <iostream>
00013 #include <set>
00014 #include <algorithm>
00015 #include <functional>
00016 #include <cassert>
00017 #include <cmath>
00018 
00019 using namespace std;
00020 
00021 
00022 int SprTreeNode::counter_ = 0;
00023 
// Strict-weak-ordering functor for (coordinate, index) pairs that compares
// the coordinate only. Pairs with equal coordinates are equivalent, so their
// relative order is whatever the sorting algorithm leaves (std::stable_sort
// keeps their input order).
//
// This used to derive from std::binary_function, which was deprecated in
// C++11 and removed in C++17. The nested typedefs it used to supply are
// declared explicitly so code that relies on them keeps compiling.
struct STNCmpPairFirst
{
  typedef std::pair<double,int> first_argument_type;
  typedef std::pair<double,int> second_argument_type;
  typedef bool result_type;

  bool operator()(const std::pair<double,int>& l,
                  const std::pair<double,int>& r) const {
    return (l.first < r.first);
  }
};
00031 
00032 
00033 SprTreeNode::~SprTreeNode()
00034 {
00035   delete data_;
00036   delete left_;
00037   delete right_;
00038 }
00039 
00040 
00041 SprTreeNode::SprTreeNode(const SprAbsTwoClassCriterion* crit,
00042                          const SprAbsFilter* data,
00043                          bool allLeafsSignal,
00044                          int nmin,
00045                          bool discrete,
00046                          bool canHavePureNodes,
00047                          bool fastSort,
00048                          SprIntegerBootstrap* bootstrap)
00049   :
00050   crit_(crit),
00051   data_(new SprBoxFilter(data)),
00052   allLeafsSignal_(allLeafsSignal),
00053   nmin_(nmin),
00054   discrete_(discrete),
00055   canHavePureNodes_(canHavePureNodes),
00056   fastSort_(fastSort),
00057   cls0_(0),
00058   cls1_(1),
00059   parent_(0),
00060   left_(0),
00061   right_(0),
00062   fom_(0),
00063   w0_(0),
00064   w1_(0),
00065   n0_(0),
00066   n1_(0),
00067   limits_(),
00068   id_(0),
00069   nodeClass_(-1),
00070   d_(-1),
00071   cut_(0),
00072   bootstrap_(bootstrap)
00073 {
00074   assert( crit_ != 0 );
00075   assert( data_->size() > nmin_ );
00076   counter_ = 0;// if no parent specified, starting a new tree from scratch
00077 }
00078 
00079 
00080 SprTreeNode::SprTreeNode(const SprAbsTwoClassCriterion* crit,
00081                          const SprBoxFilter& data, 
00082                          bool allLeafsSignal,
00083                          int nmin,
00084                          bool discrete,
00085                          bool canHavePureNodes,
00086                          bool fastSort,
00087                          const SprClass& cls0,
00088                          const SprClass& cls1,
00089                          const SprTreeNode* parent,
00090                          const SprBox& limits,
00091                          SprIntegerBootstrap* bootstrap)
00092   :
00093   crit_(crit),
00094   data_(new SprBoxFilter(&data)),
00095   allLeafsSignal_(allLeafsSignal),
00096   nmin_(nmin),
00097   discrete_(discrete),
00098   canHavePureNodes_(canHavePureNodes),
00099   fastSort_(fastSort),
00100   cls0_(cls0),
00101   cls1_(cls1),
00102   parent_(parent),
00103   left_(0),
00104   right_(0),
00105   fom_(0),
00106   w0_(0),
00107   w1_(0),
00108   n0_(0),
00109   n1_(0),
00110   limits_(limits),
00111   id_(++counter_),
00112   nodeClass_(-1),
00113   d_(-1),
00114   cut_(0),
00115   bootstrap_(bootstrap)
00116 {
00117   assert( crit_ != 0 );
00118   assert( parent_ != 0 );
00119   bool status = data_->setBox(limits);
00120   assert( status );
00121   status = data_->irreversibleFilter();
00122   assert( status );
00123 }
00124 
00125 
00126 SprInterval SprTreeNode::limits(int d) const 
00127 {
00128   // sanity check
00129   assert( d>=0 );
00130 
00131   // find the cut
00132   SprBox::const_iterator iter = limits_.find(d);
00133 
00134   // if not found, infty range
00135   if( iter == limits_.end() ) 
00136     return SprInterval(SprUtils::min(),SprUtils::max());
00137 
00138   // exit
00139   return iter->second;
00140 }
00141 
00142 
00143 bool SprTreeNode::split(std::vector<SprTreeNode*>& nodesToSplit, 
00144                         std::vector<std::pair<int,double> >& countTreeSplits,
00145                         int verbose)
00146 {
00147   // sanity check
00148   if( data_ == 0 ) return true;
00149 
00150   // message
00151   if( (id_%100)==0 && verbose>1 ) 
00152     cout << "Splitting node " << id_ << " ..." << endl;
00153 
00154   // check weights
00155   w0_ = data_->weightInClass(cls0_);
00156   w1_ = data_->weightInClass(cls1_);
00157 
00158   // check numbers of events
00159   n0_ = data_->ptsInClass(cls0_);
00160   n1_ = data_->ptsInClass(cls1_);
00161 
00162   // get totals and FOM
00163   int ntot = n0_ + n1_;
00164   double wtot = w0_ + w1_;
00165   if( ntot<nmin_ || wtot<SprUtils::eps() ) {
00166     if( verbose > 2 ) {
00167       cout << "Ignore node " << id_ << " with " 
00168            << ntot << " events and " << wtot << " total weight." << endl;
00169     }
00170     return this->prepareExit(true);
00171   }
00172 
00173   // compute fom
00174   fom_ = crit_->fom(0,w0_,w1_,0);
00175   double invertedFom = crit_->fom(w0_,0,0,w1_);
00176   if( verbose > 3 ) {
00177     cout << "===================" << endl;
00178     cout << "Direct FOM=" << fom_ 
00179          << "  Inverted FOM=" << invertedFom << endl;
00180   }
00181   if( n1_>0 && 
00182       ( allLeafsSignal_ ||
00183         ( crit_->symmetric() && (w1_+SprUtils::eps())>w0_ ) ||
00184         ( !crit_->symmetric() && 
00185           ( fom_>invertedFom || 
00186             ( fabs(fom_-invertedFom)<SprUtils::eps() 
00187               && (w1_+SprUtils::eps())>w0_ ) ) 
00188           ) 
00189         ) 
00190       ) {
00191     nodeClass_ = 1;
00192   }
00193   else {
00194     nodeClass_ = 0;
00195     fom_ = invertedFom;
00196   }
00197 
00198   // check weights
00199   if( w0_<SprUtils::eps() || w1_<SprUtils::eps() ) {
00200     if( verbose > 3 ) {
00201       cout << "Node " << id_ << " missing one of categories." << endl;
00202     }
00203     return this->prepareExit(true);
00204   }
00205 
00206   // check if minimal number of events
00207   if( (n0_+n1_) == nmin_ ) {
00208     if( verbose > 3 ) {
00209       cout << "Node " << id_ << " has minimal number of events." 
00210            << " Will exit without splitting." << endl;
00211     }
00212     return this->prepareExit(true);
00213   }
00214 
00215   // message
00216   if( verbose > 3 ) {
00217     cout << "Splitting node " << id_ << " of class " << nodeClass_ 
00218          << " with " << w0_ << " background and " 
00219          << w1_ << " signal weights and " << ntot << " events." << endl;
00220     cout << "Starting FOM=" << fom_ << endl;
00221   }
00222 
00223   // select features
00224   set<unsigned> dims;
00225   if(      bootstrap_ == 0 ) {
00226     for( int d=0;d<data_->dim();d++ ) dims.insert(d);
00227   }
00228   else if( !bootstrap_->replica(dims) ) {
00229     cerr << "Unable to select features." << endl;
00230     return this->prepareExit(false);
00231   }
00232   if( verbose > 2 ) {
00233     cout << "Selected dimensions: ";
00234     for( set<unsigned>::const_iterator 
00235            iter=dims.begin();iter!=dims.end();iter++ ) cout << *iter << " ";
00236     cout << endl;
00237   }
00238 
00239   // set limits
00240   if( parent_ == 0 ) {
00241     limits_.clear();
00242   }
00243   if( verbose > 3 ) {
00244     cout << "Limits:" << endl;
00245     for( SprBox::const_iterator 
00246            iter=limits_.begin();iter!=limits_.end();iter++ ) {
00247       cout << "Dimension " << iter->first << "    " 
00248            << iter->second.first << " " << iter->second.second << endl;
00249     }
00250   }
00251 
00252   // init
00253   unsigned dim = data_->dim();
00254   vector<double> fom(dim,SprUtils::min());
00255   vector<double> cut(dim,SprUtils::min());
00256 
00257   // loop through dimensions
00258   for( set<unsigned>::const_iterator 
00259          iter=dims.begin();iter!=dims.end();iter++ ) {
00260     // get dimension
00261     unsigned d = *iter;
00262     assert( d < dim );
00263 
00264     // sort
00265     vector<int> sorted;
00266     vector<double> division;
00267     if( !this->sort(d,sorted,division) ) {
00268       cerr << "Unable to sort tree node in dimension " << d << endl;
00269       return this->prepareExit(false);
00270     }
00271 
00272     // check divisions
00273     if( division.empty() ) continue;
00274 
00275     // init
00276     double wmis0 = w0_;
00277     double wcor1 = w1_;
00278     double wmis1(0), wcor0(0);
00279     int nmis0 = n0_;
00280     int ncor1 = n1_;
00281     int nmis1(0), ncor0(0);
00282     vector<double> flo, fhi;
00283 
00284     // loop through points
00285     int ndiv = division.size();
00286     int istart(0), isplit(0);
00287     bool lbreak = true;
00288     for( int k=0;k<ndiv;k++ ) {
00289       double z = division[k];
00290       lbreak = false;
00291       for( isplit=istart;isplit<sorted.size();isplit++ ) {
00292         if( (*data_)[sorted[isplit]]->x_[d] > z ) {
00293           lbreak = true;
00294           break;
00295         }
00296       }
00297       if( !lbreak ) isplit = sorted.size();
00298       for( int i=istart;i<isplit;i++ ) {
00299         const SprPoint* p = (*data_)[sorted[i]];
00300         double w = data_->w(sorted[i]);
00301         if(      p->class_ == cls0_ ) {
00302           wmis0 -= w;
00303           wcor0 += w;
00304           nmis0--;
00305           ncor0++;
00306         }
00307         else if( p->class_ == cls1_ ) {
00308           wcor1 -= w;
00309           wmis1 += w;
00310           ncor1--;
00311           nmis1++;
00312         }
00313       }
00314       istart = isplit;
00315       if( crit_->symmetric() ) {  // need to compute FOM for one side only
00316         if( (ncor1+nmis0)>=nmin_ && (nmis1+ncor0)>=nmin_ 
00317             && (wcor1+wmis0)>0 && (wmis1+wcor0)>0 
00318             && ( canHavePureNodes_ 
00319                  || ((ncor1*nmis0)>0 && (nmis1*ncor0)>0 
00320                      && (wcor1*wmis0)>0 && (wmis1*wcor0)>0) ) )
00321           flo.push_back(crit_->fom(wcor0,wmis0,wcor1,wmis1));
00322         else
00323           flo.push_back(SprUtils::min());
00324       }
00325       else { // take both sides into account
00326         if( (ncor1+nmis0)>=nmin_ && (wcor1+wmis0)>0 
00327             && ( canHavePureNodes_ || ((ncor1*nmis0)>0 && (wcor1*wmis0)>0) ) )
00328           flo.push_back(crit_->fom(wcor0,wmis0,wcor1,wmis1));
00329         else
00330           flo.push_back(SprUtils::min());
00331         if( (nmis1+ncor0)>=nmin_ && (wmis1+wcor0)>0
00332             && (canHavePureNodes_ || ((nmis1*ncor0)>0 && (wmis1*wcor0)>0) ) )
00333           fhi.push_back(crit_->fom(wmis0,wcor0,wmis1,wcor1));
00334         else
00335           fhi.push_back(SprUtils::min());
00336       }
00337     }
00338 
00339     // find optimal point and sign of cut
00340     vector<double>::iterator ilo = max_element(flo.begin(),flo.end());
00341     vector<double>::iterator ihi = max_element(fhi.begin(),fhi.end());
00342     if( crit_->symmetric() || *ilo>*ihi ) {
00343       int k = ilo - flo.begin();
00344       cut[d] = division[k]; 
00345       fom[d] = *ilo;
00346     }
00347     else {
00348       int k = ihi - fhi.begin();
00349       cut[d] = division[k]; 
00350       fom[d] = *ihi;
00351     }
00352   }
00353 
00354   // find optimal fom
00355   vector<double>::iterator imax = max_element(fom.begin(),fom.end());
00356   double newFom = *imax;
00357 
00358   // split the node
00359   if( newFom > fom_ ) {
00360     d_ = imax - fom.begin();// dimension on which we split
00361     cut_ = cut[d_];
00362 
00363     // print out
00364     if( verbose > 2 ) {
00365       cout << "Splitting node " << id_ << " of class " << nodeClass_
00366            << " in dimension " << d_ << " with " 
00367            << n1_ << " signal and " << n0_ << " background events" 
00368            << "     FOM=" << newFom 
00369            << "   Split=" << cut_ << endl;
00370     }
00371     if( verbose > 3 ) {
00372       double w0l(0), w0r(0), w1l(0), w1r(0);
00373       int n0l(0), n0r(0), n1l(0), n1r(0);
00374       for( int i=0;i<data_->size();i++ ) {
00375         const SprPoint* p = (*data_)[i];
00376         double w = data_->w(i);
00377         if(      p->class_ == cls0_ ) {
00378           if( p->x_[d_] < cut_ ) {
00379             w0l += w;
00380             n0l++;
00381           }
00382           else {
00383             w0r += w;
00384             n0r++;
00385           }
00386         }
00387         else if( p->class_ == cls1_ ) {
00388           if( p->x_[d_] < cut_ ) {
00389             w1l += w;
00390             n1l++;
00391           }
00392           else {
00393             w1r += w;
00394             n1r++;
00395           }
00396         }
00397       }
00398       cout << "Splitting node "<< id_ 
00399            << " into nodes " << counter_+1 << "  " << counter_+2
00400            << "   Left  (0/1)= " << w0l << "/" << w1l 
00401            << " " << n0l << "/" << n1l 
00402            << "   Right (0/1)= " << w0r << "/" << w1r 
00403            << " " << n0r << "/" << n1r 
00404            << endl;
00405     }
00406 
00407     // get limits
00408     SprBox leftLims = limits_;
00409     SprBox rightLims = limits_;
00410 
00411     // update limits
00412     SprBox::iterator iter = limits_.find(d_);
00413     if( iter == limits_.end() ) {
00414       SprInterval leftcut(SprUtils::min(),cut_);
00415       leftLims.insert(pair<const unsigned,SprInterval>(d_,leftcut));
00416       SprInterval rightcut(cut_,SprUtils::max());
00417       rightLims.insert(pair<const unsigned,SprInterval>(d_,rightcut));
00418     }
00419     else {
00420       leftLims[d_].second = cut_;
00421       rightLims[d_].first = cut_;
00422     }
00423 
00424     // make new nodes
00425     left_ = new SprTreeNode(crit_,*data_,allLeafsSignal_,nmin_,
00426                             discrete_,canHavePureNodes_,fastSort_,
00427                             cls0_,cls1_,this,leftLims,bootstrap_);
00428     right_ = new SprTreeNode(crit_,*data_,allLeafsSignal_,nmin_,
00429                              discrete_,canHavePureNodes_,fastSort_,
00430                              cls0_,cls1_,this,rightLims,bootstrap_);
00431 
00432     // add new nodes to the split list
00433     nodesToSplit.push_back(left_);
00434     nodesToSplit.push_back(right_);
00435 
00436     // update split counter
00437     if( countTreeSplits.size() == data_->dim() ) {
00438       countTreeSplits[d_].first++;
00439       countTreeSplits[d_].second += (newFom - fom_);
00440     }
00441 
00442     // exit
00443     return this->prepareExit(true);
00444   }
00445 
00446   // message
00447   if( verbose > 2 ) {
00448     cout << "Failed to split node " << id_ 
00449          << " with " << n1_ << " signal and " 
00450          << n0_ << " background events." << endl;
00451   }
00452   if( verbose > 3 )
00453     cout << "===================" << endl;
00454 
00455   // exit
00456   return this->prepareExit(true);
00457 }
00458 
00459 
00460 bool SprTreeNode::sort(unsigned d, std::vector<int>& sorted,
00461                        std::vector<double>& division)
00462 {
00463   // init
00464   assert( d < data_->dim() );
00465   int size = data_->size();
00466   sorted.clear();
00467   sorted.resize(size,-1);
00468   division.clear();
00469 
00470   // prepare vector
00471   vector<pair<double,int> > r(size);
00472   
00473   // loop through points
00474   for( int j=0;j<size;j++ )
00475     r[j] = pair<double,int>((*data_)[j]->x_[d],j);
00476   
00477   // sort
00478   if( fastSort_ )
00479     SprSort(r.begin(),r.end(),STNCmpPairFirst());
00480   else
00481     stable_sort(r.begin(),r.end(),STNCmpPairFirst());
00482   
00483   // fill out sorted indices
00484   double xprev = r[0].first;
00485   sorted[0] = r[0].second;
00486   for( int j=1;j<size;j++ ) {
00487     sorted[j] = r[j].second;
00488     double xcurr = r[j].first;
00489     if( (xcurr-xprev) > SprUtils::eps() ) {
00490       division.push_back(0.5*(xcurr+xprev));
00491       xprev = xcurr;
00492     }
00493   }
00494 
00495   // exit
00496   return true;
00497 }
00498 
00499 
00500 bool SprTreeNode::prepareExit(bool status)
00501 {
00502   delete data_;
00503   data_ = 0;
00504   return status;
00505 }
00506 
00507 
00508 SprTrainedNode* SprTreeNode::makeTrained() const
00509 {
00510   SprTrainedNode* t = new SprTrainedNode;
00511   t->id_ = id_;
00512   if( discrete_ )
00513     t->score_ = nodeClass_;
00514   else {
00515     if( (w0_+w1_) > 0 )
00516       t->score_ = w1_/(w0_+w1_);
00517     else {
00518       //      cout << "Warning: node " << id_ 
00519       // << " has no associated weight." << endl;
00520       t->score_ = 0.5;
00521     }
00522   }
00523   t->d_ = d_;
00524   t->cut_ = cut_;
00525   return t;
00526 }
00527 
00528 
00529 bool SprTreeNode::setClasses(const SprClass& cls0, const SprClass& cls1)
00530 {
00531   if( left_!=0 || right_!=0 ) {
00532     cerr << "Unable to reset classes for the tree node with daughters." 
00533          << endl;
00534     return false;
00535   }
00536   cls0_ = cls0; 
00537   cls1_ = cls1;
00538   vector<SprClass> classes(2);
00539   classes[0] = cls0_; classes[1] = cls1_;
00540   data_->chooseClasses(classes);
00541   return true;
00542 }

Generated on Tue Jun 9 17:42:04 2009 for CMSSW by  doxygen 1.5.4