CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
Tree.cc
Go to the documentation of this file.
1 // Tree.cxx //
3 // =====================================================================//
4 // This is the object implementation of a decision tree. //
5 // References include //
6 // *Elements of Statistical Learning by Hastie, //
7 // Tibshirani, and Friedman. //
8 // *Greedy Function Approximation: A Gradient Boosting Machine. //
9 // Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001. //
10 // *Inductive Learning of Tree-based Regression Models. Luis Torgo. //
11 // //
13 
15 // _______________________Includes_______________________________________//
17 
19 #include <iostream>
20 #include <sstream>
21 
23 // _______________________Constructor(s)________________________________//
25 
27 {
28  rootNode = new Node("root");
29 
30  terminalNodes.push_back(rootNode);
31  numTerminalNodes = 1;
32 }
33 
34 Tree::Tree(std::vector< std::vector<Event*> >& cEvents)
35 {
36  rootNode = new Node("root");
37  rootNode->setEvents(cEvents);
38 
39  terminalNodes.push_back(rootNode);
40  numTerminalNodes = 1;
41 }
43 // _______________________Destructor____________________________________//
45 
46 
48 {
49 // When the tree is destroyed it will delete all of the nodes in the tree.
50 // The deletion begins with the rootnode and continues recursively.
51  delete rootNode;
52 }
53 
55 // ______________________Get/Set________________________________________//
57 
58 void Tree::setRootNode(Node *sRootNode)
59 {
60  rootNode = sRootNode;
61 }
62 
64 {
65  return rootNode;
66 }
67 
68 // ----------------------------------------------------------------------
69 
70 void Tree::setTerminalNodes(std::list<Node*>& sTNodes)
71 {
72  terminalNodes = sTNodes;
73 }
74 
75 std::list<Node*>& Tree::getTerminalNodes()
76 {
77  return terminalNodes;
78 }
79 
80 // ----------------------------------------------------------------------
81 
83 {
84  return numTerminalNodes;
85 }
86 
88 // ______________________Performance____________________________________//
90 
92 {
93 // Loop through the separate predictive regions (terminal nodes) and
94 // add up the errors to get the error of the entire space.
95 
96  Double_t totalSquaredError = 0;
97 
98  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
99  {
100  totalSquaredError += (*it)->getTotalError();
101  }
102  rmsError = sqrt( totalSquaredError/rootNode->getNumEvents() );
103 }
104 
105 // ----------------------------------------------------------------------
106 
107 void Tree::buildTree(Int_t nodeLimit)
108 {
109  // We greedily pick the best terminal node to split.
110  Double_t bestNodeErrorReduction = -1;
111  Node* nodeToSplit = 0;
112 
113  if(numTerminalNodes == 1)
114  {
116  calcError();
117 // std::cout << std::endl << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
118  }
119 
120  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
121  {
122  if( (*it)->getErrorReduction() > bestNodeErrorReduction )
123  {
124  bestNodeErrorReduction = (*it)->getErrorReduction();
125  nodeToSplit = (*it);
126  }
127  }
128 
129  //std::cout << "nodeToSplit size = " << nodeToSplit->getNumEvents() << std::endl;
130 
131  // If all of the nodes have one event we can't add any more nodes and reduce the error.
132  if(nodeToSplit == 0) return;
133 
134  // Create daughter nodes, and link the nodes together appropriately.
135  nodeToSplit->theMiracleOfChildBirth();
136 
137  // Get left and right daughters for reference.
138  Node* left = nodeToSplit->getLeftDaughter();
139  Node* right = nodeToSplit->getRightDaughter();
140 
141  // Update the list of terminal nodes.
142  terminalNodes.remove(nodeToSplit);
143  terminalNodes.push_back(left);
144  terminalNodes.push_back(right);
146 
147  // Filter the events from the parent into the daughters.
148  nodeToSplit->filterEventsToDaughters();
149 
150  // Calculate the best splits for the new nodes.
151  left->calcOptimumSplit();
152  right->calcOptimumSplit();
153 
154  // See if the error reduces as we add more nodes.
155  calcError();
156 
157  if(numTerminalNodes % 1 == 0)
158  {
159 // std::cout << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
160  }
161 
162  // Repeat until done.
163  if(numTerminalNodes < nodeLimit) buildTree(nodeLimit);
164 }
165 
166 // ----------------------------------------------------------------------
167 
168 void Tree::filterEvents(std::vector<Event*>& tEvents)
169 {
170 // Use trees which have already been built to fit a bunch of events
171 // given by the tEvents vector.
172 
173  // Set the events to be filtered.
174  rootNode->getEvents() = std::vector< std::vector<Event*> >(1);
175  rootNode->getEvents()[0] = tEvents;
176 
177  // The tree now knows about the events it needs to fit.
178  // Filter them into a predictive region (terminal node).
180 }
181 
182 // ----------------------------------------------------------------------
183 
185 {
186 // Filter the events repeatedly into the daughter nodes until they
187 // fall into a terminal node.
188 
189  Node* left = node->getLeftDaughter();
190  Node* right = node->getRightDaughter();
191 
192  if(left == 0 || right == 0) return;
193 
194  node->filterEventsToDaughters();
195 
196  filterEventsRecursive(left);
197  filterEventsRecursive(right);
198 }
199 
200 // ----------------------------------------------------------------------
201 
203 {
204 // Use trees which have already been built to fit a bunch of events
205 // given by the tEvents vector.
206 
207  // Filter the event into a predictive region (terminal node).
208  Node* node = filterEventRecursive(rootNode, e);
209  return node;
210 }
211 
212 // ----------------------------------------------------------------------
213 
215 {
216 // Filter the event repeatedly into the daughter nodes until it
217 // falls into a terminal node.
218 
219 
220  Node* nextNode = node->filterEventToDaughter(e);
221  if(nextNode == 0) return node;
222 
223  return filterEventRecursive(nextNode, e);
224 }
225 
226 // ----------------------------------------------------------------------
227 
228 
229 void Tree::rankVariablesRecursive(Node* node, std::vector<Double_t>& v)
230 {
231 // We recursively go through all of the nodes in the tree and find the
232 // total error reduction for each variable. The one with the most
233 // error reduction should be the most important.
234 
235  Node* left = node->getLeftDaughter();
236  Node* right = node->getRightDaughter();
237 
238  // Terminal nodes don't contribute to error reduction.
239  if(left==0 || right==0) return;
240 
241  Int_t sv = node->getSplitVariable();
242  Double_t er = node->getErrorReduction();
243 
244  //if(sv == -1)
245  //{
246  //std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
247  //std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
248  //std::cout << "rankVarRecursive Error Reduction = " << er << std::endl;
249  //}
250 
251  // Add error reduction to the current total for the appropriate
252  // variable.
253  v[sv] += er;
254 
255  rankVariablesRecursive(left, v);
256  rankVariablesRecursive(right, v);
257 
258 }
259 
260 // ----------------------------------------------------------------------
261 
262 void Tree::rankVariables(std::vector<Double_t>& v)
263 {
265 }
266 
267 // ----------------------------------------------------------------------
268 
269 
270 void Tree::getSplitValuesRecursive(Node* node, std::vector<std::vector<Double_t>>& v)
271 {
272 // We recursively go through all of the nodes in the tree and find the
273 // split points used for each split variable.
274 
275  Node* left = node->getLeftDaughter();
276  Node* right = node->getRightDaughter();
277 
278  // Terminal nodes don't contribute.
279  if(left==0 || right==0) return;
280 
281  Int_t sv = node->getSplitVariable();
282  Double_t sp = node->getSplitValue();
283 
284  if(sv == -1)
285  {
286  std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
287  std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
288  }
289 
290  // Add the split point to the list for the correct split variable.
291  v[sv].push_back(sp);
292 
293  getSplitValuesRecursive(left, v);
294  getSplitValuesRecursive(right, v);
295 
296 }
297 
298 // ----------------------------------------------------------------------
299 
300 void Tree::getSplitValues(std::vector<std::vector<Double_t>>& v)
301 {
303 }
304 
306 // ______________________Storage/Retrieval______________________________//
308 
// Convert a number (or anything streamable) to its string form.
// The value is written through a default-configured std::stringstream, so
// floating-point values use the stream's default formatting (6 significant
// digits).
template <typename T>
std::string numToStr(T num)
{
    std::stringstream ss;
    ss << num;
    return ss.str();
}
318 
319 // ----------------------------------------------------------------------
320 
321 void Tree::addXMLAttributes(TXMLEngine* xml, Node* node, XMLNodePointer_t np)
322 {
323  // Convert Node members into XML attributes
324  // and add them to the XMLEngine.
325  xml->NewAttr(np, 0, "splitVar", numToStr(node->getSplitVariable()).c_str());
326  xml->NewAttr(np, 0, "splitVal", numToStr(node->getSplitValue()).c_str());
327  xml->NewAttr(np, 0, "fitVal", numToStr(node->getFitValue()).c_str());
328 }
329 
330 // ----------------------------------------------------------------------
331 
332 void Tree::saveToXML(const char* c)
333 {
334 
335  TXMLEngine* xml = new TXMLEngine();
336 
337  // Add the root node.
338  XMLNodePointer_t root = xml->NewChild(0, 0, rootNode->getName().c_str());
339  addXMLAttributes(xml, rootNode, root);
340 
341  // Recursively write the tree to XML.
342  saveToXMLRecursive(xml, rootNode, root);
343 
344  // Make the XML Document.
345  XMLDocPointer_t xmldoc = xml->NewDoc();
346  xml->DocSetRootElement(xmldoc, root);
347 
348  // Save to file.
349  xml->SaveDoc(xmldoc, c);
350 
351  // Clean up.
352  xml->FreeDoc(xmldoc);
353  delete xml;
354 }
355 
356 // ----------------------------------------------------------------------
357 
358 void Tree::saveToXMLRecursive(TXMLEngine* xml, Node* node, XMLNodePointer_t np)
359 {
360  Node* l = node->getLeftDaughter();
361  Node* r = node->getRightDaughter();
362 
363  if(l==0 || r==0) return;
364 
365  // Add children to the XMLEngine.
366  XMLNodePointer_t left = xml->NewChild(np, 0, "left");
367  XMLNodePointer_t right = xml->NewChild(np, 0, "right");
368 
369  // Add attributes to the children.
370  addXMLAttributes(xml, l, left);
371  addXMLAttributes(xml, r, right);
372 
373  // Recurse.
374  saveToXMLRecursive(xml, l, left);
375  saveToXMLRecursive(xml, r, right);
376 }
377 
378 // ----------------------------------------------------------------------
379 
380 void Tree::loadFromXML(const char* filename)
381 {
382  // First create the engine.
383  TXMLEngine* xml = new TXMLEngine;
384 
385  // Now try to parse xml file.
386  XMLDocPointer_t xmldoc = xml->ParseFile(filename);
387  if (xmldoc==0)
388  {
389  delete xml;
390  return;
391  }
392 
393  // Get access to main node of the xml file.
394  XMLNodePointer_t mainnode = xml->DocGetRootElement(xmldoc);
395 
396  // Recursively connect nodes together.
397  loadFromXMLRecursive(xml, mainnode, rootNode);
398 
399  // Release memory before exit
400  xml->FreeDoc(xmldoc);
401  delete xml;
402 }
403 
404 // ----------------------------------------------------------------------
405 
406 void Tree::loadFromXMLRecursive(TXMLEngine* xml, XMLNodePointer_t xnode, Node* tnode)
407 {
408 
409  // Get the split information from xml.
410  XMLAttrPointer_t attr = xml->GetFirstAttr(xnode);
411  std::vector<std::string> splitInfo(3);
412  for(unsigned int i=0; i<3; i++)
413  {
414  splitInfo[i] = xml->GetAttrValue(attr);
415  attr = xml->GetNextAttr(attr);
416  }
417 
418  // Convert strings into numbers.
419  std::stringstream converter;
420  Int_t splitVar;
421  Double_t splitVal;
422  Double_t fitVal;
423 
424  converter << splitInfo[0];
425  converter >> splitVar;
426  converter.str("");
427  converter.clear();
428 
429  converter << splitInfo[1];
430  converter >> splitVal;
431  converter.str("");
432  converter.clear();
433 
434  converter << splitInfo[2];
435  converter >> fitVal;
436  converter.str("");
437  converter.clear();
438 
439  // Store gathered splitInfo into the node object.
440  tnode->setSplitVariable(splitVar);
441  tnode->setSplitValue(splitVal);
442  tnode->setFitValue(fitVal);
443 
444  // Get the xml daughters of the current xml node.
445  XMLNodePointer_t xleft = xml->GetChild(xnode);
446  XMLNodePointer_t xright = xml->GetNext(xleft);
447 
448  // If there are no daughters we are done.
449  if(xleft == 0 || xright == 0) return;
450 
451  // If there are daughters link the node objects appropriately.
452  tnode->theMiracleOfChildBirth();
453  Node* tleft = tnode->getLeftDaughter();
454  Node* tright = tnode->getRightDaughter();
455 
456  // Update the list of terminal nodes.
457  terminalNodes.remove(tnode);
458  terminalNodes.push_back(tleft);
459  terminalNodes.push_back(tright);
461 
462  loadFromXMLRecursive(xml, xleft, tleft);
463  loadFromXMLRecursive(xml, xright, tright);
464 }
Node * filterEvent(Event *e)
Definition: Tree.cc:202
Double_t getFitValue()
Definition: Node.cc:155
int i
Definition: DBlmapReader.cc:9
void filterEventsRecursive(Node *node)
Definition: Tree.cc:184
Definition: Node.h:10
Int_t getNumEvents()
Definition: Node.cc:189
Node * getLeftDaughter()
Definition: Node.cc:99
void rankVariablesRecursive(Node *node, std::vector< Double_t > &v)
Definition: Tree.cc:229
void loadFromXML(const char *filename)
Definition: Tree.cc:380
void setSplitValue(Double_t sSplitValue)
Definition: Node.cc:128
~Tree()
Definition: Tree.cc:47
void setTerminalNodes(std::list< Node * > &sTNodes)
Definition: Tree.cc:70
string xmldoc
Some module's global variables.
Node * filterEventToDaughter(Event *e)
Definition: Node.cc:392
Definition: Event.h:16
std::list< Node * > terminalNodes
Definition: Tree.h:54
void buildTree(Int_t nodeLimit)
Definition: Tree.cc:107
Double_t getSplitValue()
Definition: Node.cc:133
void saveToXMLRecursive(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:358
int np
Definition: AMPTWrapper.h:33
void theMiracleOfChildBirth()
Definition: Node.cc:332
void getSplitValues(std::vector< std::vector< Double_t >> &v)
Definition: Tree.cc:300
std::string getName()
Definition: Node.cc:75
Node * getRightDaughter()
Definition: Node.cc:109
T sqrt(T t)
Definition: SSEVec.h:18
void getSplitValuesRecursive(Node *node, std::vector< std::vector< Double_t >> &v)
Definition: Tree.cc:270
void setRootNode(Node *sRootNode)
Definition: Tree.cc:58
Int_t getSplitVariable()
Definition: Node.cc:143
void filterEventsToDaughters()
Definition: Node.cc:347
Node * filterEventRecursive(Node *node, Event *e)
Definition: Tree.cc:214
std::vector< std::vector< Event * > > & getEvents()
Definition: Node.cc:196
void setEvents(std::vector< std::vector< Event * > > &sEvents)
Definition: Node.cc:201
Double_t getErrorReduction()
Definition: Node.cc:87
Node * getRootNode()
Definition: Tree.cc:63
Tree()
Definition: Tree.cc:26
void setFitValue(Double_t sFitValue)
Definition: Node.cc:150
std::list< Node * > & getTerminalNodes()
Definition: Tree.cc:75
Double_t rmsError
Definition: Tree.h:56
void addXMLAttributes(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:321
void filterEvents(std::vector< Event * > &tEvents)
Definition: Tree.cc:168
Int_t numTerminalNodes
Definition: Tree.h:55
tuple filename
Definition: lut2db_cfg.py:20
void setSplitVariable(Int_t sSplitVar)
Definition: Node.cc:138
tuple cout
Definition: gather_cfg.py:145
Node * rootNode
Definition: Tree.h:53
void calcError()
Definition: Tree.cc:91
void rankVariables(std::vector< Double_t > &v)
Definition: Tree.cc:262
Definition: sp.h:21
edm::TrieNode< PDet > Node
void saveToXML(const char *filename)
Definition: Tree.cc:332
std::string numToStr(T num)
Definition: Utilities.h:43
void loadFromXMLRecursive(TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
Definition: Tree.cc:406
long double T
void calcOptimumSplit()
Definition: Node.cc:211
Int_t getNumTerminalNodes()
Definition: Tree.cc:82