CMS 3D CMS Logo

List of all members | Public Member Functions | Private Member Functions | Private Attributes
emtf::Tree Class Reference

#include <Tree.h>

Public Member Functions

void addXMLAttributes (TXMLEngine *xml, Node *node, XMLNodePointer_t np)
 
void buildTree (int nodeLimit)
 
void calcError ()
 
Node * filterEvent (Event *e)
 
Node * filterEventRecursive (Node *node, Event *e)
 
void filterEvents (std::vector< Event * > &tEvents)
 
void filterEventsRecursive (Node *node)
 
double getBoostWeight (void) const
 
int getNumTerminalNodes ()
 
Node * getRootNode ()
 
void getSplitValues (std::vector< std::vector< double >> &v)
 
void getSplitValuesRecursive (Node *node, std::vector< std::vector< double >> &v)
 
std::list< Node * > & getTerminalNodes ()
 
void loadFromCondPayload (const L1TMuonEndCapForest::DTree &tree)
 
void loadFromCondPayloadRecursive (const L1TMuonEndCapForest::DTree &tree, const L1TMuonEndCapForest::DTreeNode &node, Node *tnode)
 
void loadFromXML (const char *filename)
 
void loadFromXMLRecursive (TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
 
Tree & operator= (const Tree &tree)
 
void rankVariables (std::vector< double > &v)
 
void rankVariablesRecursive (Node *node, std::vector< double > &v)
 
void saveToXML (const char *filename)
 
void saveToXMLRecursive (TXMLEngine *xml, Node *node, XMLNodePointer_t np)
 
void setBoostWeight (double wgt)
 
void setRootNode (Node *sRootNode)
 
void setTerminalNodes (std::list< Node * > &sTNodes)
 
 Tree ()
 
 Tree (std::vector< std::vector< Event * > > &cEvents)
 
 Tree (const Tree &tree)
 
 Tree (Tree &&tree)
 
 ~Tree ()
 

Private Member Functions

Node * copyFrom (const Node *local_root)
 
void findLeafs (Node *local_root, std::list< Node * > &tn)
 

Private Attributes

double boostWeight
 
int numTerminalNodes
 
double rmsError
 
Node * rootNode
 
std::list< Node * > terminalNodes
 
unsigned xmlVersion
 

Detailed Description

Definition at line 15 of file Tree.h.

Constructor & Destructor Documentation

Tree::Tree ( )

Definition at line 30 of file Tree.cc.

References boostWeight, numTerminalNodes, rootNode, terminalNodes, and xmlVersion.

31 {
32  rootNode = new Node("root");
33 
34  terminalNodes.push_back(rootNode);
35  numTerminalNodes = 1;
36  boostWeight = 0;
37  xmlVersion = 2017;
38 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
TGeoNode Node
int numTerminalNodes
Definition: Tree.h:62
double boostWeight
Definition: Tree.h:64
Node * rootNode
Definition: Tree.h:60
unsigned xmlVersion
Definition: Tree.h:65
Tree::Tree ( std::vector< std::vector< Event * > > &  cEvents)

Definition at line 40 of file Tree.cc.

References boostWeight, numTerminalNodes, rootNode, emtf::Node::setEvents(), terminalNodes, and xmlVersion.

41 {
42  rootNode = new Node("root");
43  rootNode->setEvents(cEvents);
44 
45  terminalNodes.push_back(rootNode);
46  numTerminalNodes = 1;
47  boostWeight = 0;
48  xmlVersion = 2017;
49 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
TGeoNode Node
void setEvents(std::vector< std::vector< Event * > > &sEvents)
Definition: Node.cc:204
int numTerminalNodes
Definition: Tree.h:62
double boostWeight
Definition: Tree.h:64
Node * rootNode
Definition: Tree.h:60
unsigned xmlVersion
Definition: Tree.h:65
Tree::~Tree ( )

Definition at line 55 of file Tree.cc.

References rootNode.

56 {
57 // When the tree is destroyed it will delete all of the nodes in the tree.
58 // The deletion begins with the rootnode and continues recursively.
59  if(rootNode) delete rootNode;
60 }
Node * rootNode
Definition: Tree.h:60
Tree::Tree ( const Tree &  tree)

if( numTerminalNodes != terminalNodes.size() ) throw std::runtime_error();

Definition at line 62 of file Tree.cc.

References boostWeight, copyFrom(), findLeafs(), getRootNode(), numTerminalNodes, rmsError, rootNode, terminalNodes, and xmlVersion.

63 {
64  // unfortunately, authors of these classes didn't use const qualifiers
65  rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
67  rmsError = tree.rmsError;
68  boostWeight = tree.boostWeight;
69  xmlVersion = tree.xmlVersion;
70 
71  terminalNodes.resize(0);
72  // find new leafs
74 
76 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
double rmsError
Definition: Tree.h:63
Node * getRootNode()
Definition: Tree.cc:160
void findLeafs(Node *local_root, std::list< Node * > &tn)
Definition: Tree.cc:125
Node * copyFrom(const Node *local_root)
Definition: Tree.cc:96
int numTerminalNodes
Definition: Tree.h:62
double boostWeight
Definition: Tree.h:64
Node * rootNode
Definition: Tree.h:60
unsigned xmlVersion
Definition: Tree.h:65
Tree::Tree ( Tree &&  tree)

Definition at line 140 of file Tree.cc.

References boostWeight, eostools::move(), numTerminalNodes, rmsError, rootNode, terminalNodes, and xmlVersion.

141 {
142  if(rootNode) delete rootNode; // this line is the only reason not to use default move constructor
143  rootNode = tree.rootNode;
146  rmsError = tree.rmsError;
147  boostWeight = tree.boostWeight;
148  xmlVersion = tree.xmlVersion;
149 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
double rmsError
Definition: Tree.h:63
int numTerminalNodes
Definition: Tree.h:62
double boostWeight
Definition: Tree.h:64
Node * rootNode
Definition: Tree.h:60
unsigned xmlVersion
Definition: Tree.h:65
def move(src, dest)
Definition: eostools.py:511

Member Function Documentation

void Tree::addXMLAttributes ( TXMLEngine *  xml,
Node *  node,
XMLNodePointer_t  np 
)

Definition at line 418 of file Tree.cc.

References emtf::Node::getFitValue(), emtf::Node::getSplitValue(), emtf::Node::getSplitVariable(), and emtf::numToStr().

Referenced by saveToXML(), and saveToXMLRecursive().

419 {
420  // Convert Node members into XML attributes
421  // and add them to the XMLEngine.
422  xml->NewAttr(np, nullptr, "splitVar", numToStr(node->getSplitVariable()).c_str());
423  xml->NewAttr(np, nullptr, "splitVal", numToStr(node->getSplitValue()).c_str());
424  xml->NewAttr(np, nullptr, "fitVal", numToStr(node->getFitValue()).c_str());
425 }
double getFitValue()
Definition: Node.cc:158
int np
Definition: AMPTWrapper.h:33
double getSplitValue()
Definition: Node.cc:136
int getSplitVariable()
Definition: Node.cc:146
std::string numToStr(T num)
Definition: Utilities.h:44
void Tree::buildTree ( int  nodeLimit)

Definition at line 204 of file Tree.cc.

References calcError(), emtf::Node::calcOptimumSplit(), emtf::Node::filterEventsToDaughters(), emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), numTerminalNodes, rootNode, terminalNodes, and emtf::Node::theMiracleOfChildBirth().

Referenced by emtf::Forest::doRegression().

205 {
206  // We greedily pick the best terminal node to split.
207  double bestNodeErrorReduction = -1;
208  Node* nodeToSplit = nullptr;
209 
210  if(numTerminalNodes == 1)
211  {
213  calcError();
214 // std::cout << std::endl << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
215  }
216 
217  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
218  {
219  if( (*it)->getErrorReduction() > bestNodeErrorReduction )
220  {
221  bestNodeErrorReduction = (*it)->getErrorReduction();
222  nodeToSplit = (*it);
223  }
224  }
225 
226  //std::cout << "nodeToSplit size = " << nodeToSplit->getNumEvents() << std::endl;
227 
228  // If all of the nodes have one event we can't add any more nodes and reduce the error.
229  if(nodeToSplit == nullptr) return;
230 
231  // Create daughter nodes, and link the nodes together appropriately.
232  nodeToSplit->theMiracleOfChildBirth();
233 
234  // Get left and right daughters for reference.
235  Node* left = nodeToSplit->getLeftDaughter();
236  Node* right = nodeToSplit->getRightDaughter();
237 
238  // Update the list of terminal nodes.
239  terminalNodes.remove(nodeToSplit);
240  terminalNodes.push_back(left);
241  terminalNodes.push_back(right);
243 
244  // Filter the events from the parent into the daughters.
245  nodeToSplit->filterEventsToDaughters();
246 
247  // Calculate the best splits for the new nodes.
248  left->calcOptimumSplit();
249  right->calcOptimumSplit();
250 
251  // See if the error reduces as we add more nodes.
252  calcError();
253 
254  if(numTerminalNodes % 1 == 0)
255  {
256 // std::cout << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
257  }
258 
259  // Repeat until done.
260  if(numTerminalNodes < nodeLimit) buildTree(nodeLimit);
261 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
std::list< Node * > terminalNodes
Definition: Tree.h:61
void buildTree(int nodeLimit)
Definition: Tree.cc:204
int numTerminalNodes
Definition: Tree.h:62
void calcError()
Definition: Tree.cc:188
Node * rootNode
Definition: Tree.h:60
void filterEventsToDaughters()
Definition: Node.cc:350
void calcOptimumSplit()
Definition: Node.cc:214
void theMiracleOfChildBirth()
Definition: Node.cc:335
void Tree::calcError ( )

Definition at line 188 of file Tree.cc.

References emtf::Node::getNumEvents(), rmsError, rootNode, mathSSE::sqrt(), and terminalNodes.

Referenced by buildTree().

189 {
190 // Loop through the separate predictive regions (terminal nodes) and
191 // add up the errors to get the error of the entire space.
192 
193  double totalSquaredError = 0;
194 
195  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
196  {
197  totalSquaredError += (*it)->getTotalError();
198  }
199  rmsError = sqrt( totalSquaredError/rootNode->getNumEvents() );
200 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
double rmsError
Definition: Tree.h:63
T sqrt(T t)
Definition: SSEVec.h:18
int getNumEvents()
Definition: Node.cc:192
Node * rootNode
Definition: Tree.h:60
Node * Tree::copyFrom ( const Node *  local_root)
private

Definition at line 96 of file Tree.cc.

References emtf::Node::getAvgError(), emtf::Node::getErrorReduction(), emtf::Node::getFitValue(), emtf::Node::getLeftDaughter(), emtf::Node::getName(), emtf::Node::getNumEvents(), emtf::Node::getRightDaughter(), emtf::Node::getSplitValue(), emtf::Node::getSplitVariable(), emtf::Node::getTotalError(), and emtf::Node::setParent().

Referenced by operator=(), and Tree().

97 {
98  // end-case
99  if( !local_root ) return nullptr;
100 
101  Node *lr = const_cast<Node*>(local_root);
102 
103  // recursion
104  Node *left_new_child = copyFrom( lr->getLeftDaughter() );
105  Node *right_new_child = copyFrom( lr->getRightDaughter() );
106 
107  // performing main work at this level
108  Node *new_local_root = new Node( lr->getName() );
109  if( left_new_child ) left_new_child ->setParent(new_local_root);
110  if( right_new_child ) right_new_child->setParent(new_local_root);
111  new_local_root->setLeftDaughter ( left_new_child );
112  new_local_root->setRightDaughter( right_new_child );
113  new_local_root->setErrorReduction( lr->getErrorReduction() );
114  new_local_root->setSplitValue( lr->getSplitValue() );
115  new_local_root->setSplitVariable( lr->getSplitVariable() );
116  new_local_root->setFitValue( lr->getFitValue() );
117  new_local_root->setTotalError( lr->getTotalError() );
118  new_local_root->setAvgError( lr->getAvgError() );
119  new_local_root->setNumEvents( lr->getNumEvents() );
120 // new_local_root->setEvents( lr->getEvents() ); // no ownership assumed for the events anyways
121 
122  return new_local_root;
123 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
double getFitValue()
Definition: Node.cc:158
TGeoNode Node
double getTotalError()
Definition: Node.cc:170
double getSplitValue()
Definition: Node.cc:136
std::string getName()
Definition: Node.cc:78
Node * copyFrom(const Node *local_root)
Definition: Tree.cc:96
void setParent(Node *sParent)
Definition: Node.cc:119
int getNumEvents()
Definition: Node.cc:192
double getErrorReduction()
Definition: Node.cc:90
double getAvgError()
Definition: Node.cc:180
int getSplitVariable()
Definition: Node.cc:146
Node * Tree::filterEvent ( Event *  e)

Definition at line 299 of file Tree.cc.

References filterEventRecursive(), and rootNode.

Referenced by emtf::Forest::appendCorrection().

300 {
301 // Use trees which have already been built to fit a bunch of events
302 // given by the tEvents vector.
303 
304  // Filter the event into a predictive region (terminal node).
305  Node* node = filterEventRecursive(rootNode, e);
306  return node;
307 }
Node * filterEventRecursive(Node *node, Event *e)
Definition: Tree.cc:311
Node * rootNode
Definition: Tree.h:60
Node * Tree::filterEventRecursive ( Node *  node,
Event *  e 
)

Definition at line 311 of file Tree.cc.

References emtf::Node::filterEventToDaughter().

Referenced by filterEvent().

312 {
313 // Filter the event repeatedly into the daughter nodes until it
314 // falls into a terminal node.
315 
316 
317  Node* nextNode = node->filterEventToDaughter(e);
318  if(nextNode == nullptr) return node;
319 
320  return filterEventRecursive(nextNode, e);
321 }
Node * filterEventRecursive(Node *node, Event *e)
Definition: Tree.cc:311
Node * filterEventToDaughter(Event *e)
Definition: Node.cc:395
void Tree::filterEvents ( std::vector< Event * > &  tEvents)

Definition at line 265 of file Tree.cc.

References filterEventsRecursive(), emtf::Node::getEvents(), and rootNode.

Referenced by emtf::Forest::appendCorrection().

266 {
267 // Use trees which have already been built to fit a bunch of events
268 // given by the tEvents vector.
269 
270  // Set the events to be filtered.
271  rootNode->getEvents() = std::vector< std::vector<Event*> >(1);
272  rootNode->getEvents()[0] = tEvents;
273 
274  // The tree now knows about the events it needs to fit.
275  // Filter them into a predictive region (terminal node).
277 }
std::vector< std::vector< Event * > > & getEvents()
Definition: Node.cc:199
void filterEventsRecursive(Node *node)
Definition: Tree.cc:281
Node * rootNode
Definition: Tree.h:60
void Tree::filterEventsRecursive ( Node *  node)

Definition at line 281 of file Tree.cc.

References emtf::Node::filterEventsToDaughters(), emtf::Node::getLeftDaughter(), and emtf::Node::getRightDaughter().

Referenced by filterEvents().

282 {
283 // Filter the events repeatedly into the daughter nodes until they
284 // fall into a terminal node.
285 
286  Node* left = node->getLeftDaughter();
287  Node* right = node->getRightDaughter();
288 
289  if(left == nullptr || right == nullptr) return;
290 
291  node->filterEventsToDaughters();
292 
293  filterEventsRecursive(left);
294  filterEventsRecursive(right);
295 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
void filterEventsRecursive(Node *node)
Definition: Tree.cc:281
void filterEventsToDaughters()
Definition: Node.cc:350
void Tree::findLeafs ( Node *  local_root,
std::list< Node * > &  tn 
)
private

Definition at line 125 of file Tree.cc.

References emtf::Node::getLeftDaughter(), and emtf::Node::getRightDaughter().

Referenced by operator=(), and Tree().

126 {
127  if( !local_root->getLeftDaughter() && !local_root->getRightDaughter() ){
128  // leaf or ternimal node found
129  tn.push_back(local_root);
130  return;
131  }
132 
133  if( local_root->getLeftDaughter() )
134  findLeafs( local_root->getLeftDaughter(), tn );
135 
136  if( local_root->getRightDaughter() )
137  findLeafs( local_root->getRightDaughter(), tn );
138 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
void findLeafs(Node *local_root, std::list< Node * > &tn)
Definition: Tree.cc:125
double emtf::Tree::getBoostWeight ( void  ) const
inline

Definition at line 56 of file Tree.h.

References boostWeight.

Referenced by L1TMuonEndCapForestESProducer::produce().

56 { return boostWeight; }
double boostWeight
Definition: Tree.h:64
int Tree::getNumTerminalNodes ( )

Definition at line 179 of file Tree.cc.

References numTerminalNodes.

180 {
181  return numTerminalNodes;
182 }
int numTerminalNodes
Definition: Tree.h:62
Node * Tree::getRootNode ( )

Definition at line 160 of file Tree.cc.

References rootNode.

Referenced by operator=(), L1TMuonEndCapForestESProducer::produce(), and Tree().

161 {
162  return rootNode;
163 }
Node * rootNode
Definition: Tree.h:60
void Tree::getSplitValues ( std::vector< std::vector< double >> &  v)

Definition at line 397 of file Tree.cc.

References getSplitValuesRecursive(), rootNode, and findQualityFiles::v.

398 {
400 }
Node * rootNode
Definition: Tree.h:60
void getSplitValuesRecursive(Node *node, std::vector< std::vector< double >> &v)
Definition: Tree.cc:367
void Tree::getSplitValuesRecursive ( Node *  node,
std::vector< std::vector< double >> &  v 
)

Definition at line 367 of file Tree.cc.

References gather_cfg::cout, emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), emtf::Node::getSplitValue(), emtf::Node::getSplitVariable(), pfDeepBoostedJetPreprocessParams_cfi::sv, and findQualityFiles::v.

Referenced by getSplitValues().

368 {
369 // We recursively go through all of the nodes in the tree and find the
370 // split points used for each split variable.
371 
372  Node* left = node->getLeftDaughter();
373  Node* right = node->getRightDaughter();
374 
375  // Terminal nodes don't contribute.
376  if(left==nullptr || right==nullptr) return;
377 
378  int sv = node->getSplitVariable();
379  double sp = node->getSplitValue();
380 
381  if(sv == -1)
382  {
383  std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
384  std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
385  }
386 
387  // Add the split point to the list for the correct split variable.
388  v[sv].push_back(sp);
389 
390  getSplitValuesRecursive(left, v);
391  getSplitValuesRecursive(right, v);
392 
393 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
double getSplitValue()
Definition: Node.cc:136
int getSplitVariable()
Definition: Node.cc:146
void getSplitValuesRecursive(Node *node, std::vector< std::vector< double >> &v)
Definition: Tree.cc:367
std::list< Node * > & Tree::getTerminalNodes ( )

Definition at line 172 of file Tree.cc.

References terminalNodes.

Referenced by emtf::Forest::updateEvents(), and emtf::Forest::updateRegTargets().

173 {
174  return terminalNodes;
175 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
void Tree::loadFromCondPayload ( const L1TMuonEndCapForest::DTree &  tree)

Definition at line 587 of file Tree.cc.

References loadFromCondPayloadRecursive(), and rootNode.

588 {
589  // start fresh in case this is not the only call to construct a tree
590  if( rootNode ) delete rootNode;
591  rootNode = new Node("root");
592 
593  const L1TMuonEndCapForest::DTreeNode& mainnode = tree[0];
595 }
TGeoNode Node
void loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree &tree, const L1TMuonEndCapForest::DTreeNode &node, Node *tnode)
Definition: Tree.cc:597
Node * rootNode
Definition: Tree.h:60
Definition: tree.py:1
void Tree::loadFromCondPayloadRecursive ( const L1TMuonEndCapForest::DTree &  tree,
const L1TMuonEndCapForest::DTreeNode &  node,
Node *  tnode 
)

Definition at line 597 of file Tree.cc.

References L1TMuonEndCapForest::DTreeNode::fitVal, emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), L1TMuonEndCapForest::DTreeNode::ileft, L1TMuonEndCapForest::DTreeNode::iright, numTerminalNodes, emtf::Node::setFitValue(), emtf::Node::setSplitValue(), emtf::Node::setSplitVariable(), L1TMuonEndCapForest::DTreeNode::splitVal, L1TMuonEndCapForest::DTreeNode::splitVar, terminalNodes, and emtf::Node::theMiracleOfChildBirth().

Referenced by loadFromCondPayload().

598 {
599  // Store gathered splitInfo into the node object.
600  tnode->setSplitVariable(node.splitVar);
601  tnode->setSplitValue(node.splitVal);
602  tnode->setFitValue(node.fitVal);
603 
604  // If there are no daughters we are done.
605  if( node.ileft == 0 || node.iright == 0) return; // root cannot be anyone's child
606  if( node.ileft >= tree.size() ||
607  node.iright >= tree.size() ) return; // out of range addressing on purpose
608 
609  // If there are daughters link the node objects appropriately.
610  tnode->theMiracleOfChildBirth();
611  Node* tleft = tnode->getLeftDaughter();
612  Node* tright = tnode->getRightDaughter();
613 
614  // Update the list of terminal nodes.
615  terminalNodes.remove(tnode);
616  terminalNodes.push_back(tleft);
617  terminalNodes.push_back(tright);
619 
622 }
void setFitValue(double sFitValue)
Definition: Node.cc:153
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
std::list< Node * > terminalNodes
Definition: Tree.h:61
void loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree &tree, const L1TMuonEndCapForest::DTreeNode &node, Node *tnode)
Definition: Tree.cc:597
int numTerminalNodes
Definition: Tree.h:62
void setSplitVariable(int sSplitVar)
Definition: Node.cc:141
Definition: tree.py:1
void setSplitValue(double sSplitValue)
Definition: Node.cc:131
void theMiracleOfChildBirth()
Definition: Node.cc:335
void Tree::loadFromXML ( const char *  filename)

Definition at line 477 of file Tree.cc.

References boostWeight, loadFromXMLRecursive(), rootNode, AlCaHLTBitMon_QueryRunRegistry::string, cmsPerfSuiteHarvest::xmldoc, and xmlVersion.

478 {
479  // First create the engine.
480  TXMLEngine* xml = new TXMLEngine;
481 
482  // Now try to parse xml file.
483  XMLDocPointer_t xmldoc = xml->ParseFile(filename);
484  if (xmldoc==nullptr)
485  {
486  delete xml;
487  return;
488  }
489 
490  // Get access to main node of the xml file.
491  XMLNodePointer_t mainnode = xml->DocGetRootElement(xmldoc);
492 
493  // the original 2016 pT xmls define the source tree node to be the top-level xml node
494  // while in 2017 TMVA's xmls every decision tree is wrapped in an extra block specifying boostWeight parameter
495  // I choose to identify the format by checking the top xml node name that is a "BinaryTree" in 2017
496  if( std::string("BinaryTree") == xml->GetNodeName(mainnode) ){
497  XMLAttrPointer_t attr = xml->GetFirstAttr(mainnode);
498  attr = xml->GetNextAttr(attr);
499  boostWeight = (attr ? strtod(xml->GetAttrValue(attr),nullptr) : 0);
500  // step inside the top-level xml node
501  mainnode = xml->GetChild(mainnode);
502  xmlVersion = 2017;
503  } else {
504  boostWeight = 0;
505  xmlVersion = 2016;
506  }
507  // Recursively connect nodes together.
508  loadFromXMLRecursive(xml, mainnode, rootNode);
509 
510  // Release memory before exit
511  xml->FreeDoc(xmldoc);
512  delete xml;
513 }
void loadFromXMLRecursive(TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
Definition: Tree.cc:517
double boostWeight
Definition: Tree.h:64
Node * rootNode
Definition: Tree.h:60
unsigned xmlVersion
Definition: Tree.h:65
void Tree::loadFromXMLRecursive ( TXMLEngine *  xml,
XMLNodePointer_t  node,
Node *  tnode 
)

Definition at line 517 of file Tree.cc.

References emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), mps_fire::i, numTerminalNodes, emtf::Node::setFitValue(), emtf::Node::setSplitValue(), emtf::Node::setSplitVariable(), terminalNodes, emtf::Node::theMiracleOfChildBirth(), and xmlVersion.

Referenced by loadFromXML().

518 {
519 
520  // Get the split information from xml.
521  XMLAttrPointer_t attr = xml->GetFirstAttr(xnode);
522  std::vector<std::string> splitInfo(3);
523  if( xmlVersion >= 2017 ){
524  for(unsigned int i=0,j=0; i<10; i++)
525  {
526  if(i==3 || i==4 || i==6){
527  splitInfo[j++] = xml->GetAttrValue(attr);
528  }
529  attr = xml->GetNextAttr(attr);
530  }
531  } else {
532  for(unsigned int i=0; i<3; i++)
533  {
534  splitInfo[i] = xml->GetAttrValue(attr);
535  attr = xml->GetNextAttr(attr);
536  }
537  }
538 
539  // Convert strings into numbers.
540  std::stringstream converter;
541  int splitVar;
542  double splitVal;
543  double fitVal;
544 
545  converter << splitInfo[0];
546  converter >> splitVar;
547  converter.str("");
548  converter.clear();
549 
550  converter << splitInfo[1];
551  converter >> splitVal;
552  converter.str("");
553  converter.clear();
554 
555  converter << splitInfo[2];
556  converter >> fitVal;
557  converter.str("");
558  converter.clear();
559 
560  // Store gathered splitInfo into the node object.
561  tnode->setSplitVariable(splitVar);
562  tnode->setSplitValue(splitVal);
563  tnode->setFitValue(fitVal);
564 
565  // Get the xml daughters of the current xml node.
566  XMLNodePointer_t xleft = xml->GetChild(xnode);
567  XMLNodePointer_t xright = xml->GetNext(xleft);
568 
569  // If there are no daughters we are done.
570  if(xleft == nullptr || xright == nullptr) return;
571 
572  // If there are daughters link the node objects appropriately.
573  tnode->theMiracleOfChildBirth();
574  Node* tleft = tnode->getLeftDaughter();
575  Node* tright = tnode->getRightDaughter();
576 
577  // Update the list of terminal nodes.
578  terminalNodes.remove(tnode);
579  terminalNodes.push_back(tleft);
580  terminalNodes.push_back(tright);
582 
583  loadFromXMLRecursive(xml, xleft, tleft);
584  loadFromXMLRecursive(xml, xright, tright);
585 }
void setFitValue(double sFitValue)
Definition: Node.cc:153
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
std::list< Node * > terminalNodes
Definition: Tree.h:61
void loadFromXMLRecursive(TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
Definition: Tree.cc:517
int numTerminalNodes
Definition: Tree.h:62
void setSplitVariable(int sSplitVar)
Definition: Node.cc:141
unsigned xmlVersion
Definition: Tree.h:65
void setSplitValue(double sSplitValue)
Definition: Node.cc:131
void theMiracleOfChildBirth()
Definition: Node.cc:335
Tree & Tree::operator= ( const Tree &  tree)

if( numTerminalNodes != terminalNodes.size() ) throw std::runtime_error();

Definition at line 78 of file Tree.cc.

References boostWeight, copyFrom(), findLeafs(), getRootNode(), numTerminalNodes, rmsError, rootNode, terminalNodes, and xmlVersion.

78  {
79  if(rootNode) delete rootNode;
80  // unfortunately, authors of these classes didn't use const qualifiers
81  rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
83  rmsError = tree.rmsError;
84  boostWeight = tree.boostWeight;
85  xmlVersion = tree.xmlVersion;
86 
87  terminalNodes.resize(0);
88  // find new leafs
90 
92 
93  return *this;
94 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
double rmsError
Definition: Tree.h:63
Node * getRootNode()
Definition: Tree.cc:160
void findLeafs(Node *local_root, std::list< Node * > &tn)
Definition: Tree.cc:125
Node * copyFrom(const Node *local_root)
Definition: Tree.cc:96
int numTerminalNodes
Definition: Tree.h:62
double boostWeight
Definition: Tree.h:64
Node * rootNode
Definition: Tree.h:60
unsigned xmlVersion
Definition: Tree.h:65
void Tree::rankVariables ( std::vector< double > &  v)

Definition at line 359 of file Tree.cc.

References rankVariablesRecursive(), and rootNode.

360 {
362 }
void rankVariablesRecursive(Node *node, std::vector< double > &v)
Definition: Tree.cc:326
Node * rootNode
Definition: Tree.h:60
void Tree::rankVariablesRecursive ( Node *  node,
std::vector< double > &  v 
)

Definition at line 326 of file Tree.cc.

References emtf::Node::getErrorReduction(), emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), emtf::Node::getSplitVariable(), and pfDeepBoostedJetPreprocessParams_cfi::sv.

Referenced by rankVariables().

327 {
328 // We recursively go through all of the nodes in the tree and find the
329 // total error reduction for each variable. The one with the most
330 // error reduction should be the most important.
331 
332  Node* left = node->getLeftDaughter();
333  Node* right = node->getRightDaughter();
334 
335  // Terminal nodes don't contribute to error reduction.
336  if(left==nullptr || right==nullptr) return;
337 
338  int sv = node->getSplitVariable();
339  double er = node->getErrorReduction();
340 
341  //if(sv == -1)
342  //{
343  //std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
344  //std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
345  //std::cout << "rankVarRecursive Error Reduction = " << er << std::endl;
346  //}
347 
348  // Add error reduction to the current total for the appropriate
349  // variable.
350  v[sv] += er;
351 
352  rankVariablesRecursive(left, v);
353  rankVariablesRecursive(right, v);
354 
355 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
void rankVariablesRecursive(Node *node, std::vector< double > &v)
Definition: Tree.cc:326
double getErrorReduction()
Definition: Node.cc:90
int getSplitVariable()
Definition: Node.cc:146
void Tree::saveToXML ( const char *  filename)

Definition at line 429 of file Tree.cc.

References addXMLAttributes(), emtf::Node::getName(), rootNode, saveToXMLRecursive(), and cmsPerfSuiteHarvest::xmldoc.

Referenced by emtf::Forest::doRegression().

430 {
431 
432  TXMLEngine* xml = new TXMLEngine();
433 
434  // Add the root node.
435  XMLNodePointer_t root = xml->NewChild(nullptr, nullptr, rootNode->getName().c_str());
436  addXMLAttributes(xml, rootNode, root);
437 
438  // Recursively write the tree to XML.
439  saveToXMLRecursive(xml, rootNode, root);
440 
441  // Make the XML Document.
442  XMLDocPointer_t xmldoc = xml->NewDoc();
443  xml->DocSetRootElement(xmldoc, root);
444 
445  // Save to file.
446  xml->SaveDoc(xmldoc, c);
447 
448  // Clean up.
449  xml->FreeDoc(xmldoc);
450  delete xml;
451 }
void addXMLAttributes(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:418
void saveToXMLRecursive(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:455
std::string getName()
Definition: Node.cc:78
Node * rootNode
Definition: Tree.h:60
void Tree::saveToXMLRecursive ( TXMLEngine *  xml,
Node *  node,
XMLNodePointer_t  np 
)

Definition at line 455 of file Tree.cc.

References addXMLAttributes(), emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), checklumidiff::l, and alignCSCRings::r.

Referenced by saveToXML().

456 {
457  Node* l = node->getLeftDaughter();
458  Node* r = node->getRightDaughter();
459 
460  if(l==nullptr || r==nullptr) return;
461 
462  // Add children to the XMLEngine.
463  XMLNodePointer_t left = xml->NewChild(np, nullptr, "left");
464  XMLNodePointer_t right = xml->NewChild(np, nullptr, "right");
465 
466  // Add attributes to the children.
467  addXMLAttributes(xml, l, left);
468  addXMLAttributes(xml, r, right);
469 
470  // Recurse.
471  saveToXMLRecursive(xml, l, left);
472  saveToXMLRecursive(xml, r, right);
473 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
void addXMLAttributes(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:418
int np
Definition: AMPTWrapper.h:33
void saveToXMLRecursive(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:455
void emtf::Tree::setBoostWeight ( double  wgt)
inline

Definition at line 57 of file Tree.h.

References boostWeight.

57 { boostWeight = wgt; }
double boostWeight
Definition: Tree.h:64
void Tree::setRootNode ( Node *  sRootNode)

Definition at line 155 of file Tree.cc.

References rootNode.

156 {
157  rootNode = sRootNode;
158 }
Node * rootNode
Definition: Tree.h:60
void Tree::setTerminalNodes ( std::list< Node * > &  sTNodes)

Definition at line 167 of file Tree.cc.

References terminalNodes.

168 {
169  terminalNodes = sTNodes;
170 }
std::list< Node * > terminalNodes
Definition: Tree.h:61

Member Data Documentation

double emtf::Tree::boostWeight
private

Definition at line 64 of file Tree.h.

Referenced by getBoostWeight(), loadFromXML(), operator=(), setBoostWeight(), and Tree().

int emtf::Tree::numTerminalNodes
private
double emtf::Tree::rmsError
private

Definition at line 63 of file Tree.h.

Referenced by calcError(), operator=(), and Tree().

Node* emtf::Tree::rootNode
private
std::list<Node*> emtf::Tree::terminalNodes
private
unsigned emtf::Tree::xmlVersion
private

Definition at line 65 of file Tree.h.

Referenced by loadFromXML(), loadFromXMLRecursive(), operator=(), and Tree().