CMS 3D CMS Logo

GBRTree2D.h
Go to the documentation of this file.
1 
2 #ifndef EGAMMAOBJECTS_GBRTree2D
3 #define EGAMMAOBJECTS_GBRTree2D
4 
6 // //
7 // GBRForest //
8 // //
9 // A fast minimal implementation of Gradient-Boosted Regression Trees //
10 // which has been especially optimized for size on disk and in memory. //
11 // //
12 // Designed to be built from TMVA-trained trees, but could also be //
13 // generalized to otherwise-trained trees, classification, //
14 // or other boosting methods in the future //
15 // //
16 // Josh Bendavid - MIT //
18 
19 // The decision tree is implemented here as two groups of arrays: one for
20 // intermediate nodes, containing the variable index and cut value, as well
21 // as the indices of the 'left' and 'right' daughter nodes. Positive indices
22 // indicate further intermediate nodes, whereas zero or negative indices
23 // indicate terminal nodes, whose x and y regression responses are stored at position -index.
24 
26 
27 #include <vector>
28 #include <map>
29 
/// A single regression tree with two outputs (x, y), stored as flat parallel
/// arrays to keep it compact on disk and in memory.
///
/// Intermediate node i is described by:
///   - fCutIndices[i]   : index of the input variable tested at this node,
///   - fCutVals[i]      : threshold; inputs strictly greater go right,
///   - fRightIndices[i] : daughter taken when input > threshold,
///   - fLeftIndices[i]  : daughter taken otherwise.
/// A daughter index > 0 refers to another intermediate node. An index <= 0
/// refers to terminal node -index, whose responses are fResponsesX[-index]
/// and fResponsesY[-index]. Node 0 is the root; because a daughter index of 0
/// means "terminal node 0", the root is only ever entered as the starting node.
///
/// NOTE(review): the generated cross-reference for this header lists a
/// ~GBRTree2D() destructor and the COND_SERIALIZABLE macro (Serializable.h),
/// neither of which appears in this listing — confirm against the real header.
class GBRTree2D {
public:
  GBRTree2D() {}

  /// Evaluates the tree on one input and writes the reached leaf's responses.
  /// @param vector input variables; must hold every index used in CutIndices()
  /// @param x      receives ResponsesX()[leaf]
  /// @param y      receives ResponsesY()[leaf]
  /// Precondition: the tree is non-empty and well-formed (every path ends in a leaf).
  void GetResponse(const float *vector, double &x, double &y) const;

  /// Evaluates the tree on one input and returns the (non-negative) index of
  /// the terminal node reached. Same preconditions as GetResponse().
  int TerminalIndex(const float *vector) const;

  std::vector<float> &ResponsesX() { return fResponsesX; }
  const std::vector<float> &ResponsesX() const { return fResponsesX; }

  std::vector<float> &ResponsesY() { return fResponsesY; }
  const std::vector<float> &ResponsesY() const { return fResponsesY; }

  std::vector<unsigned short> &CutIndices() { return fCutIndices; }
  const std::vector<unsigned short> &CutIndices() const { return fCutIndices; }

  std::vector<float> &CutVals() { return fCutVals; }
  const std::vector<float> &CutVals() const { return fCutVals; }

  std::vector<int> &LeftIndices() { return fLeftIndices; }
  const std::vector<int> &LeftIndices() const { return fLeftIndices; }

  std::vector<int> &RightIndices() { return fRightIndices; }
  const std::vector<int> &RightIndices() const { return fRightIndices; }

protected:
  std::vector<unsigned short> fCutIndices;  // per intermediate node: tested variable index
  std::vector<float> fCutVals;              // per intermediate node: cut threshold
  std::vector<int> fLeftIndices;            // per intermediate node: daughter when input <= cut
  std::vector<int> fRightIndices;           // per intermediate node: daughter when input > cut
  std::vector<float> fResponsesX;           // per terminal node: x response
  std::vector<float> fResponsesY;           // per terminal node: y response
};

