CMS 3D CMS Logo

GBRTreeD.h
Go to the documentation of this file.
1 
2 #ifndef EGAMMAOBJECTS_GBRTreeD
3 #define EGAMMAOBJECTS_GBRTreeD
4 
6 // //
7 // GBRForest //
8 // //
9 // A fast minimal implementation of Gradient-Boosted Regression Trees //
10 // which has been especially optimized for size on disk and in memory. //
11 // //
12 // Designed to be built from TMVA-trained trees, but could also be //
13 // generalized to otherwise-trained trees, classification, //
14 // or other boosting methods in the future //
15 // //
16 // Josh Bendavid - CERN //
18 
// The decision tree is implemented here as two sets of parallel arrays: one
// for the intermediate nodes, holding for each node the index of the input
// variable to test, the cut value, and the indices of its 'left' and 'right'
// daughter nodes; and one for the terminal nodes, which are stored simply as a
// vector of regression responses. A positive daughter index refers to a
// further intermediate node, whereas a zero or negative index -i refers to
// terminal node i.
24 
26 
27 #include <vector>
28 #include <map>
29 #include <cstdio>
30 #include <cmath>
31 #include "Rtypes.h"
32 
/// A single regression tree stored as flat parallel arrays.
///
/// Intermediate node i is described by fCutIndices[i] (which input variable to
/// test), fCutVals[i] (the threshold to compare it against) and
/// fLeftIndices[i] / fRightIndices[i] (the daughter to visit next). A daughter
/// index > 0 refers to another intermediate node; a daughter index <= 0 refers
/// to terminal node -index, whose value is fResponses[-index]. TerminalIndex()
/// performs the walk and GetResponse() reads the terminal value.
class GBRTreeD {
public:
  GBRTreeD() {}

  /// Builds a tree from any object exposing CutIndices(), CutVals(),
  /// LeftIndices(), RightIndices() and Responses() accessors.
  /// NOTE(review): not `explicit`, so such objects convert implicitly; kept
  /// that way for backward compatibility with existing callers.
  template <typename InputTreeT>
  GBRTreeD(const InputTreeT &tree);

  // The destructor below is user-declared, which suppresses the implicitly
  // generated move operations: without these declarations every "move" of a
  // tree (e.g. when a std::vector<GBRTreeD> reallocates) would deep-copy all
  // five arrays. The defaulted moves are noexcept because std::vector's are.
  GBRTreeD(const GBRTreeD &) = default;
  GBRTreeD(GBRTreeD &&) = default;
  GBRTreeD &operator=(const GBRTreeD &) = default;
  GBRTreeD &operator=(GBRTreeD &&) = default;
  virtual ~GBRTreeD();

  /// Regression response stored for terminal node `termidx` (unchecked access;
  /// typically fed with the result of TerminalIndex()).
  double GetResponse(int termidx) const { return fResponses[termidx]; }

  /// Index of the terminal node reached by `vector` (see definition below).
  int TerminalIndex(const float *vector) const;

  // Direct (mutable and const) access to the underlying arrays.
  std::vector<double> &Responses() { return fResponses; }
  const std::vector<double> &Responses() const { return fResponses; }

  std::vector<unsigned short> &CutIndices() { return fCutIndices; }
  const std::vector<unsigned short> &CutIndices() const { return fCutIndices; }

  std::vector<float> &CutVals() { return fCutVals; }
  const std::vector<float> &CutVals() const { return fCutVals; }

  std::vector<int> &LeftIndices() { return fLeftIndices; }
  const std::vector<int> &LeftIndices() const { return fLeftIndices; }

  std::vector<int> &RightIndices() { return fRightIndices; }
  const std::vector<int> &RightIndices() const { return fRightIndices; }

protected:
  std::vector<unsigned short> fCutIndices;  // per intermediate node: input-variable index
  std::vector<float> fCutVals;              // per intermediate node: cut threshold
  std::vector<int> fLeftIndices;            // per intermediate node: left daughter index
  std::vector<int> fRightIndices;           // per intermediate node: right daughter index
  std::vector<double> fResponses;           // per terminal node: regression response

