CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_6_1_2_SLHC4_patch1/src/CondFormats/EgammaObjects/interface/GBRTree.h

Go to the documentation of this file.
00001 
00002 #ifndef EGAMMAOBJECTS_GBRTree
00003 #define EGAMMAOBJECTS_GBRTree
00004 
00006 //                                                                      //
00007 // GBRForest                                                            //
00008 //                                                                      //
00009 // A fast minimal implementation of Gradient-Boosted Regression Trees   //
// which has been especially optimized for size on disk and in memory.  //
//                                                                      //
// Designed to be built from TMVA-trained trees, but could also be      //
// generalized to otherwise-trained trees, classification,              //
// or other boosting methods in the future.                             //
00015 //                                                                      //
00016 //  Josh Bendavid - MIT                                                 //
00018 
// The decision tree is stored as two kinds of arrays.  Intermediate nodes
// are stored as parallel vectors holding, for each node, the index of the
// input variable that is cut on, the cut value, and the indices of the
// 'left' and 'right' daughter nodes; node 0 is the root.  A positive
// daughter index refers to another intermediate node, whereas a zero or
// negative index refers to a terminal node: terminal node -index is stored
// simply as an entry in a vector of regression responses.
00024 
00025 
00026 #include <vector>
00027 #include <map>
00028 
  // Forward declarations of the TMVA tree types that appear only as pointer
  // parameters in GBRTree's declaration (the converting constructor and the
  // protected node helpers).  Declaring them here lets this header avoid
  // including any TMVA header.
  namespace TMVA {
    class DecisionTree;
    class DecisionTreeNode;
  }
00033 
  /// A single regression tree in flattened, array-based form.
  ///
  /// Intermediate nodes are stored as parallel vectors indexed by node number
  /// (fCutIndices, fCutVals, fLeftIndices, fRightIndices); node 0 is the root.
  /// A daughter index > 0 refers to another intermediate node, while an index
  /// <= 0 refers to terminal node -index, whose value is fResponses[-index]
  /// (see GetResponse / TerminalIndex below).
  class GBRTree {

    public:

       /// Default constructor.  Its definition is not in this header; presumably
       /// it leaves all vectors empty — confirm.  GetResponse/TerminalIndex read
       /// node 0 unconditionally, so they must not be called on a tree without
       /// at least one intermediate node.
       GBRTree();
       /// Builds the flattened representation from a TMVA decision tree.
       /// Definition not visible here; presumably it uses the protected
       /// CountIntermediateNodes / CountTerminalNodes / AddNode helpers — confirm.
       explicit GBRTree(const TMVA::DecisionTree *tree);
       virtual ~GBRTree();
       
       /// Regression response of the terminal node reached by @p vector.
       /// @param vector input variables, indexed by the values in CutIndices();
       ///               must hold at least max(CutIndices())+1 entries.
       double GetResponse(const float* vector) const;
       /// Index into Responses() of the terminal node reached by @p vector.
       int TerminalIndex(const float *vector) const;
       
       /// Regression response of each terminal node, indexed by terminal index.
       /// The non-const overload allows the vector to be filled or modified.
       std::vector<float> &Responses() { return fResponses; }       
       const std::vector<float> &Responses() const { return fResponses; }
       
       /// For each intermediate node, the index of the input variable it cuts on.
       /// Stored as unsigned char, so at most 256 input variables are addressable.
       std::vector<unsigned char> &CutIndices() { return fCutIndices; }
       const std::vector<unsigned char> &CutIndices() const { return fCutIndices; }
       
       /// For each intermediate node, the cut threshold.  Traversal goes right
       /// when vector[CutIndices()[i]] > CutVals()[i] and left otherwise
       /// (a NaN input therefore always goes left).
       std::vector<float> &CutVals() { return fCutVals; }
       const std::vector<float> &CutVals() const { return fCutVals; }
       
       /// For each intermediate node, the daughter taken when the cut fails:
       /// > 0 is an intermediate node, <= 0 is terminal node -value.
       std::vector<int> &LeftIndices() { return fLeftIndices; }
       const std::vector<int> &LeftIndices() const { return fLeftIndices; } 
       
       /// For each intermediate node, the daughter taken when the cut passes:
       /// > 0 is an intermediate node, <= 0 is terminal node -value.
       std::vector<int> &RightIndices() { return fRightIndices; }
       const std::vector<int> &RightIndices() const { return fRightIndices; }
       

       
    protected:      
        // Helpers for converting a TMVA tree; definitions are not in this header.
        // By their names they count the intermediate / terminal nodes below
        // `node` and append `node` to the arrays — NOTE(review): confirm.
        unsigned int CountIntermediateNodes(const TMVA::DecisionTreeNode *node);
        unsigned int CountTerminalNodes(const TMVA::DecisionTreeNode *node);
      
        void AddNode(const TMVA::DecisionTreeNode *node);
        
        // Parallel per-intermediate-node arrays (same length, indexed by node).
        std::vector<unsigned char> fCutIndices;
        std::vector<float> fCutVals;
        std::vector<int> fLeftIndices;
        std::vector<int> fRightIndices;
        // Per-terminal-node responses, indexed by -(daughter index).
        std::vector<float> fResponses;  
        
  };
00075 
00076 //_______________________________________________________________________
00077 inline double GBRTree::GetResponse(const float* vector) const {
00078   
00079   int index = 0;
00080   
00081   unsigned char cutindex = fCutIndices[0];
00082   float cutval = fCutVals[0];
00083   
00084   while (true) {
00085      
00086     if (vector[cutindex] > cutval) {
00087       index = fRightIndices[index];
00088     }
00089     else {
00090       index = fLeftIndices[index];
00091     }
00092     
00093     if (index>0) {
00094       cutindex = fCutIndices[index];
00095       cutval = fCutVals[index];
00096     }
00097     else {
00098       return fResponses[-index];
00099     }
00100     
00101   }
00102   
00103 
00104 }
00105 
00106 //_______________________________________________________________________
00107 inline int GBRTree::TerminalIndex(const float* vector) const {
00108   
00109   int index = 0;
00110   
00111   unsigned char cutindex = fCutIndices[0];
00112   float cutval = fCutVals[0];
00113   
00114   while (true) {
00115     if (vector[cutindex] > cutval) {
00116       index = fRightIndices[index];
00117     }
00118     else {
00119       index = fLeftIndices[index];
00120     }
00121     
00122     if (index>0) {
00123       cutindex = fCutIndices[index];
00124       cutval = fCutVals[index];
00125     }
00126     else {
00127       return (-index);
00128     }
00129     
00130   }
00131   
00132 
00133 }
00134   
00135 #endif