//_______________________________________________________________________
inline void GBRTree2D::GetResponse(const float *vector, double &x, double &y) const {
  // Both outputs come from the same leaf, so a single traversal serves both.
  const int terminal = TerminalIndex(vector);
  x = fResponsesX[terminal];
  y = fResponsesY[terminal];
}

//_______________________________________________________________________
inline int GBRTree2D::TerminalIndex(const float *vector) const {
  // Start at the root, which is always an intermediate node.
  int index = 0;
  do {
    // Strictly greater than the cut goes right; everything else (including
    // NaN inputs, for which the comparison is false) goes left.
    index = (vector[fCutIndices[index]] > fCutVals[index]) ? fRightIndices[index] : fLeftIndices[index];
  } while (index > 0);  // > 0: another intermediate node; <= 0: leaf -index
  return -index;
}
114 
115 #endif
DDAxes::y
GBRTree2D::GBRTree2D
GBRTree2D()
Definition: GBRTree2D.h:32
GBRTree2D::CutIndices
std::vector< unsigned short > & CutIndices()
Definition: GBRTree2D.h:44
COND_SERIALIZABLE
#define COND_SERIALIZABLE
Definition: Serializable.h:39
GBRTree2D::ResponsesY
const std::vector< float > & ResponsesY() const
Definition: GBRTree2D.h:42
GBRTree2D::fCutVals
std::vector< float > fCutVals
Definition: GBRTree2D.h:58
DDAxes::x
GBRTree2D::fResponsesY
std::vector< float > fResponsesY
Definition: GBRTree2D.h:62
GBRTree2D::TerminalIndex
int TerminalIndex(const float *vector) const
Definition: GBRTree2D.h:93
GBRTree2D::fCutIndices
std::vector< unsigned short > fCutIndices
Definition: GBRTree2D.h:57
GBRTree2D::CutVals
const std::vector< float > & CutVals() const
Definition: GBRTree2D.h:48
GBRTree2D::RightIndices
std::vector< int > & RightIndices()
Definition: GBRTree2D.h:53
GBRTree2D::fLeftIndices
std::vector< int > fLeftIndices
Definition: GBRTree2D.h:59
GBRTree2D::RightIndices
const std::vector< int > & RightIndices() const
Definition: GBRTree2D.h:54
GBRTree2D::fRightIndices
std::vector< int > fRightIndices
Definition: GBRTree2D.h:60
GBRTree2D::LeftIndices
std::vector< int > & LeftIndices()
Definition: GBRTree2D.h:50
GBRTree2D::ResponsesX
const std::vector< float > & ResponsesX() const
Definition: GBRTree2D.h:39
Serializable.h
trackerHitRTTI::vector
Definition: trackerHitRTTI.h:21
GBRTree2D::CutIndices
const std::vector< unsigned short > & CutIndices() const
Definition: GBRTree2D.h:45
GBRTree2D::~GBRTree2D
~GBRTree2D()
Definition: GBRTree2D.h:33
GBRTree2D::LeftIndices
const std::vector< int > & LeftIndices() const
Definition: GBRTree2D.h:51
GBRTree2D::fResponsesX
std::vector< float > fResponsesX
Definition: GBRTree2D.h:61
GBRTree2D::GetResponse
void GetResponse(const float *vector, double &x, double &y) const
Definition: GBRTree2D.h:68
GBRTree2D::CutVals
std::vector< float > & CutVals()
Definition: GBRTree2D.h:47
genVertex_cff.x
x
Definition: genVertex_cff.py:12
detailsBasic3DVector::y
float float y
Definition: extBasic3DVector.h:14
AlignmentPI::index
index
Definition: AlignmentPayloadInspectorHelper.h:46
GBRTree2D
Definition: GBRTree2D.h:30
GBRTree2D::ResponsesY
std::vector< float > & ResponsesY()
Definition: GBRTree2D.h:41
GBRTree2D::ResponsesX
std::vector< float > & ResponsesX()
Definition: GBRTree2D.h:38