CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
Tree.cc
Go to the documentation of this file.
1 // Tree.cxx //
3 // =====================================================================//
4 // This is the object implementation of a decision tree. //
5 // References include //
6 // *Elements of Statistical Learning by Hastie, //
7 // Tibshirani, and Friedman. //
8 // *Greedy Function Approximation: A Gradient Boosting Machine. //
9 // Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001. //
10 // *Inductive Learning of Tree-based Regression Models. Luis Torgo. //
11 // //
13 
15 // _______________________Includes_______________________________________//
17 
19 
20 #include <iostream>
21 #include <sstream>
22 #include <cmath>
23 
25 // _______________________Constructor(s)________________________________//
27 
28 using namespace emtf;
29 
31  rootNode = new Node("root");
32 
33  terminalNodes.push_back(rootNode);
34  numTerminalNodes = 1;
35  boostWeight = 0;
36  xmlVersion = 2017;
37 }
38 
39 Tree::Tree(std::vector<std::vector<Event*>>& cEvents) {
40  rootNode = new Node("root");
41  rootNode->setEvents(cEvents);
42 
43  terminalNodes.push_back(rootNode);
44  numTerminalNodes = 1;
45  boostWeight = 0;
46  xmlVersion = 2017;
47 }
49 // _______________________Destructor____________________________________//
51 
53  // When the tree is destroyed it will delete all of the nodes in the tree.
54  // The deletion begins with the rootnode and continues recursively.
55  if (rootNode)
56  delete rootNode;
57 }
58 
// Copy constructor: deep-copies the node structure of `tree` with copyFrom()
// (event collections are not copied; see copyFrom) and copies the scalar
// members rmsError, boostWeight and xmlVersion. The terminal-node list is
// emptied so that, per the "find new leafs" comment, it can be rebuilt from
// the newly allocated nodes.
// NOTE(review): some lines of this body (original lines 62, 69, 71) are not
// shown in this listing; presumably they copy/recompute numTerminalNodes and
// re-collect the leaves with findLeafs() — confirm against the full source.
59 Tree::Tree(const Tree& tree) {
60  // unfortunately, authors of these classes didn't use const qualifiers
61  rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
63  rmsError = tree.rmsError;
64  boostWeight = tree.boostWeight;
65  xmlVersion = tree.xmlVersion;
66 
67  // Old leaf pointers belong to `tree`; start from an empty list.
67  terminalNodes.resize(0);
68  // find new leafs
70 
72 }
73 
75  if (rootNode)
76  delete rootNode;
77  // unfortunately, authors of these classes didn't use const qualifiers
78  rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
80  rmsError = tree.rmsError;
81  boostWeight = tree.boostWeight;
82  xmlVersion = tree.xmlVersion;
83 
84  terminalNodes.resize(0);
85  // find new leafs
87 
89 
90  return *this;
91 }
92 
93 Node* Tree::copyFrom(const Node* local_root) {
94  // end-case
95  if (!local_root)
96  return nullptr;
97 
98  Node* lr = const_cast<Node*>(local_root);
99 
100  // recursion
101  Node* left_new_child = copyFrom(lr->getLeftDaughter());
102  Node* right_new_child = copyFrom(lr->getRightDaughter());
103 
104  // performing main work at this level
105  Node* new_local_root = new Node(lr->getName());
106  if (left_new_child)
107  left_new_child->setParent(new_local_root);
108  if (right_new_child)
109  right_new_child->setParent(new_local_root);
110  new_local_root->setLeftDaughter(left_new_child);
111  new_local_root->setRightDaughter(right_new_child);
112  new_local_root->setErrorReduction(lr->getErrorReduction());
113  new_local_root->setSplitValue(lr->getSplitValue());
114  new_local_root->setSplitVariable(lr->getSplitVariable());
115  new_local_root->setFitValue(lr->getFitValue());
116  new_local_root->setTotalError(lr->getTotalError());
117  new_local_root->setAvgError(lr->getAvgError());
118  new_local_root->setNumEvents(lr->getNumEvents());
119  // new_local_root->setEvents( lr->getEvents() ); // no ownership assumed for the events anyways
120 
121  return new_local_root;
122 }
123 
124 void Tree::findLeafs(Node* local_root, std::list<Node*>& tn) {
125  if (!local_root->getLeftDaughter() && !local_root->getRightDaughter()) {
126  // leaf or ternimal node found
127  tn.push_back(local_root);
128  return;
129  }
130 
131  if (local_root->getLeftDaughter())
132  findLeafs(local_root->getLeftDaughter(), tn);
133 
134  if (local_root->getRightDaughter())
135  findLeafs(local_root->getRightDaughter(), tn);
136 }
137 
139  if (rootNode)
140  delete rootNode; // this line is the only reason not to use default move constructor
141  rootNode = tree.rootNode;
142  terminalNodes = std::move(tree.terminalNodes);
143  numTerminalNodes = tree.numTerminalNodes;
144  rmsError = tree.rmsError;
145  boostWeight = tree.boostWeight;
146  xmlVersion = tree.xmlVersion;
147 }
148 
150 // ______________________Get/Set________________________________________//
152 
153 void Tree::setRootNode(Node* sRootNode) { rootNode = sRootNode; }
154 
156 
157 // ----------------------------------------------------------------------
158 
159 void Tree::setTerminalNodes(std::list<Node*>& sTNodes) { terminalNodes = sTNodes; }
160 
161 std::list<Node*>& Tree::getTerminalNodes() { return terminalNodes; }
162 
163 // ----------------------------------------------------------------------
164 
166 
168 // ______________________Performance____________________________________//
170 
172  // Loop through the separate predictive regions (terminal nodes) and
173  // add up the errors to get the error of the entire space.
174 
175  double totalSquaredError = 0;
176 
177  for (std::list<Node*>::iterator it = terminalNodes.begin(); it != terminalNodes.end(); it++) {
178  totalSquaredError += (*it)->getTotalError();
179  }
180  rmsError = sqrt(totalSquaredError / rootNode->getNumEvents());
181 }
182 
183 // ----------------------------------------------------------------------
184 
// Grows the tree greedily, one split per call, until numTerminalNodes
// reaches nodeLimit or no terminal node is eligible for splitting.
// Each call:
//  - picks the terminal node with the largest getErrorReduction()
//    (only values strictly greater than -1 qualify);
//  - splits it via theMiracleOfChildBirth() and swaps it for its two
//    daughters in terminalNodes;
//  - pushes its events down and computes the daughters' optimum splits;
//  - recomputes rmsError and recurses.
// Recursion depth therefore grows with the number of splits performed.
// NOTE(review): original lines 191 and 220 are not shown in this listing;
// presumably 191 prepares the root's split and 220 increments
// numTerminalNodes (needed for the nodeLimit stop condition) — confirm.
185 void Tree::buildTree(int nodeLimit) {
186  // We greedily pick the best terminal node to split.
187  double bestNodeErrorReduction = -1;
188  Node* nodeToSplit = nullptr;
189 
190  if (numTerminalNodes == 1) {
192  calcError();
193  // std::cout << std::endl << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
194  }
195 
196  for (std::list<Node*>::iterator it = terminalNodes.begin(); it != terminalNodes.end(); it++) {
197  if ((*it)->getErrorReduction() > bestNodeErrorReduction) {
198  bestNodeErrorReduction = (*it)->getErrorReduction();
199  nodeToSplit = (*it);
200  }
201  }
202 
203  //std::cout << "nodeToSplit size = " << nodeToSplit->getNumEvents() << std::endl;
204 
205  // If all of the nodes have one event we can't add any more nodes and reduce the error.
205  // (Concretely: stop when no terminal node has error reduction > -1.)
206  if (nodeToSplit == nullptr)
207  return;
208 
209  // Create daughter nodes, and link the nodes together appropriately.
210  nodeToSplit->theMiracleOfChildBirth();
211 
212  // Get left and right daughters for reference.
213  Node* left = nodeToSplit->getLeftDaughter();
214  Node* right = nodeToSplit->getRightDaughter();
215 
216  // Update the list of terminal nodes.
217  terminalNodes.remove(nodeToSplit);
218  terminalNodes.push_back(left);
219  terminalNodes.push_back(right);
221 
222  // Filter the events from the parent into the daughters.
223  nodeToSplit->filterEventsToDaughters();
224 
225  // Calculate the best splits for the new nodes.
226  left->calcOptimumSplit();
227  right->calcOptimumSplit();
228 
229  // See if the error reduces as we add more nodes.
230  calcError();
231 
231  // `n % 1 == 0` is always true; this block is only a hook for the
231  // commented-out progress printout and has no effect.
232  if (numTerminalNodes % 1 == 0) {
233  // std::cout << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
234  }
235 
236  // Repeat until done.
237  if (numTerminalNodes < nodeLimit)
238  buildTree(nodeLimit);
239 }
240 
241 // ----------------------------------------------------------------------
242 
// Routes a batch of events through an already-built tree.
// The root's event collection is replaced by a single collection holding a
// copy of tEvents (the Event pointers are copied, not the events).
// @param tEvents events to be filtered into terminal nodes.
// NOTE(review): original line 253 is not shown in this listing; presumably
// it calls filterEventsRecursive(rootNode) to push the events down to the
// leaves — confirm against the full source.
243 void Tree::filterEvents(std::vector<Event*>& tEvents) {
244  // Use trees which have already been built to fit a bunch of events
245  // given by the tEvents vector.
246 
247  // Set the events to be filtered.
248  rootNode->getEvents() = std::vector<std::vector<Event*>>(1);
249  rootNode->getEvents()[0] = tEvents;
250 
251  // The tree now knows about the events it needs to fit.
252  // Filter them into a predictive region (terminal node).
254 }
255 
256 // ----------------------------------------------------------------------
257 
259  // Filter the events repeatedly into the daughter nodes until they
260  // fall into a terminal node.
261 
262  Node* left = node->getLeftDaughter();
263  Node* right = node->getRightDaughter();
264 
265  if (left == nullptr || right == nullptr)
266  return;
267 
268  node->filterEventsToDaughters();
269 
270  filterEventsRecursive(left);
271  filterEventsRecursive(right);
272 }
273 
274 // ----------------------------------------------------------------------
275 
277  // Use trees which have already been built to fit a bunch of events
278  // given by the tEvents vector.
279 
280  // Filter the event into a predictive region (terminal node).
281  Node* node = filterEventRecursive(rootNode, e);
282  return node;
283 }
284 
285 // ----------------------------------------------------------------------
286 
288  // Filter the event repeatedly into the daughter nodes until it
289  // falls into a terminal node.
290 
291  Node* nextNode = node->filterEventToDaughter(e);
292  if (nextNode == nullptr)
293  return node;
294 
295  return filterEventRecursive(nextNode, e);
296 }
297 
298 // ----------------------------------------------------------------------
299 
300 void Tree::rankVariablesRecursive(Node* node, std::vector<double>& v) {
301  // We recursively go through all of the nodes in the tree and find the
302  // total error reduction for each variable. The one with the most
303  // error reduction should be the most important.
304 
305  Node* left = node->getLeftDaughter();
306  Node* right = node->getRightDaughter();
307 
308  // Terminal nodes don't contribute to error reduction.
309  if (left == nullptr || right == nullptr)
310  return;
311 
312  int sv = node->getSplitVariable();
313  double er = node->getErrorReduction();
314 
315  //if(sv == -1)
316  //{
317  //std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
318  //std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
319  //std::cout << "rankVarRecursive Error Reduction = " << er << std::endl;
320  //}
321 
322  // Add error reduction to the current total for the appropriate
323  // variable.
324  v[sv] += er;
325 
326  rankVariablesRecursive(left, v);
327  rankVariablesRecursive(right, v);
328 }
329 
330 // ----------------------------------------------------------------------
331 
332 void Tree::rankVariables(std::vector<double>& v) { rankVariablesRecursive(rootNode, v); }
333 
334 // ----------------------------------------------------------------------
335 
336 void Tree::getSplitValuesRecursive(Node* node, std::vector<std::vector<double>>& v) {
337  // We recursively go through all of the nodes in the tree and find the
338  // split points used for each split variable.
339 
340  Node* left = node->getLeftDaughter();
341  Node* right = node->getRightDaughter();
342 
343  // Terminal nodes don't contribute.
344  if (left == nullptr || right == nullptr)
345  return;
346 
347  int sv = node->getSplitVariable();
348  double sp = node->getSplitValue();
349 
350  if (sv == -1) {
351  std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
352  std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
353  }
354 
355  // Add the split point to the list for the correct split variable.
356  v[sv].push_back(sp);
357 
358  getSplitValuesRecursive(left, v);
359  getSplitValuesRecursive(right, v);
360 }
361 
362 // ----------------------------------------------------------------------
363 
365 
367 // ______________________Storage/Retrieval______________________________//
369 
370 template <typename T>
372  // Convert a number to a string.
373  std::stringstream ss;
374  ss << num;
375  std::string s = ss.str();
376  return s;
377 }
378 
379 // ----------------------------------------------------------------------
380 
381 void Tree::addXMLAttributes(TXMLEngine* xml, Node* node, XMLNodePointer_t np) {
382  // Convert Node members into XML attributes
383  // and add them to the XMLEngine.
384  xml->NewAttr(np, nullptr, "splitVar", numToStr(node->getSplitVariable()).c_str());
385  xml->NewAttr(np, nullptr, "splitVal", numToStr(node->getSplitValue()).c_str());
386  xml->NewAttr(np, nullptr, "fitVal", numToStr(node->getFitValue()).c_str());
387 }
388 
389 // ----------------------------------------------------------------------
390 
391 void Tree::saveToXML(const char* c) {
392  TXMLEngine* xml = new TXMLEngine();
393 
394  // Add the root node.
395  XMLNodePointer_t root = xml->NewChild(nullptr, nullptr, rootNode->getName().c_str());
396  addXMLAttributes(xml, rootNode, root);
397 
398  // Recursively write the tree to XML.
399  saveToXMLRecursive(xml, rootNode, root);
400 
401  // Make the XML Document.
402  XMLDocPointer_t xmldoc = xml->NewDoc();
403  xml->DocSetRootElement(xmldoc, root);
404 
405  // Save to file.
406  xml->SaveDoc(xmldoc, c);
407 
408  // Clean up.
409  xml->FreeDoc(xmldoc);
410  delete xml;
411 }
412 
413 // ----------------------------------------------------------------------
414 
415 void Tree::saveToXMLRecursive(TXMLEngine* xml, Node* node, XMLNodePointer_t np) {
416  Node* l = node->getLeftDaughter();
417  Node* r = node->getRightDaughter();
418 
419  if (l == nullptr || r == nullptr)
420  return;
421 
422  // Add children to the XMLEngine.
423  XMLNodePointer_t left = xml->NewChild(np, nullptr, "left");
424  XMLNodePointer_t right = xml->NewChild(np, nullptr, "right");
425 
426  // Add attributes to the children.
427  addXMLAttributes(xml, l, left);
428  addXMLAttributes(xml, r, right);
429 
430  // Recurse.
431  saveToXMLRecursive(xml, l, left);
432  saveToXMLRecursive(xml, r, right);
433 }
434 
435 // ----------------------------------------------------------------------
436 
437 void Tree::loadFromXML(const char* filename) {
438  // First create the engine.
439  TXMLEngine* xml = new TXMLEngine;
440 
441  // Now try to parse xml file.
442  XMLDocPointer_t xmldoc = xml->ParseFile(filename);
443  if (xmldoc == nullptr) {
444  delete xml;
445  return;
446  }
447 
448  // Get access to main node of the xml file.
449  XMLNodePointer_t mainnode = xml->DocGetRootElement(xmldoc);
450 
451  // the original 2016 pT xmls define the source tree node to be the top-level xml node
452  // while in 2017 TMVA's xmls every decision tree is wrapped in an extra block specifying boostWeight parameter
453  // I choose to identify the format by checking the top xml node name that is a "BinaryTree" in 2017
454  if (std::string("BinaryTree") == xml->GetNodeName(mainnode)) {
455  XMLAttrPointer_t attr = xml->GetFirstAttr(mainnode);
456  attr = xml->GetNextAttr(attr);
457  boostWeight = (attr ? strtod(xml->GetAttrValue(attr), nullptr) : 0);
458  // step inside the top-level xml node
459  mainnode = xml->GetChild(mainnode);
460  xmlVersion = 2017;
461  } else {
462  boostWeight = 0;
463  xmlVersion = 2016;
464  }
465  // Recursively connect nodes together.
466  loadFromXMLRecursive(xml, mainnode, rootNode);
467 
468  // Release memory before exit
469  xml->FreeDoc(xmldoc);
470  delete xml;
471 }
472 
473 // ----------------------------------------------------------------------
474 
// Reads split variable, split value and fit value from XML element xnode
// into tnode. If xnode has both a first child and a following sibling of
// that child, tnode is split (theMiracleOfChildBirth) and the two children
// are loaded into its left and right daughters recursively. terminalNodes
// is updated so that tnode is replaced by its two daughters.
// Attribute positions are fixed, not looked up by name:
//  - xmlVersion >= 2017: the attributes at 0-based positions 3, 4 and 6
//    (of the first 10 walked) are splitVar, splitVal, fitVal;
//  - otherwise: the first three attributes, in that order.
// NOTE(review): the element is assumed to carry enough attributes
// (at least 7 for 2017, 3 for 2016); this is not checked here.
// NOTE(review): original line 536 is not shown in this listing; presumably
// it increments numTerminalNodes after a split — confirm.
475 void Tree::loadFromXMLRecursive(TXMLEngine* xml, XMLNodePointer_t xnode, Node* tnode) {
476  // Get the split information from xml.
477  XMLAttrPointer_t attr = xml->GetFirstAttr(xnode);
478  std::vector<std::string> splitInfo(3);
479  if (xmlVersion >= 2017) {
479  // Walk the first 10 attributes, keeping positions 3, 4, 6 (splitVar, splitVal, fitVal).
480  for (unsigned int i = 0, j = 0; i < 10; i++) {
481  if (i == 3 || i == 4 || i == 6) {
482  splitInfo[j++] = xml->GetAttrValue(attr);
483  }
484  attr = xml->GetNextAttr(attr);
485  }
486  } else {
486  // 2016 layout: splitVar, splitVal, fitVal are the first three attributes.
487  for (unsigned int i = 0; i < 3; i++) {
488  splitInfo[i] = xml->GetAttrValue(attr);
489  attr = xml->GetNextAttr(attr);
490  }
491  }
492 
493  // Convert strings into numbers.
493  // The stream is emptied and its error flags cleared after each value.
494  std::stringstream converter;
495  int splitVar;
496  double splitVal;
497  double fitVal;
498 
499  converter << splitInfo[0];
500  converter >> splitVar;
501  converter.str("");
502  converter.clear();
503 
504  converter << splitInfo[1];
505  converter >> splitVal;
506  converter.str("");
507  converter.clear();
508 
509  converter << splitInfo[2];
510  converter >> fitVal;
511  converter.str("");
512  converter.clear();
513 
514  // Store gathered splitInfo into the node object.
515  tnode->setSplitVariable(splitVar);
516  tnode->setSplitValue(splitVal);
517  tnode->setFitValue(fitVal);
518 
519  // Get the xml daughters of the current xml node.
519  // First child element is the left daughter, its next sibling the right.
520  XMLNodePointer_t xleft = xml->GetChild(xnode);
521  XMLNodePointer_t xright = xml->GetNext(xleft);
522 
523  // If there are no daughters we are done.
524  if (xleft == nullptr || xright == nullptr)
525  return;
526 
527  // If there are daughters link the node objects appropriately.
528  tnode->theMiracleOfChildBirth();
529  Node* tleft = tnode->getLeftDaughter();
530  Node* tright = tnode->getRightDaughter();
531 
532  // Update the list of terminal nodes.
533  terminalNodes.remove(tnode);
534  terminalNodes.push_back(tleft);
535  terminalNodes.push_back(tright);
537 
538  loadFromXMLRecursive(xml, xleft, tleft);
539  loadFromXMLRecursive(xml, xright, tright);
540 }
541 
543  // start fresh in case this is not the only call to construct a tree
544  if (rootNode)
545  delete rootNode;
546  rootNode = new Node("root");
547 
548  const L1TMuonEndCapForest::DTreeNode& mainnode = tree[0];
549  loadFromCondPayloadRecursive(tree, mainnode, rootNode);
550 }
551 
553  const L1TMuonEndCapForest::DTreeNode& node,
554  Node* tnode) {
555  // Store gathered splitInfo into the node object.
556  tnode->setSplitVariable(node.splitVar);
557  tnode->setSplitValue(node.splitVal);
558  tnode->setFitValue(node.fitVal);
559 
560  // If there are no daughters we are done.
561  if (node.ileft == 0 || node.iright == 0)
562  return; // root cannot be anyone's child
563  if (node.ileft >= tree.size() || node.iright >= tree.size())
564  return; // out of range addressing on purpose
565 
566  // If there are daughters link the node objects appropriately.
567  tnode->theMiracleOfChildBirth();
568  Node* tleft = tnode->getLeftDaughter();
569  Node* tright = tnode->getRightDaughter();
570 
571  // Update the list of terminal nodes.
572  terminalNodes.remove(tnode);
573  terminalNodes.push_back(tleft);
574  terminalNodes.push_back(tright);
576 
577  loadFromCondPayloadRecursive(tree, tree[node.ileft], tleft);
578  loadFromCondPayloadRecursive(tree, tree[node.iright], tright);
579 }
void setFitValue(double sFitValue)
Definition: Node.cc:110
Node * getRightDaughter()
Definition: Node.cc:90
Node * filterEventRecursive(Node *node, Event *e)
Definition: Tree.cc:287
Node * getLeftDaughter()
Definition: Node.cc:86
double getFitValue()
Definition: Node.cc:112
const edm::EventSetup & c
std::list< Node * > terminalNodes
Definition: Tree.h:62
void buildTree(int nodeLimit)
Definition: Tree.cc:185
void getSplitValues(std::vector< std::vector< double >> &v)
Definition: Tree.cc:364
Tree()
Definition: Tree.cc:30
void loadFromXMLRecursive(TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
Definition: Tree.cc:475
double rmsError
Definition: Tree.h:64
TGeoNode Node
int getNumTerminalNodes()
Definition: Tree.cc:165
void addXMLAttributes(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:381
Node * getRootNode()
Definition: Tree.cc:155
void loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree &tree, const L1TMuonEndCapForest::DTreeNode &node, Node *tnode)
Definition: Tree.cc:552
string xmldoc
Some module's global variables.
Tree & operator=(const Tree &tree)
Definition: Tree.cc:74
void setTerminalNodes(std::list< Node * > &sTNodes)
Definition: Tree.cc:159
void findLeafs(Node *local_root, std::list< Node * > &tn)
Definition: Tree.cc:124
int np
Definition: AMPTWrapper.h:43
double getTotalError()
Definition: Node.cc:118
void loadFromCondPayload(const L1TMuonEndCapForest::DTree &tree)
Definition: Tree.cc:542
T sqrt(T t)
Definition: SSEVec.h:19
Node * filterEventToDaughter(Event *e)
Definition: Node.cc:314
void rankVariables(std::vector< double > &v)
Definition: Tree.cc:332
void rankVariablesRecursive(Node *node, std::vector< double > &v)
Definition: Tree.cc:300
void setEvents(std::vector< std::vector< Event * > > &sEvents)
Definition: Node.cc:134
def move
Definition: eostools.py:511
std::vector< std::vector< Event * > > & getEvents()
Definition: Node.cc:132
void saveToXMLRecursive(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:415
double getSplitValue()
Definition: Node.cc:102
std::string getName()
Definition: Node.cc:74
std::vector< DTreeNode > DTree
std::list< Node * > & getTerminalNodes()
Definition: Tree.cc:161
Node * copyFrom(const Node *local_root)
Definition: Tree.cc:93
void setParent(Node *sParent)
Definition: Node.cc:94
void filterEventsRecursive(Node *node)
Definition: Tree.cc:258
int numTerminalNodes
Definition: Tree.h:63
void calcError()
Definition: Tree.cc:171
void setSplitVariable(int sSplitVar)
Definition: Node.cc:104
double boostWeight
Definition: Tree.h:65
Node * filterEvent(Event *e)
Definition: Tree.cc:276
void loadFromXML(const char *filename)
Definition: Tree.cc:437
void saveToXML(const char *filename)
Definition: Tree.cc:391
int getNumEvents()
Definition: Node.cc:128
Node * rootNode
Definition: Tree.h:61
double getErrorReduction()
Definition: Node.cc:80
unsigned xmlVersion
Definition: Tree.h:66
tuple filename
Definition: lut2db_cfg.py:20
void filterEventsToDaughters()
Definition: Node.cc:267
tuple cout
Definition: gather_cfg.py:144
double getAvgError()
Definition: Node.cc:122
int getSplitVariable()
Definition: Node.cc:106
void calcOptimumSplit()
Definition: Node.cc:143
std::string numToStr(T num)
Definition: Utilities.h:43
void filterEvents(std::vector< Event * > &tEvents)
Definition: Tree.cc:243
long double T
void setRootNode(Node *sRootNode)
Definition: Tree.cc:153
void getSplitValuesRecursive(Node *node, std::vector< std::vector< double >> &v)
Definition: Tree.cc:336
void setSplitValue(double sSplitValue)
Definition: Node.cc:100
~Tree()
Definition: Tree.cc:52
void theMiracleOfChildBirth()
Definition: Node.cc:253