CMS 3D CMS Logo

List of all members | Public Member Functions | Private Member Functions | Private Attributes
emtf::Tree Class Reference

#include <Tree.h>

Public Member Functions

void addXMLAttributes (TXMLEngine *xml, Node *node, XMLNodePointer_t np)
 
void buildTree (int nodeLimit)
 
void calcError ()
 
Node * filterEvent (Event *e)
 
Node * filterEventRecursive (Node *node, Event *e)
 
void filterEvents (std::vector< Event *> &tEvents)
 
void filterEventsRecursive (Node *node)
 
double getBoostWeight (void) const
 
int getNumTerminalNodes ()
 
Node * getRootNode ()
 
void getSplitValues (std::vector< std::vector< double >> &v)
 
void getSplitValuesRecursive (Node *node, std::vector< std::vector< double >> &v)
 
std::list< Node * > & getTerminalNodes ()
 
void loadFromCondPayload (const L1TMuonEndCapForest::DTree &tree)
 
void loadFromCondPayloadRecursive (const L1TMuonEndCapForest::DTree &tree, const L1TMuonEndCapForest::DTreeNode &node, Node *tnode)
 
void loadFromXML (const char *filename)
 
void loadFromXMLRecursive (TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
 
Tree & operator= (const Tree &tree)
 
void rankVariables (std::vector< double > &v)
 
void rankVariablesRecursive (Node *node, std::vector< double > &v)
 
void saveToXML (const char *filename)
 
void saveToXMLRecursive (TXMLEngine *xml, Node *node, XMLNodePointer_t np)
 
void setBoostWeight (double wgt)
 
void setRootNode (Node *sRootNode)
 
void setTerminalNodes (std::list< Node *> &sTNodes)
 
 Tree ()
 
 Tree (std::vector< std::vector< Event *>> &cEvents)
 
 Tree (const Tree &tree)
 
 Tree (Tree &&tree)
 
 ~Tree ()
 

Private Member Functions

Node * copyFrom (const Node *local_root)
 
void findLeafs (Node *local_root, std::list< Node *> &tn)
 

Private Attributes

double boostWeight
 
int numTerminalNodes
 
double rmsError
 
Node * rootNode
 
std::list< Node * > terminalNodes
 
unsigned xmlVersion
 

Detailed Description

Definition at line 15 of file Tree.h.

Constructor & Destructor Documentation

◆ Tree() [1/4]

Tree::Tree ( )

Definition at line 30 of file Tree.cc.

References boostWeight, numTerminalNodes, rootNode, terminalNodes, and xmlVersion.

30  {
31  rootNode = new Node("root");
32 
33  terminalNodes.push_back(rootNode);
34  numTerminalNodes = 1;
35  boostWeight = 0;
36  xmlVersion = 2017;
37 }
std::list< Node * > terminalNodes
Definition: Tree.h:62
TGeoNode Node
int numTerminalNodes
Definition: Tree.h:63
double boostWeight
Definition: Tree.h:65
Node * rootNode
Definition: Tree.h:61
unsigned xmlVersion
Definition: Tree.h:66

◆ Tree() [2/4]

Tree::Tree ( std::vector< std::vector< Event *>> &  cEvents)

Definition at line 39 of file Tree.cc.

References boostWeight, numTerminalNodes, rootNode, emtf::Node::setEvents(), terminalNodes, and xmlVersion.

39  {
40  rootNode = new Node("root");
41  rootNode->setEvents(cEvents);
42 
43  terminalNodes.push_back(rootNode);
44  numTerminalNodes = 1;
45  boostWeight = 0;
46  xmlVersion = 2017;
47 }
std::list< Node * > terminalNodes
Definition: Tree.h:62
TGeoNode Node
void setEvents(std::vector< std::vector< Event *> > &sEvents)
Definition: Node.cc:134
int numTerminalNodes
Definition: Tree.h:63
double boostWeight
Definition: Tree.h:65
Node * rootNode
Definition: Tree.h:61
unsigned xmlVersion
Definition: Tree.h:66

◆ ~Tree()

Tree::~Tree ( )

Definition at line 52 of file Tree.cc.

References rootNode.

52  {
53  // When the tree is destroyed it will delete all of the nodes in the tree.
54  // The deletion begins with the rootnode and continues recursively.
55  if (rootNode)
56  delete rootNode;
57 }
Node * rootNode
Definition: Tree.h:61

◆ Tree() [3/4]

Tree::Tree ( const Tree &  tree)

if( numTerminalNodes != terminalNodes.size() ) throw std::runtime_error();

Definition at line 59 of file Tree.cc.

References boostWeight, copyFrom(), findLeafs(), getRootNode(), numTerminalNodes, rmsError, rootNode, terminalNodes, and xmlVersion.

59  {
60  // unfortunately, authors of these classes didn't use const qualifiers
61  rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
62  numTerminalNodes = tree.numTerminalNodes;
63  rmsError = tree.rmsError;
64  boostWeight = tree.boostWeight;
65  xmlVersion = tree.xmlVersion;
66 
67  terminalNodes.resize(0);
68  // find new leafs
69  findLeafs(rootNode, terminalNodes);
70 
72 }
std::list< Node * > terminalNodes
Definition: Tree.h:62
double rmsError
Definition: Tree.h:64
Node * getRootNode()
Definition: Tree.cc:155
void findLeafs(Node *local_root, std::list< Node *> &tn)
Definition: Tree.cc:124
Node * copyFrom(const Node *local_root)
Definition: Tree.cc:93
int numTerminalNodes
Definition: Tree.h:63
double boostWeight
Definition: Tree.h:65
Node * rootNode
Definition: Tree.h:61
unsigned xmlVersion
Definition: Tree.h:66
Definition: tree.py:1

◆ Tree() [4/4]

Tree::Tree ( Tree &&  tree)

Definition at line 138 of file Tree.cc.

References boostWeight, numTerminalNodes, rmsError, rootNode, std::move(), terminalNodes, and xmlVersion.

138  {
139  if (rootNode)
140  delete rootNode; // this line is the only reason not to use default move constructor
141  rootNode = tree.rootNode;
142  terminalNodes = std::move(tree.terminalNodes);
143  numTerminalNodes = tree.numTerminalNodes;
144  rmsError = tree.rmsError;
145  boostWeight = tree.boostWeight;
146  xmlVersion = tree.xmlVersion;
147 }
std::list< Node * > terminalNodes
Definition: Tree.h:62
double rmsError
Definition: Tree.h:64
int numTerminalNodes
Definition: Tree.h:63
double boostWeight
Definition: Tree.h:65
Node * rootNode
Definition: Tree.h:61
unsigned xmlVersion
Definition: Tree.h:66
Definition: tree.py:1
def move(src, dest)
Definition: eostools.py:511

Member Function Documentation

◆ addXMLAttributes()

void Tree::addXMLAttributes ( TXMLEngine *  xml,
Node *  node,
XMLNodePointer_t  np 
)

Definition at line 381 of file Tree.cc.

References emtf::Node::getFitValue(), emtf::Node::getSplitValue(), emtf::Node::getSplitVariable(), np, and emtf::numToStr().

Referenced by saveToXML(), and saveToXMLRecursive().

381  {
382  // Convert Node members into XML attributes
383  // and add them to the XMLEngine.
384  xml->NewAttr(np, nullptr, "splitVar", numToStr(node->getSplitVariable()).c_str());
385  xml->NewAttr(np, nullptr, "splitVal", numToStr(node->getSplitValue()).c_str());
386  xml->NewAttr(np, nullptr, "fitVal", numToStr(node->getFitValue()).c_str());
387 }
double getFitValue()
Definition: Node.cc:112
int np
Definition: AMPTWrapper.h:43
double getSplitValue()
Definition: Node.cc:102
int getSplitVariable()
Definition: Node.cc:106
std::string numToStr(T num)
Definition: Utilities.h:43

◆ buildTree()

void Tree::buildTree ( int  nodeLimit)

Definition at line 185 of file Tree.cc.

References calcError(), emtf::Node::calcOptimumSplit(), emtf::Node::filterEventsToDaughters(), emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), numTerminalNodes, rootNode, terminalNodes, and emtf::Node::theMiracleOfChildBirth().

185  {
186  // We greedily pick the best terminal node to split.
187  double bestNodeErrorReduction = -1;
188  Node* nodeToSplit = nullptr;
189 
190  if (numTerminalNodes == 1) {
191  rootNode->calcOptimumSplit();
192  calcError();
193  // std::cout << std::endl << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
194  }
195 
196  for (std::list<Node*>::iterator it = terminalNodes.begin(); it != terminalNodes.end(); it++) {
197  if ((*it)->getErrorReduction() > bestNodeErrorReduction) {
198  bestNodeErrorReduction = (*it)->getErrorReduction();
199  nodeToSplit = (*it);
200  }
201  }
202 
203  //std::cout << "nodeToSplit size = " << nodeToSplit->getNumEvents() << std::endl;
204 
205  // If all of the nodes have one event we can't add any more nodes and reduce the error.
206  if (nodeToSplit == nullptr)
207  return;
208 
209  // Create daughter nodes, and link the nodes together appropriately.
210  nodeToSplit->theMiracleOfChildBirth();
211 
212  // Get left and right daughters for reference.
213  Node* left = nodeToSplit->getLeftDaughter();
214  Node* right = nodeToSplit->getRightDaughter();
215 
216  // Update the list of terminal nodes.
217  terminalNodes.remove(nodeToSplit);
218  terminalNodes.push_back(left);
219  terminalNodes.push_back(right);
220  numTerminalNodes++;
221 
222  // Filter the events from the parent into the daughters.
223  nodeToSplit->filterEventsToDaughters();
224 
225  // Calculate the best splits for the new nodes.
226  left->calcOptimumSplit();
227  right->calcOptimumSplit();
228 
229  // See if the error reduces as we add more nodes.
230  calcError();
231 
232  if (numTerminalNodes % 1 == 0) {
233  // std::cout << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
234  }
235 
236  // Repeat until done.
237  if (numTerminalNodes < nodeLimit)
238  buildTree(nodeLimit);
239 }
Node * getRightDaughter()
Definition: Node.cc:90
Node * getLeftDaughter()
Definition: Node.cc:86
std::list< Node * > terminalNodes
Definition: Tree.h:62
void buildTree(int nodeLimit)
Definition: Tree.cc:185
int numTerminalNodes
Definition: Tree.h:63
void calcError()
Definition: Tree.cc:171
Node * rootNode
Definition: Tree.h:61
void filterEventsToDaughters()
Definition: Node.cc:267
void calcOptimumSplit()
Definition: Node.cc:143
void theMiracleOfChildBirth()
Definition: Node.cc:253

◆ calcError()

void Tree::calcError ( )

Definition at line 171 of file Tree.cc.

References emtf::Node::getNumEvents(), rmsError, rootNode, mathSSE::sqrt(), and terminalNodes.

Referenced by buildTree().

171  {
172  // Loop through the separate predictive regions (terminal nodes) and
173  // add up the errors to get the error of the entire space.
174 
175  double totalSquaredError = 0;
176 
177  for (std::list<Node*>::iterator it = terminalNodes.begin(); it != terminalNodes.end(); it++) {
178  totalSquaredError += (*it)->getTotalError();
179  }
180  rmsError = sqrt(totalSquaredError / rootNode->getNumEvents());
181 }
std::list< Node * > terminalNodes
Definition: Tree.h:62
double rmsError
Definition: Tree.h:64
T sqrt(T t)
Definition: SSEVec.h:19
int getNumEvents()
Definition: Node.cc:128
Node * rootNode
Definition: Tree.h:61

◆ copyFrom()

Node * Tree::copyFrom ( const Node *  local_root)
private

Definition at line 93 of file Tree.cc.

References emtf::Node::getAvgError(), emtf::Node::getErrorReduction(), emtf::Node::getFitValue(), emtf::Node::getLeftDaughter(), emtf::Node::getName(), emtf::Node::getNumEvents(), emtf::Node::getRightDaughter(), emtf::Node::getSplitValue(), emtf::Node::getSplitVariable(), emtf::Node::getTotalError(), and emtf::Node::setParent().

Referenced by operator=(), and Tree().

93  {
94  // end-case
95  if (!local_root)
96  return nullptr;
97 
98  Node* lr = const_cast<Node*>(local_root);
99 
100  // recursion
101  Node* left_new_child = copyFrom(lr->getLeftDaughter());
102  Node* right_new_child = copyFrom(lr->getRightDaughter());
103 
104  // performing main work at this level
105  Node* new_local_root = new Node(lr->getName());
106  if (left_new_child)
107  left_new_child->setParent(new_local_root);
108  if (right_new_child)
109  right_new_child->setParent(new_local_root);
110  new_local_root->setLeftDaughter(left_new_child);
111  new_local_root->setRightDaughter(right_new_child);
112  new_local_root->setErrorReduction(lr->getErrorReduction());
113  new_local_root->setSplitValue(lr->getSplitValue());
114  new_local_root->setSplitVariable(lr->getSplitVariable());
115  new_local_root->setFitValue(lr->getFitValue());
116  new_local_root->setTotalError(lr->getTotalError());
117  new_local_root->setAvgError(lr->getAvgError());
118  new_local_root->setNumEvents(lr->getNumEvents());
119  // new_local_root->setEvents( lr->getEvents() ); // no ownership assumed for the events anyways
120 
121  return new_local_root;
122 }
Node * getRightDaughter()
Definition: Node.cc:90
Node * getLeftDaughter()
Definition: Node.cc:86
double getFitValue()
Definition: Node.cc:112
TGeoNode Node
double getTotalError()
Definition: Node.cc:118
double getSplitValue()
Definition: Node.cc:102
std::string getName()
Definition: Node.cc:74
Node * copyFrom(const Node *local_root)
Definition: Tree.cc:93
void setParent(Node *sParent)
Definition: Node.cc:94
int getNumEvents()
Definition: Node.cc:128
double getErrorReduction()
Definition: Node.cc:80
double getAvgError()
Definition: Node.cc:122
int getSplitVariable()
Definition: Node.cc:106

◆ filterEvent()

Node * Tree::filterEvent ( Event *  e)

Definition at line 276 of file Tree.cc.

References MillePedeFileConverter_cfg::e, filterEventRecursive(), and rootNode.

276  {
277  // Use trees which have already been built to fit a bunch of events
278  // given by the tEvents vector.
279 
280  // Filter the event into a predictive region (terminal node).
281  Node* node = filterEventRecursive(rootNode, e);
282  return node;
283 }
Node * filterEventRecursive(Node *node, Event *e)
Definition: Tree.cc:287
Node * rootNode
Definition: Tree.h:61

◆ filterEventRecursive()

Node * Tree::filterEventRecursive ( Node *  node,
Event *  e 
)

Definition at line 287 of file Tree.cc.

References MillePedeFileConverter_cfg::e, and emtf::Node::filterEventToDaughter().

Referenced by filterEvent().

287  {
288  // Filter the event repeatedly into the daughter nodes until it
289  // falls into a terminal node.
290 
291  Node* nextNode = node->filterEventToDaughter(e);
292  if (nextNode == nullptr)
293  return node;
294 
295  return filterEventRecursive(nextNode, e);
296 }
Node * filterEventRecursive(Node *node, Event *e)
Definition: Tree.cc:287
Node * filterEventToDaughter(Event *e)
Definition: Node.cc:314

◆ filterEvents()

void Tree::filterEvents ( std::vector< Event *> &  tEvents)

Definition at line 243 of file Tree.cc.

References filterEventsRecursive(), emtf::Node::getEvents(), and rootNode.

243  {
244  // Use trees which have already been built to fit a bunch of events
245  // given by the tEvents vector.
246 
247  // Set the events to be filtered.
248  rootNode->getEvents() = std::vector<std::vector<Event*>>(1);
249  rootNode->getEvents()[0] = tEvents;
250 
251  // The tree now knows about the events it needs to fit.
252  // Filter them into a predictive region (terminal node).
253  filterEventsRecursive(rootNode);
254 }
std::vector< std::vector< Event * > > & getEvents()
Definition: Node.cc:132
void filterEventsRecursive(Node *node)
Definition: Tree.cc:258
Node * rootNode
Definition: Tree.h:61

◆ filterEventsRecursive()

void Tree::filterEventsRecursive ( Node *  node)

Definition at line 258 of file Tree.cc.

References emtf::Node::filterEventsToDaughters(), emtf::Node::getLeftDaughter(), and emtf::Node::getRightDaughter().

Referenced by filterEvents().

258  {
259  // Filter the events repeatedly into the daughter nodes until they
260  // fall into a terminal node.
261 
262  Node* left = node->getLeftDaughter();
263  Node* right = node->getRightDaughter();
264 
265  if (left == nullptr || right == nullptr)
266  return;
267 
268  node->filterEventsToDaughters();
269 
270  filterEventsRecursive(left);
271  filterEventsRecursive(right);
272 }
Node * getRightDaughter()
Definition: Node.cc:90
Node * getLeftDaughter()
Definition: Node.cc:86
void filterEventsRecursive(Node *node)
Definition: Tree.cc:258
void filterEventsToDaughters()
Definition: Node.cc:267

◆ findLeafs()

void Tree::findLeafs ( Node *  local_root,
std::list< Node *> &  tn 
)
private

Definition at line 124 of file Tree.cc.

References emtf::Node::getLeftDaughter(), and emtf::Node::getRightDaughter().

Referenced by operator=(), and Tree().

124  {
125  if (!local_root->getLeftDaughter() && !local_root->getRightDaughter()) {
126  // leaf or ternimal node found
127  tn.push_back(local_root);
128  return;
129  }
130 
131  if (local_root->getLeftDaughter())
132  findLeafs(local_root->getLeftDaughter(), tn);
133 
134  if (local_root->getRightDaughter())
135  findLeafs(local_root->getRightDaughter(), tn);
136 }
Node * getRightDaughter()
Definition: Node.cc:90
Node * getLeftDaughter()
Definition: Node.cc:86
void findLeafs(Node *local_root, std::list< Node *> &tn)
Definition: Tree.cc:124

◆ getBoostWeight()

double emtf::Tree::getBoostWeight ( void  ) const
inline

Definition at line 57 of file Tree.h.

References boostWeight.

Referenced by L1TMuonEndCapForestESProducer::produce().

57 { return boostWeight; }
double boostWeight
Definition: Tree.h:65

◆ getNumTerminalNodes()

int Tree::getNumTerminalNodes ( )

Definition at line 165 of file Tree.cc.

References numTerminalNodes.

165 { return numTerminalNodes; }
int numTerminalNodes
Definition: Tree.h:63

◆ getRootNode()

Node * Tree::getRootNode ( )

Definition at line 155 of file Tree.cc.

References rootNode.

Referenced by operator=(), L1TMuonEndCapForestESProducer::produce(), and Tree().

155 { return rootNode; }
Node * rootNode
Definition: Tree.h:61

◆ getSplitValues()

void Tree::getSplitValues ( std::vector< std::vector< double >> &  v)

Definition at line 364 of file Tree.cc.

References getSplitValuesRecursive(), rootNode, and findQualityFiles::v.

364 { getSplitValuesRecursive(rootNode, v); }
Node * rootNode
Definition: Tree.h:61
void getSplitValuesRecursive(Node *node, std::vector< std::vector< double >> &v)
Definition: Tree.cc:336

◆ getSplitValuesRecursive()

void Tree::getSplitValuesRecursive ( Node *  node,
std::vector< std::vector< double >> &  v 
)

Definition at line 336 of file Tree.cc.

References gather_cfg::cout, emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), emtf::Node::getSplitValue(), emtf::Node::getSplitVariable(), pfDeepBoostedJetPreprocessParams_cfi::sv, and findQualityFiles::v.

Referenced by getSplitValues().

336  {
337  // We recursively go through all of the nodes in the tree and find the
338  // split points used for each split variable.
339 
340  Node* left = node->getLeftDaughter();
341  Node* right = node->getRightDaughter();
342 
343  // Terminal nodes don't contribute.
344  if (left == nullptr || right == nullptr)
345  return;
346 
347  int sv = node->getSplitVariable();
348  double sp = node->getSplitValue();
349 
350  if (sv == -1) {
351  std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
352  std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
353  }
354 
355  // Add the split point to the list for the correct split variable.
356  v[sv].push_back(sp);
357 
358  getSplitValuesRecursive(left, v);
359  getSplitValuesRecursive(right, v);
360 }
Node * getRightDaughter()
Definition: Node.cc:90
Node * getLeftDaughter()
Definition: Node.cc:86
double getSplitValue()
Definition: Node.cc:102
int getSplitVariable()
Definition: Node.cc:106
void getSplitValuesRecursive(Node *node, std::vector< std::vector< double >> &v)
Definition: Tree.cc:336

◆ getTerminalNodes()

std::list< Node * > & Tree::getTerminalNodes ( )

Definition at line 161 of file Tree.cc.

References terminalNodes.

161 { return terminalNodes; }
std::list< Node * > terminalNodes
Definition: Tree.h:62

◆ loadFromCondPayload()

void Tree::loadFromCondPayload ( const L1TMuonEndCapForest::DTree &  tree)

Definition at line 550 of file Tree.cc.

References loadFromCondPayloadRecursive(), and rootNode.

550  {
551  // start fresh in case this is not the only call to construct a tree
552  if (rootNode)
553  delete rootNode;
554  rootNode = new Node("root");
555 
556  const L1TMuonEndCapForest::DTreeNode& mainnode = tree[0];
557  loadFromCondPayloadRecursive(tree, mainnode, rootNode);
558 }
TGeoNode Node
void loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree &tree, const L1TMuonEndCapForest::DTreeNode &node, Node *tnode)
Definition: Tree.cc:560
Node * rootNode
Definition: Tree.h:61
Definition: tree.py:1

◆ loadFromCondPayloadRecursive()

void Tree::loadFromCondPayloadRecursive ( const L1TMuonEndCapForest::DTree &  tree,
const L1TMuonEndCapForest::DTreeNode &  node,
Node *  tnode 
)

Definition at line 560 of file Tree.cc.

References L1TMuonEndCapForest::DTreeNode::fitVal, emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), L1TMuonEndCapForest::DTreeNode::ileft, L1TMuonEndCapForest::DTreeNode::iright, numTerminalNodes, emtf::Node::setFitValue(), emtf::Node::setSplitValue(), emtf::Node::setSplitVariable(), L1TMuonEndCapForest::DTreeNode::splitVal, L1TMuonEndCapForest::DTreeNode::splitVar, terminalNodes, and emtf::Node::theMiracleOfChildBirth().

Referenced by loadFromCondPayload().

562  {
563  // Store gathered splitInfo into the node object.
564  tnode->setSplitVariable(node.splitVar);
565  tnode->setSplitValue(node.splitVal);
566  tnode->setFitValue(node.fitVal);
567 
568  // If there are no daughters we are done.
569  if (node.ileft == 0 || node.iright == 0)
570  return; // root cannot be anyone's child
571  if (node.ileft >= tree.size() || node.iright >= tree.size())
572  return; // out of range addressing on purpose
573 
574  // If there are daughters link the node objects appropriately.
575  tnode->theMiracleOfChildBirth();
576  Node* tleft = tnode->getLeftDaughter();
577  Node* tright = tnode->getRightDaughter();
578 
579  // Update the list of terminal nodes.
580  terminalNodes.remove(tnode);
581  terminalNodes.push_back(tleft);
582  terminalNodes.push_back(tright);
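583  numTerminalNodes++;
584 
585  loadFromCondPayloadRecursive(tree, tree[node.ileft], tleft);
586  loadFromCondPayloadRecursive(tree, tree[node.iright], tright);
587 }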
void setFitValue(double sFitValue)
Definition: Node.cc:110
Node * getRightDaughter()
Definition: Node.cc:90
Node * getLeftDaughter()
Definition: Node.cc:86
std::list< Node * > terminalNodes
Definition: Tree.h:62
void loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree &tree, const L1TMuonEndCapForest::DTreeNode &node, Node *tnode)
Definition: Tree.cc:560
int numTerminalNodes
Definition: Tree.h:63
void setSplitVariable(int sSplitVar)
Definition: Node.cc:104
Definition: tree.py:1
void setSplitValue(double sSplitValue)
Definition: Node.cc:100
void theMiracleOfChildBirth()
Definition: Node.cc:253

◆ loadFromXML()

void Tree::loadFromXML ( const char *  filename)

Definition at line 437 of file Tree.cc.

References boostWeight, corrVsCorr::filename, loadFromXMLRecursive(), rootNode, AlCaHLTBitMon_QueryRunRegistry::string, ExtractAppInfoFromXML::xmldoc, and xmlVersion.

437  {
438  // First create the engine.
439  TXMLEngine* xml = new TXMLEngine;
440 
441  // Now try to parse xml file.
442  XMLDocPointer_t xmldoc = xml->ParseFile(filename);
443  if (xmldoc == nullptr) {
444  delete xml;
445  return;
446  }
447 
448  // Get access to main node of the xml file.
449  XMLNodePointer_t mainnode = xml->DocGetRootElement(xmldoc);
450 
451  // the original 2016 pT xmls define the source tree node to be the top-level xml node
452  // while in 2017 TMVA's xmls every decision tree is wrapped in an extra block specifying boostWeight parameter
453  // I choose to identify the format by checking the top xml node name that is a "BinaryTree" in 2017
454  if (std::string("BinaryTree") == xml->GetNodeName(mainnode)) {
455  XMLAttrPointer_t attr = xml->GetFirstAttr(mainnode);
456  while (std::string("boostWeight") != xml->GetAttrName(attr)) {
457  attr = xml->GetNextAttr(attr);
458  }
459  boostWeight = (attr ? strtod(xml->GetAttrValue(attr), nullptr) : 0);
460  // step inside the top-level xml node
461  mainnode = xml->GetChild(mainnode);
462  xmlVersion = 2017;
463  } else {
464  boostWeight = 0;
465  xmlVersion = 2016;
466  }
467  // Recursively connect nodes together.
468  loadFromXMLRecursive(xml, mainnode, rootNode);
469 
470  // Release memory before exit
471  xml->FreeDoc(xmldoc);
472  delete xml;
473 }
xmldoc
Some module's global variables.
void loadFromXMLRecursive(TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
Definition: Tree.cc:477
double boostWeight
Definition: Tree.h:65
Node * rootNode
Definition: Tree.h:61
unsigned xmlVersion
Definition: Tree.h:66

◆ loadFromXMLRecursive()

void Tree::loadFromXMLRecursive ( TXMLEngine *  xml,
XMLNodePointer_t  node,
Node *  tnode 
)

Definition at line 477 of file Tree.cc.

References emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), mps_fire::i, numTerminalNodes, emtf::Node::setFitValue(), emtf::Node::setSplitValue(), emtf::Node::setSplitVariable(), AlCaHLTBitMon_QueryRunRegistry::string, terminalNodes, emtf::Node::theMiracleOfChildBirth(), and xmlVersion.

Referenced by loadFromXML().

477  {
478  // Get the split information from xml.
479  XMLAttrPointer_t attr = xml->GetFirstAttr(xnode);
480  std::vector<std::string> splitInfo(3);
481  if (xmlVersion >= 2017) {
482  for (unsigned int i = 0; i < 10; i++) {
483  if (std::string("IVar") == xml->GetAttrName(attr)) {
484  splitInfo[0] = xml->GetAttrValue(attr);
485  }
486  if (std::string("Cut") == xml->GetAttrName(attr)) {
487  splitInfo[1] = xml->GetAttrValue(attr);
488  }
489  if (std::string("res") == xml->GetAttrName(attr)) {
490  splitInfo[2] = xml->GetAttrValue(attr);
491  }
492  attr = xml->GetNextAttr(attr);
493  }
494  } else {
495  for (unsigned int i = 0; i < 3; i++) {
496  splitInfo[i] = xml->GetAttrValue(attr);
497  attr = xml->GetNextAttr(attr);
498  }
499  }
500 
501  // Convert strings into numbers.
502  std::stringstream converter;
503  int splitVar;
504  double splitVal;
505  double fitVal;
506 
507  converter << splitInfo[0];
508  converter >> splitVar;
509  converter.str("");
510  converter.clear();
511 
512  converter << splitInfo[1];
513  converter >> splitVal;
514  converter.str("");
515  converter.clear();
516 
517  converter << splitInfo[2];
518  converter >> fitVal;
519  converter.str("");
520  converter.clear();
521 
522  // Store gathered splitInfo into the node object.
523  tnode->setSplitVariable(splitVar);
524  tnode->setSplitValue(splitVal);
525  tnode->setFitValue(fitVal);
526 
527  // Get the xml daughters of the current xml node.
528  XMLNodePointer_t xleft = xml->GetChild(xnode);
529  XMLNodePointer_t xright = xml->GetNext(xleft);
530 
531  // If there are no daughters we are done.
532  if (xleft == nullptr || xright == nullptr)
533  return;
534 
535  // If there are daughters link the node objects appropriately.
536  tnode->theMiracleOfChildBirth();
537  Node* tleft = tnode->getLeftDaughter();
538  Node* tright = tnode->getRightDaughter();
539 
540  // Update the list of terminal nodes.
541  terminalNodes.remove(tnode);
542  terminalNodes.push_back(tleft);
543  terminalNodes.push_back(tright);
544  numTerminalNodes++;
545 
546  loadFromXMLRecursive(xml, xleft, tleft);
547  loadFromXMLRecursive(xml, xright, tright);
548 }
void setFitValue(double sFitValue)
Definition: Node.cc:110
Node * getRightDaughter()
Definition: Node.cc:90
Node * getLeftDaughter()
Definition: Node.cc:86
std::list< Node * > terminalNodes
Definition: Tree.h:62
void loadFromXMLRecursive(TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
Definition: Tree.cc:477
int numTerminalNodes
Definition: Tree.h:63
void setSplitVariable(int sSplitVar)
Definition: Node.cc:104
unsigned xmlVersion
Definition: Tree.h:66
void setSplitValue(double sSplitValue)
Definition: Node.cc:100
void theMiracleOfChildBirth()
Definition: Node.cc:253

◆ operator=()

Tree & Tree::operator= ( const Tree &  tree)

if( numTerminalNodes != terminalNodes.size() ) throw std::runtime_error();

Definition at line 74 of file Tree.cc.

References boostWeight, copyFrom(), findLeafs(), getRootNode(), numTerminalNodes, rmsError, rootNode, terminalNodes, and xmlVersion.

74  {
75  if (rootNode)
76  delete rootNode;
77  // unfortunately, authors of these classes didn't use const qualifiers
78  rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
79  numTerminalNodes = tree.numTerminalNodes;
80  rmsError = tree.rmsError;
81  boostWeight = tree.boostWeight;
82  xmlVersion = tree.xmlVersion;
83 
84  terminalNodes.resize(0);
85  // find new leafs
86  findLeafs(rootNode, terminalNodes);
87 
89 
90  return *this;
91 }
std::list< Node * > terminalNodes
Definition: Tree.h:62
double rmsError
Definition: Tree.h:64
Node * getRootNode()
Definition: Tree.cc:155
void findLeafs(Node *local_root, std::list< Node *> &tn)
Definition: Tree.cc:124
Node * copyFrom(const Node *local_root)
Definition: Tree.cc:93
int numTerminalNodes
Definition: Tree.h:63
double boostWeight
Definition: Tree.h:65
Node * rootNode
Definition: Tree.h:61
unsigned xmlVersion
Definition: Tree.h:66
Definition: tree.py:1

◆ rankVariables()

void Tree::rankVariables ( std::vector< double > &  v)

Definition at line 332 of file Tree.cc.

References rankVariablesRecursive(), rootNode, and findQualityFiles::v.

332 { rankVariablesRecursive(rootNode, v); }
void rankVariablesRecursive(Node *node, std::vector< double > &v)
Definition: Tree.cc:300
Node * rootNode
Definition: Tree.h:61

◆ rankVariablesRecursive()

void Tree::rankVariablesRecursive ( Node *  node,
std::vector< double > &  v 
)

Definition at line 300 of file Tree.cc.

References emtf::Node::getErrorReduction(), emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), emtf::Node::getSplitVariable(), pfDeepBoostedJetPreprocessParams_cfi::sv, and findQualityFiles::v.

Referenced by rankVariables().

300  {
301  // We recursively go through all of the nodes in the tree and find the
302  // total error reduction for each variable. The one with the most
303  // error reduction should be the most important.
304 
305  Node* left = node->getLeftDaughter();
306  Node* right = node->getRightDaughter();
307 
308  // Terminal nodes don't contribute to error reduction.
309  if (left == nullptr || right == nullptr)
310  return;
311 
312  int sv = node->getSplitVariable();
313  double er = node->getErrorReduction();
314 
315  //if(sv == -1)
316  //{
317  //std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
318  //std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
319  //std::cout << "rankVarRecursive Error Reduction = " << er << std::endl;
320  //}
321 
322  // Add error reduction to the current total for the appropriate
323  // variable.
324  v[sv] += er;
325 
326  rankVariablesRecursive(left, v);
327  rankVariablesRecursive(right, v);
328 }
Node * getRightDaughter()
Definition: Node.cc:90
Node * getLeftDaughter()
Definition: Node.cc:86
void rankVariablesRecursive(Node *node, std::vector< double > &v)
Definition: Tree.cc:300
double getErrorReduction()
Definition: Node.cc:80
int getSplitVariable()
Definition: Node.cc:106

◆ saveToXML()

void Tree::saveToXML ( const char *  filename)

Definition at line 391 of file Tree.cc.

References addXMLAttributes(), HltBtagPostValidation_cff::c, emtf::Node::getName(), rootNode, saveToXMLRecursive(), and ExtractAppInfoFromXML::xmldoc.

391  {
392  TXMLEngine* xml = new TXMLEngine();
393 
394  // Add the root node.
395  XMLNodePointer_t root = xml->NewChild(nullptr, nullptr, rootNode->getName().c_str());
396  addXMLAttributes(xml, rootNode, root);
397 
398  // Recursively write the tree to XML.
399  saveToXMLRecursive(xml, rootNode, root);
400 
401  // Make the XML Document.
402  XMLDocPointer_t xmldoc = xml->NewDoc();
403  xml->DocSetRootElement(xmldoc, root);
404 
405  // Save to file.
406  xml->SaveDoc(xmldoc, c);
407 
408  // Clean up.
409  xml->FreeDoc(xmldoc);
410  delete xml;
411 }
xmldoc
Some module's global variables.
void addXMLAttributes(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:381
void saveToXMLRecursive(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:415
std::string getName()
Definition: Node.cc:74
Node * rootNode
Definition: Tree.h:61

◆ saveToXMLRecursive()

void Tree::saveToXMLRecursive ( TXMLEngine *  xml,
Node *  node,
XMLNodePointer_t  np 
)

Definition at line 415 of file Tree.cc.

References addXMLAttributes(), emtf::Node::getLeftDaughter(), emtf::Node::getRightDaughter(), cmsLHEtoEOSManager::l, and np.

Referenced by saveToXML().

415  {
416  Node* l = node->getLeftDaughter();
417  Node* r = node->getRightDaughter();
418 
419  if (l == nullptr || r == nullptr)
420  return;
421 
422  // Add children to the XMLEngine.
423  XMLNodePointer_t left = xml->NewChild(np, nullptr, "left");
424  XMLNodePointer_t right = xml->NewChild(np, nullptr, "right");
425 
426  // Add attributes to the children.
427  addXMLAttributes(xml, l, left);
428  addXMLAttributes(xml, r, right);
429 
430  // Recurse.
431  saveToXMLRecursive(xml, l, left);
432  saveToXMLRecursive(xml, r, right);
433 }
Node * getRightDaughter()
Definition: Node.cc:90
Node * getLeftDaughter()
Definition: Node.cc:86
void addXMLAttributes(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:381
int np
Definition: AMPTWrapper.h:43
void saveToXMLRecursive(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:415

◆ setBoostWeight()

void emtf::Tree::setBoostWeight ( double  wgt)
inline

Definition at line 58 of file Tree.h.

References boostWeight.

58 { boostWeight = wgt; }
double boostWeight
Definition: Tree.h:65

◆ setRootNode()

void Tree::setRootNode ( Node *  sRootNode)

Definition at line 153 of file Tree.cc.

References rootNode.

153 { rootNode = sRootNode; }
Node * rootNode
Definition: Tree.h:61

◆ setTerminalNodes()

void Tree::setTerminalNodes ( std::list< Node *> &  sTNodes)

Definition at line 159 of file Tree.cc.

References terminalNodes.

159 { terminalNodes = sTNodes; }
std::list< Node * > terminalNodes
Definition: Tree.h:62

Member Data Documentation

◆ boostWeight

double emtf::Tree::boostWeight
private

Definition at line 65 of file Tree.h.

Referenced by getBoostWeight(), loadFromXML(), operator=(), setBoostWeight(), and Tree().

◆ numTerminalNodes

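int emtf::Tree::numTerminalNodes
private

Definition at line 63 of file Tree.h.

Referenced by buildTree(), getNumTerminalNodes(), loadFromCondPayloadRecursive(), loadFromXMLRecursive(), operator=(), and Tree().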

◆ rmsError

double emtf::Tree::rmsError
private

Definition at line 64 of file Tree.h.

Referenced by calcError(), operator=(), and Tree().

◆ rootNode

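Node* emtf::Tree::rootNode
private

Definition at line 61 of file Tree.h.

Referenced by buildTree(), calcError(), filterEvent(), filterEvents(), getRootNode(), getSplitValues(), loadFromCondPayload(), loadFromXML(), operator=(), rankVariables(), saveToXML(), setRootNode(), Tree(), and ~Tree().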

◆ terminalNodes

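std::list<Node*> emtf::Tree::terminalNodes
private

Definition at line 62 of file Tree.h.

Referenced by buildTree(), calcError(), getTerminalNodes(), loadFromCondPayloadRecursive(), loadFromXMLRecursive(), operator=(), setTerminalNodes(), and Tree().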

◆ xmlVersion

unsigned emtf::Tree::xmlVersion
private

Definition at line 66 of file Tree.h.

Referenced by loadFromXML(), loadFromXMLRecursive(), operator=(), and Tree().