#ifndef EGAMMAOBJECTS_GBRTree
#define EGAMMAOBJECTS_GBRTree

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// GBRTree                                                              //
//                                                                      //
// A fast minimal implementation of Gradient-Boosted Regression Trees   //
// which has been especially optimized for size on disk and in memory.  //
//                                                                      //
// Designed to be built from TMVA-trained trees, but could also be      //
// generalized to otherwise-trained trees, classification,              //
// or other boosting methods in the future                              //
//                                                                      //
//  Josh Bendavid - MIT                                                 //
//////////////////////////////////////////////////////////////////////////

// The decision tree is implemented here as a set of parallel arrays. The
// arrays for intermediate nodes hold the variable index and cut value of
// each node, as well as the indices of its 'left' and 'right' daughters.
// A strictly positive daughter index refers to a further intermediate node;
// a zero or negative daughter index refers to a terminal node, whose
// regression response is stored at position (-index) of the response array.
// Index 0 of the intermediate-node arrays is the root, where every
// traversal starts.


#include <vector>
#include <map>

namespace TMVA {
  class DecisionTree;
  class DecisionTreeNode;
}

class GBRTree {

  public:

    GBRTree();
    explicit GBRTree(const TMVA::DecisionTree *tree);
    virtual ~GBRTree();

    // Walks the tree for the given input values and returns the regression
    // response stored for the terminal node that is reached.
    double GetResponse(const float* vector) const;

    // Walks the tree for the given input values and returns the position of
    // the reached terminal node inside Responses().
    int TerminalIndex(const float *vector) const;

    // Direct (mutable and read-only) access to the underlying arrays.
    // Responses() is indexed by terminal node; the four remaining arrays
    // are indexed by intermediate node.
    std::vector<float> &Responses() { return fResponses; }
    const std::vector<float> &Responses() const { return fResponses; }

    std::vector<unsigned char> &CutIndices() { return fCutIndices; }
    const std::vector<unsigned char> &CutIndices() const { return fCutIndices; }

    std::vector<float> &CutVals() { return fCutVals; }
    const std::vector<float> &CutVals() const { return fCutVals; }

    std::vector<int> &LeftIndices() { return fLeftIndices; }
    const std::vector<int> &LeftIndices() const { return fLeftIndices; }

    std::vector<int> &RightIndices() { return fRightIndices; }
    const std::vector<int> &RightIndices() const { return fRightIndices; }



  protected:
    // Helpers operating on TMVA nodes; their definitions are not part of
    // this header.
    unsigned int CountIntermediateNodes(const TMVA::DecisionTreeNode *node);
    unsigned int CountTerminalNodes(const TMVA::DecisionTreeNode *node);

    void AddNode(const TMVA::DecisionTreeNode *node);

    std::vector<unsigned char> fCutIndices;  // per intermediate node: index of the input variable that is cut on
    std::vector<float> fCutVals;             // per intermediate node: cut threshold
    std::vector<int> fLeftIndices;           // per intermediate node: daughter taken when value <= threshold
    std::vector<int> fRightIndices;          // per intermediate node: daughter taken when value > threshold
    std::vector<float> fResponses;           // per terminal node: regression response

};

//_______________________________________________________________________
// Returns the index into fResponses of the terminal node selected by the
// input values in 'vector'.
//
// 'vector' is read at the positions stored in fCutIndices, so it must hold
// an entry for every variable index used by the tree. No bounds checking is
// done, and element 0 of the intermediate-node arrays is read
// unconditionally, so those arrays must not be empty.
//
// Defined ahead of GetResponse so that the inline definition is visible
// before its first use.
inline int GBRTree::TerminalIndex(const float* vector) const {

  int index = 0;  // start at the root intermediate node

  do {
    // A value strictly above the threshold goes right; everything else
    // (including any value for which the comparison is false) goes left.
    index = (vector[fCutIndices[index]] > fCutVals[index]) ? fRightIndices[index]
                                                           : fLeftIndices[index];
  } while (index > 0);  // strictly positive: another intermediate node

  // Zero or negative: terminal node, stored negated.
  return -index;

}

//_______________________________________________________________________
// Returns the regression response of the terminal node selected by the
// input values in 'vector'. Same requirements on 'vector' and on the
// node arrays as TerminalIndex, which performs the traversal.
inline double GBRTree::GetResponse(const float* vector) const {

  return fResponses[TerminalIndex(vector)];

}

#endif