CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_5_3_4/src/PhysicsTools/MVAComputer/interface/MVAModuleHelper.h

Go to the documentation of this file.
00001 #ifndef PhysicsTools_MVAComputer_MVAModuleHelper_h
00002 #define PhysicsTools_MVAComputer_MVAModuleHelper_h
00003 // -*- C++ -*-
00004 //
00005 // Package:     MVAComputer
00006 // Class  :     MVAModuleHelper
00007 //
00008 
00009 //
00010 // Author:      Christophe Saout <christophe.saout@cern.ch>
00011 // Created:     Sat Apr 24 15:18 CEST 2007
00012 // $Id: MVAModuleHelper.h,v 1.3 2011/04/20 07:07:37 kukartse Exp $
00013 //
00014 
00015 #include <functional>
00016 #include <algorithm>
00017 #include <numeric>
00018 #include <cstring>
00019 #include <string>
00020 #include <vector>
00021 #include <cmath>
00022 
00023 #include <boost/bind.hpp>
00024 
00025 #include "FWCore/Framework/interface/EventSetup.h"
00026 
00027 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00028 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00029 #include "PhysicsTools/MVAComputer/interface/MVAComputerCache.h"
00030 #include "CondFormats/PhysicsToolsObjects/interface/MVAComputer.h"
00031 
00032 namespace PhysicsTools {
00033 
00042 template<typename Object>
00043 struct MVAModuleHelperDefaultFiller {
00044         MVAModuleHelperDefaultFiller(const PhysicsTools::AtomicId &name) {}
00045 
00046         double operator()(const Object &object,
00047                           const PhysicsTools::AtomicId &name)
00048         { return object.compute(name); }
00049 };
00050 
00061 template<class Record, typename Object,
00062          class Filler = MVAModuleHelperDefaultFiller<Object> >
00063 class MVAModuleHelper {
00064     public:
00065         MVAModuleHelper(const std::string &label) : label(label) {}
00066         MVAModuleHelper(const MVAModuleHelper &orig) : label(orig.label) {}
00067         ~MVAModuleHelper() {}
00068 
00069         void setEventSetup(const edm::EventSetup &setup);
00070         void setEventSetup(const edm::EventSetup &setup, const char *esLabel);
00071 
00072         double operator()(const Object &object) const;
00073 
00074         void train(const Object &object, bool target, double weight = 1.0) const;
00075 
00076     private:
00077         void init(const PhysicsTools::Calibration::MVAComputerContainer *container);
00078 
00079         const std::string               label;
00080         PhysicsTools::MVAComputerCache  cache;
00081 
00082         class Value {
00083             public:
00084                 Value(const std::string &name) :
00085                         name(name), filler(name) {}
00086                 Value(const std::string &name, double value) :
00087                         name(name), filler(name), value(value) {}
00088 
00089                 inline bool update(const Object &object) const
00090                 {
00091                         value = filler(object, name);
00092                         return !std::isfinite(value);
00093                 }
00094 
00095                 PhysicsTools::AtomicId getName() const { return name; }
00096                 double getValue() const { return value; }
00097 
00098             private:
00099                 PhysicsTools::AtomicId          name;
00100                 Filler                          filler;
00101 
00102                 mutable double                  value;
00103         };
00104 
00105         std::vector<Value>                      values;
00106 };
00107 
00108 template<class Record, typename Object, class Filler>
00109 void MVAModuleHelper<Record, Object, Filler>::setEventSetup(
00110                                                 const edm::EventSetup &setup)
00111 {
00112         edm::ESHandle<PhysicsTools::Calibration::MVAComputerContainer> handle;
00113         setup.get<Record>().get(handle);
00114         const PhysicsTools::Calibration::MVAComputerContainer *container = handle.product();
00115         if (cache.update(container, label.c_str()) && cache)
00116                 init(container);
00117 }
00118 
00119 template<class Record, typename Object, class Filler>
00120 void MVAModuleHelper<Record, Object, Filler>::setEventSetup(
00121                         const edm::EventSetup &setup, const char *esLabel)
00122 {
00123         edm::ESHandle<PhysicsTools::Calibration::MVAComputerContainer> handle;
00124         setup.get<Record>().get(esLabel, handle);
00125         const PhysicsTools::Calibration::MVAComputerContainer *container = handle.product();
00126         if (cache.update(container, label.c_str()) && cache)
00127                 init(container);
00128 }
00129 
00130 template<class Record, typename Object, class Filler>
00131 void MVAModuleHelper<Record, Object, Filler>::init(
00132         const PhysicsTools::Calibration::MVAComputerContainer *container)
00133 {
00134         const std::vector<PhysicsTools::Calibration::Variable> &vars =
00135                                         container->find(label).inputSet;
00136         values.clear();
00137         for(std::vector<PhysicsTools::Calibration::Variable>::const_iterator
00138                         iter = vars.begin(); iter != vars.end(); ++iter)
00139                 if (std::strncmp(iter->name.c_str(), "__", 2) != 0)
00140                         values.push_back(Value(iter->name));
00141 }
00142 
00143 template<class Record, typename Object, class Filler>
00144 double MVAModuleHelper<Record, Object, Filler>::operator()(
00145                                                 const Object &object) const
00146 {
00147         std::for_each(values.begin(), values.end(),
00148                       boost::bind(&Value::update, _1, object));
00149         return cache->eval(values);
00150 }
00151 
00152 template<class Record, typename Object, class Filler>
00153 void MVAModuleHelper<Record, Object, Filler>::train(
00154                 const Object &object, bool target, double weight) const
00155 {
00156         static const PhysicsTools::AtomicId kTargetId("__TARGET__");
00157         static const PhysicsTools::AtomicId kWeightId("__WEIGHT__");
00158 
00159         if (!cache)
00160                 return;
00161 
00162         using boost::bind;
00163         if (std::accumulate(values.begin(), values.end(), 0,
00164                             bind(std::plus<int>(), _1,
00165                                  bind(&Value::update, _2, object))))
00166                 return;
00167 
00168         PhysicsTools::Variable::ValueList list;
00169         list.add(kTargetId, target);
00170         list.add(kWeightId, weight);
00171         for(typename std::vector<Value>::const_iterator iter = values.begin();
00172             iter != values.end(); ++iter)
00173                 list.add(iter->getName(), iter->getValue());
00174 
00175         cache->eval(list);
00176 }
00177 
00178 } // namespace PhysicsTools
00179 
00180 #endif // PhysicsTools_MVAComputer_MVAModuleHelper_h