CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
GBRTree.h
Go to the documentation of this file.
1 
2 #ifndef EGAMMAOBJECTS_GBRTree
3 #define EGAMMAOBJECTS_GBRTree
4 
//////////////////////////////////////////////////////////////////////////
//                                                                      //
// GBRForest                                                            //
//                                                                      //
// A fast, minimal implementation of Gradient-Boosted Regression Trees, //
// especially optimized for size on disk and in memory.                 //
//                                                                      //
// Designed to be built from TMVA-trained trees, but could also be      //
// generalized in the future to trees trained by other tools, to        //
// classification, or to other boosting methods.                        //
//                                                                      //
// Josh Bendavid - MIT                                                  //
//                                                                      //
//////////////////////////////////////////////////////////////////////////
// The decision tree is implemented here as two groups of arrays. The
// intermediate nodes are stored as parallel vectors holding, for each node,
// the index of the input variable to cut on, the cut value, and the indices
// of the 'left' and 'right' daughter nodes. A positive daughter index refers
// to another intermediate node, whereas a zero or negative index d refers to
// terminal node -d; the terminal nodes are stored simply as a vector of
// regression responses.
24 
25 
27 
28 #include <vector>
29 #include <map>
30 
// Forward declarations: the TMVA types are only used through pointers in the
// declarations below, so the full TMVA headers are not needed in this header.
namespace TMVA {
  class DecisionTree;
  class DecisionTreeNode;
}  // namespace TMVA

/// A single regression tree stored as flat, index-linked arrays.
///
/// Intermediate node i is described by the parallel vectors fCutIndices[i],
/// fCutVals[i], fLeftIndices[i] and fRightIndices[i]; node 0 is the root.
/// Evaluation (TerminalIndex) moves to fRightIndices[i] when
/// input[fCutIndices[i]] > fCutVals[i] and to fLeftIndices[i] otherwise.
/// A positive daughter index names another intermediate node; a daughter
/// index d <= 0 names terminal node -d, whose value is fResponses[-d].
///
/// fCutIndices holds unsigned char, so cuts can only reference input
/// variables 0-255.
class GBRTree {

 public:

  /// Creates an empty tree (no nodes). NOTE(review): presumably populated
  /// afterwards through the mutable accessors below — confirm with callers.
  /// An empty tree must not be evaluated: TerminalIndex reads node 0.
  GBRTree();

  /// Builds the flat representation from a TMVA-trained decision tree.
  /// NOTE(review): the meaning of @p scale, @p useyesnoleaf and
  /// @p adjustboundary is defined by the out-of-line implementation; the names
  /// suggest a response scale factor, a leaf-response mode and a cut-boundary
  /// adjustment — confirm there before relying on this description.
  explicit GBRTree(const TMVA::DecisionTree *tree, double scale, bool useyesnoleaf, bool adjustboundary);

  virtual ~GBRTree();

  // Declaring the destructor above suppresses the implicitly-declared move
  // constructor and move assignment, so every move (e.g. a
  // std::vector<GBRTree> reallocating) would fall back to deep-copying all
  // five arrays. Default all four operations explicitly to restore cheap
  // moves; the defaulted moves inherit the member vectors' noexcept-ness.
  GBRTree(const GBRTree &) = default;
  GBRTree &operator=(const GBRTree &) = default;
  GBRTree(GBRTree &&) = default;
  GBRTree &operator=(GBRTree &&) = default;

  /// Returns the response of the terminal node reached by @p vector.
  /// @p vector must be readable at every index stored in CutIndices().
  double GetResponse(const float* vector) const;

  /// Returns the index into Responses() of the terminal node reached by
  /// @p vector (same precondition as GetResponse).
  int TerminalIndex(const float *vector) const;

  /// Terminal-node responses, indexed by the result of TerminalIndex().
  std::vector<float> &Responses() { return fResponses; }
  const std::vector<float> &Responses() const { return fResponses; }

  /// Per intermediate node: index of the input variable the cut is applied to.
  std::vector<unsigned char> &CutIndices() { return fCutIndices; }
  const std::vector<unsigned char> &CutIndices() const { return fCutIndices; }

