00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprTopdownTree.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprTreeNode.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedNode.hh"
00007
00008 #include <iostream>
00009 #include <map>
00010 #include <utility>
00011 #include <cassert>
00012
00013 using namespace std;
00014
00015
00016 SprTopdownTree::SprTopdownTree(SprAbsFilter* data,
00017 const SprAbsTwoClassCriterion* crit,
00018 int nmin, bool discrete,
00019 SprIntegerBootstrap* bootstrap)
00020 :
00021 SprDecisionTree(data,crit,nmin,false,discrete,bootstrap)
00022 {
00023 cout << "Using a Topdown tree." << endl;
00024 }
00025
00026
00027 SprTrainedTopdownTree* SprTopdownTree::makeTrained() const
00028 {
00029
00030 vector<const SprTrainedNode*> nodes;
00031 if( !this->makeTrainedNodes(nodes) ) {
00032 cerr << "SprTrainedTopdownTree unable to make trained nodes." << endl;
00033 return 0;
00034 }
00035 SprTrainedTopdownTree* t = new SprTrainedTopdownTree(nodes,true);
00036
00037
00038 vector<string> vars;
00039 data_->vars(vars);
00040 t->setVars(vars);
00041
00042
00043 return t;
00044 }
00045
00046
00047 bool SprTopdownTree::makeTrainedNodes(std::vector<const SprTrainedNode*>&
00048 nodes) const
00049 {
00050
00051 if( fullNodeList_.empty() || root_->id_!=0 || fullNodeList_[0]!=root_ ) {
00052 cerr << "Tree is not properly configured. Unable to make trained nodes."
00053 << endl;
00054 return false;
00055 }
00056
00057
00058 map<int,SprTrainedNode*> copy;
00059 for( int i=0;i<fullNodeList_.size();i++ ) {
00060 SprTrainedNode* node = fullNodeList_[i]->makeTrained();
00061 copy.insert(pair<const int,SprTrainedNode*>(node->id_,node));
00062 }
00063
00064
00065 if( copy.begin()->first != 0 ) {
00066 cerr << "First id in the replicated map is not zero." << endl;
00067 return false;
00068 }
00069
00070
00071 for( int i=0;i<fullNodeList_.size();i++ ) {
00072 const SprTreeNode* old = fullNodeList_[i];
00073 map<int,SprTrainedNode*>::iterator iter = copy.find(old->id_);
00074 assert( iter != copy.end() );
00075 if( old->left_ != 0 ) {
00076 map<int,SprTrainedNode*>::iterator dau1 = copy.find(old->left_->id_);
00077 assert( dau1 != copy.end() );
00078 iter->second->toDau1_ = dau1->second;
00079 dau1->second->toParent_ = iter->second;
00080 }
00081 if( old->right_ != 0 ) {
00082 map<int,SprTrainedNode*>::iterator dau2 = copy.find(old->right_->id_);
00083 assert( dau2 != copy.end() );
00084 iter->second->toDau2_ = dau2->second;
00085 dau2->second->toParent_ = iter->second;
00086 }
00087 }
00088
00089
00090 nodes.clear();
00091 for( map<int,SprTrainedNode*>::iterator iter = copy.begin();
00092 iter!=copy.end();iter++ ) {
00093 nodes.push_back(iter->second);
00094 }
00095
00096
00097 return true;
00098 }