CMS 3D CMS Logo

Node.cc
Go to the documentation of this file.
1 // Node.cxx //
3 // =====================================================================//
4 // This is the object implementation of a node, which is the //
5 // fundamental unit of a decision tree. //
6 // References include //
7 // *Elements of Statistical Learning by Hastie, //
8 // Tibshirani, and Friedman. //
9 // *Greedy Function Approximation: A Gradient Boosting Machine. //
10 // Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001. //
11 // *Inductive Learning of Tree-based Regression Models. Luis Torgo. //
12 // //
14 
16 // _______________________Includes_______________________________________//
18 
20 #include "TRandom3.h"
21 #include "TStopwatch.h"
22 #include <iostream>
23 #include <fstream>
24 
26 // _______________________Constructor(s)________________________________//
28 
29 using namespace emtf;
30 
32 {
33  name = "";
34  leftDaughter = 0;
35  rightDaughter = 0;
36  parent = 0;
37  splitValue = -99;
38  splitVariable = -1;
39  avgError = -1;
40  totalError = -1;
41  errorReduction = -1;
42 }
43 
45 {
46  name = cName;
47  leftDaughter = 0;
48  rightDaughter = 0;
49  parent = 0;
50  splitValue = -99;
51  splitVariable = -1;
52  avgError = -1;
53  totalError = -1;
54  errorReduction = -1;
55 }
56 
58 // _______________________Destructor____________________________________//
60 
62 {
63  // Recursively delete all nodes in the tree.
64  delete leftDaughter;
65  delete rightDaughter;
66 }
67 
69 // ______________________Get/Set________________________________________//
71 
73 {
74  name = sName;
75 }
76 
78 {
79  return name;
80 }
81 
82 // ----------------------------------------------------------------------
83 
84 void Node::setErrorReduction(Double_t sErrorReduction)
85 {
86  errorReduction = sErrorReduction;
87 }
88 
90 {
91  return errorReduction;
92 }
93 
94 // ----------------------------------------------------------------------
95 
96 void Node::setLeftDaughter(Node *sLeftDaughter)
97 {
98  leftDaughter = sLeftDaughter;
99 }
100 
102 {
103  return leftDaughter;
104 }
105 
106 void Node::setRightDaughter(Node *sRightDaughter)
107 {
108  rightDaughter = sRightDaughter;
109 }
110 
112 {
113  return rightDaughter;
114 }
115 
116 // ----------------------------------------------------------------------
117 
118 void Node::setParent(Node *sParent)
119 {
120  parent = sParent;
121 }
122 
124 {
125  return parent;
126 }
127 
128 // ----------------------------------------------------------------------
129 
130 void Node::setSplitValue(Double_t sSplitValue)
131 {
132  splitValue = sSplitValue;
133 }
134 
136 {
137  return splitValue;
138 }
139 
140 void Node::setSplitVariable(Int_t sSplitVar)
141 {
142  splitVariable = sSplitVar;
143 }
144 
146 {
147  return splitVariable;
148 }
149 
150 // ----------------------------------------------------------------------
151 
152 void Node::setFitValue(Double_t sFitValue)
153 {
154  fitValue = sFitValue;
155 }
156 
158 {
159  return fitValue;
160 }
161 
162 // ----------------------------------------------------------------------
163 
164 void Node::setTotalError(Double_t sTotalError)
165 {
166  totalError = sTotalError;
167 }
168 
170 {
171  return totalError;
172 }
173 
174 void Node::setAvgError(Double_t sAvgError)
175 {
176  avgError = sAvgError;
177 }
178 
180 {
181  return avgError;
182 }
183 
184 // ----------------------------------------------------------------------
185 
186 void Node::setNumEvents(Int_t sNumEvents)
187 {
188  numEvents = sNumEvents;
189 }
190 
192 {
193  return numEvents;
194 }
195 
196 // ----------------------------------------------------------------------
197 
198 std::vector< std::vector<Event*> >& Node::getEvents()
199 {
200  return events;
201 }
202 
203 void Node::setEvents(std::vector< std::vector<Event*> >& sEvents)
204 {
205  events = sEvents;
206  numEvents = events[0].size();
207 }
208 
210 // ______________________Performance_Functions__________________________//
212 
214 {
215  // Determines the split variable and split point which would most reduce the error for the given node (region).
216  // In the process we calculate the fitValue and Error. The general aglorithm is based upon Luis Torgo's thesis.
217  // Check out the reference for a more in depth outline. This part is chapter 3.
218 
219  // Intialize some variables.
220  Double_t bestSplitValue = 0;
221  Int_t bestSplitVariable = -1;
222  Double_t bestErrorReduction = -1;
223 
224  Double_t SUM = 0;
225  Double_t SSUM = 0;
226  numEvents = events[0].size();
227 
228  Double_t candidateErrorReduction = -1;
229 
230  // Calculate the sum of the target variables and the sum of
231  // the target variables squared. We use these later.
232  for(unsigned int i=0; i<events[0].size(); i++)
233  {
234  Double_t target = events[0][i]->data[0];
235  SUM += target;
236  SSUM += target*target;
237  }
238 
239  unsigned int numVars = events.size();
240 
241  // Calculate the best split point for each variable
242  for(unsigned int variableToCheck = 1; variableToCheck < numVars; variableToCheck++)
243  {
244 
245  // The sum of the target variables in the left, right nodes
246  Double_t SUMleft = 0;
247  Double_t SUMright = SUM;
248 
249  // The number of events in the left, right nodes
250  Int_t nleft = 1;
251  Int_t nright = events[variableToCheck].size()-1;
252 
253  Int_t candidateSplitVariable = variableToCheck;
254 
255  std::vector<Event*>& v = events[variableToCheck];
256 
257  // Find the best split point for this variable
258  for(unsigned int i=1; i<v.size(); i++)
259  {
260  // As the candidate split point interates, the number of events in the
261  // left/right node increases/decreases and SUMleft/right increases/decreases.
262 
263  SUMleft = SUMleft + v[i-1]->data[0];
264  SUMright = SUMright - v[i-1]->data[0];
265 
266  // No need to check the split point if x on both sides is equal
267  if(v[i-1]->data[candidateSplitVariable] < v[i]->data[candidateSplitVariable])
268  {
269  // Finding the maximum error reduction for Least Squares boils down to maximizing
270  // the following statement.
271  candidateErrorReduction = SUMleft*SUMleft/nleft + SUMright*SUMright/nright - SUM*SUM/numEvents;
272  // std::cout << "candidateErrorReduction= " << candidateErrorReduction << std::endl << std::endl;
273 
274  // if the new candidate is better than the current best, then we have a new overall best.
275  if(candidateErrorReduction > bestErrorReduction)
276  {
277  bestErrorReduction = candidateErrorReduction;
278  bestSplitValue = (v[i-1]->data[candidateSplitVariable] + v[i]->data[candidateSplitVariable])/2;
279  bestSplitVariable = candidateSplitVariable;
280  }
281  }
282 
283  nright = nright-1;
284  nleft = nleft+1;
285  }
286  }
287 
288  // Store the information gained from our computations.
289 
290  // The fit value is the average for least squares.
291  fitValue = SUM/numEvents;
292  // std::cout << "fitValue= " << fitValue << std::endl;
293 
294  // n*[ <y^2>-k^2 ]
295  totalError = SSUM - SUM*SUM/numEvents;
296  // std::cout << "totalError= " << totalError << std::endl;
297 
298  // [ <y^2>-k^2 ]
300  // std::cout << "avgError= " << avgError << std::endl;
301 
302 
303  errorReduction = bestErrorReduction;
304  // std::cout << "errorReduction= " << errorReduction << std::endl;
305 
306  splitVariable = bestSplitVariable;
307  // std::cout << "splitVariable= " << splitVariable << std::endl;
308 
309  splitValue = bestSplitValue;
310  // std::cout << "splitValue= " << splitValue << std::endl;
311 
312  //if(bestSplitVariable == -1) std::cout << "splitVar = -1. numEvents = " << numEvents << ". errRed = " << errorReduction << std::endl;
313 }
314 
315 // ----------------------------------------------------------------------
316 
318 {
319  std::cout << std::endl << "Listing Events... " << std::endl;
320 
321  for(unsigned int i=0; i < events.size(); i++)
322  {
323  std::cout << std::endl << "Variable " << i << " vector contents: " << std::endl;
324  for(unsigned int j=0; j < events[i].size(); j++)
325  {
326  events[i][j]->outputEvent();
327  }
328  std::cout << std::endl;
329  }
330 }
331 
332 // ----------------------------------------------------------------------
333 
335 {
336  // Create Daughter Nodes
337  Node* left = new Node(name + " left");
338  Node* right = new Node(name + " right");
339 
340  // Link the Nodes Appropriately
341  leftDaughter = left;
342  rightDaughter = right;
343  left->setParent(this);
344  right->setParent(this);
345 }
346 
347 // ----------------------------------------------------------------------
348 
350 {
351  // Keeping sorted copies of the event vectors allows us to save on
352  // computation time. That way we don't have to resort the events
353  // each time we calculate the splitpoint for a node. We sort them once.
354  // Every time we split a node, we simply filter them down correctly
355  // preserving the order. This way we have O(n) efficiency instead
356  // of O(nlogn) efficiency.
357 
358  // Anyways, this function takes events from the parent node
359  // and filters an event into the left or right daughter
360  // node depending on whether it is < or > the split point
361  // for the given split variable.
362 
363  Int_t sv = splitVariable;
364  Double_t sp = splitValue;
365 
366  Node* left = leftDaughter;
367  Node* right = rightDaughter;
368 
369  std::vector< std::vector<Event*> > l(events.size());
370  std::vector< std::vector<Event*> > r(events.size());
371 
372  for(unsigned int i=0; i<events.size(); i++)
373  {
374  for(unsigned int j=0; j<events[i].size(); j++)
375  {
376  Event* e = events[i][j];
377  if(e->data[sv] < sp) l[i].push_back(e);
378  if(e->data[sv] > sp) r[i].push_back(e);
379  }
380  }
381 
382  events = std::vector< std::vector<Event*> >();
383 
384  left->getEvents().swap(l);
385  right->getEvents().swap(r);
386 
387  // Set the number of events in the node.
388  left->setNumEvents(left->getEvents()[0].size());
389  right->setNumEvents(right->getEvents()[0].size());
390 }
391 
392 // ----------------------------------------------------------------------
393 
395 {
396  // Anyways, this function takes an event from the parent node
397  // and filters an event into the left or right daughter
398  // node depending on whether it is < or > the split point
399  // for the given split variable.
400 
401  Int_t sv = splitVariable;
402  Double_t sp = splitValue;
403 
404  Node* left = leftDaughter;
405  Node* right = rightDaughter;
406  Node* nextNode = 0;
407 
408  if(left ==0 || right ==0) return 0;
409 
410  if(e->data[sv] < sp) nextNode = left;
411  if(e->data[sv] > sp) nextNode = right;
412 
413  return nextNode;
414 }
Node * getRightDaughter()
Definition: Node.cc:111
Node * getLeftDaughter()
Definition: Node.cc:101
Node * leftDaughter
Definition: Node.h:63
Node()
Definition: Node.cc:31
Double_t getAvgError()
Definition: Node.cc:179
std::vector< std::vector< emtf::Event * > > events
Definition: Node.h:77
void setErrorReduction(Double_t sErrorReduction)
Definition: Node.cc:84
Int_t numEvents
Definition: Node.h:75
Node * getParent()
Definition: Node.cc:123
void listEvents()
Definition: Node.cc:317
std::string name
Definition: Node.h:61
Double_t getTotalError()
Definition: Node.cc:169
Node * parent
Definition: Node.h:65
Definition: Event.h:15
void setEvents(std::vector< std::vector< emtf::Event * > > &sEvents)
Definition: Node.cc:203
void setNumEvents(Int_t sNumEvents)
Definition: Node.cc:186
void setRightDaughter(Node *sLeftDaughter)
Definition: Node.cc:106
void setAvgError(Double_t sAvgError)
Definition: Node.cc:174
Double_t fitValue
Definition: Node.h:74
Double_t getErrorReduction()
Definition: Node.cc:89
std::vector< Double_t > data
Definition: Event.h:30
void setFitValue(Double_t sFitValue)
Definition: Node.cc:152
std::vector< std::vector< emtf::Event * > > & getEvents()
Definition: Node.cc:198
void setTotalError(Double_t sTotalError)
Definition: Node.cc:164
void setSplitVariable(Int_t sSplitVar)
Definition: Node.cc:140
Double_t totalError
Definition: Node.h:71
Double_t getSplitValue()
Definition: Node.cc:135
Node * rightDaughter
Definition: Node.h:64
std::string getName()
Definition: Node.cc:77
Int_t getNumEvents()
Definition: Node.cc:191
Double_t getFitValue()
Definition: Node.cc:157
void setParent(Node *sParent)
Definition: Node.cc:118
Double_t avgError
Definition: Node.h:72
Int_t getSplitVariable()
Definition: Node.cc:145
void setName(std::string sName)
Definition: Node.cc:72
#define SUM(A, B)
Int_t splitVariable
Definition: Node.h:68
char data[epos_bytes_allocation]
Definition: EPOS_Wrapper.h:82
void setSplitValue(Double_t sSplitValue)
Definition: Node.cc:130
~Node()
Definition: Node.cc:61
Node * filterEventToDaughter(emtf::Event *e)
Definition: Node.cc:394
Double_t errorReduction
Definition: Node.h:70
void filterEventsToDaughters()
Definition: Node.cc:349
void calcOptimumSplit()
Definition: Node.cc:213
Double_t splitValue
Definition: Node.h:67
void setLeftDaughter(Node *sLeftDaughter)
Definition: Node.cc:96
void theMiracleOfChildBirth()
Definition: Node.cc:334