CMS 3D CMS Logo

GBRForest.h
Go to the documentation of this file.
1 
2 #ifndef EGAMMAOBJECTS_GBRForest
3 #define EGAMMAOBJECTS_GBRForest
4 
6 // //
7 // GBRForest //
8 // //
9 // A fast minimal implementation of Gradient-Boosted Regression Trees //
10 // which has been especially optimized for size on disk and in memory. //
11 // //
12 // Designed to be built from TMVA-trained trees, but could also be //
13 // generalized to otherwise-trained trees, classification, //
14 // or other boosting methods in the future //
15 // //
16 // Josh Bendavid - MIT //
18 
20 
21 #include <vector>
22 #include "GBRTree.h"
23 #include <math.h>
24 #include <stdio.h>
25 
26  namespace TMVA {
27  class MethodBDT;
28  }
29 
30  class GBRForest {
31 
32  public:
33 
34  GBRForest();
35  explicit GBRForest(const TMVA::MethodBDT *bdt);
36  virtual ~GBRForest();
37 
38  double GetResponse(const float* vector) const;
39  double GetGradBoostClassifier(const float* vector) const;
40  double GetAdaBoostClassifier(const float* vector) const { return GetResponse(vector); }
41 
42  //for backwards-compatibility
43  double GetClassifier(const float* vector) const { return GetGradBoostClassifier(vector); }
44 
45  void SetInitialResponse(double response) { fInitialResponse = response; }
46 
47  std::vector<GBRTree> &Trees() { return fTrees; }
48  const std::vector<GBRTree> &Trees() const { return fTrees; }
49 
50  protected:
52  std::vector<GBRTree> fTrees;
53 
54 
56 };
57 
58 //_______________________________________________________________________
59 inline double GBRForest::GetResponse(const float* vector) const {
60  double response = fInitialResponse;
61  for (std::vector<GBRTree>::const_iterator it=fTrees.begin(); it!=fTrees.end(); ++it) {
62  response += it->GetResponse(vector);
63  }
64  return response;
65 }
66 
67 //_______________________________________________________________________
68 inline double GBRForest::GetGradBoostClassifier(const float* vector) const {
69  double response = GetResponse(vector);
70  return 2.0/(1.0+exp(-2.0*response))-1; //MVA output between -1 and 1
71 }
72 
73 #endif
double fInitialResponse
Definition: GBRForest.h:51
double GetResponse(const float *vector) const
Definition: GBRForest.h:59
std::vector< GBRTree > & Trees()
Definition: GBRForest.h:47
void SetInitialResponse(double response)
Definition: GBRForest.h:45
double GetGradBoostClassifier(const float *vector) const
Definition: GBRForest.h:68
std::vector< GBRTree > fTrees
Definition: GBRForest.h:52
const std::vector< GBRTree > & Trees() const
Definition: GBRForest.h:48
#define COND_SERIALIZABLE
Definition: Serializable.h:38
namespace TMVA
Definition: GBRForest.h:26
double GetClassifier(const float *vector) const
Definition: GBRForest.h:43
double GetAdaBoostClassifier(const float *vector) const
Definition: GBRForest.h:40