  /// Per intermediate node: cut threshold (input > cut goes right).
  std::vector<float> &CutVals() { return fCutVals; }
  const std::vector<float> &CutVals() const { return fCutVals; }

  /// Per intermediate node: daughter taken when `input > cut` is false.
  std::vector<int> &LeftIndices() { return fLeftIndices; }
  const std::vector<int> &LeftIndices() const { return fLeftIndices; }

  /// Per intermediate node: daughter taken when `input > cut` is true.
  std::vector<int> &RightIndices() { return fRightIndices; }
  const std::vector<int> &RightIndices() const { return fRightIndices; }

 protected:
  // NOTE(review): implemented out of line; presumably count the intermediate /
  // terminal nodes in the TMVA subtree rooted at `node` — confirm there.
  unsigned int CountIntermediateNodes(const TMVA::DecisionTreeNode *node);
  unsigned int CountTerminalNodes(const TMVA::DecisionTreeNode *node);

  // NOTE(review): implemented out of line; presumably appends `node` (and its
  // daughters) to the flat arrays below — confirm there.
  void AddNode(const TMVA::DecisionTreeNode *node, double scale, bool isregression, bool useyesnoleaf, bool adjustboundary);

  std::vector<unsigned char> fCutIndices;  // intermediate nodes: input-variable index of each cut
  std::vector<float> fCutVals;             // intermediate nodes: cut threshold
  std::vector<int> fLeftIndices;           // intermediate nodes: daughter when `input > cut` is false
  std::vector<int> fRightIndices;          // intermediate nodes: daughter when `input > cut` is true
  std::vector<float> fResponses;           // terminal nodes: regression responses

  // NOTE(review): the generated documentation for this header references the
  // COND_SERIALIZABLE macro (Serializable.h); the original class most likely
  // ended with a `COND_SERIALIZABLE;` marker that was lost in extraction —
  // confirm before using this declaration for conditions serialization.
};
79 
80 //_______________________________________________________________________
81 inline double GBRTree::GetResponse(const float* vector) const {
82  return fResponses[TerminalIndex(vector)];
83 }
84 
85 //_______________________________________________________________________
86 inline int GBRTree::TerminalIndex(const float* vector) const {
87  int index = 0;
88  do {
89  auto r = fRightIndices[index];
90  auto l = fLeftIndices[index];
91  index = vector[fCutIndices[index]] > fCutVals[index] ? r : l;
92  } while (index>0);
93  return -index;
94 }
95 
96 #endif
const std::vector< float > & CutVals() const
Definition: GBRTree.h:54
std::vector< float > & Responses()
Definition: GBRTree.h:47
const std::vector< int > & RightIndices() const
Definition: GBRTree.h:60
const std::vector< int > & LeftIndices() const
Definition: GBRTree.h:57
unsigned int CountIntermediateNodes(const TMVA::DecisionTreeNode *node)
std::vector< float > & CutVals()
Definition: GBRTree.h:53
std::vector< int > fRightIndices
Definition: GBRTree.h:73
std::vector< float > fResponses
Definition: GBRTree.h:74
std::vector< int > & LeftIndices()
Definition: GBRTree.h:56
const std::vector< unsigned char > & CutIndices() const
Definition: GBRTree.h:51
unsigned int CountTerminalNodes(const TMVA::DecisionTreeNode *node)
double GetResponse(const float *vector) const
Definition: GBRTree.h:81
std::vector< float > fCutVals
Definition: GBRTree.h:71
std::vector< int > fLeftIndices
Definition: GBRTree.h:72
#define COND_SERIALIZABLE
Definition: Serializable.h:38
const std::vector< float > & Responses() const
Definition: GBRTree.h:48
virtual ~GBRTree()
int TerminalIndex(const float *vector) const
Definition: GBRTree.h:86
std::vector< int > & RightIndices()
Definition: GBRTree.h:59
std::vector< unsigned char > fCutIndices
Definition: GBRTree.h:70
void AddNode(const TMVA::DecisionTreeNode *node, double scale, bool isregression, bool useyesnoleaf, bool adjustboundary)
std::vector< unsigned char > & CutIndices()
Definition: GBRTree.h:50