CMS 3D CMS Logo

/afs/cern.ch/work/a/aaltunda/public/www/CMSSW_6_2_5/src/CondFormats/EgammaObjects/interface/GBRForest.h

Go to the documentation of this file.
00001 
00002 #ifndef EGAMMAOBJECTS_GBRForest
00003 #define EGAMMAOBJECTS_GBRForest
00004 
00006 //                                                                      //
00007 // GBRForest                                                            //
00008 //                                                                      //
// A fast, minimal implementation of Gradient-Boosted Regression Trees  //
// that has been especially optimized for size on disk and in memory.   //
//                                                                      //
// Designed to be built from TMVA-trained trees, but could also be      //
// generalized to otherwise-trained trees, to classification, or to     //
// other boosting methods in the future.                                //
00015 //                                                                      //
00016 //  Josh Bendavid - MIT                                                 //
00018 
00019 #include <vector>
00020 #include "GBRTree.h"
00021 #include <math.h>
00022 #include <stdio.h>
00023 
00024   namespace TMVA {
00025     class MethodBDT;
00026   }
00027 
00028   class GBRForest {
00029 
00030     public:
00031 
00032        GBRForest();
00033        explicit GBRForest(const TMVA::MethodBDT *bdt);
00034        virtual ~GBRForest();
00035        
00036        double GetResponse(const float* vector) const;
00037        double GetClassifier(const float* vector) const;
00038        
00039        void SetInitialResponse(double response) { fInitialResponse = response; }
00040        
00041        std::vector<GBRTree> &Trees() { return fTrees; }
00042        const std::vector<GBRTree> &Trees() const { return fTrees; }
00043        
00044     protected:
00045       double               fInitialResponse;
00046       std::vector<GBRTree> fTrees;  
00047       
00048   };
00049 
00050 //_______________________________________________________________________
00051 inline double GBRForest::GetResponse(const float* vector) const {
00052   double response = fInitialResponse;
00053   for (std::vector<GBRTree>::const_iterator it=fTrees.begin(); it!=fTrees.end(); ++it) {
00054     response += it->GetResponse(vector);
00055   }
00056   return response;
00057 }
00058 
00059 //_______________________________________________________________________
00060 inline double GBRForest::GetClassifier(const float* vector) const {
00061   double response = GetResponse(vector);
00062   return 2.0/(1.0+exp(-2.0*response))-1; //MVA output between -1 and 1
00063 }
00064 
00065 #endif