CMS 3D CMS Logo

SprTopdownTree.cc

Go to the documentation of this file.
00001 //$Id: SprTopdownTree.cc,v 1.2 2007/09/21 22:32:10 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprTopdownTree.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprTreeNode.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedNode.hh"
00007 
00008 #include <iostream>
00009 #include <map>
00010 #include <utility>
00011 #include <cassert>
00012 
00013 using namespace std;
00014 
00015 
00016 SprTopdownTree::SprTopdownTree(SprAbsFilter* data, 
00017                                const SprAbsTwoClassCriterion* crit,
00018                                int nmin, bool discrete,
00019                                SprIntegerBootstrap* bootstrap)
00020   :
00021   SprDecisionTree(data,crit,nmin,false,discrete,bootstrap)
00022 {
00023   cout << "Using a Topdown tree." << endl;
00024 }
00025 
00026 
00027 SprTrainedTopdownTree* SprTopdownTree::makeTrained() const
00028 {
00029   // make
00030   vector<const SprTrainedNode*> nodes;
00031   if( !this->makeTrainedNodes(nodes) ) {
00032     cerr << "SprTrainedTopdownTree unable to make trained nodes." << endl;
00033     return 0;
00034   }
00035   SprTrainedTopdownTree* t = new SprTrainedTopdownTree(nodes,true);
00036 
00037   // vars
00038   vector<string> vars;
00039   data_->vars(vars);
00040   t->setVars(vars);
00041 
00042   // exit
00043   return t;
00044 }
00045 
00046 
00047 bool SprTopdownTree::makeTrainedNodes(std::vector<const SprTrainedNode*>& 
00048                                       nodes) const
00049 {
00050   // sanity check
00051   if( fullNodeList_.empty() || root_->id_!=0 || fullNodeList_[0]!=root_ ) {
00052     cerr << "Tree is not properly configured. Unable to make trained nodes." 
00053          << endl;
00054     return false;
00055   }
00056 
00057   // copy all nodes into the map
00058   map<int,SprTrainedNode*> copy;
00059   for( int i=0;i<fullNodeList_.size();i++ ) {
00060     SprTrainedNode* node = fullNodeList_[i]->makeTrained();
00061     copy.insert(pair<const int,SprTrainedNode*>(node->id_,node));
00062   }
00063 
00064   // make sure the first node has id 0
00065   if( copy.begin()->first != 0 ) {
00066     cerr << "First id in the replicated map is not zero." << endl;
00067     return false;
00068   }
00069 
00070   // resolve mother/daughter references
00071   for( int i=0;i<fullNodeList_.size();i++ ) {
00072     const SprTreeNode* old = fullNodeList_[i];
00073     map<int,SprTrainedNode*>::iterator iter = copy.find(old->id_);
00074     assert( iter != copy.end() );
00075     if( old->left_ != 0 ) {
00076       map<int,SprTrainedNode*>::iterator dau1 = copy.find(old->left_->id_);
00077       assert( dau1 != copy.end() );
00078       iter->second->toDau1_ = dau1->second;
00079       dau1->second->toParent_ = iter->second;
00080     }
00081     if( old->right_ != 0 ) {
00082       map<int,SprTrainedNode*>::iterator dau2 = copy.find(old->right_->id_);
00083       assert( dau2 != copy.end() );
00084       iter->second->toDau2_ = dau2->second;
00085       dau2->second->toParent_ = iter->second;
00086     }
00087   }
00088 
00089   // convert the map into a plain vector
00090   nodes.clear();
00091   for( map<int,SprTrainedNode*>::iterator iter = copy.begin();
00092        iter!=copy.end();iter++ ) {
00093     nodes.push_back(iter->second);
00094   }
00095 
00096   // exit
00097   return true;
00098 }

Generated on Tue Jun 9 17:42:04 2009 for CMSSW by doxygen 1.5.4