CMS 3D CMS Logo

GBRTree.h
Go to the documentation of this file.
1 
2 #ifndef EGAMMAOBJECTS_GBRTree
3 #define EGAMMAOBJECTS_GBRTree
4 
6 // //
7 // GBRForest //
8 // //
9 // A fast minimal implementation of Gradient-Boosted Regression Trees //
10 // which has been especially optimized for size on disk and in memory. //
11 // //
12 // Designed to be built from TMVA-trained trees, but could also be //
13 // generalized to otherwise-trained trees, classification, //
14 // or other boosting methods in the future //
15 // //
16 // Josh Bendavid - MIT //
18 
// The decision tree is implemented here as two groups of parallel arrays. The
// first group describes intermediate nodes: the variable index and cut value,
// as well as the indices of the 'left' and 'right' daughter nodes. Positive
// daughter indices indicate further intermediate nodes, whereas zero or
// negative indices indicate terminal nodes. The second group is simply a vector
// of regression responses, where the response for daughter index i is stored
// at position -i.
24 
26 
27 #include <vector>
28 
/// Flattened binary regression tree, stored as parallel arrays.
///
/// Node 0 is the root. For an intermediate node i, the input value
/// vector[CutIndices()[i]] is compared with CutVals()[i]. A strictly greater
/// value moves to RightIndices()[i]; any other outcome (including NaN) moves
/// to LeftIndices()[i]. A daughter index > 0 names another intermediate node.
/// A daughter index <= 0 names a terminal node whose value is Responses()[-index].
class GBRTree {
public:
  // NOTE: both constructors are only declared here; their definitions are not
  // part of this header.
  GBRTree();
  explicit GBRTree(int nIntermediate, int nTerminal);

  /// Evaluate the tree for one input row (see the inline definition below).
  double GetResponse(const float *vector) const;

  /// Terminal-node regression responses, addressed by -daughterIndex.
  auto Responses() -> std::vector<float> & { return fResponses; }
  auto Responses() const -> const std::vector<float> & { return fResponses; }

  /// Per intermediate node: which input variable the cut is applied to.
  auto CutIndices() -> std::vector<unsigned char> & { return fCutIndices; }
  auto CutIndices() const -> const std::vector<unsigned char> & { return fCutIndices; }

  /// Per intermediate node: cut threshold (input > cut goes right).
  auto CutVals() -> std::vector<float> & { return fCutVals; }
  auto CutVals() const -> const std::vector<float> & { return fCutVals; }

  /// Per intermediate node: daughter taken when the `>` comparison fails.
  auto LeftIndices() -> std::vector<int> & { return fLeftIndices; }
  auto LeftIndices() const -> const std::vector<int> & { return fLeftIndices; }

  /// Per intermediate node: daughter taken when the `>` comparison succeeds.
  auto RightIndices() -> std::vector<int> & { return fRightIndices; }
  auto RightIndices() const -> const std::vector<int> & { return fRightIndices; }

protected:
  // Member order and types are kept as-is; they describe the stored layout.
  std::vector<unsigned char> fCutIndices;  // variable index per node (unsigned char: at most 256 values)
  std::vector<float> fCutVals;             // cut threshold per node
  std::vector<int> fLeftIndices;           // left daughter per node (> 0 intermediate, <= 0 terminal)
  std::vector<int> fRightIndices;          // right daughter per node (> 0 intermediate, <= 0 terminal)
  std::vector<float> fResponses;           // terminal responses, indexed by -daughterIndex

  // NOTE(review): the Doxygen cross-reference for this page lists
  // COND_SERIALIZABLE (Serializable.h); the upstream header likely has a
  // `COND_SERIALIZABLE;` line here that this listing dropped -- confirm.
};
60 
61 //_______________________________________________________________________
62 inline double GBRTree::GetResponse(const float *vector) const {
63  int index = 0;
64  do {
65  auto r = fRightIndices[index];
66  auto l = fLeftIndices[index];
67  index = vector[fCutIndices[index]] > fCutVals[index] ? r : l;
68  } while (index > 0);
69  return fResponses[-index];
70 }
71 
72 #endif
GBRTree::LeftIndices
std::vector< int > & LeftIndices()
Definition: GBRTree.h:45
GBRTree::fCutIndices
std::vector< unsigned char > fCutIndices
Definition: GBRTree.h:52
GBRTree::GBRTree
GBRTree()
COND_SERIALIZABLE
#define COND_SERIALIZABLE
Definition: Serializable.h:39
GBRTree::fLeftIndices
std::vector< int > fLeftIndices
Definition: GBRTree.h:54
GBRTree::fRightIndices
std::vector< int > fRightIndices
Definition: GBRTree.h:55
GBRTree::RightIndices
const std::vector< int > & RightIndices() const
Definition: GBRTree.h:49
GBRTree::Responses
const std::vector< float > & Responses() const
Definition: GBRTree.h:37
GBRTree::CutIndices
std::vector< unsigned char > & CutIndices()
Definition: GBRTree.h:39
Serializable.h
GBRTree::fCutVals
std::vector< float > fCutVals
Definition: GBRTree.h:53
GBRTree::fResponses
std::vector< float > fResponses
Definition: GBRTree.h:56
GBRTree::CutVals
std::vector< float > & CutVals()
Definition: GBRTree.h:42
GBRTree::CutVals
const std::vector< float > & CutVals() const
Definition: GBRTree.h:43
cmsLHEtoEOSManager.l
l
Definition: cmsLHEtoEOSManager.py:193
alignCSCRings.r
r
Definition: alignCSCRings.py:93
GBRTree
Definition: GBRTree.h:29
GBRTree::CutIndices
const std::vector< unsigned char > & CutIndices() const
Definition: GBRTree.h:40
GBRTree::LeftIndices
const std::vector< int > & LeftIndices() const
Definition: GBRTree.h:46
GBRTree::RightIndices
std::vector< int > & RightIndices()
Definition: GBRTree.h:48
AlignmentPI::index
index
Definition: AlignmentPayloadInspectorHelper.h:46
GBRTree::GetResponse
double GetResponse(const float *vector) const
Definition: GBRTree.h:62
GBRTree::Responses
std::vector< float > & Responses()
Definition: GBRTree.h:36