00001 #ifndef PhysicsTools_MVAComputer_MVAModuleHelper_h
00002 #define PhysicsTools_MVAComputer_MVAModuleHelper_h
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015 #include <functional>
00016 #include <algorithm>
00017 #include <numeric>
00018 #include <cstring>
00019 #include <string>
00020 #include <vector>
00021 #include <cmath>
00022
00023 #include <boost/bind.hpp>
00024
00025 #include "FWCore/Framework/interface/EventSetup.h"
00026
00027 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00028 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00029 #include "PhysicsTools/MVAComputer/interface/MVAComputerCache.h"
00030
00031 namespace PhysicsTools {
00032
00041 template<typename Object>
00042 struct MVAModuleHelperDefaultFiller {
00043 MVAModuleHelperDefaultFiller(const PhysicsTools::AtomicId &name) {}
00044
00045 double operator()(const Object &object,
00046 const PhysicsTools::AtomicId &name)
00047 { return object.compute(name); }
00048 };
00049
00060 template<class Record, typename Object,
00061 class Filler = MVAModuleHelperDefaultFiller<Object> >
00062 class MVAModuleHelper {
00063 public:
00064 MVAModuleHelper(const std::string &label) : label(label) {}
00065 MVAModuleHelper(const MVAModuleHelper &orig) : label(orig.label) {}
00066 ~MVAModuleHelper() {}
00067
00068 void setEventSetup(const edm::EventSetup &setup);
00069 void setEventSetup(const edm::EventSetup &setup, const char *esLabel);
00070
00071 double operator()(const Object &object) const;
00072
00073 void train(const Object &object, bool target, double weight = 1.0) const;
00074
00075 private:
00076 void init(const PhysicsTools::Calibration::MVAComputerContainer *container);
00077
00078 const std::string label;
00079 PhysicsTools::MVAComputerCache cache;
00080
00081 class Value {
00082 public:
00083 Value(const std::string &name) :
00084 name(name), filler(name) {}
00085 Value(const std::string &name, double value) :
00086 name(name), filler(name), value(value) {}
00087
00088 inline bool update(const Object &object) const
00089 {
00090 value = filler(object, name);
00091 return !std::isfinite(value);
00092 }
00093
00094 PhysicsTools::AtomicId getName() const { return name; }
00095 double getValue() const { return value; }
00096
00097 private:
00098 PhysicsTools::AtomicId name;
00099 Filler filler;
00100
00101 mutable double value;
00102 };
00103
00104 std::vector<Value> values;
00105 };
00106
00107 template<class Record, typename Object, class Filler>
00108 void MVAModuleHelper<Record, Object, Filler>::setEventSetup(
00109 const edm::EventSetup &setup)
00110 {
00111 edm::ESHandle<MVAComputerContainer> handle;
00112 setup.get<Record>().get(handle);
00113 const MVAComputerContainer *container = handle.product();
00114 if (cache.update(container, label.c_str()) && cache)
00115 init(container);
00116 }
00117
00118 template<class Record, typename Object, class Filler>
00119 void MVAModuleHelper<Record, Object, Filler>::setEventSetup(
00120 const edm::EventSetup &setup, const char *esLabel)
00121 {
00122 edm::ESHandle<MVAComputerContainer> handle;
00123 setup.get<Record>().get(esLabel, handle);
00124 const MVAComputerContainer *container = handle.product();
00125 if (cache.update(container, label.c_str()) && cache)
00126 init(container);
00127 }
00128
00129 template<class Record, typename Object, class Filler>
00130 void MVAModuleHelper<Record, Object, Filler>::init(
00131 const PhysicsTools::Calibration::MVAComputerContainer *container)
00132 {
00133 const std::vector<PhysicsTools::Calibration::Variable> &vars =
00134 container->find(label).inputSet;
00135 values.clear();
00136 for(std::vector<PhysicsTools::Calibration::Variable>::const_iterator
00137 iter = vars.begin(); iter != vars.end(); ++iter)
00138 if (std::strncmp(iter->name.c_str(), "__", 2) != 0)
00139 values.push_back(Value(iter->name));
00140 }
00141
00142 template<class Record, typename Object, class Filler>
00143 double MVAModuleHelper<Record, Object, Filler>::operator()(
00144 const Object &object) const
00145 {
00146 std::for_each(values.begin(), values.end(),
00147 boost::bind(&Value::update, _1, object));
00148 return cache->eval(values);
00149 }
00150
00151 template<class Record, typename Object, class Filler>
00152 void MVAModuleHelper<Record, Object, Filler>::train(
00153 const Object &object, bool target, double weight) const
00154 {
00155 static const PhysicsTools::AtomicId kTargetId("__TARGET__");
00156 static const PhysicsTools::AtomicId kWeightId("__WEIGHT__");
00157
00158 if (!cache)
00159 return;
00160
00161 using boost::bind;
00162 if (std::accumulate(values.begin(), values.end(), 0,
00163 bind(std::plus<int>(), _1,
00164 bind(&Value::update, _2, object))))
00165 return;
00166
00167 PhysicsTools::Variable::ValueList list;
00168 list.add(kTargetId, target);
00169 list.add(kWeightId, weight);
00170 for(typename std::vector<Value>::const_iterator iter = values.begin();
00171 iter != values.end(); ++iter)
00172 list.add(iter->getName(), iter->getValue());
00173
00174 cache->eval(list);
00175 }
00176
00177 }
00178
00179 #endif // PhysicsTools_MVAComputer_MVAModuleHelper_h