00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedTopdownTree.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedNode.hh"
00006
#include <cstddef>
#include <map>
#include <utility>
00009
00010 using namespace std;
00011
00012
00013 SprTrainedTopdownTree::~SprTrainedTopdownTree()
00014 {
00015 if( ownTree_ ) {
00016 for( int i=0;i<nodes_.size();i++ ) delete nodes_[i];
00017 ownTree_ = false;
00018 }
00019 }
00020
00021
00022 double SprTrainedTopdownTree::response(const std::vector<double>& v) const
00023 {
00024 const SprTrainedNode* node = nodes_[0];
00025 while( node->d_ >= 0 ) {
00026 assert( node->d_ < v.size() );
00027 if( v[node->d_] < node->cut_ )
00028 node = node->toDau1_;
00029 else
00030 node = node->toDau2_;
00031 }
00032 return node->score_;
00033 }
00034
00035
00036 void SprTrainedTopdownTree::print(std::ostream& os) const
00037 {
00038 os << "Trained TopdownTree " << SprVersion << endl;
00039 os << "Nodes: " << nodes_.size() << " nodes." << endl;
00040 for( int i=0;i<nodes_.size();i++ ) {
00041 const SprTrainedNode* node = nodes_[i];
00042 os << "Id: " << node->id_
00043 << " Score: " << node->score_
00044 << " Dim: " << node->d_
00045 << " Cut: " << node->cut_
00046 << " Daughters: " << (node->toDau1_==0 ? -1 : node->toDau1_->id_)
00047 << " " << (node->toDau2_==0 ? -1 : node->toDau2_->id_)
00048 << endl;
00049 }
00050 }
00051
00052
00053 bool SprTrainedTopdownTree::replicate(const std::vector<
00054 const SprTrainedNode*>& nodes)
00055 {
00056
00057 map<int,SprTrainedNode*> copy;
00058 for( int i=0;i<nodes.size();i++ ) {
00059 SprTrainedNode* node = new SprTrainedNode(*nodes[i]);
00060 copy.insert(pair<const int,SprTrainedNode*>(node->id_,node));
00061 }
00062
00063
00064 if( copy.begin()->first != 0 ) {
00065 cerr << "First id in the replicated map is not zero." << endl;
00066 return false;
00067 }
00068
00069
00070 for( int i=0;i<nodes.size();i++ ) {
00071 const SprTrainedNode* old = nodes[i];
00072 map<int,SprTrainedNode*>::iterator iter = copy.find(old->id_);
00073 assert( iter != copy.end() );
00074 if( old->toDau1_ != 0 ) {
00075 map<int,SprTrainedNode*>::iterator dau1 = copy.find(old->toDau1_->id_);
00076 assert( dau1 != copy.end() );
00077 iter->second->toDau1_ = dau1->second;
00078 dau1->second->toParent_ = iter->second;
00079 }
00080 if( old->toDau2_ != 0 ) {
00081 map<int,SprTrainedNode*>::iterator dau2 = copy.find(old->toDau2_->id_);
00082 assert( dau2 != copy.end() );
00083 iter->second->toDau2_ = dau2->second;
00084 dau2->second->toParent_ = iter->second;
00085 }
00086 }
00087
00088
00089 nodes_.clear();
00090 for( map<int,SprTrainedNode*>::iterator iter = copy.begin();
00091 iter!=copy.end();iter++ ) {
00092 nodes_.push_back(iter->second);
00093 }
00094
00095
00096 return true;
00097 }
00098
00099
00100 void SprTrainedTopdownTree::printFunction(std::ostream& os,
00101 const SprTrainedNode* currentNode,
00102 int indentLevel) const
00103 {
00104
00105 const SprTrainedNode* node;
00106 if( currentNode == 0 )
00107 node = nodes_[0];
00108 else
00109 node = currentNode;
00110
00111
00112 if( node->d_ >= 0 ) {
00113 for( int I=0;I<indentLevel;I++ ) os << " ";
00114 os << "if( V[" << node->d_ << "] < " << node->cut_ << " ) {" << endl;
00115 this->printFunction(os,node->toDau1_,indentLevel+2);
00116 for( int I=0;I<indentLevel;I++ ) os << " ";
00117 os << "}" << endl;
00118 for( int I=0;I<indentLevel;I++ ) os << " ";
00119 os << "else /*if( V[" << node->d_ << "] >= "
00120 << node->cut_ << " )*/ {" << endl;
00121 this->printFunction(os,node->toDau2_,indentLevel+2);
00122 for( int I=0;I<indentLevel;I++ ) os << " ";
00123 os << "}" << endl;
00124 }
00125 else {
00126 for( int I=0;I<indentLevel;I++ ) os << " ";
00127 os << "R += " << node->score_ << ";" << endl;
00128 }
00129 }