CMS 3D CMS Logo

GBRTree.h
Go to the documentation of this file.
1 
2 #ifndef EGAMMAOBJECTS_GBRTree
3 #define EGAMMAOBJECTS_GBRTree
4 
6 // //
7 // GBRForest //
8 // //
9 // A fast minimal implementation of Gradient-Boosted Regression Trees //
10 // which has been especially optimized for size on disk and in memory. //
11 // //
12 // Designed to be built from TMVA-trained trees, but could also be //
13 // generalized to otherwise-trained trees, classification, //
14 // or other boosting methods in the future //
15 // //
16 // Josh Bendavid - MIT //
18 
19 // The decision tree is implemented here as a set of parallel arrays with one entry
20 // per intermediate node, holding the variable index and cut value, as well
21 // as the indices of the 'left' and 'right' daughter nodes. Positive indices
22 // indicate further intermediate nodes, whereas zero or negative indices indicate
23 // terminal nodes: index -i selects entry i of a separate vector of regression responses.
24 
25 
27 
28 #include <vector>
29 
/// A single regression tree stored in flattened, array-based form.
///
/// Intermediate nodes are described by four parallel arrays indexed by node
/// number (fCutIndices, fCutVals, fLeftIndices, fRightIndices); evaluation
/// starts at node 0. A daughter index greater than zero refers to another
/// intermediate node, while a daughter index of zero or less refers to the
/// terminal node whose response is stored at fResponses[-index].
class GBRTree {
public:
  /// Default constructor (defined out of line).
  GBRTree();

  /// Constructor taking the number of intermediate and terminal nodes
  /// (defined out of line).
  explicit GBRTree(int nIntermediate, int nTerminal);

  /// Walks the tree for one input feature vector and returns the response of
  /// the terminal node that is reached. `vector` is indexed by the values in
  /// fCutIndices, so it must be long enough for every index used by the tree.
  double GetResponse(const float* vector) const;

  /// @name Access to the underlying node arrays
  /// Each array is available both mutably (for filling) and read-only.
  /// @{

  /// Regression response of each terminal node.
  std::vector<float>& Responses() { return fResponses; }
  const std::vector<float>& Responses() const { return fResponses; }

  /// Input-variable index tested at each intermediate node.
  std::vector<unsigned char>& CutIndices() { return fCutIndices; }
  const std::vector<unsigned char>& CutIndices() const { return fCutIndices; }

  /// Threshold compared against the selected input at each intermediate node.
  std::vector<float>& CutVals() { return fCutVals; }
  const std::vector<float>& CutVals() const { return fCutVals; }

  /// Daughter index followed when the cut comparison is false.
  std::vector<int>& LeftIndices() { return fLeftIndices; }
  const std::vector<int>& LeftIndices() const { return fLeftIndices; }

  /// Daughter index followed when the input value exceeds the cut.
  std::vector<int>& RightIndices() { return fRightIndices; }
  const std::vector<int>& RightIndices() const { return fRightIndices; }

  /// @}

protected:
  // --- per intermediate node (parallel arrays, same length) ---
  // Variable index into the input vector; the unsigned char element type limits
  // the number of addressable input variables to the range of unsigned char.
  std::vector<unsigned char> fCutIndices;
  std::vector<float> fCutVals;     // go right when input > cut value
  std::vector<int> fLeftIndices;   // daughter when the comparison is false
  std::vector<int> fRightIndices;  // daughter when the comparison is true

  // --- per terminal node ---
  std::vector<float> fResponses;
};
65 
66 //_______________________________________________________________________
67 inline double GBRTree::GetResponse(const float* vector) const
68 {
69  int index = 0;
70  do {
71  auto r = fRightIndices[index];
72  auto l = fLeftIndices[index];
73  index = vector[fCutIndices[index]] > fCutVals[index] ? r : l;
74  } while (index>0);
75  return fResponses[-index];
76 }
77 
78 #endif
const std::vector< float > & CutVals() const
Definition: GBRTree.h:46
std::vector< float > & Responses()
Definition: GBRTree.h:39
const std::vector< int > & RightIndices() const
Definition: GBRTree.h:52
const std::vector< int > & LeftIndices() const
Definition: GBRTree.h:49
std::vector< float > & CutVals()
Definition: GBRTree.h:45
std::vector< int > fRightIndices
Definition: GBRTree.h:59
std::vector< float > fResponses
Definition: GBRTree.h:60
std::vector< int > & LeftIndices()
Definition: GBRTree.h:48
const std::vector< unsigned char > & CutIndices() const
Definition: GBRTree.h:43
double GetResponse(const float *vector) const
Definition: GBRTree.h:67
std::vector< float > fCutVals
Definition: GBRTree.h:57
std::vector< int > fLeftIndices
Definition: GBRTree.h:58
#define COND_SERIALIZABLE
Definition: Serializable.h:38
const std::vector< float > & Responses() const
Definition: GBRTree.h:40
std::vector< int > & RightIndices()
Definition: GBRTree.h:51
std::vector< unsigned char > fCutIndices
Definition: GBRTree.h:56
std::vector< unsigned char > & CutIndices()
Definition: GBRTree.h:42