#ifndef EGAMMAOBJECTS_GBRTree2D
#define EGAMMAOBJECTS_GBRTree2D

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// GBRTree2D                                                            //
//                                                                      //
// A fast minimal implementation of Gradient-Boosted Regression Trees   //
// which has been especially optimized for size on disk and in memory.  //
// This variant carries a two-component (x, y) response per leaf.       //
//                                                                      //
// Designed to be built from TMVA-trained trees, but could also be      //
// generalized to otherwise-trained trees, classification,              //
// or other boosting methods in the future                              //
//                                                                      //
// Josh Bendavid - MIT                                                  //
//////////////////////////////////////////////////////////////////////////

// The decision tree is stored as flat parallel arrays rather than as linked
// node objects.
//
// Intermediate node i is described by:
//   fCutIndices[i]   - index of the input variable the node cuts on
//   fCutVals[i]      - cut value; an input strictly greater than it goes
//                      'right', anything else (including equality / NaN)
//                      goes 'left'
//   fLeftIndices[i]  - daughter reached on the 'left' branch
//   fRightIndices[i] - daughter reached on the 'right' branch
//
// A positive daughter index names another intermediate node. A daughter
// index d <= 0 names terminal node -d (so 0 is terminal node 0). Traversal
// always starts at intermediate node 0.
//
// Terminal node t carries the two-component regression response
// (fResponsesX[t], fResponsesY[t]).

#include <vector>
#include <map>

class GBRTree2D {

public:

  GBRTree2D() {}
  // Intentionally no user-declared destructor (Rule of Zero): the vector
  // members release their own storage, and declaring a destructor would
  // suppress the implicit move operations, turning every move of a tree
  // into a deep copy of all six vectors.

  /// Evaluate the tree on one input vector.
  /// @param vector Input variables, indexed by the values in CutIndices();
  ///               must have at least (largest cut index + 1) entries.
  ///               Not bounds-checked.
  /// @param x      Receives the X response of the terminal node reached.
  /// @param y      Receives the Y response of the terminal node reached.
  /// Precondition: the tree has at least one intermediate node (node 0).
  void GetResponse(const float* vector, double &x, double &y) const;

  /// Walk the tree for one input vector.
  /// @param vector Same requirements as for GetResponse().
  /// @return Index (>= 0) of the terminal node reached; it indexes
  ///         ResponsesX() / ResponsesY().
  int TerminalIndex(const float *vector) const;

  // Per-terminal-node responses (X component).
  std::vector<float> &ResponsesX() { return fResponsesX; }
  const std::vector<float> &ResponsesX() const { return fResponsesX; }

  // Per-terminal-node responses (Y component).
  std::vector<float> &ResponsesY() { return fResponsesY; }
  const std::vector<float> &ResponsesY() const { return fResponsesY; }

  // Per-intermediate-node input-variable index.
  std::vector<unsigned short> &CutIndices() { return fCutIndices; }
  const std::vector<unsigned short> &CutIndices() const { return fCutIndices; }

  // Per-intermediate-node cut value.
  std::vector<float> &CutVals() { return fCutVals; }
  const std::vector<float> &CutVals() const { return fCutVals; }

  // Per-intermediate-node 'left' daughter (see encoding above).
  std::vector<int> &LeftIndices() { return fLeftIndices; }
  const std::vector<int> &LeftIndices() const { return fLeftIndices; }

  // Per-intermediate-node 'right' daughter (see encoding above).
  std::vector<int> &RightIndices() { return fRightIndices; }
  const std::vector<int> &RightIndices() const { return fRightIndices; }

protected:
  std::vector<unsigned short> fCutIndices;
  std::vector<float> fCutVals;
  std::vector<int> fLeftIndices;
  std::vector<int> fRightIndices;
  std::vector<float> fResponsesX;
  std::vector<float> fResponsesY;

};

//_______________________________________________________________________
inline void GBRTree2D::GetResponse(const float* vector, double &x, double &y) const {

  // Reuse the single traversal implementation, then look up both response
  // components of the terminal node that was reached.
  const int terminal = TerminalIndex(vector);
  x = fResponsesX[terminal];
  y = fResponsesY[terminal];

}

//_______________________________________________________________________
inline int GBRTree2D::TerminalIndex(const float* vector) const {

  // Start at the root (intermediate node 0) and follow cuts until a
  // daughter index <= 0, i.e. a terminal node, is produced.
  int index = 0;
  do {
    index = (vector[fCutIndices[index]] > fCutVals[index])
                ? fRightIndices[index]
                : fLeftIndices[index];
  } while (index > 0);

  // Terminal nodes are encoded as non-positive daughter indices.
  return -index;

}

#endif