CMS 3D CMS Logo

MVAModuleHelper.h

Go to the documentation of this file.
00001 #ifndef PhysicsTools_MVAComputer_MVAModuleHelper_h
00002 #define PhysicsTools_MVAComputer_MVAModuleHelper_h
00003 // -*- C++ -*-
00004 //
00005 // Package:     MVAComputer
00006 // Class  :     MVAModuleHelper
00007 //
00008 
00009 //
00010 // Author:      Christophe Saout <christophe.saout@cern.ch>
00011 // Created:     Sat Apr 24 15:18 CEST 2007
00012 // $Id: MVAModuleHelper.h,v 1.1 2008/12/14 15:05:22 saout Exp $
00013 //
00014 
00015 #include <functional>
00016 #include <algorithm>
00017 #include <numeric>
00018 #include <cstring>
00019 #include <string>
00020 #include <vector>
00021 #include <cmath>
00022 
00023 #include <boost/bind.hpp>
00024 
00025 #include "FWCore/Framework/interface/EventSetup.h"
00026 
00027 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00028 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00029 #include "PhysicsTools/MVAComputer/interface/MVAComputerCache.h"
00030 
00031 namespace PhysicsTools {
00032 
00041 template<typename Object>
00042 struct MVAModuleHelperDefaultFiller {
00043         MVAModuleHelperDefaultFiller(const PhysicsTools::AtomicId &name) {}
00044 
00045         double operator()(const Object &object,
00046                           const PhysicsTools::AtomicId &name)
00047         { return object.compute(name); }
00048 };
00049 
00060 template<class Record, typename Object,
00061          class Filler = MVAModuleHelperDefaultFiller<Object> >
00062 class MVAModuleHelper {
00063     public:
00064         MVAModuleHelper(const std::string &label) : label(label) {}
00065         MVAModuleHelper(const MVAModuleHelper &orig) : label(orig.label) {}
00066         ~MVAModuleHelper() {}
00067 
00068         void setEventSetup(const edm::EventSetup &setup);
00069         void setEventSetup(const edm::EventSetup &setup, const char *esLabel);
00070 
00071         double operator()(const Object &object) const;
00072 
00073         void train(const Object &object, bool target, double weight = 1.0) const;
00074 
00075     private:
00076         void init(const PhysicsTools::Calibration::MVAComputerContainer *container);
00077 
00078         const std::string               label;
00079         PhysicsTools::MVAComputerCache  cache;
00080 
00081         class Value {
00082             public:
00083                 Value(const std::string &name) :
00084                         name(name), filler(name) {}
00085                 Value(const std::string &name, double value) :
00086                         name(name), filler(name), value(value) {}
00087 
00088                 inline bool update(const Object &object) const
00089                 {
00090                         value = filler(object, name);
00091                         return !std::isfinite(value);
00092                 }
00093 
00094                 PhysicsTools::AtomicId getName() const { return name; }
00095                 double getValue() const { return value; }
00096 
00097             private:
00098                 PhysicsTools::AtomicId          name;
00099                 Filler                          filler;
00100 
00101                 mutable double                  value;
00102         };
00103 
00104         std::vector<Value>                      values;
00105 };
00106 
00107 template<class Record, typename Object, class Filler>
00108 void MVAModuleHelper<Record, Object, Filler>::setEventSetup(
00109                                                 const edm::EventSetup &setup)
00110 {
00111         using namespace PhysicsTools::Calibration;
00112         edm::ESHandle<MVAComputerContainer> handle;
00113         setup.get<Record>().get(handle);
00114         const MVAComputerContainer *container = handle.product();
00115         if (cache.update(container, label.c_str()) && cache)
00116                 init(container);
00117 }
00118 
00119 template<class Record, typename Object, class Filler>
00120 void MVAModuleHelper<Record, Object, Filler>::setEventSetup(
00121                         const edm::EventSetup &setup, const char *esLabel)
00122 {
00123         using namespace PhysicsTools::Calibration;
00124         edm::ESHandle<MVAComputerContainer> handle;
00125         setup.get<Record>().get(esLabel, handle);
00126         const MVAComputerContainer *container = handle.product();
00127         if (cache.update(container, label.c_str()) && cache)
00128                 init(container);
00129 }
00130 
00131 template<class Record, typename Object, class Filler>
00132 void MVAModuleHelper<Record, Object, Filler>::init(
00133         const PhysicsTools::Calibration::MVAComputerContainer *container)
00134 {
00135         const std::vector<PhysicsTools::Calibration::Variable> &vars =
00136                                         container->find(label).inputSet;
00137         values.clear();
00138         for(std::vector<PhysicsTools::Calibration::Variable>::const_iterator
00139                         iter = vars.begin(); iter != vars.end(); ++iter)
00140                 if (std::strncmp(iter->name.c_str(), "__", 2) != 0)
00141                         values.push_back(Value(iter->name));
00142 }
00143 
00144 template<class Record, typename Object, class Filler>
00145 double MVAModuleHelper<Record, Object, Filler>::operator()(
00146                                                 const Object &object) const
00147 {
00148         std::for_each(values.begin(), values.end(),
00149                       boost::bind(&Value::update, _1, object));
00150         return cache->eval(values);
00151 }
00152 
00153 template<class Record, typename Object, class Filler>
00154 void MVAModuleHelper<Record, Object, Filler>::train(
00155                 const Object &object, bool target, double weight) const
00156 {
00157         static const PhysicsTools::AtomicId kTargetId("__TARGET__");
00158         static const PhysicsTools::AtomicId kWeightId("__WEIGHT__");
00159 
00160         if (!cache)
00161                 return;
00162 
00163         using boost::bind;
00164         if (std::accumulate(values.begin(), values.end(), 0,
00165                             bind(std::plus<int>(), _1,
00166                                  bind(&Value::update, _2, object))))
00167                 return;
00168 
00169         PhysicsTools::Variable::ValueList list;
00170         list.add(kTargetId, target);
00171         list.add(kWeightId, weight);
00172         for(typename std::vector<Value>::const_iterator iter = values.begin();
00173             iter != values.end(); ++iter)
00174                 list.add(iter->getName(), iter->getValue());
00175 
00176         cache->eval(list);
00177 }
00178 
00179 } // namespace PhysicsTools
00180 
00181 #endif // PhysicsTools_MVAComputer_MVAModuleHelper_h

Generated on Tue Jun 9 17:41:17 2009 for CMSSW by doxygen 1.5.4