CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
GBRTree.h
Go to the documentation of this file.
1 
2 #ifndef EGAMMAOBJECTS_GBRTree
3 #define EGAMMAOBJECTS_GBRTree
4 
6 // //
7 // GBRForest //
8 // //
9 // A fast minimal implementation of Gradient-Boosted Regression Trees //
10 // which has been especially optimized for size on disk and in memory. //
11 // //
12 // Designed to be built from TMVA-trained trees, but could also be //
13 // generalized to otherwise-trained trees, classification, //
14 // or other boosting methods in the future //
15 // //
16 // Josh Bendavid - MIT //
18 
19 // The decision tree is implemented here as two kinds of arrays: parallel arrays
20 // for the intermediate nodes, holding the variable index and cut value as well
21 // as the indices of the 'left' and 'right' daughter nodes, and one array of
22 // regression responses for the terminal nodes. A positive daughter index refers
23 // to another intermediate node; a zero or negative index i refers to the terminal node whose response is stored at position -i
24 
25 
27 
28 #include <vector>
29 #include <map>
30 
// Forward declarations only: this header refers to these TMVA types solely
// through pointers (constructor argument and node-conversion helpers), so the
// full TMVA headers do not need to be included here.
namespace TMVA {
class DecisionTree;
class DecisionTreeNode;
}
35 
// A single regression tree stored as flat arrays (see the member comments
// below for the layout). Intermediate node 0 is the root.
class GBRTree {

 public:

  // Constructors/destructor are defined out of line; what state the default
  // constructor leaves the node arrays in is not visible from this header.
  GBRTree();
  // Converts a TMVA-trained decision tree into the flat representation.
  explicit GBRTree(const TMVA::DecisionTree *tree);
  virtual ~GBRTree();

  // Descends from the root using the input values in `vector` and returns the
  // response of the terminal node that is reached.
  double GetResponse(const float* vector) const;
  // Same descent as GetResponse(), but returns the reached terminal node's
  // position in Responses() rather than its response.
  int TerminalIndex(const float *vector) const;

  // Direct access to the underlying storage. The non-const overloads return
  // mutable references, so callers can fill or edit the tree in place.
  std::vector<float> &Responses() { return fResponses; }
  const std::vector<float> &Responses() const { return fResponses; }

  std::vector<unsigned char> &CutIndices() { return fCutIndices; }
  const std::vector<unsigned char> &CutIndices() const { return fCutIndices; }

  std::vector<float> &CutVals() { return fCutVals; }
  const std::vector<float> &CutVals() const { return fCutVals; }

  std::vector<int> &LeftIndices() { return fLeftIndices; }
  const std::vector<int> &LeftIndices() const { return fLeftIndices; }

  std::vector<int> &RightIndices() { return fRightIndices; }
  const std::vector<int> &RightIndices() const { return fRightIndices; }



 protected:
  // Helpers used when converting a TMVA tree (defined out of line);
  // presumably called from the TMVA constructor — confirm in GBRTree.cc.
  unsigned int CountIntermediateNodes(const TMVA::DecisionTreeNode *node);
  unsigned int CountTerminalNodes(const TMVA::DecisionTreeNode *node);

  void AddNode(const TMVA::DecisionTreeNode *node);

  // Parallel per-intermediate-node arrays, indexed by node number (root = 0).
  // NOTE: member order defines object layout; do not reorder.
  std::vector<unsigned char> fCutIndices;  // input variable tested (so at most 256 addressable inputs)
  std::vector<float> fCutVals;             // cut value; input > cut goes right, otherwise (incl. NaN) left
  std::vector<int> fLeftIndices;           // daughter taken when input <= cut
  std::vector<int> fRightIndices;          // daughter taken when input > cut
  // A daughter index > 0 is another intermediate node; a daughter index <= 0
  // is a terminal node whose response is fResponses[-index].
  std::vector<float> fResponses;


};
79 
80 //_______________________________________________________________________
81 inline double GBRTree::GetResponse(const float* vector) const {
82 
83  int index = 0;
84 
85  unsigned char cutindex = fCutIndices[0];
86  float cutval = fCutVals[0];
87 
88  while (true) {
89 
90  if (vector[cutindex] > cutval) {
91  index = fRightIndices[index];
92  }
93  else {
94  index = fLeftIndices[index];
95  }
96 
97  if (index>0) {
98  cutindex = fCutIndices[index];
99  cutval = fCutVals[index];
100  }
101  else {
102  return fResponses[-index];
103  }
104 
105  }
106 
107 
108 }
109 
110 //_______________________________________________________________________
111 inline int GBRTree::TerminalIndex(const float* vector) const {
112 
113  int index = 0;
114 
115  unsigned char cutindex = fCutIndices[0];
116  float cutval = fCutVals[0];
117 
118  while (true) {
119  if (vector[cutindex] > cutval) {
120  index = fRightIndices[index];
121  }
122  else {
123  index = fLeftIndices[index];
124  }
125 
126  if (index>0) {
127  cutindex = fCutIndices[index];
128  cutval = fCutVals[index];
129  }
130  else {
131  return (-index);
132  }
133 
134  }
135 
136 
137 }
138 
139 #endif
const std::vector< float > & CutVals() const
Definition: GBRTree.h:54
std::vector< float > & Responses()
Definition: GBRTree.h:47
const std::vector< int > & RightIndices() const
Definition: GBRTree.h:60
const std::vector< int > & LeftIndices() const
Definition: GBRTree.h:57
unsigned int CountIntermediateNodes(const TMVA::DecisionTreeNode *node)
tuple node
Definition: Node.py:50
std::vector< float > & CutVals()
Definition: GBRTree.h:53
std::vector< int > fRightIndices
Definition: GBRTree.h:73
std::vector< float > fResponses
Definition: GBRTree.h:74
std::vector< int > & LeftIndices()
Definition: GBRTree.h:56
const std::vector< unsigned char > & CutIndices() const
Definition: GBRTree.h:51
unsigned int CountTerminalNodes(const TMVA::DecisionTreeNode *node)
double GetResponse(const float *vector) const
Definition: GBRTree.h:81
std::vector< float > fCutVals
Definition: GBRTree.h:71
std::vector< int > fLeftIndices
Definition: GBRTree.h:72
void AddNode(const TMVA::DecisionTreeNode *node)
#define COND_SERIALIZABLE
Definition: Serializable.h:37
const std::vector< float > & Responses() const
Definition: GBRTree.h:48
virtual ~GBRTree()
int TerminalIndex(const float *vector) const
Definition: GBRTree.h:111
std::vector< int > & RightIndices()
Definition: GBRTree.h:59
std::vector< unsigned char > fCutIndices
Definition: GBRTree.h:70
std::vector< unsigned char > & CutIndices()
Definition: GBRTree.h:50