CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_5_3_9_patch3/src/CondFormats/EgammaObjects/interface/GBRTree2D.h

Go to the documentation of this file.
00001 
00002 #ifndef EGAMMAOBJECTS_GBRTree2D
00003 #define EGAMMAOBJECTS_GBRTree2D
00004 
00006 //                                                                      //
// GBRTree2D                                                            //
//                                                                      //
// A fast minimal implementation of Gradient-Boosted Regression Trees   //
// which has been especially optimized for size on disk and in memory.  //
00011 //                                                                      //
00012 // Designed to be built from TMVA-trained trees, but could also be      //
00013 // generalized to otherwise-trained trees, classification,              //
00014 //  or other boosting methods in the future                             //
00015 //                                                                      //
00016 //  Josh Bendavid - MIT                                                 //
00018 
// The decision tree is stored as two sets of parallel arrays: one for
// intermediate nodes, containing the variable index and cut value as well
// as the indices of the 'left' and 'right' daughter nodes, and one for
// terminal nodes, holding the x and y regression responses.  Positive
// daughter indices refer to further intermediate nodes, whereas non-positive
// daughter indices d refer to terminal node -d (so 0 denotes terminal node 0).
00024 
00025 
00026 #include <vector>
00027 #include <map>
00028 
/// Binary decision tree whose terminal nodes carry a two-component (x, y)
/// regression response.
///
/// Nodes are stored as parallel arrays indexed by node number, root at index 0:
///   - fCutIndices[i] : index into the input feature array tested at node i
///   - fCutVals[i]    : threshold at node i; a feature value strictly greater
///                      than it selects the right daughter, otherwise the left
///   - fLeftIndices[i], fRightIndices[i] : daughters of node i. A positive value
///                      is another intermediate node; a value d <= 0 denotes
///                      terminal node -d.
///   - fResponsesX[k], fResponsesY[k] : response components of terminal node k.
///
/// The evaluation routines always start at node 0, so a usable tree must
/// contain at least one intermediate node.
///
/// The class owns nothing but std::vector members, so the compiler-generated
/// destructor, copy and move operations are correct (Rule of Zero).
class GBRTree2D {

  public:

    GBRTree2D() {}
    // Deliberately no user-declared destructor: declaring one (even an empty
    // one) suppresses the implicit move constructor/assignment, so every move
    // of a tree would fall back to deep-copying all six arrays.

    /// Descend the tree for input `vector` and store the responses of the
    /// terminal node reached in `x` and `y`.
    void GetResponse(const float* vector, double &x, double &y) const;

    /// Return the index k of the terminal node reached for input `vector`.
    int TerminalIndex(const float *vector) const;

    /// Per-terminal-node x responses.
    std::vector<float> &ResponsesX() { return fResponsesX; }
    const std::vector<float> &ResponsesX() const { return fResponsesX; }

    /// Per-terminal-node y responses.
    std::vector<float> &ResponsesY() { return fResponsesY; }
    const std::vector<float> &ResponsesY() const { return fResponsesY; }

    /// Per-intermediate-node index of the input variable that is cut on.
    std::vector<unsigned short> &CutIndices() { return fCutIndices; }
    const std::vector<unsigned short> &CutIndices() const { return fCutIndices; }

    /// Per-intermediate-node cut threshold.
    std::vector<float> &CutVals() { return fCutVals; }
    const std::vector<float> &CutVals() const { return fCutVals; }

    /// Per-intermediate-node left daughter (> 0: node, <= 0: terminal -d).
    std::vector<int> &LeftIndices() { return fLeftIndices; }
    const std::vector<int> &LeftIndices() const { return fLeftIndices; }

    /// Per-intermediate-node right daughter (> 0: node, <= 0: terminal -d).
    std::vector<int> &RightIndices() { return fRightIndices; }
    const std::vector<int> &RightIndices() const { return fRightIndices; }

  protected:
    std::vector<unsigned short> fCutIndices;
    std::vector<float> fCutVals;
    std::vector<int> fLeftIndices;
    std::vector<int> fRightIndices;
    std::vector<float> fResponsesX;
    std::vector<float> fResponsesY;

};
00068 
00069 //_______________________________________________________________________
00070 inline void GBRTree2D::GetResponse(const float* vector, double &x, double &y) const {
00071   
00072   int index = 0;
00073   
00074   unsigned short cutindex = fCutIndices[0];
00075   float cutval = fCutVals[0];
00076   
00077   while (true) {
00078      
00079     if (vector[cutindex] > cutval) {
00080       index = fRightIndices[index];
00081     }
00082     else {
00083       index = fLeftIndices[index];
00084     }
00085     
00086     if (index>0) {
00087       cutindex = fCutIndices[index];
00088       cutval = fCutVals[index];
00089     }
00090     else {
00091       x = fResponsesX[-index];
00092       y = fResponsesY[-index];
00093       return;
00094     }
00095     
00096   }
00097   
00098 
00099 }
00100 
00101 //_______________________________________________________________________
00102 inline int GBRTree2D::TerminalIndex(const float* vector) const {
00103   
00104   int index = 0;
00105   
00106   unsigned short cutindex = fCutIndices[0];
00107   float cutval = fCutVals[0];
00108   
00109   while (true) {
00110     if (vector[cutindex] > cutval) {
00111       index = fRightIndices[index];
00112     }
00113     else {
00114       index = fLeftIndices[index];
00115     }
00116     
00117     if (index>0) {
00118       cutindex = fCutIndices[index];
00119       cutval = fCutVals[index];
00120     }
00121     else {
00122       return (-index);
00123     }
00124     
00125   }
00126   
00127 
00128 }
00129   
00130 #endif