00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprTreeNode.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprPoint.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprIntegerBootstrap.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedNode.hh"
00011
00012 #include <iostream>
00013 #include <set>
00014 #include <algorithm>
00015 #include <functional>
00016 #include <cassert>
00017 #include <cmath>
00018
00019 using namespace std;
00020
00021
// Shared node counter: reset to 0 whenever a root node is constructed and
// pre-incremented to assign an id to every daughter node that is created.
int SprTreeNode::counter_ = 0;
00023
// Strict weak ordering on (value,index) pairs that compares the value only.
// Pairs with equal values are treated as equivalent, so a stable sort keeps
// them in their original relative order.
//
// The functor no longer derives from std::binary_function (deprecated in
// C++11, removed in C++17); the adaptor typedefs that base class used to
// provide are declared explicitly so existing uses keep working.
struct STNCmpPairFirst {
  typedef std::pair<double,int> first_argument_type;
  typedef std::pair<double,int> second_argument_type;
  typedef bool result_type;

  // Returns true if the value of l is strictly less than the value of r.
  bool operator()(const std::pair<double,int>& l,
                  const std::pair<double,int>& r) const {
    return (l.first < r.first);
  }
};
00031
00032
00033 SprTreeNode::~SprTreeNode()
00034 {
00035 delete data_;
00036 delete left_;
00037 delete right_;
00038 }
00039
00040
00041 SprTreeNode::SprTreeNode(const SprAbsTwoClassCriterion* crit,
00042 const SprAbsFilter* data,
00043 bool allLeafsSignal,
00044 int nmin,
00045 bool discrete,
00046 bool canHavePureNodes,
00047 bool fastSort,
00048 SprIntegerBootstrap* bootstrap)
00049 :
00050 crit_(crit),
00051 data_(new SprBoxFilter(data)),
00052 allLeafsSignal_(allLeafsSignal),
00053 nmin_(nmin),
00054 discrete_(discrete),
00055 canHavePureNodes_(canHavePureNodes),
00056 fastSort_(fastSort),
00057 cls0_(0),
00058 cls1_(1),
00059 parent_(0),
00060 left_(0),
00061 right_(0),
00062 fom_(0),
00063 w0_(0),
00064 w1_(0),
00065 n0_(0),
00066 n1_(0),
00067 limits_(),
00068 id_(0),
00069 nodeClass_(-1),
00070 d_(-1),
00071 cut_(0),
00072 bootstrap_(bootstrap)
00073 {
00074 assert( crit_ != 0 );
00075 assert( data_->size() > nmin_ );
00076 counter_ = 0;
00077 }
00078
00079
00080 SprTreeNode::SprTreeNode(const SprAbsTwoClassCriterion* crit,
00081 const SprBoxFilter& data,
00082 bool allLeafsSignal,
00083 int nmin,
00084 bool discrete,
00085 bool canHavePureNodes,
00086 bool fastSort,
00087 const SprClass& cls0,
00088 const SprClass& cls1,
00089 const SprTreeNode* parent,
00090 const SprBox& limits,
00091 SprIntegerBootstrap* bootstrap)
00092 :
00093 crit_(crit),
00094 data_(new SprBoxFilter(&data)),
00095 allLeafsSignal_(allLeafsSignal),
00096 nmin_(nmin),
00097 discrete_(discrete),
00098 canHavePureNodes_(canHavePureNodes),
00099 fastSort_(fastSort),
00100 cls0_(cls0),
00101 cls1_(cls1),
00102 parent_(parent),
00103 left_(0),
00104 right_(0),
00105 fom_(0),
00106 w0_(0),
00107 w1_(0),
00108 n0_(0),
00109 n1_(0),
00110 limits_(limits),
00111 id_(++counter_),
00112 nodeClass_(-1),
00113 d_(-1),
00114 cut_(0),
00115 bootstrap_(bootstrap)
00116 {
00117 assert( crit_ != 0 );
00118 assert( parent_ != 0 );
00119 bool status = data_->setBox(limits);
00120 assert( status );
00121 status = data_->irreversibleFilter();
00122 assert( status );
00123 }
00124
00125
00126 SprInterval SprTreeNode::limits(int d) const
00127 {
00128
00129 assert( d>=0 );
00130
00131
00132 SprBox::const_iterator iter = limits_.find(d);
00133
00134
00135 if( iter == limits_.end() )
00136 return SprInterval(SprUtils::min(),SprUtils::max());
00137
00138
00139 return iter->second;
00140 }
00141
00142
00143 bool SprTreeNode::split(std::vector<SprTreeNode*>& nodesToSplit,
00144 std::vector<std::pair<int,double> >& countTreeSplits,
00145 int verbose)
00146 {
00147
00148 if( data_ == 0 ) return true;
00149
00150
00151 if( (id_%100)==0 && verbose>1 )
00152 cout << "Splitting node " << id_ << " ..." << endl;
00153
00154
00155 w0_ = data_->weightInClass(cls0_);
00156 w1_ = data_->weightInClass(cls1_);
00157
00158
00159 n0_ = data_->ptsInClass(cls0_);
00160 n1_ = data_->ptsInClass(cls1_);
00161
00162
00163 int ntot = n0_ + n1_;
00164 double wtot = w0_ + w1_;
00165 if( ntot<nmin_ || wtot<SprUtils::eps() ) {
00166 if( verbose > 2 ) {
00167 cout << "Ignore node " << id_ << " with "
00168 << ntot << " events and " << wtot << " total weight." << endl;
00169 }
00170 return this->prepareExit(true);
00171 }
00172
00173
00174 fom_ = crit_->fom(0,w0_,w1_,0);
00175 double invertedFom = crit_->fom(w0_,0,0,w1_);
00176 if( verbose > 3 ) {
00177 cout << "===================" << endl;
00178 cout << "Direct FOM=" << fom_
00179 << " Inverted FOM=" << invertedFom << endl;
00180 }
00181 if( n1_>0 &&
00182 ( allLeafsSignal_ ||
00183 ( crit_->symmetric() && (w1_+SprUtils::eps())>w0_ ) ||
00184 ( !crit_->symmetric() &&
00185 ( fom_>invertedFom ||
00186 ( fabs(fom_-invertedFom)<SprUtils::eps()
00187 && (w1_+SprUtils::eps())>w0_ ) )
00188 )
00189 )
00190 ) {
00191 nodeClass_ = 1;
00192 }
00193 else {
00194 nodeClass_ = 0;
00195 fom_ = invertedFom;
00196 }
00197
00198
00199 if( w0_<SprUtils::eps() || w1_<SprUtils::eps() ) {
00200 if( verbose > 3 ) {
00201 cout << "Node " << id_ << " missing one of categories." << endl;
00202 }
00203 return this->prepareExit(true);
00204 }
00205
00206
00207 if( (n0_+n1_) == nmin_ ) {
00208 if( verbose > 3 ) {
00209 cout << "Node " << id_ << " has minimal number of events."
00210 << " Will exit without splitting." << endl;
00211 }
00212 return this->prepareExit(true);
00213 }
00214
00215
00216 if( verbose > 3 ) {
00217 cout << "Splitting node " << id_ << " of class " << nodeClass_
00218 << " with " << w0_ << " background and "
00219 << w1_ << " signal weights and " << ntot << " events." << endl;
00220 cout << "Starting FOM=" << fom_ << endl;
00221 }
00222
00223
00224 set<unsigned> dims;
00225 if( bootstrap_ == 0 ) {
00226 for( int d=0;d<data_->dim();d++ ) dims.insert(d);
00227 }
00228 else if( !bootstrap_->replica(dims) ) {
00229 cerr << "Unable to select features." << endl;
00230 return this->prepareExit(false);
00231 }
00232 if( verbose > 2 ) {
00233 cout << "Selected dimensions: ";
00234 for( set<unsigned>::const_iterator
00235 iter=dims.begin();iter!=dims.end();iter++ ) cout << *iter << " ";
00236 cout << endl;
00237 }
00238
00239
00240 if( parent_ == 0 ) {
00241 limits_.clear();
00242 }
00243 if( verbose > 3 ) {
00244 cout << "Limits:" << endl;
00245 for( SprBox::const_iterator
00246 iter=limits_.begin();iter!=limits_.end();iter++ ) {
00247 cout << "Dimension " << iter->first << " "
00248 << iter->second.first << " " << iter->second.second << endl;
00249 }
00250 }
00251
00252
00253 unsigned dim = data_->dim();
00254 vector<double> fom(dim,SprUtils::min());
00255 vector<double> cut(dim,SprUtils::min());
00256
00257
00258 for( set<unsigned>::const_iterator
00259 iter=dims.begin();iter!=dims.end();iter++ ) {
00260
00261 unsigned d = *iter;
00262 assert( d < dim );
00263
00264
00265 vector<int> sorted;
00266 vector<double> division;
00267 if( !this->sort(d,sorted,division) ) {
00268 cerr << "Unable to sort tree node in dimension " << d << endl;
00269 return this->prepareExit(false);
00270 }
00271
00272
00273 if( division.empty() ) continue;
00274
00275
00276 double wmis0 = w0_;
00277 double wcor1 = w1_;
00278 double wmis1(0), wcor0(0);
00279 int nmis0 = n0_;
00280 int ncor1 = n1_;
00281 int nmis1(0), ncor0(0);
00282 vector<double> flo, fhi;
00283
00284
00285 int ndiv = division.size();
00286 int istart(0), isplit(0);
00287 bool lbreak = true;
00288 for( int k=0;k<ndiv;k++ ) {
00289 double z = division[k];
00290 lbreak = false;
00291 for( isplit=istart;isplit<sorted.size();isplit++ ) {
00292 if( (*data_)[sorted[isplit]]->x_[d] > z ) {
00293 lbreak = true;
00294 break;
00295 }
00296 }
00297 if( !lbreak ) isplit = sorted.size();
00298 for( int i=istart;i<isplit;i++ ) {
00299 const SprPoint* p = (*data_)[sorted[i]];
00300 double w = data_->w(sorted[i]);
00301 if( p->class_ == cls0_ ) {
00302 wmis0 -= w;
00303 wcor0 += w;
00304 nmis0--;
00305 ncor0++;
00306 }
00307 else if( p->class_ == cls1_ ) {
00308 wcor1 -= w;
00309 wmis1 += w;
00310 ncor1--;
00311 nmis1++;
00312 }
00313 }
00314 istart = isplit;
00315 if( crit_->symmetric() ) {
00316 if( (ncor1+nmis0)>=nmin_ && (nmis1+ncor0)>=nmin_
00317 && (wcor1+wmis0)>0 && (wmis1+wcor0)>0
00318 && ( canHavePureNodes_
00319 || ((ncor1*nmis0)>0 && (nmis1*ncor0)>0
00320 && (wcor1*wmis0)>0 && (wmis1*wcor0)>0) ) )
00321 flo.push_back(crit_->fom(wcor0,wmis0,wcor1,wmis1));
00322 else
00323 flo.push_back(SprUtils::min());
00324 }
00325 else {
00326 if( (ncor1+nmis0)>=nmin_ && (wcor1+wmis0)>0
00327 && ( canHavePureNodes_ || ((ncor1*nmis0)>0 && (wcor1*wmis0)>0) ) )
00328 flo.push_back(crit_->fom(wcor0,wmis0,wcor1,wmis1));
00329 else
00330 flo.push_back(SprUtils::min());
00331 if( (nmis1+ncor0)>=nmin_ && (wmis1+wcor0)>0
00332 && (canHavePureNodes_ || ((nmis1*ncor0)>0 && (wmis1*wcor0)>0) ) )
00333 fhi.push_back(crit_->fom(wmis0,wcor0,wmis1,wcor1));
00334 else
00335 fhi.push_back(SprUtils::min());
00336 }
00337 }
00338
00339
00340 vector<double>::iterator ilo = max_element(flo.begin(),flo.end());
00341 vector<double>::iterator ihi = max_element(fhi.begin(),fhi.end());
00342 if( crit_->symmetric() || *ilo>*ihi ) {
00343 int k = ilo - flo.begin();
00344 cut[d] = division[k];
00345 fom[d] = *ilo;
00346 }
00347 else {
00348 int k = ihi - fhi.begin();
00349 cut[d] = division[k];
00350 fom[d] = *ihi;
00351 }
00352 }
00353
00354
00355 vector<double>::iterator imax = max_element(fom.begin(),fom.end());
00356 double newFom = *imax;
00357
00358
00359 if( newFom > fom_ ) {
00360 d_ = imax - fom.begin();
00361 cut_ = cut[d_];
00362
00363
00364 if( verbose > 2 ) {
00365 cout << "Splitting node " << id_ << " of class " << nodeClass_
00366 << " in dimension " << d_ << " with "
00367 << n1_ << " signal and " << n0_ << " background events"
00368 << " FOM=" << newFom
00369 << " Split=" << cut_ << endl;
00370 }
00371 if( verbose > 3 ) {
00372 double w0l(0), w0r(0), w1l(0), w1r(0);
00373 int n0l(0), n0r(0), n1l(0), n1r(0);
00374 for( int i=0;i<data_->size();i++ ) {
00375 const SprPoint* p = (*data_)[i];
00376 double w = data_->w(i);
00377 if( p->class_ == cls0_ ) {
00378 if( p->x_[d_] < cut_ ) {
00379 w0l += w;
00380 n0l++;
00381 }
00382 else {
00383 w0r += w;
00384 n0r++;
00385 }
00386 }
00387 else if( p->class_ == cls1_ ) {
00388 if( p->x_[d_] < cut_ ) {
00389 w1l += w;
00390 n1l++;
00391 }
00392 else {
00393 w1r += w;
00394 n1r++;
00395 }
00396 }
00397 }
00398 cout << "Splitting node "<< id_
00399 << " into nodes " << counter_+1 << " " << counter_+2
00400 << " Left (0/1)= " << w0l << "/" << w1l
00401 << " " << n0l << "/" << n1l
00402 << " Right (0/1)= " << w0r << "/" << w1r
00403 << " " << n0r << "/" << n1r
00404 << endl;
00405 }
00406
00407
00408 SprBox leftLims = limits_;
00409 SprBox rightLims = limits_;
00410
00411
00412 SprBox::iterator iter = limits_.find(d_);
00413 if( iter == limits_.end() ) {
00414 SprInterval leftcut(SprUtils::min(),cut_);
00415 leftLims.insert(pair<const unsigned,SprInterval>(d_,leftcut));
00416 SprInterval rightcut(cut_,SprUtils::max());
00417 rightLims.insert(pair<const unsigned,SprInterval>(d_,rightcut));
00418 }
00419 else {
00420 leftLims[d_].second = cut_;
00421 rightLims[d_].first = cut_;
00422 }
00423
00424
00425 left_ = new SprTreeNode(crit_,*data_,allLeafsSignal_,nmin_,
00426 discrete_,canHavePureNodes_,fastSort_,
00427 cls0_,cls1_,this,leftLims,bootstrap_);
00428 right_ = new SprTreeNode(crit_,*data_,allLeafsSignal_,nmin_,
00429 discrete_,canHavePureNodes_,fastSort_,
00430 cls0_,cls1_,this,rightLims,bootstrap_);
00431
00432
00433 nodesToSplit.push_back(left_);
00434 nodesToSplit.push_back(right_);
00435
00436
00437 if( countTreeSplits.size() == data_->dim() ) {
00438 countTreeSplits[d_].first++;
00439 countTreeSplits[d_].second += (newFom - fom_);
00440 }
00441
00442
00443 return this->prepareExit(true);
00444 }
00445
00446
00447 if( verbose > 2 ) {
00448 cout << "Failed to split node " << id_
00449 << " with " << n1_ << " signal and "
00450 << n0_ << " background events." << endl;
00451 }
00452 if( verbose > 3 )
00453 cout << "===================" << endl;
00454
00455
00456 return this->prepareExit(true);
00457 }
00458
00459
00460 bool SprTreeNode::sort(unsigned d, std::vector<int>& sorted,
00461 std::vector<double>& division)
00462 {
00463
00464 assert( d < data_->dim() );
00465 int size = data_->size();
00466 sorted.clear();
00467 sorted.resize(size,-1);
00468 division.clear();
00469
00470
00471 vector<pair<double,int> > r(size);
00472
00473
00474 for( int j=0;j<size;j++ )
00475 r[j] = pair<double,int>((*data_)[j]->x_[d],j);
00476
00477
00478 if( fastSort_ )
00479 SprSort(r.begin(),r.end(),STNCmpPairFirst());
00480 else
00481 stable_sort(r.begin(),r.end(),STNCmpPairFirst());
00482
00483
00484 double xprev = r[0].first;
00485 sorted[0] = r[0].second;
00486 for( int j=1;j<size;j++ ) {
00487 sorted[j] = r[j].second;
00488 double xcurr = r[j].first;
00489 if( (xcurr-xprev) > SprUtils::eps() ) {
00490 division.push_back(0.5*(xcurr+xprev));
00491 xprev = xcurr;
00492 }
00493 }
00494
00495
00496 return true;
00497 }
00498
00499
00500 bool SprTreeNode::prepareExit(bool status)
00501 {
00502 delete data_;
00503 data_ = 0;
00504 return status;
00505 }
00506
00507
00508 SprTrainedNode* SprTreeNode::makeTrained() const
00509 {
00510 SprTrainedNode* t = new SprTrainedNode;
00511 t->id_ = id_;
00512 if( discrete_ )
00513 t->score_ = nodeClass_;
00514 else {
00515 if( (w0_+w1_) > 0 )
00516 t->score_ = w1_/(w0_+w1_);
00517 else {
00518
00519
00520 t->score_ = 0.5;
00521 }
00522 }
00523 t->d_ = d_;
00524 t->cut_ = cut_;
00525 return t;
00526 }
00527
00528
00529 bool SprTreeNode::setClasses(const SprClass& cls0, const SprClass& cls1)
00530 {
00531 if( left_!=0 || right_!=0 ) {
00532 cerr << "Unable to reset classes for the tree node with daughters."
00533 << endl;
00534 return false;
00535 }
00536 cls0_ = cls0;
00537 cls1_ = cls1;
00538 vector<SprClass> classes(2);
00539 classes[0] = cls0_; classes[1] = cls1_;
00540 data_->chooseClasses(classes);
00541 return true;
00542 }