CMS 3D CMS Logo

SprTrainedTopdownTree.cc

Go to the documentation of this file.
00001 //$Id: SprTrainedTopdownTree.cc,v 1.2 2007/09/21 22:32:10 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedTopdownTree.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedNode.hh"
00006 
00007 #include <map>
00008 #include <utility>
00009 
00010 using namespace std;
00011 
00012 
00013 SprTrainedTopdownTree::~SprTrainedTopdownTree()
00014 {
00015   if( ownTree_ ) {
00016     for( int i=0;i<nodes_.size();i++ ) delete nodes_[i];
00017     ownTree_ = false;
00018   }
00019 }
00020 
00021 
00022 double SprTrainedTopdownTree::response(const std::vector<double>& v) const
00023 {
00024   const SprTrainedNode* node = nodes_[0];
00025   while( node->d_ >= 0 ) {
00026     assert( node->d_ < v.size() );
00027     if( v[node->d_] < node->cut_ )
00028       node = node->toDau1_;
00029     else
00030       node = node->toDau2_;
00031   }
00032   return node->score_;
00033 }
00034 
00035 
00036 void SprTrainedTopdownTree::print(std::ostream& os) const
00037 {
00038   os << "Trained TopdownTree " << SprVersion << endl;
00039   os << "Nodes: " << nodes_.size() << " nodes." << endl;
00040   for( int i=0;i<nodes_.size();i++ ) {
00041     const SprTrainedNode* node = nodes_[i];
00042     os << "Id: "         << node->id_
00043        << " Score: "     << node->score_
00044        << " Dim: "       << node->d_
00045        << " Cut: "       << node->cut_
00046        << " Daughters: " << (node->toDau1_==0 ? -1 : node->toDau1_->id_)
00047        << " "            << (node->toDau2_==0 ? -1 : node->toDau2_->id_)
00048        << endl;
00049   }
00050 }
00051 
00052 
00053 bool SprTrainedTopdownTree::replicate(const std::vector<
00054                                       const SprTrainedNode*>& nodes)
00055 {
00056   // copy all nodes into the map
00057   map<int,SprTrainedNode*> copy;
00058   for( int i=0;i<nodes.size();i++ ) {
00059     SprTrainedNode* node = new SprTrainedNode(*nodes[i]);
00060     copy.insert(pair<const int,SprTrainedNode*>(node->id_,node));
00061   }
00062 
00063   // make sure the first node has id 0
00064   if( copy.begin()->first != 0 ) {
00065     cerr << "First id in the replicated map is not zero." << endl;
00066     return false;
00067   }
00068 
00069   // resolve mother/daughter references
00070   for( int i=0;i<nodes.size();i++ ) {
00071     const SprTrainedNode* old = nodes[i];
00072     map<int,SprTrainedNode*>::iterator iter = copy.find(old->id_);
00073     assert( iter != copy.end() );
00074     if( old->toDau1_ != 0 ) {
00075       map<int,SprTrainedNode*>::iterator dau1 = copy.find(old->toDau1_->id_);
00076       assert( dau1 != copy.end() );
00077       iter->second->toDau1_ = dau1->second;
00078       dau1->second->toParent_ = iter->second;
00079     }
00080     if( old->toDau2_ != 0 ) {
00081       map<int,SprTrainedNode*>::iterator dau2 = copy.find(old->toDau2_->id_);
00082       assert( dau2 != copy.end() );
00083       iter->second->toDau2_ = dau2->second;
00084       dau2->second->toParent_ = iter->second;
00085     }
00086   }
00087 
00088   // convert the map into a plain vector
00089   nodes_.clear();
00090   for( map<int,SprTrainedNode*>::iterator iter = copy.begin();
00091        iter!=copy.end();iter++ ) {
00092     nodes_.push_back(iter->second);
00093   }
00094 
00095   // exit
00096   return true;
00097 }
00098 
00099 
00100 void SprTrainedTopdownTree::printFunction(std::ostream& os,
00101                                           const SprTrainedNode* currentNode,
00102                                           int indentLevel) const
00103 {
00104   // Use root if no node given.
00105   const SprTrainedNode* node; 
00106   if( currentNode == 0 ) 
00107     node = nodes_[0]; 
00108   else 
00109     node = currentNode; 
00110  
00111   // Print this node. 
00112   if( node->d_ >= 0 ) { 
00113     for( int I=0;I<indentLevel;I++ ) os << " ";
00114     os << "if( V[" << node->d_ << "] < " << node->cut_ << " ) {" << endl; 
00115     this->printFunction(os,node->toDau1_,indentLevel+2); 
00116     for( int I=0;I<indentLevel;I++ ) os << " ";
00117     os << "}" << endl; 
00118     for( int I=0;I<indentLevel;I++ ) os << " ";
00119     os << "else /*if( V[" << node->d_ << "] >= " 
00120        << node->cut_ << " )*/ {" << endl; 
00121     this->printFunction(os,node->toDau2_,indentLevel+2); 
00122     for( int I=0;I<indentLevel;I++ ) os << " ";
00123     os << "}" << endl; 
00124   } 
00125   else { 
00126     for( int I=0;I<indentLevel;I++ ) os << " ";
00127     os << "R += " << node->score_ << ";" << endl; 
00128   } 
00129 }

Generated on Tue Jun 9 17:42:04 2009 for CMSSW by  doxygen 1.5.4