CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
GBRTree2D.h
Go to the documentation of this file.
1 
2 #ifndef EGAMMAOBJECTS_GBRTree2D
3 #define EGAMMAOBJECTS_GBRTree2D
4 
6 // //
7 // GBRTree2D (GBRForest package) //
8 // //
9 // A fast minimal implementation of Gradient-Boosted Regression Trees //
10 // which has been especially optimized for size on disk and in memory. //
11 // //
12 // Designed to be built from TMVA-trained trees, but could also be //
13 // generalized to otherwise-trained trees, classification, //
14 // or other boosting methods in the future //
15 // //
16 // Josh Bendavid - MIT //
18 
19 // The decision tree is implemented here as two groups of parallel arrays: one
20 // for intermediate nodes, holding the variable index and cut value as well as
21 // the indices of the 'left' and 'right' daughter nodes. Positive daughter indices
22 // point to further intermediate nodes, whereas non-positive indices (including 0)
23 // point to terminal node -index, stored as two vectors of responses (X and Y).
24 
25 
26 #include <vector>
27 #include <map>
28 
/// A single regression tree with two outputs (X and Y), stored as flat
/// parallel arrays to keep it compact on disk and in memory.
///
/// Intermediate node i: fCutIndices[i] (input variable index), fCutVals[i]
/// (cut value), fLeftIndices[i] / fRightIndices[i] (daughter indices).
/// A daughter index > 0 names another intermediate node; an index <= 0 names
/// terminal node -index, whose outputs are fResponsesX[-index] and
/// fResponsesY[-index]. Traversal starts at node 0 (the root).
class GBRTree2D {

 public:

  GBRTree2D() {}

  // --- evaluation ---------------------------------------------------------

  /// Write the X/Y responses of the leaf reached by @p vector into x and y.
  void GetResponse(const float* vector, double &x, double &y) const;

  /// Return the (non-negative) index of the leaf reached by @p vector.
  int TerminalIndex(const float *vector) const;

  // --- intermediate-node storage (read-only and mutable views) ------------

  const std::vector<unsigned short> &CutIndices() const { return fCutIndices; }
  std::vector<unsigned short> &CutIndices() { return fCutIndices; }

  const std::vector<float> &CutVals() const { return fCutVals; }
  std::vector<float> &CutVals() { return fCutVals; }

  const std::vector<int> &LeftIndices() const { return fLeftIndices; }
  std::vector<int> &LeftIndices() { return fLeftIndices; }

  const std::vector<int> &RightIndices() const { return fRightIndices; }
  std::vector<int> &RightIndices() { return fRightIndices; }

  // --- terminal-node storage ----------------------------------------------

  const std::vector<float> &ResponsesX() const { return fResponsesX; }
  std::vector<float> &ResponsesX() { return fResponsesX; }

  const std::vector<float> &ResponsesY() const { return fResponsesY; }
  std::vector<float> &ResponsesY() { return fResponsesY; }

 protected:
  std::vector<unsigned short> fCutIndices;  // input variable tested at each intermediate node
  std::vector<float> fCutVals;              // cut value for that variable
  std::vector<int> fLeftIndices;            // daughter taken when the value is not > cut
  std::vector<int> fRightIndices;           // daughter taken when the value is > cut
  std::vector<float> fResponsesX;           // X response per terminal node
  std::vector<float> fResponsesY;           // Y response per terminal node

};
68 
69 //_______________________________________________________________________
70 inline void GBRTree2D::GetResponse(const float* vector, double &x, double &y) const {
71 
72  int index = 0;
73 
74  unsigned short cutindex = fCutIndices[0];
75  float cutval = fCutVals[0];
76 
77  while (true) {
78 
79  if (vector[cutindex] > cutval) {
80  index = fRightIndices[index];
81  }
82  else {
83  index = fLeftIndices[index];
84  }
85 
86  if (index>0) {
87  cutindex = fCutIndices[index];
88  cutval = fCutVals[index];
89  }
90  else {
91  x = fResponsesX[-index];
92  y = fResponsesY[-index];
93  return;
94  }
95 
96  }
97 
98 
99 }
100 
101 //_______________________________________________________________________
102 inline int GBRTree2D::TerminalIndex(const float* vector) const {
103 
104  int index = 0;
105 
106  unsigned short cutindex = fCutIndices[0];
107  float cutval = fCutVals[0];
108 
109  while (true) {
110  if (vector[cutindex] > cutval) {
111  index = fRightIndices[index];
112  }
113  else {
114  index = fLeftIndices[index];
115  }
116 
117  if (index>0) {
118  cutindex = fCutIndices[index];
119  cutval = fCutVals[index];
120  }
121  else {
122  return (-index);
123  }
124 
125  }
126 
127 
128 }
129 
130 #endif
std::vector< unsigned short > & CutIndices()
Definition: GBRTree2D.h:45
const std::vector< int > & RightIndices() const
Definition: GBRTree2D.h:55
std::vector< int > & LeftIndices()
Definition: GBRTree2D.h:51
const std::vector< int > & LeftIndices() const
Definition: GBRTree2D.h:52
int TerminalIndex(const float *vector) const
Definition: GBRTree2D.h:102
std::vector< float > fResponsesY
Definition: GBRTree2D.h:65
std::vector< float > & ResponsesY()
Definition: GBRTree2D.h:42
std::vector< int > & RightIndices()
Definition: GBRTree2D.h:54
std::vector< float > & CutVals()
Definition: GBRTree2D.h:48
const std::vector< float > & CutVals() const
Definition: GBRTree2D.h:49
void GetResponse(const float *vector, double &x, double &y) const
Definition: GBRTree2D.h:70
~GBRTree2D()
Definition: GBRTree2D.h:34
std::vector< int > fLeftIndices
Definition: GBRTree2D.h:62
std::vector< int > fRightIndices
Definition: GBRTree2D.h:63
GBRTree2D()
Definition: GBRTree2D.h:33
const std::vector< float > & ResponsesY() const
Definition: GBRTree2D.h:43
std::vector< float > fResponsesX
Definition: GBRTree2D.h:64
const std::vector< float > & ResponsesX() const
Definition: GBRTree2D.h:40
std::vector< float > & ResponsesX()
Definition: GBRTree2D.h:39
x
Definition: VDTMath.h:216
std::vector< float > fCutVals
Definition: GBRTree2D.h:61
std::vector< unsigned short > fCutIndices
Definition: GBRTree2D.h:60
const std::vector< unsigned short > & CutIndices() const
Definition: GBRTree2D.h:46