CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
Node.cc
Go to the documentation of this file.
1 // Node.cxx //
3 // =====================================================================//
4 // This is the object implementation of a node, which is the //
5 // fundamental unit of a decision tree. //
6 // References include //
7 // *Elements of Statistical Learning by Hastie, //
8 // Tibshirani, and Friedman. //
9 // *Greedy Function Approximation: A Gradient Boosting Machine. //
10 // Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001. //
11 // *Inductive Learning of Tree-based Regression Models. Luis Torgo. //
12 // //
14 
16 // _______________________Includes_______________________________________//
18 
20 #include "TRandom3.h"
21 #include "TStopwatch.h"
22 #include <iostream>
23 #include <fstream>
24 
26 // _______________________Constructor(s)________________________________//
28 
30 {
31  name = "";
32  leftDaughter = 0;
33  rightDaughter = 0;
34  parent = 0;
35  splitValue = -99;
36  splitVariable = -1;
37  avgError = -1;
38  totalError = -1;
39  errorReduction = -1;
40 }
41 
43 {
44  name = cName;
45  leftDaughter = 0;
46  rightDaughter = 0;
47  parent = 0;
48  splitValue = -99;
49  splitVariable = -1;
50  avgError = -1;
51  totalError = -1;
52  errorReduction = -1;
53 }
54 
56 // _______________________Destructor____________________________________//
58 
60 {
61 // Recursively delete all nodes in the tree.
62  delete leftDaughter;
63  delete rightDaughter;
64 }
65 
67 // ______________________Get/Set________________________________________//
69 
71 {
72  name = sName;
73 }
74 
76 {
77  return name;
78 }
79 
80 // ----------------------------------------------------------------------
81 
82 void Node::setErrorReduction(Double_t sErrorReduction)
83 {
84  errorReduction = sErrorReduction;
85 }
86 
88 {
89  return errorReduction;
90 }
91 
92 // ----------------------------------------------------------------------
93 
94 void Node::setLeftDaughter(Node *sLeftDaughter)
95 {
96  leftDaughter = sLeftDaughter;
97 }
98 
100 {
101  return leftDaughter;
102 }
103 
104 void Node::setRightDaughter(Node *sRightDaughter)
105 {
106  rightDaughter = sRightDaughter;
107 }
108 
110 {
111  return rightDaughter;
112 }
113 
114 // ----------------------------------------------------------------------
115 
116 void Node::setParent(Node *sParent)
117 {
118  parent = sParent;
119 }
120 
122 {
123  return parent;
124 }
125 
126 // ----------------------------------------------------------------------
127 
128 void Node::setSplitValue(Double_t sSplitValue)
129 {
130  splitValue = sSplitValue;
131 }
132 
134 {
135  return splitValue;
136 }
137 
138 void Node::setSplitVariable(Int_t sSplitVar)
139 {
140  splitVariable = sSplitVar;
141 }
142 
144 {
145  return splitVariable;
146 }
147 
148 // ----------------------------------------------------------------------
149 
150 void Node::setFitValue(Double_t sFitValue)
151 {
152  fitValue = sFitValue;
153 }
154 
156 {
157  return fitValue;
158 }
159 
160 // ----------------------------------------------------------------------
161 
162 void Node::setTotalError(Double_t sTotalError)
163 {
164  totalError = sTotalError;
165 }
166 
168 {
169  return totalError;
170 }
171 
172 void Node::setAvgError(Double_t sAvgError)
173 {
174  avgError = sAvgError;
175 }
176 
178 {
179  return avgError;
180 }
181 
182 // ----------------------------------------------------------------------
183 
184 void Node::setNumEvents(Int_t sNumEvents)
185 {
186  numEvents = sNumEvents;
187 }
188 
190 {
191  return numEvents;
192 }
193 
194 // ----------------------------------------------------------------------
195 
196 std::vector< std::vector<Event*> >& Node::getEvents()
197 {
198  return events;
199 }
200 
201 void Node::setEvents(std::vector< std::vector<Event*> >& sEvents)
202 {
203  events = sEvents;
204  numEvents = events[0].size();
205 }
206 
208 // ______________________Performance_Functions___________________________//
210 
212 {
213 // Determines the split variable and split point which would most reduce the error for the given node (region).
214 // In the process we calculate the fitValue and Error. The general aglorithm is based upon Luis Torgo's thesis.
215 // Check out the reference for a more in depth outline. This part is chapter 3.
216 
217  // Intialize some variables.
218  Double_t bestSplitValue = 0;
219  Int_t bestSplitVariable = -1;
220  Double_t bestErrorReduction = -1;
221 
222  Double_t SUM = 0;
223  Double_t SSUM = 0;
224  numEvents = events[0].size();
225 
226  Double_t candidateErrorReduction = -1;
227 
228  // Calculate the sum of the target variables and the sum of
229  // the target variables squared. We use these later.
230  for(unsigned int i=0; i<events[0].size(); i++)
231  {
232  Double_t target = events[0][i]->data[0];
233  SUM += target;
234  SSUM += target*target;
235  }
236 
237  unsigned int numVars = events.size();
238 
239  // Calculate the best split point for each variable
240  for(unsigned int variableToCheck = 1; variableToCheck < numVars; variableToCheck++)
241  {
242 
243  // The sum of the target variables in the left, right nodes
244  Double_t SUMleft = 0;
245  Double_t SUMright = SUM;
246 
247  // The number of events in the left, right nodes
248  Int_t nleft = 1;
249  Int_t nright = events[variableToCheck].size()-1;
250 
251  Int_t candidateSplitVariable = variableToCheck;
252 
253  std::vector<Event*>& v = events[variableToCheck];
254 
255  // Find the best split point for this variable
256  for(unsigned int i=1; i<v.size(); i++)
257  {
258  // As the candidate split point interates, the number of events in the
259  // left/right node increases/decreases and SUMleft/right increases/decreases.
260 
261  SUMleft = SUMleft + v[i-1]->data[0];
262  SUMright = SUMright - v[i-1]->data[0];
263 
264  // No need to check the split point if x on both sides is equal
265  if(v[i-1]->data[candidateSplitVariable] < v[i]->data[candidateSplitVariable])
266  {
267  // Finding the maximum error reduction for Least Squares boils down to maximizing
268  // the following statement.
269  candidateErrorReduction = SUMleft*SUMleft/nleft + SUMright*SUMright/nright - SUM*SUM/numEvents;
270 // std::cout << "candidateErrorReduction= " << candidateErrorReduction << std::endl << std::endl;
271 
272  // if the new candidate is better than the current best, then we have a new overall best.
273  if(candidateErrorReduction > bestErrorReduction)
274  {
275  bestErrorReduction = candidateErrorReduction;
276  bestSplitValue = (v[i-1]->data[candidateSplitVariable] + v[i]->data[candidateSplitVariable])/2;
277  bestSplitVariable = candidateSplitVariable;
278  }
279  }
280 
281  nright = nright-1;
282  nleft = nleft+1;
283  }
284  }
285 
286  // Store the information gained from our computations.
287 
288  // The fit value is the average for least squares.
289  fitValue = SUM/numEvents;
290 // std::cout << "fitValue= " << fitValue << std::endl;
291 
292  // n*[ <y^2>-k^2 ]
293  totalError = SSUM - SUM*SUM/numEvents;
294 // std::cout << "totalError= " << totalError << std::endl;
295 
296  // [ <y^2>-k^2 ]
298 // std::cout << "avgError= " << avgError << std::endl;
299 
300 
301  errorReduction = bestErrorReduction;
302 // std::cout << "errorReduction= " << errorReduction << std::endl;
303 
304  splitVariable = bestSplitVariable;
305 // std::cout << "splitVariable= " << splitVariable << std::endl;
306 
307  splitValue = bestSplitValue;
308 // std::cout << "splitValue= " << splitValue << std::endl;
309 
310  //if(bestSplitVariable == -1) std::cout << "splitVar = -1. numEvents = " << numEvents << ". errRed = " << errorReduction << std::endl;
311 }
312 
313 // ----------------------------------------------------------------------
314 
316 {
317  std::cout << std::endl << "Listing Events... " << std::endl;
318 
319  for(unsigned int i=0; i < events.size(); i++)
320  {
321  std::cout << std::endl << "Variable " << i << " vector contents: " << std::endl;
322  for(unsigned int j=0; j < events[i].size(); j++)
323  {
324  events[i][j]->outputEvent();
325  }
326  std::cout << std::endl;
327  }
328 }
329 
330 // ----------------------------------------------------------------------
331 
333 {
334  // Create Daughter Nodes
335  Node* left = new Node(name + " left");
336  Node* right = new Node(name + " right");
337 
338  // Link the Nodes Appropriately
339  leftDaughter = left;
340  rightDaughter = right;
341  left->setParent(this);
342  right->setParent(this);
343 }
344 
345 // ----------------------------------------------------------------------
346 
348 {
349 // Keeping sorted copies of the event vectors allows us to save on
350 // computation time. That way we don't have to resort the events
351 // each time we calculate the splitpoint for a node. We sort them once.
352 // Every time we split a node, we simply filter them down correctly
353 // preserving the order. This way we have O(n) efficiency instead
354 // of O(nlogn) efficiency.
355 
356 // Anyways, this function takes events from the parent node
357 // and filters an event into the left or right daughter
358 // node depending on whether it is < or > the split point
359 // for the given split variable.
360 
361  Int_t sv = splitVariable;
362  Double_t sp = splitValue;
363 
364  Node* left = leftDaughter;
365  Node* right = rightDaughter;
366 
367  std::vector< std::vector<Event*> > l(events.size());
368  std::vector< std::vector<Event*> > r(events.size());
369 
370  for(unsigned int i=0; i<events.size(); i++)
371  {
372  for(unsigned int j=0; j<events[i].size(); j++)
373  {
374  Event* e = events[i][j];
375  if(e->data[sv] < sp) l[i].push_back(e);
376  if(e->data[sv] > sp) r[i].push_back(e);
377  }
378  }
379 
380  events = std::vector< std::vector<Event*> >();
381 
382  left->getEvents().swap(l);
383  right->getEvents().swap(r);
384 
385  // Set the number of events in the node.
386  left->setNumEvents(left->getEvents()[0].size());
387  right->setNumEvents(right->getEvents()[0].size());
388 }
389 
390 // ----------------------------------------------------------------------
391 
393 {
394 // Anyways, this function takes an event from the parent node
395 // and filters an event into the left or right daughter
396 // node depending on whether it is < or > the split point
397 // for the given split variable.
398 
399  Int_t sv = splitVariable;
400  Double_t sp = splitValue;
401 
402  Node* left = leftDaughter;
403  Node* right = rightDaughter;
404  Node* nextNode = 0;
405 
406  if(left ==0 || right ==0) return 0;
407 
408  if(e->data[sv] < sp) nextNode = left;
409  if(e->data[sv] > sp) nextNode = right;
410 
411  return nextNode;
412 }
Node * getParent()
Definition: Node.cc:121
Double_t getFitValue()
Definition: Node.cc:155
int i
Definition: DBlmapReader.cc:9
void setLeftDaughter(Node *sLeftDaughter)
Definition: Node.cc:94
std::string name
Definition: Node.h:60
Definition: Node.h:10
Node()
Definition: Node.cc:29
Int_t getNumEvents()
Definition: Node.cc:189
Node * getLeftDaughter()
Definition: Node.cc:99
void setAvgError(Double_t sAvgError)
Definition: Node.cc:172
Int_t numEvents
Definition: Node.h:74
void setSplitValue(Double_t sSplitValue)
Definition: Node.cc:128
void setNumEvents(Int_t sNumEvents)
Definition: Node.cc:184
Double_t getTotalError()
Definition: Node.cc:167
Double_t totalError
Definition: Node.h:70
Node * filterEventToDaughter(Event *e)
Definition: Node.cc:392
Definition: Event.h:16
Double_t getSplitValue()
Definition: Node.cc:133
Node * leftDaughter
Definition: Node.h:62
void setRightDaughter(Node *sLeftDaughter)
Definition: Node.cc:104
void theMiracleOfChildBirth()
Definition: Node.cc:332
std::string getName()
Definition: Node.cc:75
Double_t fitValue
Definition: Node.h:73
Node * getRightDaughter()
Definition: Node.cc:109
void setName(std::string sName)
Definition: Node.cc:70
Node * rightDaughter
Definition: Node.h:63
void listEvents()
Definition: Node.cc:315
Int_t getSplitVariable()
Definition: Node.cc:143
void filterEventsToDaughters()
Definition: Node.cc:347
int j
Definition: DBlmapReader.cc:9
std::vector< std::vector< Event * > > & getEvents()
Definition: Node.cc:196
Int_t splitVariable
Definition: Node.h:67
Double_t errorReduction
Definition: Node.h:69
std::vector< std::vector< Event * > > events
Definition: Node.h:76
Double_t getAvgError()
Definition: Node.cc:177
void setEvents(std::vector< std::vector< Event * > > &sEvents)
Definition: Node.cc:201
Double_t splitValue
Definition: Node.h:66
Double_t getErrorReduction()
Definition: Node.cc:87
void setParent(Node *sParent)
Definition: Node.cc:116
Double_t avgError
Definition: Node.h:71
#define SUM(A, B)
~Node()
Definition: Node.cc:59
void setFitValue(Double_t sFitValue)
Definition: Node.cc:150
char data[epos_bytes_allocation]
Definition: EPOS_Wrapper.h:82
void setTotalError(Double_t sTotalError)
Definition: Node.cc:162
void setSplitVariable(Int_t sSplitVar)
Definition: Node.cc:138
tuple cout
Definition: gather_cfg.py:145
Definition: sp.h:21
void calcOptimumSplit()
Definition: Node.cc:211
Node * parent
Definition: Node.h:64
std::vector< Double_t > data
Definition: Event.h:30
void setErrorReduction(Double_t sErrorReduction)
Definition: Node.cc:82