CMS 3D CMS Logo

GBRTree2D.h
Go to the documentation of this file.
1 
2 #ifndef EGAMMAOBJECTS_GBRTree2D
3 #define EGAMMAOBJECTS_GBRTree2D
4 
6 // //
7 // GBRForest //
8 // //
9 // A fast minimal implementation of Gradient-Boosted Regression Trees //
10 // which has been especially optimized for size on disk and in memory. //
11 // //
12 // Designed to be built from TMVA-trained trees, but could also be //
13 // generalized to otherwise-trained trees, classification, //
14 // or other boosting methods in the future //
15 // //
16 // Josh Bendavid - MIT //
18 
19 // The decision tree is stored as parallel arrays: for each intermediate node,
20 // the index of the input variable tested, the cut value, and the indices of
21 // the 'left' and 'right' daughter nodes. Positive daughter indices refer to
22 // further intermediate nodes; zero or negative indices refer to terminal nodes,
23 // whose regression responses (x and y) are stored at position -index.
24 
25 
27 
28 #include <vector>
29 #include <map>
30 
/// Regression tree whose terminal nodes each carry a pair of responses (x, y).
///
/// Storage layout (parallel arrays indexed by intermediate-node number, root = 0):
///   - fCutIndices[i] : index into the input feature array tested at node i
///   - fCutVals[i]    : cut value at node i; an input strictly greater than the
///                      cut goes to the right daughter, anything else (including
///                      NaN) goes to the left daughter
///   - fLeftIndices[i], fRightIndices[i] : daughter references. A positive value
///                      is another intermediate node; a value <= 0 is a terminal
///                      node whose responses are fResponsesX[-value] and
///                      fResponsesY[-value].
///
/// A tree with no intermediate nodes is treated as a single terminal node whose
/// responses live at position 0.
class GBRTree2D {

 public:

  GBRTree2D() {}

  /// Evaluate the tree on one input and write the reached terminal node's responses.
  /// @param vector  input features; must be indexable by every value in CutIndices()
  /// @param x       receives ResponsesX()[TerminalIndex(vector)]
  /// @param y       receives ResponsesY()[TerminalIndex(vector)]
  void GetResponse(const float* vector, double &x, double &y) const;

  /// Walk the tree for one input and return the index of the reached terminal node.
  /// @param vector  input features; must be indexable by every value in CutIndices()
  /// @return index into ResponsesX() / ResponsesY()
  int TerminalIndex(const float *vector) const;

  // Mutable and read-only access to the underlying arrays, used both to fill
  // the tree and to inspect it.
  std::vector<float> &ResponsesX() { return fResponsesX; }
  const std::vector<float> &ResponsesX() const { return fResponsesX; }

  std::vector<float> &ResponsesY() { return fResponsesY; }
  const std::vector<float> &ResponsesY() const { return fResponsesY; }

  std::vector<unsigned short> &CutIndices() { return fCutIndices; }
  const std::vector<unsigned short> &CutIndices() const { return fCutIndices; }

  std::vector<float> &CutVals() { return fCutVals; }
  const std::vector<float> &CutVals() const { return fCutVals; }

  std::vector<int> &LeftIndices() { return fLeftIndices; }
  const std::vector<int> &LeftIndices() const { return fLeftIndices; }

  std::vector<int> &RightIndices() { return fRightIndices; }
  const std::vector<int> &RightIndices() const { return fRightIndices; }

 protected:
  std::vector<unsigned short> fCutIndices;  ///< feature index tested at each intermediate node
  std::vector<float> fCutVals;              ///< cut value at each intermediate node
  std::vector<int> fLeftIndices;            ///< daughter taken when feature is not > cut
  std::vector<int> fRightIndices;           ///< daughter taken when feature > cut
  std::vector<float> fResponsesX;           ///< x response of each terminal node
  std::vector<float> fResponsesY;           ///< y response of each terminal node

};

//_______________________________________________________________________
inline void GBRTree2D::GetResponse(const float* vector, double &x, double &y) const {
  // Reuse the single traversal implementation so both entry points agree.
  const int leaf = TerminalIndex(vector);
  x = fResponsesX[leaf];
  y = fResponsesY[leaf];
}

//_______________________________________________________________________
inline int GBRTree2D::TerminalIndex(const float* vector) const {
  // No intermediate nodes: the root itself is the only terminal node (slot 0).
  // Without this guard, fCutIndices[0] below would read past an empty vector.
  if (fCutIndices.empty()) return 0;

  int index = 0;  // start at the root
  do {
    // Strictly greater goes right; ties, smaller values and NaN go left.
    index = (vector[fCutIndices[index]] > fCutVals[index]) ? fRightIndices[index]
                                                           : fLeftIndices[index];
  } while (index > 0);  // positive => another intermediate node

  // index <= 0 encodes terminal node -index (0 is a valid terminal).
  return -index;
}
133 
134 #endif
std::vector< unsigned short > & CutIndices()
Definition: GBRTree2D.h:47
const std::vector< int > & RightIndices() const
Definition: GBRTree2D.h:57
std::vector< int > & LeftIndices()
Definition: GBRTree2D.h:53
const std::vector< int > & LeftIndices() const
Definition: GBRTree2D.h:54
int TerminalIndex(const float *vector) const
Definition: GBRTree2D.h:106
std::vector< float > fResponsesY
Definition: GBRTree2D.h:67
std::vector< float > & ResponsesY()
Definition: GBRTree2D.h:44
std::vector< int > & RightIndices()
Definition: GBRTree2D.h:56
std::vector< float > & CutVals()
Definition: GBRTree2D.h:50
const std::vector< float > & CutVals() const
Definition: GBRTree2D.h:51
void GetResponse(const float *vector, double &x, double &y) const
Definition: GBRTree2D.h:74
~GBRTree2D()
Definition: GBRTree2D.h:36
std::vector< int > fLeftIndices
Definition: GBRTree2D.h:64
std::vector< int > fRightIndices
Definition: GBRTree2D.h:65
GBRTree2D()
Definition: GBRTree2D.h:35
const std::vector< float > & ResponsesY() const
Definition: GBRTree2D.h:45
std::vector< float > fResponsesX
Definition: GBRTree2D.h:66
const std::vector< float > & ResponsesX() const
Definition: GBRTree2D.h:42
#define COND_SERIALIZABLE
Definition: Serializable.h:38
std::vector< float > & ResponsesX()
Definition: GBRTree2D.h:41
std::vector< float > fCutVals
Definition: GBRTree2D.h:63
std::vector< unsigned short > fCutIndices
Definition: GBRTree2D.h:62
const std::vector< unsigned short > & CutIndices() const
Definition: GBRTree2D.h:48