00001 #ifndef PhysicsTools_MVAComputer_MVAModuleHelper_h
00002 #define PhysicsTools_MVAComputer_MVAModuleHelper_h
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015 #include <functional>
00016 #include <algorithm>
00017 #include <numeric>
00018 #include <cstring>
00019 #include <string>
00020 #include <vector>
00021 #include <cmath>
00022
00023 #include <boost/bind.hpp>
00024
00025 #include "FWCore/Framework/interface/EventSetup.h"
00026
00027 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00028 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00029 #include "PhysicsTools/MVAComputer/interface/MVAComputerCache.h"
00030 #include "CondFormats/PhysicsToolsObjects/interface/MVAComputer.h"
00031
00032 namespace PhysicsTools {
00033
00042 template<typename Object>
00043 struct MVAModuleHelperDefaultFiller {
00044 MVAModuleHelperDefaultFiller(const PhysicsTools::AtomicId &name) {}
00045
00046 double operator()(const Object &object,
00047 const PhysicsTools::AtomicId &name)
00048 { return object.compute(name); }
00049 };
00050
00061 template<class Record, typename Object,
00062 class Filler = MVAModuleHelperDefaultFiller<Object> >
00063 class MVAModuleHelper {
00064 public:
00065 MVAModuleHelper(const std::string &label) : label(label) {}
00066 MVAModuleHelper(const MVAModuleHelper &orig) : label(orig.label) {}
00067 ~MVAModuleHelper() {}
00068
00069 void setEventSetup(const edm::EventSetup &setup);
00070 void setEventSetup(const edm::EventSetup &setup, const char *esLabel);
00071
00072 double operator()(const Object &object) const;
00073
00074 void train(const Object &object, bool target, double weight = 1.0) const;
00075
00076 private:
00077 void init(const PhysicsTools::Calibration::MVAComputerContainer *container);
00078
00079 const std::string label;
00080 PhysicsTools::MVAComputerCache cache;
00081
00082 class Value {
00083 public:
00084 Value(const std::string &name) :
00085 name(name), filler(name) {}
00086 Value(const std::string &name, double value) :
00087 name(name), filler(name), value(value) {}
00088
00089 inline bool update(const Object &object) const
00090 {
00091 value = filler(object, name);
00092 return !std::isfinite(value);
00093 }
00094
00095 PhysicsTools::AtomicId getName() const { return name; }
00096 double getValue() const { return value; }
00097
00098 private:
00099 PhysicsTools::AtomicId name;
00100 Filler filler;
00101
00102 mutable double value;
00103 };
00104
00105 std::vector<Value> values;
00106 };
00107
00108 template<class Record, typename Object, class Filler>
00109 void MVAModuleHelper<Record, Object, Filler>::setEventSetup(
00110 const edm::EventSetup &setup)
00111 {
00112 edm::ESHandle<PhysicsTools::Calibration::MVAComputerContainer> handle;
00113 setup.get<Record>().get(handle);
00114 const PhysicsTools::Calibration::MVAComputerContainer *container = handle.product();
00115 if (cache.update(container, label.c_str()) && cache)
00116 init(container);
00117 }
00118
00119 template<class Record, typename Object, class Filler>
00120 void MVAModuleHelper<Record, Object, Filler>::setEventSetup(
00121 const edm::EventSetup &setup, const char *esLabel)
00122 {
00123 edm::ESHandle<PhysicsTools::Calibration::MVAComputerContainer> handle;
00124 setup.get<Record>().get(esLabel, handle);
00125 const PhysicsTools::Calibration::MVAComputerContainer *container = handle.product();
00126 if (cache.update(container, label.c_str()) && cache)
00127 init(container);
00128 }
00129
00130 template<class Record, typename Object, class Filler>
00131 void MVAModuleHelper<Record, Object, Filler>::init(
00132 const PhysicsTools::Calibration::MVAComputerContainer *container)
00133 {
00134 const std::vector<PhysicsTools::Calibration::Variable> &vars =
00135 container->find(label).inputSet;
00136 values.clear();
00137 for(std::vector<PhysicsTools::Calibration::Variable>::const_iterator
00138 iter = vars.begin(); iter != vars.end(); ++iter)
00139 if (std::strncmp(iter->name.c_str(), "__", 2) != 0)
00140 values.push_back(Value(iter->name));
00141 }
00142
00143 template<class Record, typename Object, class Filler>
00144 double MVAModuleHelper<Record, Object, Filler>::operator()(
00145 const Object &object) const
00146 {
00147 std::for_each(values.begin(), values.end(),
00148 boost::bind(&Value::update, _1, object));
00149 return cache->eval(values);
00150 }
00151
00152 template<class Record, typename Object, class Filler>
00153 void MVAModuleHelper<Record, Object, Filler>::train(
00154 const Object &object, bool target, double weight) const
00155 {
00156 static const PhysicsTools::AtomicId kTargetId("__TARGET__");
00157 static const PhysicsTools::AtomicId kWeightId("__WEIGHT__");
00158
00159 if (!cache)
00160 return;
00161
00162 using boost::bind;
00163 if (std::accumulate(values.begin(), values.end(), 0,
00164 bind(std::plus<int>(), _1,
00165 bind(&Value::update, _2, object))))
00166 return;
00167
00168 PhysicsTools::Variable::ValueList list;
00169 list.add(kTargetId, target);
00170 list.add(kWeightId, weight);
00171 for(typename std::vector<Value>::const_iterator iter = values.begin();
00172 iter != values.end(); ++iter)
00173 list.add(iter->getName(), iter->getValue());
00174
00175 cache->eval(list);
00176 }
00177
00178 }
00179
00180 #endif // PhysicsTools_MVAComputer_MVAModuleHelper_h