CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
Node.cc
Go to the documentation of this file.
1 // Node.cxx //
3 // =====================================================================//
4 // This is the object implementation of a node, which is the //
5 // fundamental unit of a decision tree. //
6 // References include //
7 // *Elements of Statistical Learning by Hastie, //
8 // Tibshirani, and Friedman. //
9 // *Greedy Function Approximation: A Gradient Boosting Machine. //
10 // Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001. //
11 // *Inductive Learning of Tree-based Regression Models. Luis Torgo. //
12 // //
14 
16 // _______________________Includes_______________________________________//
18 
20 #include "TRandom3.h"
21 #include "TStopwatch.h"
22 #include <iostream>
23 #include <fstream>
24 
26 // _______________________Constructor(s)________________________________//
28 
30 {
31  name = "";
32  leftDaughter = 0;
33  rightDaughter = 0;
34  parent = 0;
35  splitValue = -99;
36  splitVariable = -1;
37  avgError = -1;
38  totalError = -1;
39  errorReduction = -1;
40 }
41 
43 {
44  name = cName;
45  leftDaughter = 0;
46  rightDaughter = 0;
47  parent = 0;
48  splitValue = -99;
49  splitVariable = -1;
50  avgError = -1;
51  totalError = -1;
52  errorReduction = -1;
53 }
54 
56 // _______________________Destructor____________________________________//
58 
60 {
61 // Recursively delete all nodes in the tree.
62  delete leftDaughter;
63  delete rightDaughter;
64 }
65 
67 // ______________________Get/Set________________________________________//
69 
71 {
72  name = sName;
73 }
74 
76 {
77  return name;
78 }
79 
80 // ----------------------------------------------------------------------
81 
82 void Node::setErrorReduction(Double_t sErrorReduction)
83 {
84  errorReduction = sErrorReduction;
85 }
86 
88 {
89  return errorReduction;
90 }
91 
92 // ----------------------------------------------------------------------
93 
94 void Node::setLeftDaughter(Node *sLeftDaughter)
95 {
96  leftDaughter = sLeftDaughter;
97 }
98 
100 {
101  return leftDaughter;
102 }
103 
104 void Node::setRightDaughter(Node *sRightDaughter)
105 {
106  rightDaughter = sRightDaughter;
107 }
108 
110 {
111  return rightDaughter;
112 }
113 
114 // ----------------------------------------------------------------------
115 
116 void Node::setParent(Node *sParent)
117 {
118  parent = sParent;
119 }
120 
122 {
123  return parent;
124 }
125 
126 // ----------------------------------------------------------------------
127 
128 void Node::setSplitValue(Double_t sSplitValue)
129 {
130  splitValue = sSplitValue;
131 }
132 
134 {
135  return splitValue;
136 }
137 
138 void Node::setSplitVariable(Int_t sSplitVar)
139 {
140  splitVariable = sSplitVar;
141 }
142 
144 {
145  return splitVariable;
146 }
147 
148 // ----------------------------------------------------------------------
149 
150 void Node::setFitValue(Double_t sFitValue)
151 {
152  fitValue = sFitValue;
153 }
154 
156 {
157  return fitValue;
158 }
159 
160 // ----------------------------------------------------------------------
161 
162 void Node::setTotalError(Double_t sTotalError)
163 {
164  totalError = sTotalError;
165 }
166 
168 {
169  return totalError;
170 }
171 
172 void Node::setAvgError(Double_t sAvgError)
173 {
174  avgError = sAvgError;
175 }
176 
178 {
179  return avgError;
180 }
181 
182 // ----------------------------------------------------------------------
183 
184 void Node::setNumEvents(Int_t sNumEvents)
185 {
186  numEvents = sNumEvents;
187 }
188 
190 {
191  return numEvents;
192 }
193 
194 // ----------------------------------------------------------------------
195 
196 std::vector< std::vector<Event*> >& Node::getEvents()
197 {
198  return events;
199 }
200 
201 void Node::setEvents(std::vector< std::vector<Event*> >& sEvents)
202 {
203  events = sEvents;
204 }
205 
207 // ______________________Performance_Functions__________________________//
209 
211 {
212 // Determines the split variable and split point which would most reduce the error for the given node (region).
213 // In the process we calculate the fitValue and Error. The general aglorithm is based upon Luis Torgo's thesis.
214 // Check out the reference for a more in depth outline. This part is chapter 3.
215 
216  // Intialize some variables.
217  Double_t bestSplitValue = 0;
218  Int_t bestSplitVariable = -1;
219  Double_t bestErrorReduction = 0;
220 
221  Double_t SUM = 0;
222  Double_t SSUM = 0;
223  numEvents = events[0].size();
224 
225  Double_t candidateErrorReduction = 0;
226 
227  // Calculate the sum of the target variables and the sum of
228  // the target variables squared. We use these later.
229  for(unsigned int i=0; i<events[0].size(); i++)
230  {
231  Double_t target = events[0][i]->data[0];
232  SUM += target;
233  SSUM += target*target;
234  }
235 
236  unsigned int numVars = events.size();
237 
238  // Calculate the best split point for each variable
239  for(unsigned int variableToCheck = 1; variableToCheck < numVars; variableToCheck++)
240  {
241  // The sum of the target variables in the left, right nodes
242  Double_t SUMleft = 0;
243  Double_t SUMright = SUM;
244 
245  // The number of events in the left, right nodes
246  Int_t nleft = 1;
247  Int_t nright = events[variableToCheck].size()-1;
248 
249  Int_t candidateSplitVariable = variableToCheck;
250 
251  std::vector<Event*>& v = events[variableToCheck];
252 
253  // Find the best split point for this variable
254  for(unsigned int i=1; i<v.size(); i++)
255  {
256  // As the candidate split point interates, the number of events in the
257  // left/right node increases/decreases and SUMleft/right increases/decreases.
258 
259  SUMleft = SUMleft + v[i-1]->data[0];
260  SUMright = SUMright - v[i-1]->data[0];
261 
262  // No need to check the split point if x on both sides is equal
263  if(v[i-1]->data[candidateSplitVariable] < v[i]->data[candidateSplitVariable])
264  {
265  // Finding the maximum error reduction for Least Squares boils down to maximizing
266  // the following statement.
267  candidateErrorReduction = SUMleft*SUMleft/nleft + SUMright*SUMright/nright - SUM*SUM/numEvents;
268 // std::cout << "candidateErrorReduction= " << candidateErrorReduction << std::endl << std::endl;
269 
270  // if the new candidate is better than the current best, then we have a new overall best.
271  if(candidateErrorReduction > bestErrorReduction)
272  {
273  bestErrorReduction = candidateErrorReduction;
274  bestSplitValue = (v[i-1]->data[candidateSplitVariable] + v[i]->data[candidateSplitVariable])/2;
275  bestSplitVariable = candidateSplitVariable;
276  }
277  }
278 
279  nright = nright-1;
280  nleft = nleft+1;
281  }
282  }
283 
284  // Store the information gained from our computations.
285 
286  // The fit value is the average for least squares.
287  fitValue = SUM/numEvents;
288 // std::cout << "fitValue= " << fitValue << std::endl;
289 
290  // n*[ <y^2>-k^2 ]
291  totalError = SSUM - SUM*SUM/numEvents;
292 // std::cout << "totalError= " << totalError << std::endl;
293 
294  // [ <y^2>-k^2 ]
296 // std::cout << "avgError= " << avgError << std::endl;
297 
298 
299  errorReduction = bestErrorReduction;
300 // std::cout << "errorReduction= " << errorReduction << std::endl;
301 
302  splitVariable = bestSplitVariable;
303 // std::cout << "splitVariable= " << splitVariable << std::endl;
304 
305  splitValue = bestSplitValue;
306 // std::cout << "splitValue= " << splitValue << std::endl;
307 
308 }
309 
310 // ----------------------------------------------------------------------
311 
313 {
314  std::cout << std::endl << "Listing Events... " << std::endl;
315 
316  for(unsigned int i=0; i < events.size(); i++)
317  {
318  std::cout << std::endl << "Variable " << i << " vector contents: " << std::endl;
319  for(unsigned int j=0; j < events[i].size(); j++)
320  {
321  events[i][j]->outputEvent();
322  }
323  std::cout << std::endl;
324  }
325 }
326 
327 // ----------------------------------------------------------------------
328 
330 {
331  // Create Daughter Nodes
332  Node* left = new Node(name + " left");
333  Node* right = new Node(name + " right");
334 
335  // Link the Nodes Appropriately
336  leftDaughter = left;
337  rightDaughter = right;
338  left->setParent(this);
339  right->setParent(this);
340 }
341 
342 // ----------------------------------------------------------------------
343 
345 {
346 // Keeping sorted copies of the event vectors allows us to save on
347 // computation time. That way we don't have to resort the events
348 // each time we calculate the splitpoint for a node. We sort them once.
349 // Every time we split a node, we simply filter them down correctly
350 // preserving the order. This way we have O(n) efficiency instead
351 // of O(nlogn) efficiency.
352 
353 // Anyways, this function takes events from the parent node
354 // and filters an event into the left or right daughter
355 // node depending on whether it is < or > the split point
356 // for the given split variable.
357 
358  Int_t sv = splitVariable;
359  Double_t sp = splitValue;
360 
361  Node* left = leftDaughter;
362  Node* right = rightDaughter;
363 
364  std::vector< std::vector<Event*> > l(events.size());
365  std::vector< std::vector<Event*> > r(events.size());
366 
367  for(unsigned int i=0; i<events.size(); i++)
368  {
369  for(unsigned int j=0; j<events[i].size(); j++)
370  {
371  Event* e = events[i][j];
372  if(e->data[sv] < sp) l[i].push_back(e);
373  if(e->data[sv] > sp) r[i].push_back(e);
374  }
375  }
376 
377  events = std::vector< std::vector<Event*> >();
378 
379  left->getEvents().swap(l);
380  right->getEvents().swap(r);
381 }
Node * getParent()
Definition: Node.cc:121
Double_t getFitValue()
Definition: Node.cc:155
int i
Definition: DBlmapReader.cc:9
void setLeftDaughter(Node *sLeftDaughter)
Definition: Node.cc:94
std::string name
Definition: Node.h:59
Definition: Node.h:10
Node()
Definition: Node.cc:29
Int_t getNumEvents()
Definition: Node.cc:189
Node * getLeftDaughter()
Definition: Node.cc:99
void setAvgError(Double_t sAvgError)
Definition: Node.cc:172
Int_t numEvents
Definition: Node.h:73
void setSplitValue(Double_t sSplitValue)
Definition: Node.cc:128
void setNumEvents(Int_t sNumEvents)
Definition: Node.cc:184
Double_t getTotalError()
Definition: Node.cc:167
Double_t totalError
Definition: Node.h:69
Definition: Event.h:16
Double_t getSplitValue()
Definition: Node.cc:133
Node * leftDaughter
Definition: Node.h:61
void setRightDaughter(Node *sLeftDaughter)
Definition: Node.cc:104
void theMiracleOfChildBirth()
Definition: Node.cc:329
std::string getName()
Definition: Node.cc:75
Double_t fitValue
Definition: Node.h:72
Node * getRightDaughter()
Definition: Node.cc:109
void setName(std::string sName)
Definition: Node.cc:70
Node * rightDaughter
Definition: Node.h:62
void listEvents()
Definition: Node.cc:312
Int_t getSplitVariable()
Definition: Node.cc:143
void filterEventsToDaughters()
Definition: Node.cc:344
int j
Definition: DBlmapReader.cc:9
std::vector< std::vector< Event * > > & getEvents()
Definition: Node.cc:196
Int_t splitVariable
Definition: Node.h:66
Double_t errorReduction
Definition: Node.h:68
std::vector< std::vector< Event * > > events
Definition: Node.h:75
Double_t getAvgError()
Definition: Node.cc:177
void setEvents(std::vector< std::vector< Event * > > &sEvents)
Definition: Node.cc:201
Double_t splitValue
Definition: Node.h:65
Double_t getErrorReduction()
Definition: Node.cc:87
void setParent(Node *sParent)
Definition: Node.cc:116
Double_t avgError
Definition: Node.h:70
#define SUM(A, B)
~Node()
Definition: Node.cc:59
void setFitValue(Double_t sFitValue)
Definition: Node.cc:150
void setTotalError(Double_t sTotalError)
Definition: Node.cc:162
void setSplitVariable(Int_t sSplitVar)
Definition: Node.cc:138
tuple cout
Definition: gather_cfg.py:145
Definition: sp.h:21
void calcOptimumSplit()
Definition: Node.cc:210
Node * parent
Definition: Node.h:63
std::vector< Double_t > data
Definition: Event.h:30
void setErrorReduction(Double_t sErrorReduction)
Definition: Node.cc:82