  // NOTE(review): the extracted listing this was taken from skips original
  // source line 66, and its cross-reference index points at COND_SERIALIZABLE
  // (Serializable.h); the upstream class most likely ends with
  // `COND_SERIALIZABLE;` here -- confirm and restore before building.
};
68 
69 //_______________________________________________________________________
70 inline int GBRTreeD::TerminalIndex(const float *vector) const {
71  int index = 0;
72 
73  unsigned short cutindex = fCutIndices[0];
74  float cutval = fCutVals[0];
75 
76  while (true) {
77  if (vector[cutindex] > cutval) {
79  } else {
81  }
82 
83  if (index > 0) {
84  cutindex = fCutIndices[index];
85  cutval = fCutVals[index];
86  } else {
87  return (-index);
88  }
89  }
90 }
91 
92 //_______________________________________________________________________
93 template <typename InputTreeT>
94 GBRTreeD::GBRTreeD(const InputTreeT &tree)
95  : fCutIndices(tree.CutIndices()),
96  fCutVals(tree.CutVals()),
97  fLeftIndices(tree.LeftIndices()),
98  fRightIndices(tree.RightIndices()),
99  fResponses(tree.Responses()) {}
100 
101 #endif
GBRTreeD::LeftIndices
std::vector< int > & LeftIndices()
Definition: GBRTreeD.h:53
GBRTreeD::fRightIndices
std::vector< int > fRightIndices
Definition: GBRTreeD.h:63
tree
Definition: tree.py:1
COND_SERIALIZABLE
#define COND_SERIALIZABLE
Definition: Serializable.h:39
GBRTreeD::Responses
std::vector< double > & Responses()
Definition: GBRTreeD.h:44
GBRTreeD::GetResponse
double GetResponse(int termidx) const
Definition: GBRTreeD.h:41
GBRTreeD::LeftIndices
const std::vector< int > & LeftIndices() const
Definition: GBRTreeD.h:54
GBRTreeD::TerminalIndex
int TerminalIndex(const float *vector) const
Definition: GBRTreeD.h:70
GBRTreeD::~GBRTreeD
virtual ~GBRTreeD()
Definition: GBRTreeD.cc:4
GBRTreeD::CutIndices
std::vector< unsigned short > & CutIndices()
Definition: GBRTreeD.h:47
GBRTreeD::CutVals
std::vector< float > & CutVals()
Definition: GBRTreeD.h:50
GBRTreeD::fCutVals
std::vector< float > fCutVals
Definition: GBRTreeD.h:61
Serializable.h
trackerHitRTTI::vector
Definition: trackerHitRTTI.h:21
GBRTreeD::fCutIndices
std::vector< unsigned short > fCutIndices
Definition: GBRTreeD.h:60
GBRTreeD::fResponses
std::vector< double > fResponses
Definition: GBRTreeD.h:64
GBRTreeD::RightIndices
std::vector< int > & RightIndices()
Definition: GBRTreeD.h:56
GBRTreeD::CutVals
const std::vector< float > & CutVals() const
Definition: GBRTreeD.h:51
GBRTreeD::RightIndices
const std::vector< int > & RightIndices() const
Definition: GBRTreeD.h:57
GBRTreeD::CutIndices
const std::vector< unsigned short > & CutIndices() const
Definition: GBRTreeD.h:48
GBRTreeD::Responses
const std::vector< double > & Responses() const
Definition: GBRTreeD.h:45
GBRTreeD
Definition: GBRTreeD.h:33
AlignmentPI::index
index
Definition: AlignmentPayloadInspectorHelper.h:46
GBRTreeD::fLeftIndices
std::vector< int > fLeftIndices
Definition: GBRTreeD.h:62
GBRTreeD::GBRTreeD
GBRTreeD()
Definition: GBRTreeD.h:35