CMS 3D CMS Logo

Tree.cc
Go to the documentation of this file.
1 // Tree.cxx //
3 // =====================================================================//
4 // This is the object implementation of a decision tree. //
5 // References include //
6 // *Elements of Statistical Learning by Hastie, //
7 // Tibshirani, and Friedman. //
8 // *Greedy Function Approximation: A Gradient Boosting Machine. //
9 // Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001. //
10 // *Inductive Learning of Tree-based Regression Models. Luis Torgo. //
11 // //
13 
15 // _______________________Includes_______________________________________//
17 
19 
20 #include <iostream>
21 #include <sstream>
22 #include <cmath>
23 
25 // _______________________Constructor(s)________________________________//
27 
28 using namespace emtf;
29 
31 {
32  rootNode = new Node("root");
33 
34  terminalNodes.push_back(rootNode);
35  numTerminalNodes = 1;
36  boostWeight = 0;
37  xmlVersion = 2017;
38 }
39 
40 Tree::Tree(std::vector< std::vector<Event*> >& cEvents)
41 {
42  rootNode = new Node("root");
43  rootNode->setEvents(cEvents);
44 
45  terminalNodes.push_back(rootNode);
46  numTerminalNodes = 1;
47  boostWeight = 0;
48  xmlVersion = 2017;
49 }
51 // _______________________Destructor____________________________________//
53 
54 
56 {
57 // When the tree is destroyed it will delete all of the nodes in the tree.
58 // The deletion begins with the rootnode and continues recursively.
59  if(rootNode) delete rootNode;
60 }
61 
63 {
64  // unfortunately, authors of these classes didn't use const qualifiers
65  rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
67  rmsError = tree.rmsError;
68  boostWeight = tree.boostWeight;
69  xmlVersion = tree.xmlVersion;
70 
71  terminalNodes.resize(0);
72  // find new leafs
74 
76 }
77 
79  if(rootNode) delete rootNode;
80  // unfortunately, authors of these classes didn't use const qualifiers
81  rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
83  rmsError = tree.rmsError;
84  boostWeight = tree.boostWeight;
85  xmlVersion = tree.xmlVersion;
86 
87  terminalNodes.resize(0);
88  // find new leafs
90 
92 
93  return *this;
94 }
95 
96 Node* Tree::copyFrom(const Node *local_root)
97 {
98  // end-case
99  if( !local_root ) return nullptr;
100 
101  Node *lr = const_cast<Node*>(local_root);
102 
103  // recursion
104  Node *left_new_child = copyFrom( lr->getLeftDaughter() );
105  Node *right_new_child = copyFrom( lr->getRightDaughter() );
106 
107  // performing main work at this level
108  Node *new_local_root = new Node( lr->getName() );
109  if( left_new_child ) left_new_child ->setParent(new_local_root);
110  if( right_new_child ) right_new_child->setParent(new_local_root);
111  new_local_root->setLeftDaughter ( left_new_child );
112  new_local_root->setRightDaughter( right_new_child );
113  new_local_root->setErrorReduction( lr->getErrorReduction() );
114  new_local_root->setSplitValue( lr->getSplitValue() );
115  new_local_root->setSplitVariable( lr->getSplitVariable() );
116  new_local_root->setFitValue( lr->getFitValue() );
117  new_local_root->setTotalError( lr->getTotalError() );
118  new_local_root->setAvgError( lr->getAvgError() );
119  new_local_root->setNumEvents( lr->getNumEvents() );
120 // new_local_root->setEvents( lr->getEvents() ); // no ownership assumed for the events anyways
121 
122  return new_local_root;
123 }
124 
125 void Tree::findLeafs(Node *local_root, std::list<Node*> &tn)
126 {
127  if( !local_root->getLeftDaughter() && !local_root->getRightDaughter() ){
128  // leaf or ternimal node found
129  tn.push_back(local_root);
130  return;
131  }
132 
133  if( local_root->getLeftDaughter() )
134  findLeafs( local_root->getLeftDaughter(), tn );
135 
136  if( local_root->getRightDaughter() )
137  findLeafs( local_root->getRightDaughter(), tn );
138 }
139 
141 {
142  if(rootNode) delete rootNode; // this line is the only reason not to use default move constructor
143  rootNode = tree.rootNode;
144  terminalNodes = std::move(tree.terminalNodes);
145  numTerminalNodes = tree.numTerminalNodes;
146  rmsError = tree.rmsError;
147  boostWeight = tree.boostWeight;
148  xmlVersion = tree.xmlVersion;
149 }
150 
152 // ______________________Get/Set________________________________________//
154 
155 void Tree::setRootNode(Node *sRootNode)
156 {
157  rootNode = sRootNode;
158 }
159 
161 {
162  return rootNode;
163 }
164 
165 // ----------------------------------------------------------------------
166 
167 void Tree::setTerminalNodes(std::list<Node*>& sTNodes)
168 {
169  terminalNodes = sTNodes;
170 }
171 
172 std::list<Node*>& Tree::getTerminalNodes()
173 {
174  return terminalNodes;
175 }
176 
177 // ----------------------------------------------------------------------
178 
180 {
181  return numTerminalNodes;
182 }
183 
185 // ______________________Performance_____________________________________//
187 
189 {
190 // Loop through the separate predictive regions (terminal nodes) and
191 // add up the errors to get the error of the entire space.
192 
193  double totalSquaredError = 0;
194 
195  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
196  {
197  totalSquaredError += (*it)->getTotalError();
198  }
199  rmsError = sqrt( totalSquaredError/rootNode->getNumEvents() );
200 }
201 
202 // ----------------------------------------------------------------------
203 
204 void Tree::buildTree(int nodeLimit)
205 {
206  // We greedily pick the best terminal node to split.
207  double bestNodeErrorReduction = -1;
208  Node* nodeToSplit = nullptr;
209 
210  if(numTerminalNodes == 1)
211  {
213  calcError();
214 // std::cout << std::endl << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
215  }
216 
217  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
218  {
219  if( (*it)->getErrorReduction() > bestNodeErrorReduction )
220  {
221  bestNodeErrorReduction = (*it)->getErrorReduction();
222  nodeToSplit = (*it);
223  }
224  }
225 
226  //std::cout << "nodeToSplit size = " << nodeToSplit->getNumEvents() << std::endl;
227 
228  // If all of the nodes have one event we can't add any more nodes and reduce the error.
229  if(nodeToSplit == nullptr) return;
230 
231  // Create daughter nodes, and link the nodes together appropriately.
232  nodeToSplit->theMiracleOfChildBirth();
233 
234  // Get left and right daughters for reference.
235  Node* left = nodeToSplit->getLeftDaughter();
236  Node* right = nodeToSplit->getRightDaughter();
237 
238  // Update the list of terminal nodes.
239  terminalNodes.remove(nodeToSplit);
240  terminalNodes.push_back(left);
241  terminalNodes.push_back(right);
243 
244  // Filter the events from the parent into the daughters.
245  nodeToSplit->filterEventsToDaughters();
246 
247  // Calculate the best splits for the new nodes.
248  left->calcOptimumSplit();
249  right->calcOptimumSplit();
250 
251  // See if the error reduces as we add more nodes.
252  calcError();
253 
254  if(numTerminalNodes % 1 == 0)
255  {
256 // std::cout << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
257  }
258 
259  // Repeat until done.
260  if(numTerminalNodes < nodeLimit) buildTree(nodeLimit);
261 }
262 
263 // ----------------------------------------------------------------------
264 
265 void Tree::filterEvents(std::vector<Event*>& tEvents)
266 {
267 // Use trees which have already been built to fit a bunch of events
268 // given by the tEvents vector.
269 
270  // Set the events to be filtered.
271  rootNode->getEvents() = std::vector< std::vector<Event*> >(1);
272  rootNode->getEvents()[0] = tEvents;
273 
274  // The tree now knows about the events it needs to fit.
275  // Filter them into a predictive region (terminal node).
277 }
278 
279 // ----------------------------------------------------------------------
280 
282 {
283 // Filter the events repeatedly into the daughter nodes until they
284 // fall into a terminal node.
285 
286  Node* left = node->getLeftDaughter();
287  Node* right = node->getRightDaughter();
288 
289  if(left == nullptr || right == nullptr) return;
290 
291  node->filterEventsToDaughters();
292 
293  filterEventsRecursive(left);
294  filterEventsRecursive(right);
295 }
296 
297 // ----------------------------------------------------------------------
298 
300 {
301 // Use trees which have already been built to fit a bunch of events
302 // given by the tEvents vector.
303 
304  // Filter the event into a predictive region (terminal node).
305  Node* node = filterEventRecursive(rootNode, e);
306  return node;
307 }
308 
309 // ----------------------------------------------------------------------
310 
312 {
313 // Filter the event repeatedly into the daughter nodes until it
314 // falls into a terminal node.
315 
316 
317  Node* nextNode = node->filterEventToDaughter(e);
318  if(nextNode == nullptr) return node;
319 
320  return filterEventRecursive(nextNode, e);
321 }
322 
323 // ----------------------------------------------------------------------
324 
325 
326 void Tree::rankVariablesRecursive(Node* node, std::vector<double>& v)
327 {
328 // We recursively go through all of the nodes in the tree and find the
329 // total error reduction for each variable. The one with the most
330 // error reduction should be the most important.
331 
332  Node* left = node->getLeftDaughter();
333  Node* right = node->getRightDaughter();
334 
335  // Terminal nodes don't contribute to error reduction.
336  if(left==nullptr || right==nullptr) return;
337 
338  int sv = node->getSplitVariable();
339  double er = node->getErrorReduction();
340 
341  //if(sv == -1)
342  //{
343  //std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
344  //std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
345  //std::cout << "rankVarRecursive Error Reduction = " << er << std::endl;
346  //}
347 
348  // Add error reduction to the current total for the appropriate
349  // variable.
350  v[sv] += er;
351 
352  rankVariablesRecursive(left, v);
353  rankVariablesRecursive(right, v);
354 
355 }
356 
357 // ----------------------------------------------------------------------
358 
359 void Tree::rankVariables(std::vector<double>& v)
360 {
362 }
363 
364 // ----------------------------------------------------------------------
365 
366 
367 void Tree::getSplitValuesRecursive(Node* node, std::vector<std::vector<double>>& v)
368 {
369 // We recursively go through all of the nodes in the tree and find the
370 // split points used for each split variable.
371 
372  Node* left = node->getLeftDaughter();
373  Node* right = node->getRightDaughter();
374 
375  // Terminal nodes don't contribute.
376  if(left==nullptr || right==nullptr) return;
377 
378  int sv = node->getSplitVariable();
379  double sp = node->getSplitValue();
380 
381  if(sv == -1)
382  {
383  std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
384  std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
385  }
386 
387  // Add the split point to the list for the correct split variable.
388  v[sv].push_back(sp);
389 
390  getSplitValuesRecursive(left, v);
391  getSplitValuesRecursive(right, v);
392 
393 }
394 
395 // ----------------------------------------------------------------------
396 
397 void Tree::getSplitValues(std::vector<std::vector<double>>& v)
398 {
400 }
401 
403 // ______________________Storage/Retrieval______________________________//
405 
406 template <typename T>
408 {
409 // Convert a number to a string.
410  std::stringstream ss;
411  ss << num;
412  std::string s = ss.str();
413  return s;
414 }
415 
416 // ----------------------------------------------------------------------
417 
418 void Tree::addXMLAttributes(TXMLEngine* xml, Node* node, XMLNodePointer_t np)
419 {
420  // Convert Node members into XML attributes
421  // and add them to the XMLEngine.
422  xml->NewAttr(np, nullptr, "splitVar", numToStr(node->getSplitVariable()).c_str());
423  xml->NewAttr(np, nullptr, "splitVal", numToStr(node->getSplitValue()).c_str());
424  xml->NewAttr(np, nullptr, "fitVal", numToStr(node->getFitValue()).c_str());
425 }
426 
427 // ----------------------------------------------------------------------
428 
429 void Tree::saveToXML(const char* c)
430 {
431 
432  TXMLEngine* xml = new TXMLEngine();
433 
434  // Add the root node.
435  XMLNodePointer_t root = xml->NewChild(nullptr, nullptr, rootNode->getName().c_str());
436  addXMLAttributes(xml, rootNode, root);
437 
438  // Recursively write the tree to XML.
439  saveToXMLRecursive(xml, rootNode, root);
440 
441  // Make the XML Document.
442  XMLDocPointer_t xmldoc = xml->NewDoc();
443  xml->DocSetRootElement(xmldoc, root);
444 
445  // Save to file.
446  xml->SaveDoc(xmldoc, c);
447 
448  // Clean up.
449  xml->FreeDoc(xmldoc);
450  delete xml;
451 }
452 
453 // ----------------------------------------------------------------------
454 
455 void Tree::saveToXMLRecursive(TXMLEngine* xml, Node* node, XMLNodePointer_t np)
456 {
457  Node* l = node->getLeftDaughter();
458  Node* r = node->getRightDaughter();
459 
460  if(l==nullptr || r==nullptr) return;
461 
462  // Add children to the XMLEngine.
463  XMLNodePointer_t left = xml->NewChild(np, nullptr, "left");
464  XMLNodePointer_t right = xml->NewChild(np, nullptr, "right");
465 
466  // Add attributes to the children.
467  addXMLAttributes(xml, l, left);
468  addXMLAttributes(xml, r, right);
469 
470  // Recurse.
471  saveToXMLRecursive(xml, l, left);
472  saveToXMLRecursive(xml, r, right);
473 }
474 
475 // ----------------------------------------------------------------------
476 
477 void Tree::loadFromXML(const char* filename)
478 {
479  // First create the engine.
480  TXMLEngine* xml = new TXMLEngine;
481 
482  // Now try to parse xml file.
483  XMLDocPointer_t xmldoc = xml->ParseFile(filename);
484  if (xmldoc==nullptr)
485  {
486  delete xml;
487  return;
488  }
489 
490  // Get access to main node of the xml file.
491  XMLNodePointer_t mainnode = xml->DocGetRootElement(xmldoc);
492 
493  // the original 2016 pT xmls define the source tree node to be the top-level xml node
494  // while in 2017 TMVA's xmls every decision tree is wrapped in an extra block specifying boostWeight parameter
495  // I choose to identify the format by checking the top xml node name that is a "BinaryTree" in 2017
496  if( std::string("BinaryTree") == xml->GetNodeName(mainnode) ){
497  XMLAttrPointer_t attr = xml->GetFirstAttr(mainnode);
498  attr = xml->GetNextAttr(attr);
499  boostWeight = (attr ? strtod(xml->GetAttrValue(attr),nullptr) : 0);
500  // step inside the top-level xml node
501  mainnode = xml->GetChild(mainnode);
502  xmlVersion = 2017;
503  } else {
504  boostWeight = 0;
505  xmlVersion = 2016;
506  }
507  // Recursively connect nodes together.
508  loadFromXMLRecursive(xml, mainnode, rootNode);
509 
510  // Release memory before exit
511  xml->FreeDoc(xmldoc);
512  delete xml;
513 }
514 
515 // ----------------------------------------------------------------------
516 
517 void Tree::loadFromXMLRecursive(TXMLEngine* xml, XMLNodePointer_t xnode, Node* tnode)
518 {
519 
520  // Get the split information from xml.
521  XMLAttrPointer_t attr = xml->GetFirstAttr(xnode);
522  std::vector<std::string> splitInfo(3);
523  if( xmlVersion >= 2017 ){
524  for(unsigned int i=0,j=0; i<10; i++)
525  {
526  if(i==3 || i==4 || i==6){
527  splitInfo[j++] = xml->GetAttrValue(attr);
528  }
529  attr = xml->GetNextAttr(attr);
530  }
531  } else {
532  for(unsigned int i=0; i<3; i++)
533  {
534  splitInfo[i] = xml->GetAttrValue(attr);
535  attr = xml->GetNextAttr(attr);
536  }
537  }
538 
539  // Convert strings into numbers.
540  std::stringstream converter;
541  int splitVar;
542  double splitVal;
543  double fitVal;
544 
545  converter << splitInfo[0];
546  converter >> splitVar;
547  converter.str("");
548  converter.clear();
549 
550  converter << splitInfo[1];
551  converter >> splitVal;
552  converter.str("");
553  converter.clear();
554 
555  converter << splitInfo[2];
556  converter >> fitVal;
557  converter.str("");
558  converter.clear();
559 
560  // Store gathered splitInfo into the node object.
561  tnode->setSplitVariable(splitVar);
562  tnode->setSplitValue(splitVal);
563  tnode->setFitValue(fitVal);
564 
565  // Get the xml daughters of the current xml node.
566  XMLNodePointer_t xleft = xml->GetChild(xnode);
567  XMLNodePointer_t xright = xml->GetNext(xleft);
568 
569  // If there are no daughters we are done.
570  if(xleft == nullptr || xright == nullptr) return;
571 
572  // If there are daughters link the node objects appropriately.
573  tnode->theMiracleOfChildBirth();
574  Node* tleft = tnode->getLeftDaughter();
575  Node* tright = tnode->getRightDaughter();
576 
577  // Update the list of terminal nodes.
578  terminalNodes.remove(tnode);
579  terminalNodes.push_back(tleft);
580  terminalNodes.push_back(tright);
582 
583  loadFromXMLRecursive(xml, xleft, tleft);
584  loadFromXMLRecursive(xml, xright, tright);
585 }
586 
588 {
589  // start fresh in case this is not the only call to construct a tree
590  if( rootNode ) delete rootNode;
591  rootNode = new Node("root");
592 
593  const L1TMuonEndCapForest::DTreeNode& mainnode = tree[0];
594  loadFromCondPayloadRecursive(tree, mainnode, rootNode);
595 }
596 
598 {
599  // Store gathered splitInfo into the node object.
600  tnode->setSplitVariable(node.splitVar);
601  tnode->setSplitValue(node.splitVal);
602  tnode->setFitValue(node.fitVal);
603 
604  // If there are no daughters we are done.
605  if( node.ileft == 0 || node.iright == 0) return; // root cannot be anyone's child
606  if( node.ileft >= tree.size() ||
607  node.iright >= tree.size() ) return; // out of range addressing on purpose
608 
609  // If there are daughters link the node objects appropriately.
610  tnode->theMiracleOfChildBirth();
611  Node* tleft = tnode->getLeftDaughter();
612  Node* tright = tnode->getRightDaughter();
613 
614  // Update the list of terminal nodes.
615  terminalNodes.remove(tnode);
616  terminalNodes.push_back(tleft);
617  terminalNodes.push_back(tright);
619 
620  loadFromCondPayloadRecursive(tree, tree[node.ileft], tleft);
621  loadFromCondPayloadRecursive(tree, tree[node.iright], tright);
622 }
void setFitValue(double sFitValue)
Definition: Node.cc:153
Node * getRightDaughter()
Definition: Node.cc:112
Node * filterEventRecursive(Node *node, Event *e)
Definition: Tree.cc:311
Node * getLeftDaughter()
Definition: Node.cc:102
double getFitValue()
Definition: Node.cc:158
std::list< Node * > terminalNodes
Definition: Tree.h:61
void buildTree(int nodeLimit)
Definition: Tree.cc:204
void getSplitValues(std::vector< std::vector< double >> &v)
Definition: Tree.cc:397
Tree()
Definition: Tree.cc:30
void loadFromXMLRecursive(TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
Definition: Tree.cc:517
double rmsError
Definition: Tree.h:63
int getNumTerminalNodes()
Definition: Tree.cc:179
void addXMLAttributes(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:418
Node * getRootNode()
Definition: Tree.cc:160
void loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree &tree, const L1TMuonEndCapForest::DTreeNode &node, Node *tnode)
Definition: Tree.cc:597
Definition: Event.h:15
Tree & operator=(const Tree &tree)
Definition: Tree.cc:78
void setTerminalNodes(std::list< Node * > &sTNodes)
Definition: Tree.cc:167
void findLeafs(Node *local_root, std::list< Node * > &tn)
Definition: Tree.cc:125
int np
Definition: AMPTWrapper.h:33
double getTotalError()
Definition: Node.cc:170
void loadFromCondPayload(const L1TMuonEndCapForest::DTree &tree)
Definition: Tree.cc:587
T sqrt(T t)
Definition: SSEVec.h:18
Node * filterEventToDaughter(Event *e)
Definition: Node.cc:395
void rankVariables(std::vector< double > &v)
Definition: Tree.cc:359
void rankVariablesRecursive(Node *node, std::vector< double > &v)
Definition: Tree.cc:326
void setEvents(std::vector< std::vector< Event * > > &sEvents)
Definition: Node.cc:204
std::vector< std::vector< Event * > > & getEvents()
Definition: Node.cc:199
void saveToXMLRecursive(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:455
double getSplitValue()
Definition: Node.cc:136
std::string getName()
Definition: Node.cc:78
std::vector< DTreeNode > DTree
std::list< Node * > & getTerminalNodes()
Definition: Tree.cc:172
Node * copyFrom(const Node *local_root)
Definition: Tree.cc:96
void setParent(Node *sParent)
Definition: Node.cc:119
void filterEventsRecursive(Node *node)
Definition: Tree.cc:281
int numTerminalNodes
Definition: Tree.h:62
void calcError()
Definition: Tree.cc:188
void setSplitVariable(int sSplitVar)
Definition: Node.cc:141
double boostWeight
Definition: Tree.h:64
Node * filterEvent(Event *e)
Definition: Tree.cc:299
void loadFromXML(const char *filename)
Definition: Tree.cc:477
void saveToXML(const char *filename)
Definition: Tree.cc:429
int getNumEvents()
Definition: Node.cc:192
Node * rootNode
Definition: Tree.h:60
double getErrorReduction()
Definition: Node.cc:90
unsigned xmlVersion
Definition: Tree.h:65
void filterEventsToDaughters()
Definition: Node.cc:350
double getAvgError()
Definition: Node.cc:180
int getSplitVariable()
Definition: Node.cc:146
Definition: tree.py:1
void calcOptimumSplit()
Definition: Node.cc:214
std::string numToStr(T num)
Definition: Utilities.h:44
void filterEvents(std::vector< Event * > &tEvents)
Definition: Tree.cc:265
long double T
void setRootNode(Node *sRootNode)
Definition: Tree.cc:155
void getSplitValuesRecursive(Node *node, std::vector< std::vector< double >> &v)
Definition: Tree.cc:367
void setSplitValue(double sSplitValue)
Definition: Node.cc:131
def move(src, dest)
Definition: eostools.py:510
~Tree()
Definition: Tree.cc:55
void theMiracleOfChildBirth()
Definition: Node.cc:335