CMS 3D CMS Logo

Tree.cc
Go to the documentation of this file.
1 // Tree.cxx //
3 // =====================================================================//
4 // This is the object implementation of a decision tree. //
5 // References include //
6 // *Elements of Statistical Learning by Hastie, //
7 // Tibshirani, and Friedman. //
8 // *Greedy Function Approximation: A Gradient Boosting Machine. //
9 // Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001. //
10 // *Inductive Learning of Tree-based Regression Models. Luis Torgo. //
11 // //
13 
15 // _______________________Includes_______________________________________//
17 
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>
#include <type_traits>
21 
22 using namespace emtf;
23 
25 // _______________________Constructor(s)________________________________//
27 
29 {
30  rootNode = new Node("root");
31 
32  terminalNodes.push_back(rootNode);
33  numTerminalNodes = 1;
34 }
35 
36 Tree::Tree(std::vector< std::vector<Event*> >& cEvents)
37 {
38  rootNode = new Node("root");
39  rootNode->setEvents(cEvents);
40 
41  terminalNodes.push_back(rootNode);
42  numTerminalNodes = 1;
43 }
45 // _______________________Destructor____________________________________//
47 
49 {
50  // When the tree is destroyed it will delete all of the nodes in the tree.
51 // The deletion begins with the rootnode and continues recursively.
52  delete rootNode;
53 }
54 
56 // ______________________Get/Set________________________________________//
58 
59 void Tree::setRootNode(Node *sRootNode)
60 {
61  rootNode = sRootNode;
62 }
63 
65 {
66  return rootNode;
67 }
68 
69 // ----------------------------------------------------------------------
70 
71 void Tree::setTerminalNodes(std::list<Node*>& sTNodes)
72 {
73  terminalNodes = sTNodes;
74 }
75 
76 std::list<Node*>& Tree::getTerminalNodes()
77 {
78  return terminalNodes;
79 }
80 
81 // ----------------------------------------------------------------------
82 
84 {
85  return numTerminalNodes;
86 }
87 
89 // ______________________Performance_____________________________________//
91 
93 {
94  // Loop through the separate predictive regions (terminal nodes) and
95  // add up the errors to get the error of the entire space.
96 
97  Double_t totalSquaredError = 0;
98 
99  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
100  {
101  totalSquaredError += (*it)->getTotalError();
102  }
103  rmsError = sqrt( totalSquaredError/rootNode->getNumEvents() );
104 }
105 
106 // ----------------------------------------------------------------------
107 
108 void Tree::buildTree(Int_t nodeLimit)
109 {
110  // We greedily pick the best terminal node to split.
111  Double_t bestNodeErrorReduction = -1;
112  Node* nodeToSplit = 0;
113 
114  if(numTerminalNodes == 1)
115  {
117  calcError();
118  // std::cout << std::endl << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
119  }
120 
121  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
122  {
123  if( (*it)->getErrorReduction() > bestNodeErrorReduction )
124  {
125  bestNodeErrorReduction = (*it)->getErrorReduction();
126  nodeToSplit = (*it);
127  }
128  }
129 
130  //std::cout << "nodeToSplit size = " << nodeToSplit->getNumEvents() << std::endl;
131 
132  // If all of the nodes have one event we can't add any more nodes and reduce the error.
133  if(nodeToSplit == 0) return;
134 
135  // Create daughter nodes, and link the nodes together appropriately.
136  nodeToSplit->theMiracleOfChildBirth();
137 
138  // Get left and right daughters for reference.
139  Node* left = nodeToSplit->getLeftDaughter();
140  Node* right = nodeToSplit->getRightDaughter();
141 
142  // Update the list of terminal nodes.
143  terminalNodes.remove(nodeToSplit);
144  terminalNodes.push_back(left);
145  terminalNodes.push_back(right);
147 
148  // Filter the events from the parent into the daughters.
149  nodeToSplit->filterEventsToDaughters();
150 
151  // Calculate the best splits for the new nodes.
152  left->calcOptimumSplit();
153  right->calcOptimumSplit();
154 
155  // See if the error reduces as we add more nodes.
156  calcError();
157 
158  if(numTerminalNodes % 1 == 0)
159  {
160  // std::cout << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
161  }
162 
163  // Repeat until done.
164  if(numTerminalNodes < nodeLimit) buildTree(nodeLimit);
165 }
166 
167 // ----------------------------------------------------------------------
168 
169 void Tree::filterEvents(std::vector<Event*>& tEvents)
170 {
171  // Use trees which have already been built to fit a bunch of events
172  // given by the tEvents vector.
173 
174  // Set the events to be filtered.
175  rootNode->getEvents() = std::vector< std::vector<Event*> >(1);
176  rootNode->getEvents()[0] = tEvents;
177 
178  // The tree now knows about the events it needs to fit.
179  // Filter them into a predictive region (terminal node).
181 }
182 
183 // ----------------------------------------------------------------------
184 
186 {
187  // Filter the events repeatedly into the daughter nodes until they
188  // fall into a terminal node.
189 
190  Node* left = node->getLeftDaughter();
191  Node* right = node->getRightDaughter();
192 
193  if(left == 0 || right == 0) return;
194 
195  node->filterEventsToDaughters();
196 
197  filterEventsRecursive(left);
198  filterEventsRecursive(right);
199 }
200 
201 // ----------------------------------------------------------------------
202 
204 {
205  // Use trees which have already been built to fit a bunch of events
206  // given by the tEvents vector.
207 
208  // Filter the event into a predictive region (terminal node).
209  Node* node = filterEventRecursive(rootNode, e);
210  return node;
211 }
212 
213 // ----------------------------------------------------------------------
214 
216 {
217  // Filter the event repeatedly into the daughter nodes until it
218  // falls into a terminal node.
219 
220 
221  Node* nextNode = node->filterEventToDaughter(e);
222  if(nextNode == 0) return node;
223 
224  return filterEventRecursive(nextNode, e);
225 }
226 
227 // ----------------------------------------------------------------------
228 
229 
230 void Tree::rankVariablesRecursive(Node* node, std::vector<Double_t>& v)
231 {
232  // We recursively go through all of the nodes in the tree and find the
233  // total error reduction for each variable. The one with the most
234  // error reduction should be the most important.
235 
236  Node* left = node->getLeftDaughter();
237  Node* right = node->getRightDaughter();
238 
239  // Terminal nodes don't contribute to error reduction.
240  if(left==0 || right==0) return;
241 
242  Int_t sv = node->getSplitVariable();
243  Double_t er = node->getErrorReduction();
244 
245  //if(sv == -1)
246  //{
247  //std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
248  //std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
249  //std::cout << "rankVarRecursive Error Reduction = " << er << std::endl;
250  //}
251 
252  // Add error reduction to the current total for the appropriate
253  // variable.
254  v[sv] += er;
255 
256  rankVariablesRecursive(left, v);
257  rankVariablesRecursive(right, v);
258 
259 }
260 
261 // ----------------------------------------------------------------------
262 
263 void Tree::rankVariables(std::vector<Double_t>& v)
264 {
266 }
267 
268 // ----------------------------------------------------------------------
269 
270 
271 void Tree::getSplitValuesRecursive(Node* node, std::vector<std::vector<Double_t>>& v)
272 {
273  // We recursively go through all of the nodes in the tree and find the
274  // split points used for each split variable.
275 
276  Node* left = node->getLeftDaughter();
277  Node* right = node->getRightDaughter();
278 
279  // Terminal nodes don't contribute.
280  if(left==0 || right==0) return;
281 
282  Int_t sv = node->getSplitVariable();
283  Double_t sp = node->getSplitValue();
284 
285  if(sv == -1)
286  {
287  std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
288  std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
289  }
290 
291  // Add the split point to the list for the correct split variable.
292  v[sv].push_back(sp);
293 
294  getSplitValuesRecursive(left, v);
295  getSplitValuesRecursive(right, v);
296 
297 }
298 
299 // ----------------------------------------------------------------------
300 
301 void Tree::getSplitValues(std::vector<std::vector<Double_t>>& v)
302 {
304 }
305 
307 // ______________________Storage/Retrieval______________________________//
309 
template <typename T>
std::string numToStr(T num)
{
  // Convert a number to a string.
  // Floating-point values are written with max_digits10 significant digits,
  // so parsing the string back (as loadFromXMLRecursive does) recovers
  // exactly the same value. The stream default of 6 significant digits
  // silently rounded split and fit values when trees were saved to XML.
  std::ostringstream ss;
  if(std::is_floating_point<T>::value)
  {
    ss << std::setprecision(std::numeric_limits<T>::max_digits10);
  }
  ss << num;
  return ss.str();
}
319 
320 // ----------------------------------------------------------------------
321 
322 void Tree::addXMLAttributes(TXMLEngine* xml, Node* node, XMLNodePointer_t np)
323 {
324  // Convert Node members into XML attributes
325  // and add them to the XMLEngine.
326  xml->NewAttr(np, 0, "splitVar", numToStr(node->getSplitVariable()).c_str());
327  xml->NewAttr(np, 0, "splitVal", numToStr(node->getSplitValue()).c_str());
328  xml->NewAttr(np, 0, "fitVal", numToStr(node->getFitValue()).c_str());
329 }
330 
331 // ----------------------------------------------------------------------
332 
333 void Tree::saveToXML(const char* c)
334 {
335 
336  TXMLEngine* xml = new TXMLEngine();
337 
338  // Add the root node.
339  XMLNodePointer_t root = xml->NewChild(0, 0, rootNode->getName().c_str());
340  addXMLAttributes(xml, rootNode, root);
341 
342  // Recursively write the tree to XML.
343  saveToXMLRecursive(xml, rootNode, root);
344 
345  // Make the XML Document.
346  XMLDocPointer_t xmldoc = xml->NewDoc();
347  xml->DocSetRootElement(xmldoc, root);
348 
349  // Save to file.
350  xml->SaveDoc(xmldoc, c);
351 
352  // Clean up.
353  xml->FreeDoc(xmldoc);
354  delete xml;
355 }
356 
357 // ----------------------------------------------------------------------
358 
359 void Tree::saveToXMLRecursive(TXMLEngine* xml, Node* node, XMLNodePointer_t np)
360 {
361  Node* l = node->getLeftDaughter();
362  Node* r = node->getRightDaughter();
363 
364  if(l==0 || r==0) return;
365 
366  // Add children to the XMLEngine.
367  XMLNodePointer_t left = xml->NewChild(np, 0, "left");
368  XMLNodePointer_t right = xml->NewChild(np, 0, "right");
369 
370  // Add attributes to the children.
371  addXMLAttributes(xml, l, left);
372  addXMLAttributes(xml, r, right);
373 
374  // Recurse.
375  saveToXMLRecursive(xml, l, left);
376  saveToXMLRecursive(xml, r, right);
377 }
378 
379 // ----------------------------------------------------------------------
380 
381 void Tree::loadFromXML(const char* filename)
382 {
383  // First create the engine.
384  TXMLEngine* xml = new TXMLEngine;
385 
386  // Now try to parse xml file.
387  XMLDocPointer_t xmldoc = xml->ParseFile(filename);
388  if (xmldoc==0)
389  {
390  delete xml;
391  return;
392  }
393 
394  // Get access to main node of the xml file.
395  XMLNodePointer_t mainnode = xml->DocGetRootElement(xmldoc);
396 
397  // Recursively connect nodes together.
398  loadFromXMLRecursive(xml, mainnode, rootNode);
399 
400  // Release memory before exit
401  xml->FreeDoc(xmldoc);
402  delete xml;
403 }
404 
405 // ----------------------------------------------------------------------
406 
407 void Tree::loadFromXMLRecursive(TXMLEngine* xml, XMLNodePointer_t xnode, Node* tnode)
408 {
409 
410  // Get the split information from xml.
411  XMLAttrPointer_t attr = xml->GetFirstAttr(xnode);
412  std::vector<std::string> splitInfo(3);
413  for(unsigned int i=0; i<3; i++)
414  {
415  splitInfo[i] = xml->GetAttrValue(attr);
416  attr = xml->GetNextAttr(attr);
417  }
418 
419  // Convert strings into numbers.
420  std::stringstream converter;
421  Int_t splitVar;
422  Double_t splitVal;
423  Double_t fitVal;
424 
425  converter << splitInfo[0];
426  converter >> splitVar;
427  converter.str("");
428  converter.clear();
429 
430  converter << splitInfo[1];
431  converter >> splitVal;
432  converter.str("");
433  converter.clear();
434 
435  converter << splitInfo[2];
436  converter >> fitVal;
437  converter.str("");
438  converter.clear();
439 
440  // Store gathered splitInfo into the node object.
441  tnode->setSplitVariable(splitVar);
442  tnode->setSplitValue(splitVal);
443  tnode->setFitValue(fitVal);
444 
445  // Get the xml daughters of the current xml node.
446  XMLNodePointer_t xleft = xml->GetChild(xnode);
447  XMLNodePointer_t xright = xml->GetNext(xleft);
448 
449  // If there are no daughters we are done.
450  if(xleft == 0 || xright == 0) return;
451 
452  // If there are daughters link the node objects appropriately.
453  tnode->theMiracleOfChildBirth();
454  Node* tleft = tnode->getLeftDaughter();
455  Node* tright = tnode->getRightDaughter();
456 
457  // Update the list of terminal nodes.
458  terminalNodes.remove(tnode);
459  terminalNodes.push_back(tleft);
460  terminalNodes.push_back(tright);
462 
463  loadFromXMLRecursive(xml, xleft, tleft);
464  loadFromXMLRecursive(xml, xright, tright);
465 }
Node * getRightDaughter()
Definition: Node.cc:111
Node * filterEventRecursive(Node *node, Event *e)
Definition: Tree.cc:215
Node * getLeftDaughter()
Definition: Node.cc:101
std::list< Node * > terminalNodes
Definition: Tree.h:55
void getSplitValues(std::vector< std::vector< Double_t >> &v)
Definition: Tree.cc:301
Tree()
Definition: Tree.cc:28
void loadFromXMLRecursive(TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
Definition: Tree.cc:407
void addXMLAttributes(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:322
Node * getRootNode()
Definition: Tree.cc:64
Definition: Event.h:15
void getSplitValuesRecursive(Node *node, std::vector< std::vector< Double_t >> &v)
Definition: Tree.cc:271
void rankVariables(std::vector< Double_t > &v)
Definition: Tree.cc:263
void setEvents(std::vector< std::vector< emtf::Event * > > &sEvents)
Definition: Node.cc:203
void setTerminalNodes(std::list< Node * > &sTNodes)
Definition: Tree.cc:71
Double_t getErrorReduction()
Definition: Node.cc:89
int np
Definition: AMPTWrapper.h:33
T sqrt(T t)
Definition: SSEVec.h:18
void setFitValue(Double_t sFitValue)
Definition: Node.cc:152
std::vector< std::vector< emtf::Event * > > & getEvents()
Definition: Node.cc:198
void setSplitVariable(Int_t sSplitVar)
Definition: Node.cc:140
std::string numToStr(T num)
Definition: Tree.cc:311
void saveToXMLRecursive(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:359
Double_t getSplitValue()
Definition: Node.cc:135
std::string getName()
Definition: Node.cc:77
Int_t numTerminalNodes
Definition: Tree.h:56
Int_t getNumEvents()
Definition: Node.cc:191
Int_t getNumTerminalNodes()
Definition: Tree.cc:83
std::list< Node * > & getTerminalNodes()
Definition: Tree.cc:76
Double_t getFitValue()
Definition: Node.cc:157
Int_t getSplitVariable()
Definition: Node.cc:145
void filterEventsRecursive(Node *node)
Definition: Tree.cc:185
void buildTree(Int_t nodeLimit)
Definition: Tree.cc:108
void calcError()
Definition: Tree.cc:92
Node * filterEvent(Event *e)
Definition: Tree.cc:203
void loadFromXML(const char *filename)
Definition: Tree.cc:381
Double_t rmsError
Definition: Tree.h:57
void saveToXML(const char *filename)
Definition: Tree.cc:333
void setSplitValue(Double_t sSplitValue)
Definition: Node.cc:130
Node * filterEventToDaughter(emtf::Event *e)
Definition: Node.cc:394
Node * rootNode
Definition: Tree.h:54
void filterEventsToDaughters()
Definition: Node.cc:349
Definition: sp.h:21
void rankVariablesRecursive(Node *node, std::vector< Double_t > &v)
Definition: Tree.cc:230
void calcOptimumSplit()
Definition: Node.cc:213
void filterEvents(std::vector< Event * > &tEvents)
Definition: Tree.cc:169
long double T
void setRootNode(Node *sRootNode)
Definition: Tree.cc:59
~Tree()
Definition: Tree.cc:48
void theMiracleOfChildBirth()
Definition: Node.cc:334