CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
GBRTree.h
Go to the documentation of this file.
1 
2 #ifndef EGAMMAOBJECTS_GBRTree
3 #define EGAMMAOBJECTS_GBRTree
4 
6 // //
7 // GBRForest //
8 // //
9 // A fast minimal implementation of Gradient-Boosted Regression Trees //
10 // which has been especially optimized for size on disk and in memory. //
11 // //
12 // Designed to be built from TMVA-trained trees, but could also be //
13 // generalized to otherwise-trained trees, classification, //
14 // or other boosting methods in the future //
15 // //
16 // Josh Bendavid - MIT //
18 
19 // The decision tree is stored as two groups of arrays. Intermediate nodes use
20 // parallel arrays holding, for each node, the index of the input variable that is
21 // cut on and the cut value, together with the indices of its 'left' and 'right'
22 // daughter nodes. A positive daughter index refers to another intermediate node;
23 // zero or a negative index -k refers to terminal node k, whose regression response is stored in a separate vector.
24 
25 
26 #include <vector>
27 #include <map>
28 
namespace TMVA {
  class DecisionTree;
  class DecisionTreeNode;
}  // namespace TMVA

// GBRTree: a single regression tree kept in a flattened, array-based layout.
//
// Intermediate nodes are stored in four parallel arrays indexed by node number,
// with node 0 treated as the root:
//   - the index of the input variable that is cut on,
//   - the cut value,
//   - the left daughter index,
//   - the right daughter index.
// Terminal-node responses are stored in a separate array.
//
// Daughter-index encoding: a value > 0 names another intermediate node; a
// value <= 0, written -k, names terminal node k (i.e. Responses()[k]).
//
// NOTE(review): because a destructor is user-declared, the compiler does not
// implicitly declare move operations for this class, so moving a GBRTree (e.g.
// when it is stored in a std::vector that reallocates) copies all five member
// vectors. Consider explicitly defaulting the copy/move constructors and
// assignment operators.
class GBRTree {

 public:

  GBRTree();
  // Conversion from a TMVA decision tree; implemented out of line.
  explicit GBRTree(const TMVA::DecisionTree *tree);
  virtual ~GBRTree();

  // Walk the tree for the input point `vector`, which is indexed by the stored
  // cut-variable indices, and return the response of the terminal node reached.
  double GetResponse(const float *vector) const;
  // Same walk as GetResponse, but return the index into Responses() of the
  // terminal node reached instead of its response value.
  int TerminalIndex(const float *vector) const;

  // Direct access to the node arrays (read-only and mutable overloads).
  const std::vector<unsigned char> &CutIndices() const { return fCutIndices; }
  std::vector<unsigned char>       &CutIndices()       { return fCutIndices; }

  const std::vector<float> &CutVals() const { return fCutVals; }
  std::vector<float>       &CutVals()       { return fCutVals; }

  const std::vector<int> &LeftIndices() const { return fLeftIndices; }
  std::vector<int>       &LeftIndices()       { return fLeftIndices; }

  const std::vector<int> &RightIndices() const { return fRightIndices; }
  std::vector<int>       &RightIndices()       { return fRightIndices; }

  const std::vector<float> &Responses() const { return fResponses; }
  std::vector<float>       &Responses()       { return fResponses; }

 protected:

  // Helpers operating on TMVA nodes, defined out of line. Presumably they are
  // used by the TMVA-converting constructor; confirm in the implementation file.
  unsigned int CountIntermediateNodes(const TMVA::DecisionTreeNode *node);
  unsigned int CountTerminalNodes(const TMVA::DecisionTreeNode *node);
  void AddNode(const TMVA::DecisionTreeNode *node);

  // Node storage. NOTE(review): keep this declaration order stable in case
  // the class is persisted.
  std::vector<unsigned char> fCutIndices;   // per intermediate node: input variable that is cut on
  std::vector<float>         fCutVals;      // per intermediate node: cut threshold
  std::vector<int>           fLeftIndices;  // daughter taken unless input > cut (NaN goes here)
  std::vector<int>           fRightIndices; // daughter taken when input > cut
  std::vector<float>         fResponses;    // per terminal node: regression response

};
75 
76 //_______________________________________________________________________
77 inline double GBRTree::GetResponse(const float* vector) const {
78 
79  int index = 0;
80 
81  unsigned char cutindex = fCutIndices[0];
82  float cutval = fCutVals[0];
83 
84  while (true) {
85 
86  if (vector[cutindex] > cutval) {
87  index = fRightIndices[index];
88  }
89  else {
90  index = fLeftIndices[index];
91  }
92 
93  if (index>0) {
94  cutindex = fCutIndices[index];
95  cutval = fCutVals[index];
96  }
97  else {
98  return fResponses[-index];
99  }
100 
101  }
102 
103 
104 }
105 
106 //_______________________________________________________________________
107 inline int GBRTree::TerminalIndex(const float* vector) const {
108 
109  int index = 0;
110 
111  unsigned char cutindex = fCutIndices[0];
112  float cutval = fCutVals[0];
113 
114  while (true) {
115  if (vector[cutindex] > cutval) {
116  index = fRightIndices[index];
117  }
118  else {
119  index = fLeftIndices[index];
120  }
121 
122  if (index>0) {
123  cutindex = fCutIndices[index];
124  cutval = fCutVals[index];
125  }
126  else {
127  return (-index);
128  }
129 
130  }
131 
132 
133 }
134 
135 #endif
const std::vector< float > & CutVals() const
Definition: GBRTree.h:52
std::vector< float > & Responses()
Definition: GBRTree.h:45
const std::vector< int > & RightIndices() const
Definition: GBRTree.h:58
const std::vector< int > & LeftIndices() const
Definition: GBRTree.h:55
unsigned int CountIntermediateNodes(const TMVA::DecisionTreeNode *node)
tuple node
Definition: Node.py:50
std::vector< float > & CutVals()
Definition: GBRTree.h:51
std::vector< int > fRightIndices
Definition: GBRTree.h:71
std::vector< float > fResponses
Definition: GBRTree.h:72
std::vector< int > & LeftIndices()
Definition: GBRTree.h:54
const std::vector< unsigned char > & CutIndices() const
Definition: GBRTree.h:49
unsigned int CountTerminalNodes(const TMVA::DecisionTreeNode *node)
double GetResponse(const float *vector) const
Definition: GBRTree.h:77
std::vector< float > fCutVals
Definition: GBRTree.h:69
std::vector< int > fLeftIndices
Definition: GBRTree.h:70
void AddNode(const TMVA::DecisionTreeNode *node)
const std::vector< float > & Responses() const
Definition: GBRTree.h:46
virtual ~GBRTree()
int TerminalIndex(const float *vector) const
Definition: GBRTree.h:107
std::vector< int > & RightIndices()
Definition: GBRTree.h:57
std::vector< unsigned char > fCutIndices
Definition: GBRTree.h:68
std::vector< unsigned char > & CutIndices()
Definition: GBRTree.h:48