CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_4_1_8_patch12/src/PhysicsTools/MVAComputer/interface/MVAModuleHelper.h

Go to the documentation of this file.
00001 #ifndef PhysicsTools_MVAComputer_MVAModuleHelper_h
00002 #define PhysicsTools_MVAComputer_MVAModuleHelper_h
00003 // -*- C++ -*-
00004 //
00005 // Package:     MVAComputer
00006 // Class  :     MVAModuleHelper
00007 //
00008 
00009 //
00010 // Author:      Christophe Saout <christophe.saout@cern.ch>
00011 // Created:     Sat Apr 24 15:18 CEST 2007
00012 // $Id: MVAModuleHelper.h,v 1.2 2010/10/20 20:39:10 wmtan Exp $
00013 //
00014 
00015 #include <functional>
00016 #include <algorithm>
00017 #include <numeric>
00018 #include <cstring>
00019 #include <string>
00020 #include <vector>
00021 #include <cmath>
00022 
00023 #include <boost/bind.hpp>
00024 
00025 #include "FWCore/Framework/interface/EventSetup.h"
00026 
00027 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00028 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00029 #include "PhysicsTools/MVAComputer/interface/MVAComputerCache.h"
00030 
00031 namespace PhysicsTools {
00032 
00041 template<typename Object>
00042 struct MVAModuleHelperDefaultFiller {
00043         MVAModuleHelperDefaultFiller(const PhysicsTools::AtomicId &name) {}
00044 
00045         double operator()(const Object &object,
00046                           const PhysicsTools::AtomicId &name)
00047         { return object.compute(name); }
00048 };
00049 
00060 template<class Record, typename Object,
00061          class Filler = MVAModuleHelperDefaultFiller<Object> >
00062 class MVAModuleHelper {
00063     public:
00064         MVAModuleHelper(const std::string &label) : label(label) {}
00065         MVAModuleHelper(const MVAModuleHelper &orig) : label(orig.label) {}
00066         ~MVAModuleHelper() {}
00067 
00068         void setEventSetup(const edm::EventSetup &setup);
00069         void setEventSetup(const edm::EventSetup &setup, const char *esLabel);
00070 
00071         double operator()(const Object &object) const;
00072 
00073         void train(const Object &object, bool target, double weight = 1.0) const;
00074 
00075     private:
00076         void init(const PhysicsTools::Calibration::MVAComputerContainer *container);
00077 
00078         const std::string               label;
00079         PhysicsTools::MVAComputerCache  cache;
00080 
00081         class Value {
00082             public:
00083                 Value(const std::string &name) :
00084                         name(name), filler(name) {}
00085                 Value(const std::string &name, double value) :
00086                         name(name), filler(name), value(value) {}
00087 
00088                 inline bool update(const Object &object) const
00089                 {
00090                         value = filler(object, name);
00091                         return !std::isfinite(value);
00092                 }
00093 
00094                 PhysicsTools::AtomicId getName() const { return name; }
00095                 double getValue() const { return value; }
00096 
00097             private:
00098                 PhysicsTools::AtomicId          name;
00099                 Filler                          filler;
00100 
00101                 mutable double                  value;
00102         };
00103 
00104         std::vector<Value>                      values;
00105 };
00106 
00107 template<class Record, typename Object, class Filler>
00108 void MVAModuleHelper<Record, Object, Filler>::setEventSetup(
00109                                                 const edm::EventSetup &setup)
00110 {
00111         edm::ESHandle<MVAComputerContainer> handle;
00112         setup.get<Record>().get(handle);
00113         const MVAComputerContainer *container = handle.product();
00114         if (cache.update(container, label.c_str()) && cache)
00115                 init(container);
00116 }
00117 
00118 template<class Record, typename Object, class Filler>
00119 void MVAModuleHelper<Record, Object, Filler>::setEventSetup(
00120                         const edm::EventSetup &setup, const char *esLabel)
00121 {
00122         edm::ESHandle<MVAComputerContainer> handle;
00123         setup.get<Record>().get(esLabel, handle);
00124         const MVAComputerContainer *container = handle.product();
00125         if (cache.update(container, label.c_str()) && cache)
00126                 init(container);
00127 }
00128 
00129 template<class Record, typename Object, class Filler>
00130 void MVAModuleHelper<Record, Object, Filler>::init(
00131         const PhysicsTools::Calibration::MVAComputerContainer *container)
00132 {
00133         const std::vector<PhysicsTools::Calibration::Variable> &vars =
00134                                         container->find(label).inputSet;
00135         values.clear();
00136         for(std::vector<PhysicsTools::Calibration::Variable>::const_iterator
00137                         iter = vars.begin(); iter != vars.end(); ++iter)
00138                 if (std::strncmp(iter->name.c_str(), "__", 2) != 0)
00139                         values.push_back(Value(iter->name));
00140 }
00141 
00142 template<class Record, typename Object, class Filler>
00143 double MVAModuleHelper<Record, Object, Filler>::operator()(
00144                                                 const Object &object) const
00145 {
00146         std::for_each(values.begin(), values.end(),
00147                       boost::bind(&Value::update, _1, object));
00148         return cache->eval(values);
00149 }
00150 
00151 template<class Record, typename Object, class Filler>
00152 void MVAModuleHelper<Record, Object, Filler>::train(
00153                 const Object &object, bool target, double weight) const
00154 {
00155         static const PhysicsTools::AtomicId kTargetId("__TARGET__");
00156         static const PhysicsTools::AtomicId kWeightId("__WEIGHT__");
00157 
00158         if (!cache)
00159                 return;
00160 
00161         using boost::bind;
00162         if (std::accumulate(values.begin(), values.end(), 0,
00163                             bind(std::plus<int>(), _1,
00164                                  bind(&Value::update, _2, object))))
00165                 return;
00166 
00167         PhysicsTools::Variable::ValueList list;
00168         list.add(kTargetId, target);
00169         list.add(kWeightId, weight);
00170         for(typename std::vector<Value>::const_iterator iter = values.begin();
00171             iter != values.end(); ++iter)
00172                 list.add(iter->getName(), iter->getValue());
00173 
00174         cache->eval(list);
00175 }
00176 
00177 } // namespace PhysicsTools
00178 
00179 #endif // PhysicsTools_MVAComputer_MVAModuleHelper_h