CMS 3D CMS Logo

GBRTree.h
Go to the documentation of this file.
1 
2 #ifndef EGAMMAOBJECTS_GBRTree
3 #define EGAMMAOBJECTS_GBRTree
4 
6 // //
7 // GBRTree //
8 // //
9 // A fast minimal implementation of Gradient-Boosted Regression Trees //
10 // which has been especially optimized for size on disk and in memory. //
11 // //
12 // Designed to be built from TMVA-trained trees, but could also be //
13 // generalized to otherwise-trained trees, classification, //
14 // or other boosting methods in the future //
15 // //
16 // Josh Bendavid - MIT //
18 
19 // The decision tree is implemented here as a set of two arrays, one for
20 // intermediate nodes, containing the variable index and cut value, as well
21 // as the indices of the 'left' and 'right' daughter nodes. Positive indices
22 // indicate further intermediate nodes, whereas zero or negative indices
23 // indicate terminal nodes, which are stored simply as a vector of regression responses addressed by the negated index.
24 
25 
27 
28 #include <vector>
29 #include <map>
30 
31  namespace TMVA {
32  class DecisionTree;
33  class DecisionTreeNode;
34  }
35 
36  class GBRTree {
37 
38  public:
39 
40  GBRTree();
41  explicit GBRTree(const TMVA::DecisionTree *tree, double scale, bool useyesnoleaf, bool adjustboundary);
42  explicit GBRTree(int nIntermediate, int nTerminal);
43  virtual ~GBRTree();
44 
45  double GetResponse(const float* vector) const;
46  int TerminalIndex(const float *vector) const;
47 
48  std::vector<float> &Responses() { return fResponses; }
49  const std::vector<float> &Responses() const { return fResponses; }
50 
51  std::vector<unsigned char> &CutIndices() { return fCutIndices; }
52  const std::vector<unsigned char> &CutIndices() const { return fCutIndices; }
53 
54  std::vector<float> &CutVals() { return fCutVals; }
55  const std::vector<float> &CutVals() const { return fCutVals; }
56 
57  std::vector<int> &LeftIndices() { return fLeftIndices; }
58  const std::vector<int> &LeftIndices() const { return fLeftIndices; }
59 
60  std::vector<int> &RightIndices() { return fRightIndices; }
61  const std::vector<int> &RightIndices() const { return fRightIndices; }
62 
63 
64 
65  protected:
66  unsigned int CountIntermediateNodes(const TMVA::DecisionTreeNode *node);
67  unsigned int CountTerminalNodes(const TMVA::DecisionTreeNode *node);
68 
69  void AddNode(const TMVA::DecisionTreeNode *node, double scale, bool isregression, bool useyesnoleaf, bool adjustboundary);
70 
71  std::vector<unsigned char> fCutIndices;
72  std::vector<float> fCutVals;
73  std::vector<int> fLeftIndices;
74  std::vector<int> fRightIndices;
75  std::vector<float> fResponses;
76 
77 
79 };
80 
81 //_______________________________________________________________________
82 inline double GBRTree::GetResponse(const float* vector) const {
83  return fResponses[TerminalIndex(vector)];
84 }
85 
86 //_______________________________________________________________________
87 inline int GBRTree::TerminalIndex(const float* vector) const {
88  int index = 0;
89  do {
90  auto r = fRightIndices[index];
91  auto l = fLeftIndices[index];
92  index = vector[fCutIndices[index]] > fCutVals[index] ? r : l;
93  } while (index>0);
94  return -index;
95 }
96 
97 #endif
const std::vector< float > & CutVals() const
Definition: GBRTree.h:55
std::vector< float > & Responses()
Definition: GBRTree.h:48
const std::vector< int > & RightIndices() const
Definition: GBRTree.h:61
const std::vector< int > & LeftIndices() const
Definition: GBRTree.h:58
std::vector< float > & CutVals()
Definition: GBRTree.h:54
std::vector< int > fRightIndices
Definition: GBRTree.h:74
std::vector< float > fResponses
Definition: GBRTree.h:75
std::vector< int > & LeftIndices()
Definition: GBRTree.h:57
const std::vector< unsigned char > & CutIndices() const
Definition: GBRTree.h:52
double GetResponse(const float *vector) const
Definition: GBRTree.h:82
std::vector< float > fCutVals
Definition: GBRTree.h:72
std::vector< int > fLeftIndices
Definition: GBRTree.h:73
#define COND_SERIALIZABLE
Definition: Serializable.h:38
const std::vector< float > & Responses() const
Definition: GBRTree.h:49
Definition: GBRForest.h:26
int TerminalIndex(const float *vector) const
Definition: GBRTree.h:87
Definition: tree.py:1
std::vector< int > & RightIndices()
Definition: GBRTree.h:60
std::vector< unsigned char > fCutIndices
Definition: GBRTree.h:71
std::vector< unsigned char > & CutIndices()
Definition: GBRTree.h:51