CMS 3D CMS Logo

GBRTree.h
Go to the documentation of this file.
1 
2 #ifndef EGAMMAOBJECTS_GBRTree
3 #define EGAMMAOBJECTS_GBRTree
4 
6 // //
7 // GBRForest //
8 // //
9 // A fast minimal implementation of Gradient-Boosted Regression Trees //
10 // which has been especially optimized for size on disk and in memory. //
11 // //
12 // Designed to be built from TMVA-trained trees, but could also be //
13 // generalized to otherwise-trained trees, classification, //
14 // or other boosting methods in the future //
15 // //
16 // Josh Bendavid - MIT //
18 
19 // The decision tree is implemented here as two sets of arrays: one for
20 // intermediate nodes, containing the variable index and cut value, as well
21 // as the indices of the 'left' and 'right' daughter nodes. Positive daughter indices
22 // indicate further intermediate nodes, whereas non-positive indices indicate
23 // terminal nodes (terminal i is encoded as -i), which are stored simply as a vector of regression responses.
24 
26 
27 #include <vector>
28 
/// Compact, array-based storage for a single regression tree.
///
/// Intermediate node i is described by the i-th entries of CutIndices(),
/// CutVals(), LeftIndices() and RightIndices(); terminal node j carries the
/// regression response Responses()[j]. Daughter indices > 0 refer to further
/// intermediate nodes, daughter indices <= 0 refer to terminal node -index.
class GBRTree {
public:
  /// Create a tree with no nodes at all. GetResponse() reads intermediate
  /// node 0 unconditionally, so the arrays must be filled before evaluation.
  GBRTree() {}

  /// Sized constructor, defined out of line (definition not visible here).
  /// NOTE(review): presumably prepares the arrays for nIntermediate
  /// intermediate and nTerminal terminal nodes — confirm against the .cc file.
  explicit GBRTree(int nIntermediate, int nTerminal);

  /// Walk the tree for one input feature vector and return the response of
  /// the terminal node that is reached.
  double GetResponse(const float *vector) const;

  // ---- terminal nodes -------------------------------------------------------

  /// Regression response of every terminal node, indexed by terminal number.
  const std::vector<float> &Responses() const { return fResponses; }
  std::vector<float> &Responses() { return fResponses; }

  // ---- intermediate nodes ---------------------------------------------------

  /// Position in the input vector of the variable each node cuts on.
  const std::vector<unsigned char> &CutIndices() const { return fCutIndices; }
  std::vector<unsigned char> &CutIndices() { return fCutIndices; }

  /// Threshold each node compares its input variable against.
  const std::vector<float> &CutVals() const { return fCutVals; }
  std::vector<float> &CutVals() { return fCutVals; }

  /// Daughter followed when the input is NOT strictly above the cut.
  const std::vector<int> &LeftIndices() const { return fLeftIndices; }
  std::vector<int> &LeftIndices() { return fLeftIndices; }

  /// Daughter followed when the input is strictly above the cut.
  const std::vector<int> &RightIndices() const { return fRightIndices; }
  std::vector<int> &RightIndices() { return fRightIndices; }

protected:
  // Data members are declared in their original order.
  // NOTE(review): persistence/serialization code may rely on this order —
  // confirm before reordering.
  std::vector<unsigned char> fCutIndices;
  std::vector<float> fCutVals;
  std::vector<int> fLeftIndices;
  std::vector<int> fRightIndices;
  std::vector<float> fResponses;

  // NOTE(review): the source listing this text came from omits line 58; its
  // Doxygen index references COND_SERIALIZABLE (Serializable.h), so the
  // upstream header presumably declares it here — confirm.
};
60 
61 //_______________________________________________________________________
62 inline double GBRTree::GetResponse(const float *vector) const {
63  int index = 0;
64  do {
65  auto r = fRightIndices[index];
66  auto l = fLeftIndices[index];
67  unsigned int x = vector[fCutIndices[index]] > fCutVals[index] ? ~0 : 0;
68  index = (x & r) | ((~x) & l);
69  } while (index > 0);
70  return fResponses[-index];
71 }
72 
73 #endif
std::vector< float > & Responses()
Definition: GBRTree.h:36
const std::vector< float > & Responses() const
Definition: GBRTree.h:37
GBRTree()
Definition: GBRTree.h:31
double GetResponse(const float *vector) const
Definition: GBRTree.h:62
const std::vector< int > & LeftIndices() const
Definition: GBRTree.h:46
std::vector< float > & CutVals()
Definition: GBRTree.h:42
std::vector< int > fRightIndices
Definition: GBRTree.h:55
std::vector< float > fResponses
Definition: GBRTree.h:56
std::vector< int > & LeftIndices()
Definition: GBRTree.h:45
const std::vector< unsigned char > & CutIndices() const
Definition: GBRTree.h:40
std::vector< float > fCutVals
Definition: GBRTree.h:53
std::vector< int > fLeftIndices
Definition: GBRTree.h:54
#define COND_SERIALIZABLE
Definition: Serializable.h:39
const std::vector< float > & CutVals() const
Definition: GBRTree.h:43
const std::vector< int > & RightIndices() const
Definition: GBRTree.h:49
std::vector< int > & RightIndices()
Definition: GBRTree.h:48
std::vector< unsigned char > fCutIndices
Definition: GBRTree.h:52
std::vector< unsigned char > & CutIndices()
Definition: GBRTree.h:39