00001 #ifndef PhysicsTools_MVAComputer_MVAModuleHelper_h
00002 #define PhysicsTools_MVAComputer_MVAModuleHelper_h
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
#include <algorithm>
#include <cmath>
#include <cstring>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>
00022
00023 #include <boost/bind.hpp>
00024
00025 #include "FWCore/Framework/interface/EventSetup.h"
00026
00027 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00028 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00029 #include "PhysicsTools/MVAComputer/interface/MVAComputerCache.h"
00030
00031 namespace PhysicsTools {
00032
00041 template<typename Object>
00042 struct MVAModuleHelperDefaultFiller {
00043 MVAModuleHelperDefaultFiller(const PhysicsTools::AtomicId &name) {}
00044
00045 double operator()(const Object &object,
00046 const PhysicsTools::AtomicId &name)
00047 { return object.compute(name); }
00048 };
00049
00060 template<class Record, typename Object,
00061 class Filler = MVAModuleHelperDefaultFiller<Object> >
00062 class MVAModuleHelper {
00063 public:
00064 MVAModuleHelper(const std::string &label) : label(label) {}
00065 MVAModuleHelper(const MVAModuleHelper &orig) : label(orig.label) {}
00066 ~MVAModuleHelper() {}
00067
00068 void setEventSetup(const edm::EventSetup &setup);
00069 void setEventSetup(const edm::EventSetup &setup, const char *esLabel);
00070
00071 double operator()(const Object &object) const;
00072
00073 void train(const Object &object, bool target, double weight = 1.0) const;
00074
00075 private:
00076 void init(const PhysicsTools::Calibration::MVAComputerContainer *container);
00077
00078 const std::string label;
00079 PhysicsTools::MVAComputerCache cache;
00080
00081 class Value {
00082 public:
00083 Value(const std::string &name) :
00084 name(name), filler(name) {}
00085 Value(const std::string &name, double value) :
00086 name(name), filler(name), value(value) {}
00087
00088 inline bool update(const Object &object) const
00089 {
00090 value = filler(object, name);
00091 return !std::isfinite(value);
00092 }
00093
00094 PhysicsTools::AtomicId getName() const { return name; }
00095 double getValue() const { return value; }
00096
00097 private:
00098 PhysicsTools::AtomicId name;
00099 Filler filler;
00100
00101 mutable double value;
00102 };
00103
00104 std::vector<Value> values;
00105 };
00106
00107 template<class Record, typename Object, class Filler>
00108 void MVAModuleHelper<Record, Object, Filler>::setEventSetup(
00109 const edm::EventSetup &setup)
00110 {
00111 using namespace PhysicsTools::Calibration;
00112 edm::ESHandle<MVAComputerContainer> handle;
00113 setup.get<Record>().get(handle);
00114 const MVAComputerContainer *container = handle.product();
00115 if (cache.update(container, label.c_str()) && cache)
00116 init(container);
00117 }
00118
00119 template<class Record, typename Object, class Filler>
00120 void MVAModuleHelper<Record, Object, Filler>::setEventSetup(
00121 const edm::EventSetup &setup, const char *esLabel)
00122 {
00123 using namespace PhysicsTools::Calibration;
00124 edm::ESHandle<MVAComputerContainer> handle;
00125 setup.get<Record>().get(esLabel, handle);
00126 const MVAComputerContainer *container = handle.product();
00127 if (cache.update(container, label.c_str()) && cache)
00128 init(container);
00129 }
00130
00131 template<class Record, typename Object, class Filler>
00132 void MVAModuleHelper<Record, Object, Filler>::init(
00133 const PhysicsTools::Calibration::MVAComputerContainer *container)
00134 {
00135 const std::vector<PhysicsTools::Calibration::Variable> &vars =
00136 container->find(label).inputSet;
00137 values.clear();
00138 for(std::vector<PhysicsTools::Calibration::Variable>::const_iterator
00139 iter = vars.begin(); iter != vars.end(); ++iter)
00140 if (std::strncmp(iter->name.c_str(), "__", 2) != 0)
00141 values.push_back(Value(iter->name));
00142 }
00143
00144 template<class Record, typename Object, class Filler>
00145 double MVAModuleHelper<Record, Object, Filler>::operator()(
00146 const Object &object) const
00147 {
00148 std::for_each(values.begin(), values.end(),
00149 boost::bind(&Value::update, _1, object));
00150 return cache->eval(values);
00151 }
00152
00153 template<class Record, typename Object, class Filler>
00154 void MVAModuleHelper<Record, Object, Filler>::train(
00155 const Object &object, bool target, double weight) const
00156 {
00157 static const PhysicsTools::AtomicId kTargetId("__TARGET__");
00158 static const PhysicsTools::AtomicId kWeightId("__WEIGHT__");
00159
00160 if (!cache)
00161 return;
00162
00163 using boost::bind;
00164 if (std::accumulate(values.begin(), values.end(), 0,
00165 bind(std::plus<int>(), _1,
00166 bind(&Value::update, _2, object))))
00167 return;
00168
00169 PhysicsTools::Variable::ValueList list;
00170 list.add(kTargetId, target);
00171 list.add(kWeightId, weight);
00172 for(typename std::vector<Value>::const_iterator iter = values.begin();
00173 iter != values.end(); ++iter)
00174 list.add(iter->getName(), iter->getValue());
00175
00176 cache->eval(list);
00177 }
00178
00179 }
00180
00181 #endif // PhysicsTools_MVAComputer_MVAModuleHelper_h