CMS 3D CMS Logo

ProcTMVA.cc

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 //
00003 // Package:     MVAComputer
00004 // Class  :     ProcTMVA
00005 // 
00006 
00007 // Implementation:
00008 //     TMVA wrapper, needs n non-optional, non-multiple input variables
00009 //     and outputs one result variable. All TMVA algorithms can be used,
00010 //     calibration data is passed via stream and extracted from a zipped
00011 //     buffer.
00012 //
00013 // Author:      Christophe Saout
00014 // Created:     Sat Apr 24 15:18 CEST 2007
00015 // $Id: ProcTMVA.cc,v 1.7 2008/03/15 22:26:55 saout Exp $
00016 //
00017 
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
00021 
00022 // ROOT version magic to support TMVA interface changes in newer ROOT
00023 #include <RVersion.h>
00024 
00025 #include <TMVA/DataSet.h>
00026 #include <TMVA/Types.h>
00027 #include <TMVA/MethodBase.h>
00028 #include <TMVA/Methods.h>
00029 
00030 #include "PhysicsTools/MVAComputer/interface/memstream.h"
00031 #include "PhysicsTools/MVAComputer/interface/zstream.h"
00032 
00033 #include "PhysicsTools/MVAComputer/interface/VarProcessor.h"
00034 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00035 
00036 using namespace PhysicsTools;
00037 
00038 namespace { // anonymous
00039 
00040 class ProcTMVA : public VarProcessor {
00041     public:
00042         typedef VarProcessor::Registry::Registry<ProcTMVA,
00043                                         Calibration::ProcTMVA> Registry;
00044 
00045         ProcTMVA(const char *name,
00046                  const Calibration::ProcTMVA *calib,
00047                  const MVAComputer *computer);
00048         virtual ~ProcTMVA() {}
00049 
00050         virtual void configure(ConfIterator iter, unsigned int n);
00051         virtual void eval(ValueIterator iter, unsigned int n) const;
00052 
00053     private:
00054         mutable TMVA::DataSet           data;
00055         std::auto_ptr<TMVA::MethodBase> method;
00056         unsigned int                    nVars;
00057 };
00058 
00059 static ProcTMVA::Registry registry("ProcTMVA");
00060 
00061 #define SWITCH_METHOD(name)                                     \
00062         case (TMVA::Types::k##name):                            \
00063                 return new TMVA::Method##name(*data, "");
00064 
00065 static TMVA::MethodBase *methodInst(TMVA::DataSet *data, TMVA::Types::EMVA type)
00066 {
00067         switch(type) {
00068                 SWITCH_METHOD(Cuts)
00069                 SWITCH_METHOD(SeedDistance)
00070                 SWITCH_METHOD(Likelihood)
00071                 SWITCH_METHOD(PDERS)
00072                 SWITCH_METHOD(HMatrix)
00073                 SWITCH_METHOD(Fisher)
00074                 SWITCH_METHOD(CFMlpANN)
00075                 SWITCH_METHOD(TMlpANN)
00076                 SWITCH_METHOD(BDT)
00077                 SWITCH_METHOD(RuleFit)
00078                 SWITCH_METHOD(SVM)
00079                 SWITCH_METHOD(MLP)
00080                 SWITCH_METHOD(BayesClassifier)
00081                 SWITCH_METHOD(FDA)
00082                 SWITCH_METHOD(Committee)
00083             default:
00084                 return 0;
00085         }
00086 }
00087 
00088 #undef SWITCH_METHOD
00089 
00090 ProcTMVA::ProcTMVA(const char *name,
00091                    const Calibration::ProcTMVA *calib,
00092                    const MVAComputer *computer) :
00093         VarProcessor(name, calib, computer),
00094         nVars(calib->variables.size())
00095 {
00096         for(std::vector<std::string>::const_iterator iter =
00097                                                 calib->variables.begin();
00098             iter != calib->variables.end(); iter++)
00099                 data.AddVariable(iter->c_str());
00100 
00101         ext::imemstream is(
00102                 reinterpret_cast<const char*>(&calib->store.front()),
00103                 calib->store.size());
00104         ext::izstream izs(&is);
00105 
00106         TMVA::Types::EMVA methodType =
00107                         TMVA::Types::Instance().GetMethodType(calib->method);
00108 
00109         method = std::auto_ptr<TMVA::MethodBase>(
00110                                         methodInst(&data, methodType));
00111 
00112         method->ReadStateFromStream(izs);
00113 }
00114 
00115 void ProcTMVA::configure(ConfIterator iter, unsigned int n)
00116 {
00117         if (n != nVars)
00118                 return;
00119 
00120         for(unsigned int i = 0; i < n; i++)
00121                 iter++(Variable::FLAG_NONE);
00122 
00123         iter << Variable::FLAG_NONE;
00124 }
00125 
00126 void ProcTMVA::eval(ValueIterator iter, unsigned int n) const
00127 {
00128         for(unsigned int i = 0; i < n; i++)
00129                 data.GetEvent().SetVal(i, *iter++);
00130 
00131         method->GetVarTransform().GetEventRaw().CopyVarValues(data.GetEvent());
00132         method->GetVarTransform().ApplyTransformation(TMVA::Types::kSignal);
00133         iter(method->GetMvaValue());
00134 }
00135 
00136 } // anonymous namespace

Generated on Tue Jun 9 17:41:31 2009 for CMSSW by  doxygen 1.5.4