CMS 3D CMS Logo

Node.cc
Go to the documentation of this file.
1 // Node.cxx //
3 // =====================================================================//
4 // This is the object implementation of a node, which is the //
5 // fundamental unit of a decision tree. //
6 // References include //
7 // *Elements of Statistical Learning by Hastie, //
8 // Tibshirani, and Friedman. //
9 // *Greedy Function Approximation: A Gradient Boosting Machine. //
10 // Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001. //
11 // *Inductive Learning of Tree-based Regression Models. Luis Torgo. //
12 // //
14 
16 // _______________________Includes_______________________________________//
18 
20 
21 #include "TRandom3.h"
22 #include "TStopwatch.h"
23 #include <iostream>
24 #include <fstream>
25 
27 // _______________________Constructor(s)________________________________//
29 
30 using namespace emtf;
31 
33 {
34  name = "";
35  leftDaughter = nullptr;
36  rightDaughter = nullptr;
37  parent = nullptr;
38  splitValue = -99;
39  splitVariable = -1;
40  avgError = -1;
41  totalError = -1;
42  errorReduction = -1;
43 }
44 
46 {
47  name = cName;
48  leftDaughter = nullptr;
49  rightDaughter = nullptr;
50  parent = nullptr;
51  splitValue = -99;
52  splitVariable = -1;
53  avgError = -1;
54  totalError = -1;
55  errorReduction = -1;
56 }
57 
59 // _______________________Destructor____________________________________//
61 
63 {
64 // Recursively delete all nodes in the tree.
65  if(leftDaughter) delete leftDaughter;
66  if(rightDaughter) delete rightDaughter;
67 }
68 
70 // ______________________Get/Set________________________________________//
72 
74 {
75  name = sName;
76 }
77 
79 {
80  return name;
81 }
82 
83 // ----------------------------------------------------------------------
84 
85 void Node::setErrorReduction(double sErrorReduction)
86 {
87  errorReduction = sErrorReduction;
88 }
89 
91 {
92  return errorReduction;
93 }
94 
95 // ----------------------------------------------------------------------
96 
97 void Node::setLeftDaughter(Node *sLeftDaughter)
98 {
99  leftDaughter = sLeftDaughter;
100 }
101 
103 {
104  return leftDaughter;
105 }
106 
107 void Node::setRightDaughter(Node *sRightDaughter)
108 {
109  rightDaughter = sRightDaughter;
110 }
111 
113 {
114  return rightDaughter;
115 }
116 
117 // ----------------------------------------------------------------------
118 
119 void Node::setParent(Node *sParent)
120 {
121  parent = sParent;
122 }
123 
125 {
126  return parent;
127 }
128 
129 // ----------------------------------------------------------------------
130 
131 void Node::setSplitValue(double sSplitValue)
132 {
133  splitValue = sSplitValue;
134 }
135 
137 {
138  return splitValue;
139 }
140 
141 void Node::setSplitVariable(int sSplitVar)
142 {
143  splitVariable = sSplitVar;
144 }
145 
147 {
148  return splitVariable;
149 }
150 
151 // ----------------------------------------------------------------------
152 
153 void Node::setFitValue(double sFitValue)
154 {
155  fitValue = sFitValue;
156 }
157 
159 {
160  return fitValue;
161 }
162 
163 // ----------------------------------------------------------------------
164 
165 void Node::setTotalError(double sTotalError)
166 {
167  totalError = sTotalError;
168 }
169 
171 {
172  return totalError;
173 }
174 
175 void Node::setAvgError(double sAvgError)
176 {
177  avgError = sAvgError;
178 }
179 
181 {
182  return avgError;
183 }
184 
185 // ----------------------------------------------------------------------
186 
187 void Node::setNumEvents(int sNumEvents)
188 {
189  numEvents = sNumEvents;
190 }
191 
193 {
194  return numEvents;
195 }
196 
197 // ----------------------------------------------------------------------
198 
199 std::vector< std::vector<Event*> >& Node::getEvents()
200 {
201  return events;
202 }
203 
204 void Node::setEvents(std::vector< std::vector<Event*> >& sEvents)
205 {
206  events = sEvents;
207  numEvents = events[0].size();
208 }
209 
211 // ______________________Performance_Functions__________________________//
213 
215 {
216 // Determines the split variable and split point which would most reduce the error for the given node (region).
217 // In the process we calculate the fitValue and Error. The general aglorithm is based upon Luis Torgo's thesis.
218 // Check out the reference for a more in depth outline. This part is chapter 3.
219 
220  // Intialize some variables.
221  double bestSplitValue = 0;
222  int bestSplitVariable = -1;
223  double bestErrorReduction = -1;
224 
225  double SUM = 0;
226  double SSUM = 0;
227  numEvents = events[0].size();
228 
229  double candidateErrorReduction = -1;
230 
231  // Calculate the sum of the target variables and the sum of
232  // the target variables squared. We use these later.
233  for(unsigned int i=0; i<events[0].size(); i++)
234  {
235  double target = events[0][i]->data[0];
236  SUM += target;
237  SSUM += target*target;
238  }
239 
240  unsigned int numVars = events.size();
241 
242  // Calculate the best split point for each variable
243  for(unsigned int variableToCheck = 1; variableToCheck < numVars; variableToCheck++)
244  {
245 
246  // The sum of the target variables in the left, right nodes
247  double SUMleft = 0;
248  double SUMright = SUM;
249 
250  // The number of events in the left, right nodes
251  int nleft = 1;
252  int nright = events[variableToCheck].size()-1;
253 
254  int candidateSplitVariable = variableToCheck;
255 
256  std::vector<Event*>& v = events[variableToCheck];
257 
258  // Find the best split point for this variable
259  for(unsigned int i=1; i<v.size(); i++)
260  {
261  // As the candidate split point interates, the number of events in the
262  // left/right node increases/decreases and SUMleft/right increases/decreases.
263 
264  SUMleft = SUMleft + v[i-1]->data[0];
265  SUMright = SUMright - v[i-1]->data[0];
266 
267  // No need to check the split point if x on both sides is equal
268  if(v[i-1]->data[candidateSplitVariable] < v[i]->data[candidateSplitVariable])
269  {
270  // Finding the maximum error reduction for Least Squares boils down to maximizing
271  // the following statement.
272  candidateErrorReduction = SUMleft*SUMleft/nleft + SUMright*SUMright/nright - SUM*SUM/numEvents;
273 // std::cout << "candidateErrorReduction= " << candidateErrorReduction << std::endl << std::endl;
274 
275  // if the new candidate is better than the current best, then we have a new overall best.
276  if(candidateErrorReduction > bestErrorReduction)
277  {
278  bestErrorReduction = candidateErrorReduction;
279  bestSplitValue = (v[i-1]->data[candidateSplitVariable] + v[i]->data[candidateSplitVariable])/2;
280  bestSplitVariable = candidateSplitVariable;
281  }
282  }
283 
284  nright = nright-1;
285  nleft = nleft+1;
286  }
287  }
288 
289  // Store the information gained from our computations.
290 
291  // The fit value is the average for least squares.
292  fitValue = SUM/numEvents;
293 // std::cout << "fitValue= " << fitValue << std::endl;
294 
295  // n*[ <y^2>-k^2 ]
296  totalError = SSUM - SUM*SUM/numEvents;
297 // std::cout << "totalError= " << totalError << std::endl;
298 
299  // [ <y^2>-k^2 ]
301 // std::cout << "avgError= " << avgError << std::endl;
302 
303 
304  errorReduction = bestErrorReduction;
305 // std::cout << "errorReduction= " << errorReduction << std::endl;
306 
307  splitVariable = bestSplitVariable;
308 // std::cout << "splitVariable= " << splitVariable << std::endl;
309 
310  splitValue = bestSplitValue;
311 // std::cout << "splitValue= " << splitValue << std::endl;
312 
313  //if(bestSplitVariable == -1) std::cout << "splitVar = -1. numEvents = " << numEvents << ". errRed = " << errorReduction << std::endl;
314 }
315 
316 // ----------------------------------------------------------------------
317 
319 {
320  std::cout << std::endl << "Listing Events... " << std::endl;
321 
322  for(unsigned int i=0; i < events.size(); i++)
323  {
324  std::cout << std::endl << "Variable " << i << " vector contents: " << std::endl;
325  for(unsigned int j=0; j < events[i].size(); j++)
326  {
327  events[i][j]->outputEvent();
328  }
329  std::cout << std::endl;
330  }
331 }
332 
333 // ----------------------------------------------------------------------
334 
336 {
337  // Create Daughter Nodes
338  Node* left = new Node(name + " left");
339  Node* right = new Node(name + " right");
340 
341  // Link the Nodes Appropriately
342  leftDaughter = left;
343  rightDaughter = right;
344  left->setParent(this);
345  right->setParent(this);
346 }
347 
348 // ----------------------------------------------------------------------
349 
351 {
352 // Keeping sorted copies of the event vectors allows us to save on
353 // computation time. That way we don't have to resort the events
354 // each time we calculate the splitpoint for a node. We sort them once.
355 // Every time we split a node, we simply filter them down correctly
356 // preserving the order. This way we have O(n) efficiency instead
357 // of O(nlogn) efficiency.
358 
359 // Anyways, this function takes events from the parent node
360 // and filters an event into the left or right daughter
361 // node depending on whether it is < or > the split point
362 // for the given split variable.
363 
364  int sv = splitVariable;
365  double sp = splitValue;
366 
367  Node* left = leftDaughter;
368  Node* right = rightDaughter;
369 
370  std::vector< std::vector<Event*> > l(events.size());
371  std::vector< std::vector<Event*> > r(events.size());
372 
373  for(unsigned int i=0; i<events.size(); i++)
374  {
375  for(unsigned int j=0; j<events[i].size(); j++)
376  {
377  Event* e = events[i][j];
378  if(e->data[sv] < sp) l[i].push_back(e);
379  if(e->data[sv] > sp) r[i].push_back(e);
380  }
381  }
382 
383  events = std::vector< std::vector<Event*> >();
384 
385  left->getEvents().swap(l);
386  right->getEvents().swap(r);
387 
388  // Set the number of events in the node.
389  left->setNumEvents(left->getEvents()[0].size());
390  right->setNumEvents(right->getEvents()[0].size());
391 }
392 
393 // ----------------------------------------------------------------------
394 
396 {
397 // Anyways, this function takes an event from the parent node
398 // and filters an event into the left or right daughter
399 // node depending on whether it is < or > the split point
400 // for the given split variable.
401 
402  int sv = splitVariable;
403  double sp = splitValue;
404 
405  Node* left = leftDaughter;
406  Node* right = rightDaughter;
407  Node* nextNode = nullptr;
408 
409  if(left ==nullptr || right ==nullptr) return nullptr;
410 
411  if(e->data[sv] < sp) nextNode = left;
412  if(e->data[sv] >= sp) nextNode = right;
413 
414  return nextNode;
415 }
void setFitValue(double sFitValue)
Definition: Node.cc:153
Node * getRightDaughter()
Definition: Node.cc:112
Node * getLeftDaughter()
Definition: Node.cc:102
Node * leftDaughter
Definition: Node.h:69
double getFitValue()
Definition: Node.cc:158
Node()
Definition: Node.cc:32
int splitVariable
Definition: Node.h:74
Node * getParent()
Definition: Node.cc:124
void listEvents()
Definition: Node.cc:318
std::string name
Definition: Node.h:67
double fitValue
Definition: Node.h:80
Node * parent
Definition: Node.h:71
int numEvents
Definition: Node.h:81
Definition: Event.h:15
double errorReduction
Definition: Node.h:76
void setRightDaughter(Node *sLeftDaughter)
Definition: Node.cc:107
std::vector< std::vector< Event * > > events
Definition: Node.h:83
void setTotalError(double sTotalError)
Definition: Node.cc:165
double splitValue
Definition: Node.h:73
double getTotalError()
Definition: Node.cc:170
Node * filterEventToDaughter(Event *e)
Definition: Node.cc:395
void setErrorReduction(double sErrorReduction)
Definition: Node.cc:85
void setEvents(std::vector< std::vector< Event * > > &sEvents)
Definition: Node.cc:204
std::vector< std::vector< Event * > > & getEvents()
Definition: Node.cc:199
void setAvgError(double sAvgError)
Definition: Node.cc:175
double getSplitValue()
Definition: Node.cc:136
Node * rightDaughter
Definition: Node.h:70
std::string getName()
Definition: Node.cc:78
double avgError
Definition: Node.h:78
void setParent(Node *sParent)
Definition: Node.cc:119
double totalError
Definition: Node.h:77
void setName(std::string sName)
Definition: Node.cc:73
#define SUM(A, B)
void setSplitVariable(int sSplitVar)
Definition: Node.cc:141
char data[epos_bytes_allocation]
Definition: EPOS_Wrapper.h:82
int getNumEvents()
Definition: Node.cc:192
void setNumEvents(int sNumEvents)
Definition: Node.cc:187
~Node()
Definition: Node.cc:62
double getErrorReduction()
Definition: Node.cc:90
void filterEventsToDaughters()
Definition: Node.cc:350
double getAvgError()
Definition: Node.cc:180
int getSplitVariable()
Definition: Node.cc:146
void calcOptimumSplit()
Definition: Node.cc:214
std::vector< double > data
Definition: Event.h:31
void setSplitValue(double sSplitValue)
Definition: Node.cc:131
void setLeftDaughter(Node *sLeftDaughter)
Definition: Node.cc:97
void theMiracleOfChildBirth()
Definition: Node.cc:335