CMS 3D CMS Logo

GBRTreeD.h
Go to the documentation of this file.
1 
2 #ifndef EGAMMAOBJECTS_GBRTreeD
3 #define EGAMMAOBJECTS_GBRTreeD
4 
6 // //
7 // GBRForest //
8 // //
9 // A fast minimal implementation of Gradient-Boosted Regression Trees //
10 // which has been especially optimized for size on disk and in memory. //
11 // //
12 // Designed to be built from TMVA-trained trees, but could also be //
13 // generalized to otherwise-trained trees, classification, //
14 // or other boosting methods in the future //
15 // //
16 // Josh Bendavid - CERN //
18 
19 // The decision tree is implemented here as a set of parallel arrays: for each
20 // intermediate node they hold the variable index and cut value, as well as
21 // the indices of the 'left' and 'right' daughter nodes. Positive daughter indices
22 // point to further intermediate nodes, whereas zero or negative indices denote
23 // terminal nodes; terminal node -index has its regression response in a separate vector
24 
26 
27 #include <vector>
28 
/// One regression tree stored as flat, parallel arrays.
///
/// Entry i of CutIndices(), CutVals(), LeftIndices() and RightIndices()
/// together describe intermediate node i; Responses() holds the value of each
/// terminal node, looked up by terminal index via GetResponse().
class GBRTreeD {
public:
  /// Empty tree; populate it through the mutable accessors below.
  GBRTreeD() {}

  /// Converting constructor: copies the node arrays of any tree type that
  /// exposes CutIndices(), CutVals(), LeftIndices(), RightIndices() and
  /// Responses().
  template <typename InputTreeT>
  GBRTreeD(const InputTreeT &tree);

  /// Descends the tree for the feature array `vector` and returns the index of
  /// the terminal node that is reached.
  int TerminalIndex(const float *vector) const;

  /// Response stored for terminal node `termidx` (std::vector::operator[],
  /// so the index is not bounds-checked).
  double GetResponse(int termidx) const { return fResponses[termidx]; }

  // ---- read-only views of the node arrays --------------------------------
  const std::vector<double> &Responses() const { return fResponses; }
  const std::vector<unsigned short> &CutIndices() const { return fCutIndices; }
  const std::vector<float> &CutVals() const { return fCutVals; }
  const std::vector<int> &LeftIndices() const { return fLeftIndices; }
  const std::vector<int> &RightIndices() const { return fRightIndices; }

  // ---- mutable access, used when filling the tree -----------------------
  std::vector<double> &Responses() { return fResponses; }
  std::vector<unsigned short> &CutIndices() { return fCutIndices; }
  std::vector<float> &CutVals() { return fCutVals; }
  std::vector<int> &LeftIndices() { return fLeftIndices; }
  std::vector<int> &RightIndices() { return fRightIndices; }

protected:
  // Declaration order is kept as-is: the converting constructor's initializer
  // list relies on it.
  std::vector<unsigned short> fCutIndices;  // feature index tested at each intermediate node
  std::vector<float> fCutVals;              // cut value for that feature
  std::vector<int> fLeftIndices;            // left-daughter index of each intermediate node
  std::vector<int> fRightIndices;           // right-daughter index of each intermediate node
  std::vector<double> fResponses;           // response of each terminal node

  // NOTE(review): the rendered listing skips its line 61 here and lists a
  // COND_SERIALIZABLE macro among its definitions; the upstream class likely
  // ends with `COND_SERIALIZABLE;` — confirm against the real header.
};
63 
64 //_______________________________________________________________________
65 inline int GBRTreeD::TerminalIndex(const float *vector) const {
66  int index = 0;
67 
68  unsigned short cutindex = fCutIndices[0];
69  float cutval = fCutVals[0];
70 
71  while (true) {
72  if (vector[cutindex] > cutval) {
74  } else {
76  }
77 
78  if (index > 0) {
79  cutindex = fCutIndices[index];
80  cutval = fCutVals[index];
81  } else {
82  return (-index);
83  }
84  }
85 }
86 
87 //_______________________________________________________________________
88 template <typename InputTreeT>
89 GBRTreeD::GBRTreeD(const InputTreeT &tree)
90  : fCutIndices(tree.CutIndices()),
91  fCutVals(tree.CutVals()),
92  fLeftIndices(tree.LeftIndices()),
93  fRightIndices(tree.RightIndices()),
94  fResponses(tree.Responses()) {}
95 
96 #endif
const std::vector< int > & LeftIndices() const
Definition: GBRTreeD.h:49
std::vector< unsigned short > & CutIndices()
Definition: GBRTreeD.h:42
std::vector< int > fRightIndices
Definition: GBRTreeD.h:58
std::vector< int > & LeftIndices()
Definition: GBRTreeD.h:48
const std::vector< double > & Responses() const
Definition: GBRTreeD.h:40
std::vector< float > & CutVals()
Definition: GBRTreeD.h:45
std::vector< double > fResponses
Definition: GBRTreeD.h:59
std::vector< float > fCutVals
Definition: GBRTreeD.h:56
std::vector< int > fLeftIndices
Definition: GBRTreeD.h:57
std::vector< double > & Responses()
Definition: GBRTreeD.h:39
GBRTreeD()
Definition: GBRTreeD.h:31
const std::vector< float > & CutVals() const
Definition: GBRTreeD.h:46
int TerminalIndex(const float *vector) const
Definition: GBRTreeD.h:65
std::vector< unsigned short > fCutIndices
Definition: GBRTreeD.h:55
const std::vector< int > & RightIndices() const
Definition: GBRTreeD.h:52
#define COND_SERIALIZABLE
Definition: Serializable.h:39
Definition: tree.py:1
std::vector< int > & RightIndices()
Definition: GBRTreeD.h:51
const std::vector< unsigned short > & CutIndices() const
Definition: GBRTreeD.h:43
double GetResponse(int termidx) const
Definition: GBRTreeD.h:36