CMS 3D CMS Logo

List of all members | Public Member Functions | Private Member Functions | Private Attributes
emtf::Tree Class Reference

#include <Tree.h>

Public Member Functions

void addXMLAttributes (TXMLEngine *xml, Node *node, XMLNodePointer_t np)
 
void buildTree (int nodeLimit)
 
void calcError ()
 
Node * filterEvent (Event *e)
 
Node * filterEventRecursive (Node *node, Event *e)
 
void filterEvents (std::vector< Event * > &tEvents)
 
void filterEventsRecursive (Node *node)
 
double getBoostWeight (void) const
 
int getNumTerminalNodes ()
 
Node * getRootNode ()
 
void getSplitValues (std::vector< std::vector< double >> &v)
 
void getSplitValuesRecursive (Node *node, std::vector< std::vector< double >> &v)
 
std::list< Node * > & getTerminalNodes ()
 
void loadFromCondPayload (const L1TMuonEndCapForest::DTree &tree)
 
void loadFromCondPayloadRecursive (const L1TMuonEndCapForest::DTree &tree, const L1TMuonEndCapForest::DTreeNode &node, Node *tnode)
 
void loadFromXML (const char *filename)
 
void loadFromXMLRecursive (TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
 
Tree & operator= (const Tree &tree)
 
void rankVariables (std::vector< double > &v)
 
void rankVariablesRecursive (Node *node, std::vector< double > &v)
 
void saveToXML (const char *filename)
 
void saveToXMLRecursive (TXMLEngine *xml, Node *node, XMLNodePointer_t np)
 
void setBoostWeight (double wgt)
 
void setRootNode (Node *sRootNode)
 
void setTerminalNodes (std::list< Node * > &sTNodes)
 
 Tree ()
 
 Tree (std::vector< std::vector< Event * > > &cEvents)
 
 Tree (const Tree &tree)
 
 Tree (Tree &&tree)
 
 ~Tree ()
 

Private Member Functions

Node * copyFrom (const Node *local_root)
 
void findLeafs (Node *local_root, std::list< Node * > &tn)
 

Private Attributes

double boostWeight
 
int numTerminalNodes
 
double rmsError
 
Node * rootNode
 
std::list< Node * > terminalNodes
 
unsigned xmlVersion
 

Detailed Description

Definition at line 15 of file Tree.h.

Constructor & Destructor Documentation

Tree::Tree ( )

Definition at line 29 of file Tree.cc.

References boostWeight, numTerminalNodes, rootNode, terminalNodes, and xmlVersion.

30 {
31  rootNode = new Node("root");
32 
33  terminalNodes.push_back(rootNode);
34  numTerminalNodes = 1;
35  boostWeight = 0;
36  xmlVersion = 2017;
37 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
int numTerminalNodes
Definition: Tree.h:62
double boostWeight
Definition: Tree.h:64
Node * rootNode
Definition: Tree.h:60
unsigned xmlVersion
Definition: Tree.h:65
Tree::Tree ( std::vector< std::vector< Event * > > &  cEvents)

Definition at line 39 of file Tree.cc.

References boostWeight, numTerminalNodes, rootNode, emtf::Node::setEvents(), terminalNodes, and xmlVersion.

40 {
41  rootNode = new Node("root");
42  rootNode->setEvents(cEvents);
43 
44  terminalNodes.push_back(rootNode);
45  numTerminalNodes = 1;
46  boostWeight = 0;
47  xmlVersion = 2017;
48 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
void setEvents(std::vector< std::vector< Event * > > &sEvents)
Definition: Node.cc:204
int numTerminalNodes
Definition: Tree.h:62
double boostWeight
Definition: Tree.h:64
Node * rootNode
Definition: Tree.h:60
unsigned xmlVersion
Definition: Tree.h:65
Tree::~Tree ( )

Definition at line 54 of file Tree.cc.

References rootNode.

55 {
56 // When the tree is destroyed it will delete all of the nodes in the tree.
57 // The deletion begins with the rootnode and continues recursively.
58  if(rootNode) delete rootNode;
59 }
Node * rootNode
Definition: Tree.h:60
Tree::Tree ( const Tree & tree)

Disabled consistency check (kept only as a comment in the source): if( numTerminalNodes != terminalNodes.size() ) throw std::runtime_error();

Definition at line 61 of file Tree.cc.

References boostWeight, copyFrom(), findLeafs(), getRootNode(), numTerminalNodes, rmsError, rootNode, terminalNodes, and xmlVersion.

62 {
63  // unfortunately, authors of these classes didn't use const qualifiers
64  rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
66  rmsError = tree.rmsError;
67  boostWeight = tree.boostWeight;
68  xmlVersion = tree.xmlVersion;
69 
70  terminalNodes.resize(0);
71  // find new leafs
73 
75 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
double rmsError
Definition: Tree.h:63
Node * getRootNode()
Definition: Tree.cc:159
void findLeafs(Node *local_root, std::list< Node * > &tn)
Definition: Tree.cc:124
Node * copyFrom(const Node *local_root)
Definition: Tree.cc:95
int numTerminalNodes
Definition: Tree.h:62
double boostWeight
Definition: Tree.h:64
Node * rootNode
Definition: Tree.h:60
unsigned xmlVersion
Definition: Tree.h:65
Tree::Tree ( Tree &&  tree)

Definition at line 139 of file Tree.cc.

References boostWeight, eostools::move(), numTerminalNodes, rmsError, rootNode, terminalNodes, and xmlVersion.

140 {
141  if(rootNode) delete rootNode; // this line is the only reason not to use default move constructor
142  rootNode = tree.rootNode;
145  rmsError = tree.rmsError;
146  boostWeight = tree.boostWeight;
147  xmlVersion = tree.xmlVersion;
148 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
double rmsError
Definition: Tree.h:63
int numTerminalNodes
Definition: Tree.h:62
double boostWeight
Definition: Tree.h:64
Node * rootNode
Definition: Tree.h:60
unsigned xmlVersion
Definition: Tree.h:65
def move(src, dest)
Definition: eostools.py:510

Member Function Documentation

void Tree::addXMLAttributes ( TXMLEngine *  xml,
Node * node,
XMLNodePointer_t  np 
)

Definition at line 417 of file Tree.cc.

References emtf::Node::getFitValue(), emtf::Node::getSplitValue(), emtf::Node::getSplitVariable(), and emtf::numToStr().

Referenced by saveToXML(), and saveToXMLRecursive().

418 {
419  // Convert Node members into XML attributes
420  // and add them to the XMLEngine.
421  xml->NewAttr(np, nullptr, "splitVar", numToStr(node->getSplitVariable()).c_str());
422  xml->NewAttr(np, nullptr, "splitVal", numToStr(node->getSplitValue()).c_str());
423  xml->NewAttr(np, nullptr, "fitVal", numToStr(node->getFitValue()).c_str());
424 }
double getFitValue()
Definition: Node.cc:158
int np
Definition: AMPTWrapper.h:33
double getSplitValue()
Definition: Node.cc:136
int getSplitVariable()
Definition: Node.cc:146
std::string numToStr(T num)
Definition: Utilities.h:44
void Tree::buildTree ( int  nodeLimit)

Definition at line 203 of file Tree.cc.

References calcError(), emtf::Node::calcOptimumSplit(), emtf::Node::filterEventsToDaughters(), emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), numTerminalNodes, rootNode, terminalNodes, and emtf::Node::theMiracleOfChildBirth().

Referenced by emtf::Forest::doRegression().

204 {
205  // We greedily pick the best terminal node to split.
206  double bestNodeErrorReduction = -1;
207  Node* nodeToSplit = nullptr;
208 
209  if(numTerminalNodes == 1)
210  {
212  calcError();
213 // std::cout << std::endl << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
214  }
215 
216  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
217  {
218  if( (*it)->getErrorReduction() > bestNodeErrorReduction )
219  {
220  bestNodeErrorReduction = (*it)->getErrorReduction();
221  nodeToSplit = (*it);
222  }
223  }
224 
225  //std::cout << "nodeToSplit size = " << nodeToSplit->getNumEvents() << std::endl;
226 
227  // If all of the nodes have one event we can't add any more nodes and reduce the error.
228  if(nodeToSplit == nullptr) return;
229 
230  // Create daughter nodes, and link the nodes together appropriately.
231  nodeToSplit->theMiracleOfChildBirth();
232 
233  // Get left and right daughters for reference.
234  Node* left = nodeToSplit->getLeftDaughter();
235  Node* right = nodeToSplit->getRightDaughter();
236 
237  // Update the list of terminal nodes.
238  terminalNodes.remove(nodeToSplit);
239  terminalNodes.push_back(left);
240  terminalNodes.push_back(right);
242 
243  // Filter the events from the parent into the daughters.
244  nodeToSplit->filterEventsToDaughters();
245 
246  // Calculate the best splits for the new nodes.
247  left->calcOptimumSplit();
248  right->calcOptimumSplit();
249 
250  // See if the error reduces as we add more nodes.
251  calcError();
252 
253  if(numTerminalNodes % 1 == 0)
254  {
255 // std::cout << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
256  }
257 
258  // Repeat until done.
259  if(numTerminalNodes < nodeLimit) buildTree(nodeLimit);
260 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
std::list< Node * > terminalNodes
Definition: Tree.h:61
void buildTree(int nodeLimit)
Definition: Tree.cc:203
int numTerminalNodes
Definition: Tree.h:62
void calcError()
Definition: Tree.cc:187
Node * rootNode
Definition: Tree.h:60
void filterEventsToDaughters()
Definition: Node.cc:350
void calcOptimumSplit()
Definition: Node.cc:214
void theMiracleOfChildBirth()
Definition: Node.cc:335
void Tree::calcError ( )

Definition at line 187 of file Tree.cc.

References emtf::Node::getNumEvents(), rmsError, rootNode, mathSSE::sqrt(), and terminalNodes.

Referenced by buildTree().

188 {
189 // Loop through the separate predictive regions (terminal nodes) and
190 // add up the errors to get the error of the entire space.
191 
192  double totalSquaredError = 0;
193 
194  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
195  {
196  totalSquaredError += (*it)->getTotalError();
197  }
198  rmsError = sqrt( totalSquaredError/rootNode->getNumEvents() );
199 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
double rmsError
Definition: Tree.h:63
T sqrt(T t)
Definition: SSEVec.h:18
int getNumEvents()
Definition: Node.cc:192
Node * rootNode
Definition: Tree.h:60
Node * Tree::copyFrom ( const Node * local_root)
private

Definition at line 95 of file Tree.cc.

References emtf::Node::getAvgError(), emtf::Node::getErrorReduction(), emtf::Node::getFitValue(), emtf::Node::getLeftDaughter(), emtf::Node::getName(), emtf::Node::getNumEvents(), emtf::Node::getRightDaughter(), emtf::Node::getSplitValue(), emtf::Node::getSplitVariable(), emtf::Node::getTotalError(), and emtf::Node::setParent().

Referenced by operator=(), and Tree().

96 {
97  // end-case
98  if( !local_root ) return nullptr;
99 
100  Node *lr = const_cast<Node*>(local_root);
101 
102  // recursion
103  Node *left_new_child = copyFrom( lr->getLeftDaughter() );
104  Node *right_new_child = copyFrom( lr->getRightDaughter() );
105 
106  // performing main work at this level
107  Node *new_local_root = new Node( lr->getName() );
108  if( left_new_child ) left_new_child ->setParent(new_local_root);
109  if( right_new_child ) right_new_child->setParent(new_local_root);
110  new_local_root->setLeftDaughter ( left_new_child );
111  new_local_root->setRightDaughter( right_new_child );
112  new_local_root->setErrorReduction( lr->getErrorReduction() );
113  new_local_root->setSplitValue( lr->getSplitValue() );
114  new_local_root->setSplitVariable( lr->getSplitVariable() );
115  new_local_root->setFitValue( lr->getFitValue() );
116  new_local_root->setTotalError( lr->getTotalError() );
117  new_local_root->setAvgError( lr->getAvgError() );
118  new_local_root->setNumEvents( lr->getNumEvents() );
119 // new_local_root->setEvents( lr->getEvents() ); // no ownership assumed for the events anyways
120 
121  return new_local_root;
122 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
double getFitValue()
Definition: Node.cc:158
double getTotalError()
Definition: Node.cc:170
double getSplitValue()
Definition: Node.cc:136
std::string getName()
Definition: Node.cc:78
Node * copyFrom(const Node *local_root)
Definition: Tree.cc:95
void setParent(Node *sParent)
Definition: Node.cc:119
int getNumEvents()
Definition: Node.cc:192
double getErrorReduction()
Definition: Node.cc:90
double getAvgError()
Definition: Node.cc:180
int getSplitVariable()
Definition: Node.cc:146
Node * Tree::filterEvent ( Event * e)

Definition at line 298 of file Tree.cc.

References filterEventRecursive(), and rootNode.

Referenced by emtf::Forest::appendCorrection().

299 {
300 // Use trees which have already been built to fit a bunch of events
301 // given by the tEvents vector.
302 
303  // Filter the event into a predictive region (terminal node).
304  Node* node = filterEventRecursive(rootNode, e);
305  return node;
306 }
Node * filterEventRecursive(Node *node, Event *e)
Definition: Tree.cc:310
Node * rootNode
Definition: Tree.h:60
Node * Tree::filterEventRecursive ( Node * node,
Event * e 
)

Definition at line 310 of file Tree.cc.

References emtf::Node::filterEventToDaughter().

Referenced by filterEvent().

311 {
312 // Filter the event repeatedly into the daughter nodes until it
313 // falls into a terminal node.
314 
315 
316  Node* nextNode = node->filterEventToDaughter(e);
317  if(nextNode == nullptr) return node;
318 
319  return filterEventRecursive(nextNode, e);
320 }
Node * filterEventRecursive(Node *node, Event *e)
Definition: Tree.cc:310
Node * filterEventToDaughter(Event *e)
Definition: Node.cc:395
void Tree::filterEvents ( std::vector< Event * > &  tEvents)

Definition at line 264 of file Tree.cc.

References filterEventsRecursive(), emtf::Node::getEvents(), and rootNode.

Referenced by emtf::Forest::appendCorrection().

265 {
266 // Use trees which have already been built to fit a bunch of events
267 // given by the tEvents vector.
268 
269  // Set the events to be filtered.
270  rootNode->getEvents() = std::vector< std::vector<Event*> >(1);
271  rootNode->getEvents()[0] = tEvents;
272 
273  // The tree now knows about the events it needs to fit.
274  // Filter them into a predictive region (terminal node).
276 }
std::vector< std::vector< Event * > > & getEvents()
Definition: Node.cc:199
void filterEventsRecursive(Node *node)
Definition: Tree.cc:280
Node * rootNode
Definition: Tree.h:60
void Tree::filterEventsRecursive ( Node * node)

Definition at line 280 of file Tree.cc.

References emtf::Node::filterEventsToDaughters(), emtf::Node::getLeftDaughter(), and emtf::Node::getRightDaughter().

Referenced by filterEvents().

281 {
282 // Filter the events repeatedly into the daughter nodes until they
283 // fall into a terminal node.
284 
285  Node* left = node->getLeftDaughter();
286  Node* right = node->getRightDaughter();
287 
288  if(left == nullptr || right == nullptr) return;
289 
290  node->filterEventsToDaughters();
291 
292  filterEventsRecursive(left);
293  filterEventsRecursive(right);
294 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
void filterEventsRecursive(Node *node)
Definition: Tree.cc:280
void filterEventsToDaughters()
Definition: Node.cc:350
void Tree::findLeafs ( Node * local_root,
std::list< Node * > &  tn 
)
private

Definition at line 124 of file Tree.cc.

References emtf::Node::getLeftDaughter(), and emtf::Node::getRightDaughter().

Referenced by operator=(), and Tree().

125 {
126  if( !local_root->getLeftDaughter() && !local_root->getRightDaughter() ){
127  // leaf or ternimal node found
128  tn.push_back(local_root);
129  return;
130  }
131 
132  if( local_root->getLeftDaughter() )
133  findLeafs( local_root->getLeftDaughter(), tn );
134 
135  if( local_root->getRightDaughter() )
136  findLeafs( local_root->getRightDaughter(), tn );
137 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
void findLeafs(Node *local_root, std::list< Node * > &tn)
Definition: Tree.cc:124
double emtf::Tree::getBoostWeight ( void  ) const
inline

Definition at line 56 of file Tree.h.

References boostWeight.

Referenced by L1TMuonEndCapForestESProducer::produce().

56 { return boostWeight; }
double boostWeight
Definition: Tree.h:64
int Tree::getNumTerminalNodes ( )

Definition at line 178 of file Tree.cc.

References numTerminalNodes.

179 {
180  return numTerminalNodes;
181 }
int numTerminalNodes
Definition: Tree.h:62
Node * Tree::getRootNode ( )

Definition at line 159 of file Tree.cc.

References rootNode.

Referenced by operator=(), L1TMuonEndCapForestESProducer::produce(), and Tree().

160 {
161  return rootNode;
162 }
Node * rootNode
Definition: Tree.h:60
void Tree::getSplitValues ( std::vector< std::vector< double >> &  v)

Definition at line 396 of file Tree.cc.

References getSplitValuesRecursive(), rootNode, and findQualityFiles::v.

397 {
399 }
Node * rootNode
Definition: Tree.h:60
void getSplitValuesRecursive(Node *node, std::vector< std::vector< double >> &v)
Definition: Tree.cc:366
void Tree::getSplitValuesRecursive ( Node node,
std::vector< std::vector< double >> &  v 
)

Definition at line 366 of file Tree.cc.

References gather_cfg::cout, emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), emtf::Node::getSplitValue(), emtf::Node::getSplitVariable(), pfDeepBoostedJetPreprocessParams_cfi::sv, and findQualityFiles::v.

Referenced by getSplitValues().

367 {
368 // We recursively go through all of the nodes in the tree and find the
369 // split points used for each split variable.
370 
371  Node* left = node->getLeftDaughter();
372  Node* right = node->getRightDaughter();
373 
374  // Terminal nodes don't contribute.
375  if(left==nullptr || right==nullptr) return;
376 
377  int sv = node->getSplitVariable();
378  double sp = node->getSplitValue();
379 
380  if(sv == -1)
381  {
382  std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
383  std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
384  }
385 
386  // Add the split point to the list for the correct split variable.
387  v[sv].push_back(sp);
388 
389  getSplitValuesRecursive(left, v);
390  getSplitValuesRecursive(right, v);
391 
392 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
double getSplitValue()
Definition: Node.cc:136
int getSplitVariable()
Definition: Node.cc:146
void getSplitValuesRecursive(Node *node, std::vector< std::vector< double >> &v)
Definition: Tree.cc:366
std::list< Node * > & Tree::getTerminalNodes ( )

Definition at line 171 of file Tree.cc.

References terminalNodes.

Referenced by emtf::Forest::updateEvents(), and emtf::Forest::updateRegTargets().

172 {
173  return terminalNodes;
174 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
void Tree::loadFromCondPayload ( const L1TMuonEndCapForest::DTree & tree)

Definition at line 586 of file Tree.cc.

References loadFromCondPayloadRecursive(), and rootNode.

587 {
588  // start fresh in case this is not the only call to construct a tree
589  if( rootNode ) delete rootNode;
590  rootNode = new Node("root");
591 
592  const L1TMuonEndCapForest::DTreeNode& mainnode = tree[0];
594 }
void loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree &tree, const L1TMuonEndCapForest::DTreeNode &node, Node *tnode)
Definition: Tree.cc:596
Node * rootNode
Definition: Tree.h:60
Definition: tree.py:1
void Tree::loadFromCondPayloadRecursive ( const L1TMuonEndCapForest::DTree & tree,
const L1TMuonEndCapForest::DTreeNode & node,
Node * tnode 
)

Definition at line 596 of file Tree.cc.

References L1TMuonEndCapForest::DTreeNode::fitVal, emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), L1TMuonEndCapForest::DTreeNode::ileft, L1TMuonEndCapForest::DTreeNode::iright, numTerminalNodes, emtf::Node::setFitValue(), emtf::Node::setSplitValue(), emtf::Node::setSplitVariable(), L1TMuonEndCapForest::DTreeNode::splitVal, L1TMuonEndCapForest::DTreeNode::splitVar, terminalNodes, and emtf::Node::theMiracleOfChildBirth().

Referenced by loadFromCondPayload().

597 {
598  // Store gathered splitInfo into the node object.
599  tnode->setSplitVariable(node.splitVar);
600  tnode->setSplitValue(node.splitVal);
601  tnode->setFitValue(node.fitVal);
602 
603  // If there are no daughters we are done.
604  if( node.ileft == 0 || node.iright == 0) return; // root cannot be anyone's child
605  if( node.ileft >= tree.size() ||
606  node.iright >= tree.size() ) return; // out of range addressing on purpose
607 
608  // If there are daughters link the node objects appropriately.
609  tnode->theMiracleOfChildBirth();
610  Node* tleft = tnode->getLeftDaughter();
611  Node* tright = tnode->getRightDaughter();
612 
613  // Update the list of terminal nodes.
614  terminalNodes.remove(tnode);
615  terminalNodes.push_back(tleft);
616  terminalNodes.push_back(tright);
618 
621 }
void setFitValue(double sFitValue)
Definition: Node.cc:153
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
std::list< Node * > terminalNodes
Definition: Tree.h:61
void loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree &tree, const L1TMuonEndCapForest::DTreeNode &node, Node *tnode)
Definition: Tree.cc:596
int numTerminalNodes
Definition: Tree.h:62
void setSplitVariable(int sSplitVar)
Definition: Node.cc:141
Definition: tree.py:1
void setSplitValue(double sSplitValue)
Definition: Node.cc:131
void theMiracleOfChildBirth()
Definition: Node.cc:335
void Tree::loadFromXML ( const char *  filename)

Definition at line 476 of file Tree.cc.

References boostWeight, loadFromXMLRecursive(), rootNode, AlCaHLTBitMon_QueryRunRegistry::string, cmsPerfSuiteHarvest::xmldoc, and xmlVersion.

477 {
478  // First create the engine.
479  TXMLEngine* xml = new TXMLEngine;
480 
481  // Now try to parse xml file.
482  XMLDocPointer_t xmldoc = xml->ParseFile(filename);
483  if (xmldoc==nullptr)
484  {
485  delete xml;
486  return;
487  }
488 
489  // Get access to main node of the xml file.
490  XMLNodePointer_t mainnode = xml->DocGetRootElement(xmldoc);
491 
492  // the original 2016 pT xmls define the source tree node to be the top-level xml node
493  // while in 2017 TMVA's xmls every decision tree is wrapped in an extra block specifying boostWeight parameter
494  // I choose to identify the format by checking the top xml node name that is a "BinaryTree" in 2017
495  if( std::string("BinaryTree") == xml->GetNodeName(mainnode) ){
496  XMLAttrPointer_t attr = xml->GetFirstAttr(mainnode);
497  attr = xml->GetNextAttr(attr);
498  boostWeight = (attr ? strtod(xml->GetAttrValue(attr),nullptr) : 0);
499  // step inside the top-level xml node
500  mainnode = xml->GetChild(mainnode);
501  xmlVersion = 2017;
502  } else {
503  boostWeight = 0;
504  xmlVersion = 2016;
505  }
506  // Recursively connect nodes together.
507  loadFromXMLRecursive(xml, mainnode, rootNode);
508 
509  // Release memory before exit
510  xml->FreeDoc(xmldoc);
511  delete xml;
512 }
void loadFromXMLRecursive(TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
Definition: Tree.cc:516
double boostWeight
Definition: Tree.h:64
Node * rootNode
Definition: Tree.h:60
unsigned xmlVersion
Definition: Tree.h:65
void Tree::loadFromXMLRecursive ( TXMLEngine *  xml,
XMLNodePointer_t  node,
Node * tnode 
)

Definition at line 516 of file Tree.cc.

References emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), mps_fire::i, numTerminalNodes, emtf::Node::setFitValue(), emtf::Node::setSplitValue(), emtf::Node::setSplitVariable(), terminalNodes, emtf::Node::theMiracleOfChildBirth(), and xmlVersion.

Referenced by loadFromXML().

517 {
518 
519  // Get the split information from xml.
520  XMLAttrPointer_t attr = xml->GetFirstAttr(xnode);
521  std::vector<std::string> splitInfo(3);
522  if( xmlVersion >= 2017 ){
523  for(unsigned int i=0,j=0; i<10; i++)
524  {
525  if(i==3 || i==4 || i==6){
526  splitInfo[j++] = xml->GetAttrValue(attr);
527  }
528  attr = xml->GetNextAttr(attr);
529  }
530  } else {
531  for(unsigned int i=0; i<3; i++)
532  {
533  splitInfo[i] = xml->GetAttrValue(attr);
534  attr = xml->GetNextAttr(attr);
535  }
536  }
537 
538  // Convert strings into numbers.
539  std::stringstream converter;
540  int splitVar;
541  double splitVal;
542  double fitVal;
543 
544  converter << splitInfo[0];
545  converter >> splitVar;
546  converter.str("");
547  converter.clear();
548 
549  converter << splitInfo[1];
550  converter >> splitVal;
551  converter.str("");
552  converter.clear();
553 
554  converter << splitInfo[2];
555  converter >> fitVal;
556  converter.str("");
557  converter.clear();
558 
559  // Store gathered splitInfo into the node object.
560  tnode->setSplitVariable(splitVar);
561  tnode->setSplitValue(splitVal);
562  tnode->setFitValue(fitVal);
563 
564  // Get the xml daughters of the current xml node.
565  XMLNodePointer_t xleft = xml->GetChild(xnode);
566  XMLNodePointer_t xright = xml->GetNext(xleft);
567 
568  // If there are no daughters we are done.
569  if(xleft == nullptr || xright == nullptr) return;
570 
571  // If there are daughters link the node objects appropriately.
572  tnode->theMiracleOfChildBirth();
573  Node* tleft = tnode->getLeftDaughter();
574  Node* tright = tnode->getRightDaughter();
575 
576  // Update the list of terminal nodes.
577  terminalNodes.remove(tnode);
578  terminalNodes.push_back(tleft);
579  terminalNodes.push_back(tright);
581 
582  loadFromXMLRecursive(xml, xleft, tleft);
583  loadFromXMLRecursive(xml, xright, tright);
584 }
void setFitValue(double sFitValue)
Definition: Node.cc:153
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
std::list< Node * > terminalNodes
Definition: Tree.h:61
void loadFromXMLRecursive(TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
Definition: Tree.cc:516
int numTerminalNodes
Definition: Tree.h:62
void setSplitVariable(int sSplitVar)
Definition: Node.cc:141
unsigned xmlVersion
Definition: Tree.h:65
void setSplitValue(double sSplitValue)
Definition: Node.cc:131
void theMiracleOfChildBirth()
Definition: Node.cc:335
Tree & Tree::operator= ( const Tree & tree)

Disabled consistency check (kept only as a comment in the source): if( numTerminalNodes != terminalNodes.size() ) throw std::runtime_error();

Definition at line 77 of file Tree.cc.

References boostWeight, copyFrom(), findLeafs(), getRootNode(), numTerminalNodes, rmsError, rootNode, terminalNodes, and xmlVersion.

77  {
78  if(rootNode) delete rootNode;
79  // unfortunately, authors of these classes didn't use const qualifiers
80  rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
82  rmsError = tree.rmsError;
83  boostWeight = tree.boostWeight;
84  xmlVersion = tree.xmlVersion;
85 
86  terminalNodes.resize(0);
87  // find new leafs
89 
91 
92  return *this;
93 }
std::list< Node * > terminalNodes
Definition: Tree.h:61
double rmsError
Definition: Tree.h:63
Node * getRootNode()
Definition: Tree.cc:159
void findLeafs(Node *local_root, std::list< Node * > &tn)
Definition: Tree.cc:124
Node * copyFrom(const Node *local_root)
Definition: Tree.cc:95
int numTerminalNodes
Definition: Tree.h:62
double boostWeight
Definition: Tree.h:64
Node * rootNode
Definition: Tree.h:60
unsigned xmlVersion
Definition: Tree.h:65
void Tree::rankVariables ( std::vector< double > &  v)

Definition at line 358 of file Tree.cc.

References rankVariablesRecursive(), and rootNode.

359 {
361 }
void rankVariablesRecursive(Node *node, std::vector< double > &v)
Definition: Tree.cc:325
Node * rootNode
Definition: Tree.h:60
void Tree::rankVariablesRecursive ( Node * node,
std::vector< double > &  v 
)

Definition at line 325 of file Tree.cc.

References emtf::Node::getErrorReduction(), emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), emtf::Node::getSplitVariable(), and pfDeepBoostedJetPreprocessParams_cfi::sv.

Referenced by rankVariables().

326 {
327 // We recursively go through all of the nodes in the tree and find the
328 // total error reduction for each variable. The one with the most
329 // error reduction should be the most important.
330 
331  Node* left = node->getLeftDaughter();
332  Node* right = node->getRightDaughter();
333 
334  // Terminal nodes don't contribute to error reduction.
335  if(left==nullptr || right==nullptr) return;
336 
337  int sv = node->getSplitVariable();
338  double er = node->getErrorReduction();
339 
340  //if(sv == -1)
341  //{
342  //std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
343  //std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
344  //std::cout << "rankVarRecursive Error Reduction = " << er << std::endl;
345  //}
346 
347  // Add error reduction to the current total for the appropriate
348  // variable.
349  v[sv] += er;
350 
351  rankVariablesRecursive(left, v);
352  rankVariablesRecursive(right, v);
353 
354 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
void rankVariablesRecursive(Node *node, std::vector< double > &v)
Definition: Tree.cc:325
double getErrorReduction()
Definition: Node.cc:90
int getSplitVariable()
Definition: Node.cc:146
void Tree::saveToXML ( const char *  filename)

Definition at line 428 of file Tree.cc.

References addXMLAttributes(), emtf::Node::getName(), rootNode, saveToXMLRecursive(), and cmsPerfSuiteHarvest::xmldoc.

Referenced by emtf::Forest::doRegression().

429 {
430 
431  TXMLEngine* xml = new TXMLEngine();
432 
433  // Add the root node.
434  XMLNodePointer_t root = xml->NewChild(nullptr, nullptr, rootNode->getName().c_str());
435  addXMLAttributes(xml, rootNode, root);
436 
437  // Recursively write the tree to XML.
438  saveToXMLRecursive(xml, rootNode, root);
439 
440  // Make the XML Document.
441  XMLDocPointer_t xmldoc = xml->NewDoc();
442  xml->DocSetRootElement(xmldoc, root);
443 
444  // Save to file.
445  xml->SaveDoc(xmldoc, c);
446 
447  // Clean up.
448  xml->FreeDoc(xmldoc);
449  delete xml;
450 }
void addXMLAttributes(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:417
void saveToXMLRecursive(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:454
std::string getName()
Definition: Node.cc:78
Node * rootNode
Definition: Tree.h:60
void Tree::saveToXMLRecursive ( TXMLEngine *  xml,
Node * node,
XMLNodePointer_t  np 
)

Definition at line 454 of file Tree.cc.

References addXMLAttributes(), emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), checklumidiff::l, and alignCSCRings::r.

Referenced by saveToXML().

455 {
456  Node* l = node->getLeftDaughter();
457  Node* r = node->getRightDaughter();
458 
459  if(l==nullptr || r==nullptr) return;
460 
461  // Add children to the XMLEngine.
462  XMLNodePointer_t left = xml->NewChild(np, nullptr, "left");
463  XMLNodePointer_t right = xml->NewChild(np, nullptr, "right");
464 
465  // Add attributes to the children.
466  addXMLAttributes(xml, l, left);
467  addXMLAttributes(xml, r, right);
468 
469  // Recurse.
470  saveToXMLRecursive(xml, l, left);
471  saveToXMLRecursive(xml, r, right);
472 }
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
void addXMLAttributes(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:417
int np
Definition: AMPTWrapper.h:33
void saveToXMLRecursive(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:454
void emtf::Tree::setBoostWeight ( double  wgt)
inline

Definition at line 57 of file Tree.h.

References boostWeight.

57 { boostWeight = wgt; }
double boostWeight
Definition: Tree.h:64
void Tree::setRootNode ( Node * sRootNode)

Definition at line 154 of file Tree.cc.

References rootNode.

155 {
156  rootNode = sRootNode;
157 }
Node * rootNode
Definition: Tree.h:60
void Tree::setTerminalNodes ( std::list< Node * > &  sTNodes)

Definition at line 166 of file Tree.cc.

References terminalNodes.

167 {
168  terminalNodes = sTNodes;
169 }
std::list< Node * > terminalNodes
Definition: Tree.h:61

Member Data Documentation

double emtf::Tree::boostWeight
private

Definition at line 64 of file Tree.h.

Referenced by getBoostWeight(), loadFromXML(), operator=(), setBoostWeight(), and Tree().

int emtf::Tree::numTerminalNodes
private
double emtf::Tree::rmsError
private

Definition at line 63 of file Tree.h.

Referenced by calcError(), operator=(), and Tree().

Node* emtf::Tree::rootNode
private
std::list<Node*> emtf::Tree::terminalNodes
private
unsigned emtf::Tree::xmlVersion
private

Definition at line 65 of file Tree.h.

Referenced by loadFromXML(), loadFromXMLRecursive(), operator=(), and Tree().