test
CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
Tree.cc
Go to the documentation of this file.
1 // Tree.cxx //
3 // =====================================================================//
4 // This is the object implementation of a decision tree. //
5 // References include //
6 // *Elements of Statistical Learning by Hastie, //
7 // Tibshirani, and Friedman. //
8 // *Greedy Function Approximation: A Gradient Boosting Machine. //
9 // Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001. //
10 // *Inductive Learning of Tree-based Regression Models. Luis Torgo. //
11 // //
13 
15 // _______________________Includes_______________________________________//
17 
#include <iostream>
#include <limits>
#include <sstream>
21 
23 // _______________________Constructor(s)________________________________//
25 
27 {
28  rootNode = new Node("root");
29 
30  terminalNodes.push_back(rootNode);
31  numTerminalNodes = 1;
32 }
33 
34 Tree::Tree(std::vector< std::vector<Event*> >& cEvents)
35 {
36  rootNode = new Node("root");
37  rootNode->setEvents(cEvents);
38 
39  terminalNodes.push_back(rootNode);
40  numTerminalNodes = 1;
41 }
43 // _______________________Destructor____________________________________//
45 
47 {
48  // When the tree is destroyed it will delete all of the nodes in the tree.
49 // The deletion begins with the rootnode and continues recursively.
50  delete rootNode;
51 }
52 
54 // ______________________Get/Set________________________________________//
56 
57 void Tree::setRootNode(Node *sRootNode)
58 {
59  rootNode = sRootNode;
60 }
61 
63 {
64  return rootNode;
65 }
66 
67 // ----------------------------------------------------------------------
68 
69 void Tree::setTerminalNodes(std::list<Node*>& sTNodes)
70 {
71  terminalNodes = sTNodes;
72 }
73 
74 std::list<Node*>& Tree::getTerminalNodes()
75 {
76  return terminalNodes;
77 }
78 
79 // ----------------------------------------------------------------------
80 
82 {
83  return numTerminalNodes;
84 }
85 
87 // ______________________Performance____________________________________//
89 
91 {
92  // Loop through the separate predictive regions (terminal nodes) and
93  // add up the errors to get the error of the entire space.
94 
95  Double_t totalSquaredError = 0;
96 
97  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
98  {
99  totalSquaredError += (*it)->getTotalError();
100  }
101  rmsError = sqrt( totalSquaredError/rootNode->getNumEvents() );
102 }
103 
104 // ----------------------------------------------------------------------
105 
106 void Tree::buildTree(Int_t nodeLimit)
107 {
108  // We greedily pick the best terminal node to split.
109  Double_t bestNodeErrorReduction = -1;
110  Node* nodeToSplit = 0;
111 
112  if(numTerminalNodes == 1)
113  {
115  calcError();
116  // std::cout << std::endl << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
117  }
118 
119  for(std::list<Node*>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); it++)
120  {
121  if( (*it)->getErrorReduction() > bestNodeErrorReduction )
122  {
123  bestNodeErrorReduction = (*it)->getErrorReduction();
124  nodeToSplit = (*it);
125  }
126  }
127 
128  //std::cout << "nodeToSplit size = " << nodeToSplit->getNumEvents() << std::endl;
129 
130  // If all of the nodes have one event we can't add any more nodes and reduce the error.
131  if(nodeToSplit == 0) return;
132 
133  // Create daughter nodes, and link the nodes together appropriately.
134  nodeToSplit->theMiracleOfChildBirth();
135 
136  // Get left and right daughters for reference.
137  Node* left = nodeToSplit->getLeftDaughter();
138  Node* right = nodeToSplit->getRightDaughter();
139 
140  // Update the list of terminal nodes.
141  terminalNodes.remove(nodeToSplit);
142  terminalNodes.push_back(left);
143  terminalNodes.push_back(right);
145 
146  // Filter the events from the parent into the daughters.
147  nodeToSplit->filterEventsToDaughters();
148 
149  // Calculate the best splits for the new nodes.
150  left->calcOptimumSplit();
151  right->calcOptimumSplit();
152 
153  // See if the error reduces as we add more nodes.
154  calcError();
155 
156  if(numTerminalNodes % 1 == 0)
157  {
158  // std::cout << " " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
159  }
160 
161  // Repeat until done.
162  if(numTerminalNodes < nodeLimit) buildTree(nodeLimit);
163 }
164 
165 // ----------------------------------------------------------------------
166 
167 void Tree::filterEvents(std::vector<Event*>& tEvents)
168 {
169  // Use trees which have already been built to fit a bunch of events
170  // given by the tEvents vector.
171 
172  // Set the events to be filtered.
173  rootNode->getEvents() = std::vector< std::vector<Event*> >(1);
174  rootNode->getEvents()[0] = tEvents;
175 
176  // The tree now knows about the events it needs to fit.
177  // Filter them into a predictive region (terminal node).
179 }
180 
181 // ----------------------------------------------------------------------
182 
184 {
185  // Filter the events repeatedly into the daughter nodes until they
186  // fall into a terminal node.
187 
188  Node* left = node->getLeftDaughter();
189  Node* right = node->getRightDaughter();
190 
191  if(left == 0 || right == 0) return;
192 
193  node->filterEventsToDaughters();
194 
195  filterEventsRecursive(left);
196  filterEventsRecursive(right);
197 }
198 
199 // ----------------------------------------------------------------------
200 
202 {
203  // Use trees which have already been built to fit a bunch of events
204  // given by the tEvents vector.
205 
206  // Filter the event into a predictive region (terminal node).
207  Node* node = filterEventRecursive(rootNode, e);
208  return node;
209 }
210 
211 // ----------------------------------------------------------------------
212 
214 {
215  // Filter the event repeatedly into the daughter nodes until it
216  // falls into a terminal node.
217 
218 
219  Node* nextNode = node->filterEventToDaughter(e);
220  if(nextNode == 0) return node;
221 
222  return filterEventRecursive(nextNode, e);
223 }
224 
225 // ----------------------------------------------------------------------
226 
227 
228 void Tree::rankVariablesRecursive(Node* node, std::vector<Double_t>& v)
229 {
230  // We recursively go through all of the nodes in the tree and find the
231  // total error reduction for each variable. The one with the most
232  // error reduction should be the most important.
233 
234  Node* left = node->getLeftDaughter();
235  Node* right = node->getRightDaughter();
236 
237  // Terminal nodes don't contribute to error reduction.
238  if(left==0 || right==0) return;
239 
240  Int_t sv = node->getSplitVariable();
241  Double_t er = node->getErrorReduction();
242 
243  //if(sv == -1)
244  //{
245  //std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
246  //std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
247  //std::cout << "rankVarRecursive Error Reduction = " << er << std::endl;
248  //}
249 
250  // Add error reduction to the current total for the appropriate
251  // variable.
252  v[sv] += er;
253 
254  rankVariablesRecursive(left, v);
255  rankVariablesRecursive(right, v);
256 
257 }
258 
259 // ----------------------------------------------------------------------
260 
261 void Tree::rankVariables(std::vector<Double_t>& v)
262 {
264 }
265 
266 // ----------------------------------------------------------------------
267 
268 
269 void Tree::getSplitValuesRecursive(Node* node, std::vector<std::vector<Double_t>>& v)
270 {
271  // We recursively go through all of the nodes in the tree and find the
272  // split points used for each split variable.
273 
274  Node* left = node->getLeftDaughter();
275  Node* right = node->getRightDaughter();
276 
277  // Terminal nodes don't contribute.
278  if(left==0 || right==0) return;
279 
280  Int_t sv = node->getSplitVariable();
281  Double_t sp = node->getSplitValue();
282 
283  if(sv == -1)
284  {
285  std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
286  std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
287  }
288 
289  // Add the split point to the list for the correct split variable.
290  v[sv].push_back(sp);
291 
292  getSplitValuesRecursive(left, v);
293  getSplitValuesRecursive(right, v);
294 
295 }
296 
297 // ----------------------------------------------------------------------
298 
299 void Tree::getSplitValues(std::vector<std::vector<Double_t>>& v)
300 {
302 }
303 
305 // ______________________Storage/Retrieval______________________________//
307 
// Convert a number to a string.
//
// The stream precision is raised to max_digits10 so floating-point values
// survive a text round trip. Split and fit values are written to XML through
// this helper and parsed back when the tree is loaded; the default precision
// of 6 significant digits would silently perturb them. For integral types
// max_digits10 is 0, which does not affect their output.
template <typename T>
std::string numToStr(T num)
{
    std::ostringstream ss;
    ss.precision(std::numeric_limits<T>::max_digits10);
    ss << num;
    return ss.str();
}
317 
318 // ----------------------------------------------------------------------
319 
320 void Tree::addXMLAttributes(TXMLEngine* xml, Node* node, XMLNodePointer_t np)
321 {
322  // Convert Node members into XML attributes
323  // and add them to the XMLEngine.
324  xml->NewAttr(np, 0, "splitVar", numToStr(node->getSplitVariable()).c_str());
325  xml->NewAttr(np, 0, "splitVal", numToStr(node->getSplitValue()).c_str());
326  xml->NewAttr(np, 0, "fitVal", numToStr(node->getFitValue()).c_str());
327 }
328 
329 // ----------------------------------------------------------------------
330 
331 void Tree::saveToXML(const char* c)
332 {
333 
334  TXMLEngine* xml = new TXMLEngine();
335 
336  // Add the root node.
337  XMLNodePointer_t root = xml->NewChild(0, 0, rootNode->getName().c_str());
338  addXMLAttributes(xml, rootNode, root);
339 
340  // Recursively write the tree to XML.
341  saveToXMLRecursive(xml, rootNode, root);
342 
343  // Make the XML Document.
344  XMLDocPointer_t xmldoc = xml->NewDoc();
345  xml->DocSetRootElement(xmldoc, root);
346 
347  // Save to file.
348  xml->SaveDoc(xmldoc, c);
349 
350  // Clean up.
351  xml->FreeDoc(xmldoc);
352  delete xml;
353 }
354 
355 // ----------------------------------------------------------------------
356 
357 void Tree::saveToXMLRecursive(TXMLEngine* xml, Node* node, XMLNodePointer_t np)
358 {
359  Node* l = node->getLeftDaughter();
360  Node* r = node->getRightDaughter();
361 
362  if(l==0 || r==0) return;
363 
364  // Add children to the XMLEngine.
365  XMLNodePointer_t left = xml->NewChild(np, 0, "left");
366  XMLNodePointer_t right = xml->NewChild(np, 0, "right");
367 
368  // Add attributes to the children.
369  addXMLAttributes(xml, l, left);
370  addXMLAttributes(xml, r, right);
371 
372  // Recurse.
373  saveToXMLRecursive(xml, l, left);
374  saveToXMLRecursive(xml, r, right);
375 }
376 
377 // ----------------------------------------------------------------------
378 
379 void Tree::loadFromXML(const char* filename)
380 {
381  // First create the engine.
382  TXMLEngine* xml = new TXMLEngine;
383 
384  // Now try to parse xml file.
385  XMLDocPointer_t xmldoc = xml->ParseFile(filename);
386  if (xmldoc==0)
387  {
388  delete xml;
389  return;
390  }
391 
392  // Get access to main node of the xml file.
393  XMLNodePointer_t mainnode = xml->DocGetRootElement(xmldoc);
394 
395  // Recursively connect nodes together.
396  loadFromXMLRecursive(xml, mainnode, rootNode);
397 
398  // Release memory before exit
399  xml->FreeDoc(xmldoc);
400  delete xml;
401 }
402 
403 // ----------------------------------------------------------------------
404 
405 void Tree::loadFromXMLRecursive(TXMLEngine* xml, XMLNodePointer_t xnode, Node* tnode)
406 {
407 
408  // Get the split information from xml.
409  XMLAttrPointer_t attr = xml->GetFirstAttr(xnode);
410  std::vector<std::string> splitInfo(3);
411  for(unsigned int i=0; i<3; i++)
412  {
413  splitInfo[i] = xml->GetAttrValue(attr);
414  attr = xml->GetNextAttr(attr);
415  }
416 
417  // Convert strings into numbers.
418  std::stringstream converter;
419  Int_t splitVar;
420  Double_t splitVal;
421  Double_t fitVal;
422 
423  converter << splitInfo[0];
424  converter >> splitVar;
425  converter.str("");
426  converter.clear();
427 
428  converter << splitInfo[1];
429  converter >> splitVal;
430  converter.str("");
431  converter.clear();
432 
433  converter << splitInfo[2];
434  converter >> fitVal;
435  converter.str("");
436  converter.clear();
437 
438  // Store gathered splitInfo into the node object.
439  tnode->setSplitVariable(splitVar);
440  tnode->setSplitValue(splitVal);
441  tnode->setFitValue(fitVal);
442 
443  // Get the xml daughters of the current xml node.
444  XMLNodePointer_t xleft = xml->GetChild(xnode);
445  XMLNodePointer_t xright = xml->GetNext(xleft);
446 
447  // If there are no daughters we are done.
448  if(xleft == 0 || xright == 0) return;
449 
450  // If there are daughters link the node objects appropriately.
451  tnode->theMiracleOfChildBirth();
452  Node* tleft = tnode->getLeftDaughter();
453  Node* tright = tnode->getRightDaughter();
454 
455  // Update the list of terminal nodes.
456  terminalNodes.remove(tnode);
457  terminalNodes.push_back(tleft);
458  terminalNodes.push_back(tright);
460 
461  loadFromXMLRecursive(xml, xleft, tleft);
462  loadFromXMLRecursive(xml, xright, tright);
463 }
Node * filterEvent(Event *e)
Definition: Tree.cc:201
Double_t getFitValue()
Definition: Node.cc:155
int i
Definition: DBlmapReader.cc:9
void filterEventsRecursive(Node *node)
Definition: Tree.cc:183
Definition: Node.h:10
Int_t getNumEvents()
Definition: Node.cc:189
Node * getLeftDaughter()
Definition: Node.cc:99
void rankVariablesRecursive(Node *node, std::vector< Double_t > &v)
Definition: Tree.cc:228
void loadFromXML(const char *filename)
Definition: Tree.cc:379
void setSplitValue(Double_t sSplitValue)
Definition: Node.cc:128
~Tree()
Definition: Tree.cc:46
void setTerminalNodes(std::list< Node * > &sTNodes)
Definition: Tree.cc:69
string xmldoc
Some module's global variables.
Node * filterEventToDaughter(Event *e)
Definition: Node.cc:392
Definition: Event.h:16
std::list< Node * > terminalNodes
Definition: Tree.h:54
void buildTree(Int_t nodeLimit)
Definition: Tree.cc:106
Double_t getSplitValue()
Definition: Node.cc:133
void saveToXMLRecursive(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:357
int np
Definition: AMPTWrapper.h:33
void theMiracleOfChildBirth()
Definition: Node.cc:332
void getSplitValues(std::vector< std::vector< Double_t >> &v)
Definition: Tree.cc:299
std::string getName()
Definition: Node.cc:75
Node * getRightDaughter()
Definition: Node.cc:109
T sqrt(T t)
Definition: SSEVec.h:18
void getSplitValuesRecursive(Node *node, std::vector< std::vector< Double_t >> &v)
Definition: Tree.cc:269
void setRootNode(Node *sRootNode)
Definition: Tree.cc:57
Int_t getSplitVariable()
Definition: Node.cc:143
void filterEventsToDaughters()
Definition: Node.cc:347
Node * filterEventRecursive(Node *node, Event *e)
Definition: Tree.cc:213
std::vector< std::vector< Event * > > & getEvents()
Definition: Node.cc:196
void setEvents(std::vector< std::vector< Event * > > &sEvents)
Definition: Node.cc:201
Double_t getErrorReduction()
Definition: Node.cc:87
Node * getRootNode()
Definition: Tree.cc:62
Tree()
Definition: Tree.cc:26
void setFitValue(Double_t sFitValue)
Definition: Node.cc:150
std::list< Node * > & getTerminalNodes()
Definition: Tree.cc:74
Double_t rmsError
Definition: Tree.h:56
void addXMLAttributes(TXMLEngine *xml, Node *node, XMLNodePointer_t np)
Definition: Tree.cc:320
void filterEvents(std::vector< Event * > &tEvents)
Definition: Tree.cc:167
Int_t numTerminalNodes
Definition: Tree.h:55
tuple filename
Definition: lut2db_cfg.py:20
void setSplitVariable(Int_t sSplitVar)
Definition: Node.cc:138
tuple cout
Definition: gather_cfg.py:145
Node * rootNode
Definition: Tree.h:53
void calcError()
Definition: Tree.cc:90
void rankVariables(std::vector< Double_t > &v)
Definition: Tree.cc:261
Definition: sp.h:21
edm::TrieNode< PDet > Node
void saveToXML(const char *filename)
Definition: Tree.cc:331
std::string numToStr(T num)
Definition: Utilities.h:43
void loadFromXMLRecursive(TXMLEngine *xml, XMLNodePointer_t node, Node *tnode)
Definition: Tree.cc:405
long double T
void calcOptimumSplit()
Definition: Node.cc:211
Int_t getNumTerminalNodes()
Definition: Tree.cc:81