CMS 3D CMS Logo

Tree.cc
Go to the documentation of this file.
1 // Tree.cxx //
3 // =====================================================================//
4 // This is the object implementation of a decision tree. //
5 // References include //
6 // *Elements of Statistical Learning by Hastie, //
7 // Tibshirani, and Friedman. //
8 // *Greedy Function Approximation: A Gradient Boosting Machine. //
9 // Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001. //
10 // *Inductive Learning of Tree-based Regression Models. Luis Torgo. //
11 // //
13 
15 // _______________________Includes_______________________________________//
17 
19 
#include <iostream>
#include <limits>
#include <sstream>
22 
24 // _______________________Constructor(s)________________________________//
26 
27 using namespace emtf;
28 
30 {
31  rootNode = new Node("root");
32 
33  terminalNodes.push_back(rootNode);
34  numTerminalNodes = 1;
35  boostWeight = 0;
36  xmlVersion = 2017;
37 }
38 
39 Tree::Tree(std::vector< std::vector<Event*> >& cEvents)
40 {
41  rootNode = new Node("root");
42  rootNode->setEvents(cEvents);
43 
44  terminalNodes.push_back(rootNode);
45  numTerminalNodes = 1;
46  boostWeight = 0;
47  xmlVersion = 2017;
48 }
50 // _______________________Destructor____________________________________//
52 
53 
55 {
56 // When the tree is destroyed it will delete all of the nodes in the tree.
57 // The deletion begins with the rootnode and continues recursively.
58  if(rootNode) delete rootNode;
59 }
60 
62 {
63  // unfortunately, authors of these classes didn't use const qualifiers
64  rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
66  rmsError = tree.rmsError;
67  boostWeight = tree.boostWeight;
68  xmlVersion = tree.xmlVersion;
69 
70  terminalNodes.resize(0);
71  // find new leafs
73 
75 }
76 
78  if(rootNode) delete rootNode;
79  // unfortunately, authors of these classes didn't use const qualifiers
80  rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
82  rmsError = tree.rmsError;
83  boostWeight = tree.boostWeight;
84  xmlVersion = tree.xmlVersion;
85 
86  terminalNodes.resize(0);
87  // find new leafs
89 
91 
92  return *this;
93 }
94 
95 Node* Tree::copyFrom(const Node *local_root)
96 {
97  // end-case
98  if( !local_root ) return nullptr;
99 
100  Node *lr = const_cast<Node*>(local_root);
101 
102  // recursion
103  Node *left_new_child = copyFrom( lr->getLeftDaughter() );
104  Node *right_new_child = copyFrom( lr->getRightDaughter() );
105 
106  // performing main work at this level
107  Node *new_local_root = new Node( lr->getName() );
108  if( left_new_child ) left_new_child ->setParent(new_local_root);
109  if( right_new_child ) right_new_child->setParent(new_local_root);
110  new_local_root->setLeftDaughter ( left_new_child );
111  new_local_root->setRightDaughter( right_new_child );
112  new_local_root->setErrorReduction( lr->getErrorReduction() );
113  new_local_root->setSplitValue( lr->getSplitValue() );
114  new_local_root->setSplitVariable( lr->getSplitVariable() );
115  new_local_root->setFitValue( lr->getFitValue() );
116  new_local_root->setTotalError( lr->getTotalError() );
117  new_local_root->setAvgError( lr->getAvgError() );
118  new_local_root->setNumEvents( lr->getNumEvents() );
119 // new_local_root->setEvents( lr->getEvents() ); // no ownership assumed for the events anyways
120 
121  return new_local_root;
122 }
123 
124 void Tree::findLeafs(Node *local_root, std::list<Node*> &tn)
125 {
126  if( !local_root->getLeftDaughter() && !local_root->getRightDaughter() ){
127  // leaf or ternimal node found
128  tn.push_back(local_root);
129  return;
130  }
131 
132  if( local_root->getLeftDaughter() )
133  findLeafs( local_root->getLeftDaughter(), tn );
134 
135  if( local_root->getRightDaughter() )
136  findLeafs( local_root->getRightDaughter(), tn );
137 }
138 
140 {
141  if(rootNode) delete rootNode; // this line is the only reason not to use default move constructor
142  rootNode = tree.rootNode;
143  terminalNodes = std::move(tree.terminalNodes);
144  numTerminalNodes = tree.numTerminalNodes;
145  rmsError = tree.rmsError;
146  boostWeight = tree.boostWeight;
147  xmlVersion = tree.xmlVersion;
148 }
149 
151 // ______________________Get/Set________________________________________//
153 
154 void Tree::setRootNode(Node *sRootNode)
155 {
156  rootNode = sRootNode;
157 }
158 
160 {
161  return rootNode;
162 }
163 
164 // ----------------------------------------------------------------------
165 
166 void Tree::setTerminalNodes(std::list<Node*>& sTNodes)
167 {
168  terminalNodes = sTNodes;
169 }
170 
171 std::list<Node*>& Tree::getTerminalNodes()
172 {
173  return terminalNodes;
174 }
175 
176 // ----------------------------------------------------------------------
177 
179 {
180  return numTerminalNodes;
181 }
182 
184 // ______________________Performance_____________________________________//
186 
188 {
189 // Loop through the separate predictive regions (terminal nodes) and
190 // add up the errors to get the error of the entire space.
191 
192  double totalSquaredError = 0;
193 
194  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
195  {
196  totalSquaredError += (*it)->getTotalError();
197  }
198  rmsError = sqrt( totalSquaredError/rootNode->getNumEvents() );
199 }
200 
201 // ----------------------------------------------------------------------
202 
203 void Tree::buildTree(int nodeLimit)
204 {
205  // We greedily pick the best terminal node to split.
206  double bestNodeErrorReduction = -1;
207  Node* nodeToSplit = nullptr;
208 
209  if(numTerminalNodes == 1)
210  {
212  calcError();
213 // std::cout << std::endl << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
214  }
215 
216  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
217  {
218  if( (*it)->getErrorReduction() > bestNodeErrorReduction )
219  {
220  bestNodeErrorReduction = (*it)->getErrorReduction();
221  nodeToSplit = (*it);
222  }
223  }
224 
225  //std::cout << "nodeToSplit size = " << nodeToSplit->getNumEvents() << std::endl;
226 
227  // If all of the nodes have one event we can't add any more nodes and reduce the error.
228  if(nodeToSplit == nullptr) return;
229 
230  // Create daughter nodes, and link the nodes together appropriately.
231  nodeToSplit->theMiracleOfChildBirth();
232 
233  // Get left and right daughters for reference.
234  Node* left = nodeToSplit->getLeftDaughter();
235  Node* right = nodeToSplit->getRightDaughter();
236 
237  // Update the list of terminal nodes.
238  terminalNodes.remove(nodeToSplit);
239  terminalNodes.push_back(left);
240  terminalNodes.push_back(right);
242 
243  // Filter the events from the parent into the daughters.
244  nodeToSplit->filterEventsToDaughters();
245 
246  // Calculate the best splits for the new nodes.
247  left->calcOptimumSplit();
248  right->calcOptimumSplit();
249 
250  // See if the error reduces as we add more nodes.
251  calcError();
252 
253  if(numTerminalNodes % 1 == 0)
254  {
255 // std::cout << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
256  }
257 
258  // Repeat until done.
259  if(numTerminalNodes < nodeLimit) buildTree(nodeLimit);
260 }
261 
262 // ----------------------------------------------------------------------
263 
264 void Tree::filterEvents(std::vector<Event*>& tEvents)
265 {
266 // Use trees which have already been built to fit a bunch of events
267 // given by the tEvents vector.
268 
269  // Set the events to be filtered.
270  rootNode->getEvents() = std::vector< std::vector<Event*> >(1);
271  rootNode->getEvents()[0] = tEvents;
272 
273  // The tree now knows about the events it needs to fit.
274  // Filter them into a predictive region (terminal node).
276 }
277 
278 // ----------------------------------------------------------------------
279 
281 {
282 // Filter the events repeatedly into the daughter nodes until they
283 // fall into a terminal node.
284 
285  Node* left = node->getLeftDaughter();
286  Node* right = node->getRightDaughter();
287 
288  if(left == nullptr || right == nullptr) return;
289 
290  node->filterEventsToDaughters();
291 
292  filterEventsRecursive(left);
293  filterEventsRecursive(right);
294 }
295 
296 // ----------------------------------------------------------------------
297 
299 {
300 // Use trees which have already been built to fit a bunch of events
301 // given by the tEvents vector.
302 
303  // Filter the event into a predictive region (terminal node).
304  Node* node = filterEventRecursive(rootNode, e);
305  return node;
306 }
307 
308 // ----------------------------------------------------------------------
309 
311 {
312 // Filter the event repeatedly into the daughter nodes until it
313 // falls into a terminal node.
314 
315 
316  Node* nextNode = node->filterEventToDaughter(e);
317  if(nextNode == nullptr) return node;
318 
319  return filterEventRecursive(nextNode, e);
320 }
321 
322 // ----------------------------------------------------------------------
323 
324 
325 void Tree::rankVariablesRecursive(Node* node, std::vector<double>& v)
326 {
327 // We recursively go through all of the nodes in the tree and find the
328 // total error reduction for each variable. The one with the most
329 // error reduction should be the most important.
330 
331  Node* left = node->getLeftDaughter();
332  Node* right = node->getRightDaughter();
333 
334  // Terminal nodes don't contribute to error reduction.
335  if(left==nullptr || right==nullptr) return;
336 
337  int sv = node->getSplitVariable();
338  double er = node->getErrorReduction();
339 
340  //if(sv == -1)
341  //{
342  //std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
343  //std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
344  //std::cout << "rankVarRecursive Error Reduction = " << er << std::endl;
345  //}
346 
347  // Add error reduction to the current total for the appropriate
348  // variable.
349  v[sv] += er;
350 
351  rankVariablesRecursive(left, v);
352  rankVariablesRecursive(right, v);
353 
354 }
355 
356 // ----------------------------------------------------------------------
357 
358 void Tree::rankVariables(std::vector<double>& v)
359 {
361 }
362 
363 // ----------------------------------------------------------------------
364 
365 
366 void Tree::getSplitValuesRecursive(Node* node, std::vector<std::vector<double>>& v)
367 {
368 // We recursively go through all of the nodes in the tree and find the
369 // split points used for each split variable.
370 
371  Node* left = node->getLeftDaughter();
372  Node* right = node->getRightDaughter();
373 
374  // Terminal nodes don't contribute.
375  if(left==nullptr || right==nullptr) return;
376 
377  int sv = node->getSplitVariable();
378  double sp = node->getSplitValue();
379 
380  if(sv == -1)
381  {
382  std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
383  std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
384  }
385 
386  // Add the split point to the list for the correct split variable.
387  v[sv].push_back(sp);
388 
389  getSplitValuesRecursive(left, v);
390  getSplitValuesRecursive(right, v);
391 
392 }
393 
394 // ----------------------------------------------------------------------
395 
396 void Tree::getSplitValues(std::vector<std::vector<double>>& v)
397 {
399 }
400 
402 // ______________________Storage/Retrieval______________________________//
404 
template <typename T>
std::string numToStr(T num)
{
// Convert a number to a string. Floating-point values are written with
// max_digits10 significant digits so they survive a text round trip; the
// stream default of 6 digits silently truncated saved split/fit values.
// Integers are unaffected (precision only applies to floating point).
// NOTE(review): Doxygen resolves this file's numToStr calls to Utilities.h,
// so this template may be shadowed there - confirm which one is used.
  std::ostringstream ss;
  ss.precision(std::numeric_limits<T>::max_digits10);
  ss << num;
  return ss.str();
}
414 
415 // ----------------------------------------------------------------------
416 
417 void Tree::addXMLAttributes(TXMLEngine* xml, Node* node, XMLNodePointer_t np)
418 {
419  // Convert Node members into XML attributes
420  // and add them to the XMLEngine.
421  xml->NewAttr(np, nullptr, "splitVar", numToStr(node->getSplitVariable()).c_str());
422  xml->NewAttr(np, nullptr, "splitVal", numToStr(node->getSplitValue()).c_str());
423  xml->NewAttr(np, nullptr, "fitVal", numToStr(node->getFitValue()).c_str());
424 }
425 
426 // ----------------------------------------------------------------------
427 
428 void Tree::saveToXML(const char* c)
429 {
430 
431  TXMLEngine* xml = new TXMLEngine();
432 
433  // Add the root node.
434  XMLNodePointer_t root = xml->NewChild(nullptr, nullptr, rootNode->getName().c_str());
435  addXMLAttributes(xml, rootNode, root);
436 
437  // Recursively write the tree to XML.
438  saveToXMLRecursive(xml, rootNode, root);
439 
440  // Make the XML Document.
441  XMLDocPointer_t xmldoc = xml->NewDoc();
442  xml->DocSetRootElement(xmldoc, root);
443 
444  // Save to file.
445  xml->SaveDoc(xmldoc, c);
446 
447  // Clean up.
448  xml->FreeDoc(xmldoc);
449  delete xml;
450 }
451 
452 // ----------------------------------------------------------------------
453 
454 void Tree::saveToXMLRecursive(TXMLEngine* xml, Node* node, XMLNodePointer_t np)
455 {
456  Node* l = node->getLeftDaughter();
457  Node* r = node->getRightDaughter();
458 
459  if(l==nullptr || r==nullptr) return;
460 
461  // Add children to the XMLEngine.
462  XMLNodePointer_t left = xml->NewChild(np, nullptr, "left");
463  XMLNodePointer_t right = xml->NewChild(np, nullptr, "right");
464 
465  // Add attributes to the children.
466  addXMLAttributes(xml, l, left);
467  addXMLAttributes(xml, r, right);
468 
469  // Recurse.
470  saveToXMLRecursive(xml, l, left);
471  saveToXMLRecursive(xml, r, right);
472 }
473 
474 // ----------------------------------------------------------------------
475 
476 void Tree::loadFromXML(const char* filename)
477 {
478  // First create the engine.
479  TXMLEngine* xml = new TXMLEngine;
480 
481  // Now try to parse xml file.
482  XMLDocPointer_t xmldoc = xml->ParseFile(filename);
483  if (xmldoc==nullptr)
484  {
485  delete xml;
486  return;
487  }
488 
489  // Get access to main node of the xml file.
490  XMLNodePointer_t mainnode = xml->DocGetRootElement(xmldoc);
491 
492  // the original 2016 pT xmls define the source tree node to be the top-level xml node
493  // while in 2017 TMVA's xmls every decision tree is wrapped in an extra block specifying boostWeight parameter
494  // I choose to identify the format by checking the top xml node name that is a "BinaryTree" in 2017
495  if( std::string("BinaryTree") == xml->GetNodeName(mainnode) ){
496  XMLAttrPointer_t attr = xml->GetFirstAttr(mainnode);
497  attr = xml->GetNextAttr(attr);
498  boostWeight = (attr ? strtod(xml->GetAttrValue(attr),nullptr) : 0);
499  // step inside the top-level xml node
500  mainnode = xml->GetChild(mainnode);
501  xmlVersion = 2017;
502  } else {
503  boostWeight = 0;
504  xmlVersion = 2016;
505  }
506  // Recursively connect nodes together.
507  loadFromXMLRecursive(xml, mainnode, rootNode);
508 
509  // Release memory before exit
510  xml->FreeDoc(xmldoc);
511  delete xml;
512 }
513 
514 // ----------------------------------------------------------------------
515 
516 void Tree::loadFromXMLRecursive(TXMLEngine* xml, XMLNodePointer_t xnode, Node* tnode)
517 {
518 
519  // Get the split information from xml.
520  XMLAttrPointer_t attr = xml->GetFirstAttr(xnode);
521  std::vector<std::string> splitInfo(3);
522  if( xmlVersion >= 2017 ){
523  for(unsigned int i=0,j=0; i<10; i++)
524  {
525  if(i==3 || i==4 || i==6){
526  splitInfo[j++] = xml->GetAttrValue(attr);
527  }
528  attr = xml->GetNextAttr(attr);
529  }
530  } else {
531  for(unsigned int i=0; i<3; i++)
532  {
533  splitInfo[i] = xml->GetAttrValue(attr);
534  attr = xml->GetNextAttr(attr);
535  }
536  }
537 
538  // Convert strings into numbers.
539  std::stringstream converter;
540  int splitVar;
541  double splitVal;
542  double fitVal;
543 
544  converter << splitInfo[0];
545  converter >> splitVar;
546  converter.str("");
547  converter.clear();
548 
549  converter << splitInfo[1];
550  converter >> splitVal;
551  converter.str("");
552  converter.clear();
553 
554  converter << splitInfo[2];
555  converter >> fitVal;
556  converter.str("");
557  converter.clear();
558 
559  // Store gathered splitInfo into the node object.
560  tnode->setSplitVariable(splitVar);
561  tnode->setSplitValue(splitVal);
562  tnode->setFitValue(fitVal);
563 
564  // Get the xml daughters of the current xml node.
565  XMLNodePointer_t xleft = xml->GetChild(xnode);
566  XMLNodePointer_t xright = xml->GetNext(xleft);
567 
568  // If there are no daughters we are done.
569  if(xleft == nullptr || xright == nullptr) return;
570 
571  // If there are daughters link the node objects appropriately.
572  tnode->theMiracleOfChildBirth();
573  Node* tleft = tnode->getLeftDaughter();
574  Node* tright = tnode->getRightDaughter();
575 
576  // Update the list of terminal nodes.
577  terminalNodes.remove(tnode);
578  terminalNodes.push_back(tleft);
579  terminalNodes.push_back(tright);
581 
582  loadFromXMLRecursive(xml, xleft, tleft);
583  loadFromXMLRecursive(xml, xright, tright);
584 }
585 
587 {
588  // start fresh in case this is not the only call to construct a tree
589  if( rootNode ) delete rootNode;
590  rootNode = new Node("root");
591 
592  const L1TMuonEndCapForest::DTreeNode& mainnode = tree[0];
593  loadFromCondPayloadRecursive(tree, mainnode, rootNode);
594 }
595 
597 {
598  // Store gathered splitInfo into the node object.
599  tnode->setSplitVariable(node.splitVar);
600  tnode->setSplitValue(node.splitVal);
601  tnode->setFitValue(node.fitVal);
602 
603  // If there are no daughters we are done.
604  if( node.ileft == 0 || node.iright == 0) return; // root cannot be anyone's child
605  if( node.ileft >= tree.size() ||
606  node.iright >= tree.size() ) return; // out of range addressing on purpose
607 
608  // If there are daughters link the node objects appropriately.
609  tnode->theMiracleOfChildBirth();
610  Node* tleft = tnode->getLeftDaughter();
611  Node* tright = tnode->getRightDaughter();
612 
613  // Update the list of terminal nodes.
614  terminalNodes.remove(tnode);
615  terminalNodes.push_back(tleft);
616  terminalNodes.push_back(tright);
618 
619  loadFromCondPayloadRecursive(tree, tree[node.ileft], tleft);
620  loadFromCondPayloadRecursive(tree, tree[node.iright], tright);
621 }
void setFitValue(double sFitValue)
Definition: Node.cc:153
Node * getRightDaughter()
Definition: Node.cc:112
Node * filterEventRecursive(Node *node, Event *e)
Definition: Tree.cc:310
Node * getLeftDaughter()
Definition: Node.cc:102
double getFitValue()
Definition: Node.cc:158
std::list< Node * > terminalNodes
Definition: Tree.h:61
void buildTree(int nodeLimit)
Definition: Tree.cc:203
void getSplitValues(std::vector< std::vector< double >> &v)
Definition: Tree.cc:396
Tree()
Definition: Tree.cc:29
void loadFromXMLRecursive(TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
Definition: Tree.cc:516
double rmsError
Definition: Tree.h:63
int getNumTerminalNodes()
Definition: Tree.cc:178
void addXMLAttributes(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:417
Node * getRootNode()
Definition: Tree.cc:159
void loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree &tree, const L1TMuonEndCapForest::DTreeNode &node, Node *tnode)
Definition: Tree.cc:596
Definition: Event.h:15
Tree & operator=(const Tree &tree)
Definition: Tree.cc:77
void setTerminalNodes(std::list< Node * > &sTNodes)
Definition: Tree.cc:166
void findLeafs(Node *local_root, std::list< Node * > &tn)
Definition: Tree.cc:124
int np
Definition: AMPTWrapper.h:33
double getTotalError()
Definition: Node.cc:170
void loadFromCondPayload(const L1TMuonEndCapForest::DTree &tree)
Definition: Tree.cc:586
T sqrt(T t)
Definition: SSEVec.h:18
Node * filterEventToDaughter(Event *e)
Definition: Node.cc:395
void rankVariables(std::vector< double > &v)
Definition: Tree.cc:358
void rankVariablesRecursive(Node *node, std::vector< double > &v)
Definition: Tree.cc:325
void setEvents(std::vector< std::vector< Event * > > &sEvents)
Definition: Node.cc:204
std::vector< std::vector< Event * > > & getEvents()
Definition: Node.cc:199
void saveToXMLRecursive(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:454
double getSplitValue()
Definition: Node.cc:136
std::string getName()
Definition: Node.cc:78
std::vector< DTreeNode > DTree
std::list< Node * > & getTerminalNodes()
Definition: Tree.cc:171
Node * copyFrom(const Node *local_root)
Definition: Tree.cc:95
void setParent(Node *sParent)
Definition: Node.cc:119
void filterEventsRecursive(Node *node)
Definition: Tree.cc:280
int numTerminalNodes
Definition: Tree.h:62
void calcError()
Definition: Tree.cc:187
void setSplitVariable(int sSplitVar)
Definition: Node.cc:141
double boostWeight
Definition: Tree.h:64
Node * filterEvent(Event *e)
Definition: Tree.cc:298
void loadFromXML(const char *filename)
Definition: Tree.cc:476
void saveToXML(const char *filename)
Definition: Tree.cc:428
int getNumEvents()
Definition: Node.cc:192
Node * rootNode
Definition: Tree.h:60
double getErrorReduction()
Definition: Node.cc:90
unsigned xmlVersion
Definition: Tree.h:65
void filterEventsToDaughters()
Definition: Node.cc:350
double getAvgError()
Definition: Node.cc:180
int getSplitVariable()
Definition: Node.cc:146
Definition: tree.py:1
void calcOptimumSplit()
Definition: Node.cc:214
std::string numToStr(T num)
Definition: Utilities.h:44
void filterEvents(std::vector< Event * > &tEvents)
Definition: Tree.cc:264
long double T
void setRootNode(Node *sRootNode)
Definition: Tree.cc:154
void getSplitValuesRecursive(Node *node, std::vector< std::vector< double >> &v)
Definition: Tree.cc:366
void setSplitValue(double sSplitValue)
Definition: Node.cc:131
def move(src, dest)
Definition: eostools.py:510
~Tree()
Definition: Tree.cc:54
void theMiracleOfChildBirth()
Definition: Node.cc:335