00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
00021
00022
00023 #include <RVersion.h>
00024
00025 #include <TMVA/DataSet.h>
00026 #include <TMVA/Types.h>
00027 #include <TMVA/MethodBase.h>
00028 #include <TMVA/Methods.h>
00029
00030 #include "PhysicsTools/MVAComputer/interface/memstream.h"
00031 #include "PhysicsTools/MVAComputer/interface/zstream.h"
00032
00033 #include "PhysicsTools/MVAComputer/interface/VarProcessor.h"
00034 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00035
00036 using namespace PhysicsTools;
00037
00038 namespace {
00039
00040 class ProcTMVA : public VarProcessor {
00041 public:
00042 typedef VarProcessor::Registry::Registry<ProcTMVA,
00043 Calibration::ProcTMVA> Registry;
00044
00045 ProcTMVA(const char *name,
00046 const Calibration::ProcTMVA *calib,
00047 const MVAComputer *computer);
00048 virtual ~ProcTMVA() {}
00049
00050 virtual void configure(ConfIterator iter, unsigned int n);
00051 virtual void eval(ValueIterator iter, unsigned int n) const;
00052
00053 private:
00054 mutable TMVA::DataSet data;
00055 std::auto_ptr<TMVA::MethodBase> method;
00056 unsigned int nVars;
00057 };
00058
00059 static ProcTMVA::Registry registry("ProcTMVA");
00060
00061 #define SWITCH_METHOD(name) \
00062 case (TMVA::Types::k##name): \
00063 return new TMVA::Method##name(*data, "");
00064
00065 static TMVA::MethodBase *methodInst(TMVA::DataSet *data, TMVA::Types::EMVA type)
00066 {
00067 switch(type) {
00068 SWITCH_METHOD(Cuts)
00069 SWITCH_METHOD(SeedDistance)
00070 SWITCH_METHOD(Likelihood)
00071 SWITCH_METHOD(PDERS)
00072 SWITCH_METHOD(HMatrix)
00073 SWITCH_METHOD(Fisher)
00074 SWITCH_METHOD(CFMlpANN)
00075 SWITCH_METHOD(TMlpANN)
00076 SWITCH_METHOD(BDT)
00077 SWITCH_METHOD(RuleFit)
00078 SWITCH_METHOD(SVM)
00079 SWITCH_METHOD(MLP)
00080 SWITCH_METHOD(BayesClassifier)
00081 SWITCH_METHOD(FDA)
00082 SWITCH_METHOD(Committee)
00083 default:
00084 return 0;
00085 }
00086 }
00087
00088 #undef SWITCH_METHOD
00089
00090 ProcTMVA::ProcTMVA(const char *name,
00091 const Calibration::ProcTMVA *calib,
00092 const MVAComputer *computer) :
00093 VarProcessor(name, calib, computer),
00094 nVars(calib->variables.size())
00095 {
00096 for(std::vector<std::string>::const_iterator iter =
00097 calib->variables.begin();
00098 iter != calib->variables.end(); iter++)
00099 data.AddVariable(iter->c_str());
00100
00101 ext::imemstream is(
00102 reinterpret_cast<const char*>(&calib->store.front()),
00103 calib->store.size());
00104 ext::izstream izs(&is);
00105
00106 TMVA::Types::EMVA methodType =
00107 TMVA::Types::Instance().GetMethodType(calib->method);
00108
00109 method = std::auto_ptr<TMVA::MethodBase>(
00110 methodInst(&data, methodType));
00111
00112 method->ReadStateFromStream(izs);
00113 }
00114
00115 void ProcTMVA::configure(ConfIterator iter, unsigned int n)
00116 {
00117 if (n != nVars)
00118 return;
00119
00120 for(unsigned int i = 0; i < n; i++)
00121 iter++(Variable::FLAG_NONE);
00122
00123 iter << Variable::FLAG_NONE;
00124 }
00125
00126 void ProcTMVA::eval(ValueIterator iter, unsigned int n) const
00127 {
00128 for(unsigned int i = 0; i < n; i++)
00129 data.GetEvent().SetVal(i, *iter++);
00130
00131 method->GetVarTransform().GetEventRaw().CopyVarValues(data.GetEvent());
00132 method->GetVarTransform().ApplyTransformation(TMVA::Types::kSignal);
00133 iter(method->GetMvaValue());
00134 }
00135
00136 }