CMS 3D CMS Logo

Node.cc
Go to the documentation of this file.
1 // Node.cxx //
3 // =====================================================================//
4 // This is the object implementation of a node, which is the //
5 // fundamental unit of a decision tree. //
6 // References include //
7 // *Elements of Statistical Learning by Hastie, //
8 // Tibshirani, and Friedman. //
9 // *Greedy Function Approximation: A Gradient Boosting Machine. //
10 // Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001. //
11 // *Inductive Learning of Tree-based Regression Models. Luis Torgo. //
12 // //
14 
16 // _______________________Includes_______________________________________//
18 
20 
21 #include "TRandom3.h"
22 #include "TStopwatch.h"
23 #include <iostream>
24 #include <fstream>
25 
27 // _______________________Constructor(s)________________________________//
29 
30 using namespace emtf;
31 
33  name = "";
34  leftDaughter = nullptr;
35  rightDaughter = nullptr;
36  parent = nullptr;
37  splitValue = -99;
38  splitVariable = -1;
39  avgError = -1;
40  totalError = -1;
41  errorReduction = -1;
42 }
43 
45  name = cName;
46  leftDaughter = nullptr;
47  rightDaughter = nullptr;
48  parent = nullptr;
49  splitValue = -99;
50  splitVariable = -1;
51  avgError = -1;
52  totalError = -1;
53  errorReduction = -1;
54 }
55 
57 // _______________________Destructor____________________________________//
59 
61  // Recursively delete all nodes in the tree.
62  if (leftDaughter)
63  delete leftDaughter;
64  if (rightDaughter)
65  delete rightDaughter;
66 }
67 
69 // ______________________Get/Set________________________________________//
71 
72 void Node::setName(std::string sName) { this->name = sName; }  // Set this node's name.
73 
75 
76 // ----------------------------------------------------------------------
77 
78 void Node::setErrorReduction(double sErrorReduction) { errorReduction = sErrorReduction; }
79 
81 
82 // ----------------------------------------------------------------------
83 
84 void Node::setLeftDaughter(Node* sLeftDaughter) { this->leftDaughter = sLeftDaughter; }  // Point this node at its left daughter.
85 
87 
88 void Node::setRightDaughter(Node* sRightDaughter) { this->rightDaughter = sRightDaughter; }  // Point this node at its right daughter.
89 
91 
92 // ----------------------------------------------------------------------
93 
94 void Node::setParent(Node* sParent) { this->parent = sParent; }  // Point this node at its parent.
95 
96 Node* Node::getParent() { return this->parent; }  // Stored parent pointer (nullptr until setParent is called).
97 
98 // ----------------------------------------------------------------------
99 
100 void Node::setSplitValue(double sSplitValue) { this->splitValue = sSplitValue; }  // Overwrite the stored split point.
101 
102 double Node::getSplitValue() { return this->splitValue; }  // Stored split point.
103 
104 void Node::setSplitVariable(int sSplitVar) { this->splitVariable = sSplitVar; }  // Overwrite the index of the split variable.
105 
107 
108 // ----------------------------------------------------------------------
109 
110 void Node::setFitValue(double sFitValue) { this->fitValue = sFitValue; }  // Overwrite the stored fit value.
111 
112 double Node::getFitValue() { return this->fitValue; }  // Stored fit value.
113 
114 // ----------------------------------------------------------------------
115 
116 void Node::setTotalError(double sTotalError) { totalError = sTotalError; }
117 
118 double Node::getTotalError() { return totalError; }
119 
120 void Node::setAvgError(double sAvgError) { avgError = sAvgError; }
121 
122 double Node::getAvgError() { return avgError; }
123 
124 // ----------------------------------------------------------------------
125 
126 void Node::setNumEvents(int sNumEvents) { this->numEvents = sNumEvents; }  // Overwrite the cached event count.
127 
128 int Node::getNumEvents() { return this->numEvents; }  // Cached event count.
129 
130 // ----------------------------------------------------------------------
131 
132 std::vector<std::vector<Event*> >& Node::getEvents() { return this->events; }  // Mutable reference to this node's event lists.
133 
134 void Node::setEvents(std::vector<std::vector<Event*> >& sEvents) {  // Copy sEvents into this node and refresh numEvents.
135  events = sEvents;  // Copies the outer and inner vectors; the Event pointers themselves are shared.
136  numEvents = events.empty() ? 0 : static_cast<int>(events[0].size());  // Guard: events[0] does not exist when the outer vector is empty.
137 }
138 
140 // ______________________Performance_Functions__________________________//
142 
144  // Determines the split variable and split point which would most reduce the error for the given node (region).
145  // In the process we calculate the fitValue and Error. The general algorithm is based upon Luis Torgo's thesis.
146  // Check out the reference for a more in depth outline. This part is chapter 3.
147 
148  // Initialize some variables.
149  double bestSplitValue = 0;
150  int bestSplitVariable = -1;
151  double bestErrorReduction = -1;
152 
153  double SUM = 0;
154  double SSUM = 0;
155  numEvents = events[0].size();
156 
157  double candidateErrorReduction = -1;
158 
159  // Calculate the sum of the target variables and the sum of
160  // the target variables squared. We use these later.
161  for (unsigned int i = 0; i < events[0].size(); i++) {
162  double target = events[0][i]->data[0];
163  SUM += target;
164  SSUM += target * target;
165  }
166 
167  unsigned int numVars = events.size();
168 
169  // Calculate the best split point for each variable
170  for (unsigned int variableToCheck = 1; variableToCheck < numVars; variableToCheck++) {
171  // The sum of the target variables in the left, right nodes
172  double SUMleft = 0;
173  double SUMright = SUM;
174 
175  // The number of events in the left, right nodes
176  int nleft = 1;
177  int nright = events[variableToCheck].size() - 1;
178 
179  int candidateSplitVariable = variableToCheck;
180 
181  std::vector<Event*>& v = events[variableToCheck];
182 
183  // Find the best split point for this variable
184  for (unsigned int i = 1; i < v.size(); i++) {
185  // As the candidate split point iterates, the number of events in the
186  // left/right node increases/decreases and SUMleft/right increases/decreases.
187 
188  SUMleft = SUMleft + v[i - 1]->data[0];
189  SUMright = SUMright - v[i - 1]->data[0];
190 
191  // No need to check the split point if x on both sides is equal
192  if (v[i - 1]->data[candidateSplitVariable] < v[i]->data[candidateSplitVariable]) {
193  // Finding the maximum error reduction for Least Squares boils down to maximizing
194  // the following statement.
195  candidateErrorReduction = SUMleft * SUMleft / nleft + SUMright * SUMright / nright - SUM * SUM / numEvents;
196  // std::cout << "candidateErrorReduction= " << candidateErrorReduction << std::endl << std::endl;
197 
198  // if the new candidate is better than the current best, then we have a new overall best.
199  if (candidateErrorReduction > bestErrorReduction) {
200  bestErrorReduction = candidateErrorReduction;
201  bestSplitValue = (v[i - 1]->data[candidateSplitVariable] + v[i]->data[candidateSplitVariable]) / 2;
202  bestSplitVariable = candidateSplitVariable;
203  }
204  }
205 
206  nright = nright - 1;
207  nleft = nleft + 1;
208  }
209  }
210 
211  // Store the information gained from our computations.
212 
213  // The fit value is the average for least squares.
214  fitValue = SUM / numEvents;
215  // std::cout << "fitValue= " << fitValue << std::endl;
216 
217  // n*[ <y^2>-k^2 ]
218  totalError = SSUM - SUM * SUM / numEvents;
219  // std::cout << "totalError= " << totalError << std::endl;
220 
221  // [ <y^2>-k^2 ]
223  // std::cout << "avgError= " << avgError << std::endl;
224 
225  errorReduction = bestErrorReduction;
226  // std::cout << "errorReduction= " << errorReduction << std::endl;
227 
228  splitVariable = bestSplitVariable;
229  // std::cout << "splitVariable= " << splitVariable << std::endl;
230 
231  splitValue = bestSplitValue;
232  // std::cout << "splitValue= " << splitValue << std::endl;
233 
234  //if(bestSplitVariable == -1) std::cout << "splitVar = -1. numEvents = " << numEvents << ". errRed = " << errorReduction << std::endl;
235 }
236 
237 // ----------------------------------------------------------------------
238 
240  std::cout << std::endl << "Listing Events... " << std::endl;
241 
242  for (unsigned int i = 0; i < events.size(); i++) {
243  std::cout << std::endl << "Variable " << i << " vector contents: " << std::endl;
244  for (unsigned int j = 0; j < events[i].size(); j++) {
245  events[i][j]->outputEvent();
246  }
247  std::cout << std::endl;
248  }
249 }
250 
251 // ----------------------------------------------------------------------
252 
254  // Create Daughter Nodes
255  Node* left = new Node(name + " left");
256  Node* right = new Node(name + " right");
257 
258  // Link the Nodes Appropriately
259  leftDaughter = left;
260  rightDaughter = right;
261  left->setParent(this);
262  right->setParent(this);
263 }
264 
265 // ----------------------------------------------------------------------
266 
268  // Keeping sorted copies of the event vectors allows us to save on
269  // computation time. That way we don't have to re-sort the events
270  // each time we calculate the splitpoint for a node. We sort them once.
271  // Every time we split a node, we simply filter them down correctly
272  // preserving the order. This way we have O(n) efficiency instead
273  // of O(nlogn) efficiency.
274 
275  // Anyways, this function takes events from the parent node
276  // and filters an event into the left or right daughter
277  // node depending on whether it is < or > the split point
278  // for the given split variable.
279 
280  unsigned int sv = splitVariable;
281  double sp = splitValue;
282 
283  Node* left = leftDaughter;
284  Node* right = rightDaughter;
285 
286  std::vector<std::vector<Event*> > l(events.size());
287  std::vector<std::vector<Event*> > r(events.size());
288 
289  for (unsigned int i = 0; i < events.size(); i++) {
290  for (unsigned int j = 0; j < events[i].size(); j++) {
291  Event* e = events[i][j];
292  // Prevent out-of-bounds access
293  if (sv >= e->data.size())
294  continue;
295  if (e->data[sv] < sp)
296  l[i].push_back(e);
297  if (e->data[sv] > sp)
298  r[i].push_back(e);
299  }
300  }
301 
302  events = std::vector<std::vector<Event*> >();
303 
304  left->getEvents().swap(l);
305  right->getEvents().swap(r);
306 
307  // Set the number of events in the node.
308  left->setNumEvents(left->getEvents()[0].size());
309  right->setNumEvents(right->getEvents()[0].size());
310 }
311 
312 // ----------------------------------------------------------------------
313 
315  // This function takes a single event from the parent node
316  // and returns the left or right daughter node it belongs to,
317  // depending on whether its value is < or >= the split point
318  // for the given split variable.
319 
320  unsigned int sv = splitVariable;
321  double sp = splitValue;
322 
323  Node* left = leftDaughter;
324  Node* right = rightDaughter;
325  Node* nextNode = nullptr;
326 
327  // Prevent out-of-bounds access
328  if (left == nullptr || right == nullptr || sv >= e->data.size())
329  return nullptr;
330 
331  if (e->data[sv] < sp)
332  nextNode = left;
333  if (e->data[sv] >= sp)
334  nextNode = right;
335 
336  return nextNode;
337 }
void setFitValue(double sFitValue)
Definition: Node.cc:110
Node * getRightDaughter()
Definition: Node.cc:90
Node * getLeftDaughter()
Definition: Node.cc:86
Node * leftDaughter
Definition: Node.h:67
double getFitValue()
Definition: Node.cc:112
Node()
Definition: Node.cc:32
int splitVariable
Definition: Node.h:72
Node * getParent()
Definition: Node.cc:96
void listEvents()
Definition: Node.cc:239
std::string name
Definition: Node.h:65
double fitValue
Definition: Node.h:78
Node * parent
Definition: Node.h:69
int numEvents
Definition: Node.h:79
Definition: Event.h:15
double errorReduction
Definition: Node.h:74
void setRightDaughter(Node *sLeftDaughter)
Definition: Node.cc:88
void setEvents(std::vector< std::vector< Event *> > &sEvents)
Definition: Node.cc:134
void setTotalError(double sTotalError)
Definition: Node.cc:116
double splitValue
Definition: Node.h:71
double getTotalError()
Definition: Node.cc:118
Node * filterEventToDaughter(Event *e)
Definition: Node.cc:314
void setErrorReduction(double sErrorReduction)
Definition: Node.cc:78
std::vector< std::vector< Event * > > & getEvents()
Definition: Node.cc:132
void setAvgError(double sAvgError)
Definition: Node.cc:120
double getSplitValue()
Definition: Node.cc:102
Node * rightDaughter
Definition: Node.h:68
std::string getName()
Definition: Node.cc:74
std::vector< std::vector< Event * > > events
Definition: Node.h:81
double avgError
Definition: Node.h:76
void setParent(Node *sParent)
Definition: Node.cc:94
double totalError
Definition: Node.h:75
void setName(std::string sName)
Definition: Node.cc:72
#define SUM(A, B)
void setSplitVariable(int sSplitVar)
Definition: Node.cc:104
char data[epos_bytes_allocation]
Definition: EPOS_Wrapper.h:79
int getNumEvents()
Definition: Node.cc:128
void setNumEvents(int sNumEvents)
Definition: Node.cc:126
~Node()
Definition: Node.cc:60
double getErrorReduction()
Definition: Node.cc:80
void filterEventsToDaughters()
Definition: Node.cc:267
double getAvgError()
Definition: Node.cc:122
int getSplitVariable()
Definition: Node.cc:106
void calcOptimumSplit()
Definition: Node.cc:143
void setSplitValue(double sSplitValue)
Definition: Node.cc:100
void setLeftDaughter(Node *sLeftDaughter)
Definition: Node.cc:84
void theMiracleOfChildBirth()
Definition: Node.cc:253