#include "RecoTauTag/RecoTau/interface/RecoTauMVAHelper.h"
00002 #include "FWCore/MessageLogger/interface/MessageLogger.h"
00003
00004 #include <boost/foreach.hpp>
00005 #include <boost/bind.hpp>
00006
00007 #include "CondFormats/DataRecord/interface/TauTagMVAComputerRcd.h"
00008 #include "DataFormats/TauReco/interface/PFTau.h"
00009 #include "RecoTauTag/RecoTau/interface/RecoTauDiscriminantPlugins.h"
00010
00011 #include "FWCore/ParameterSet/interface/ParameterSet.h"
00012
00013 namespace reco { namespace tau {
00014
00015 RecoTauMVAHelper::RecoTauMVAHelper(const std::string &name,
00016 const std::string eslabel):
00017 name_(name), eslabel_(eslabel) {}
00018
00019 void RecoTauMVAHelper::setEvent(const edm::Event& evt,
00020 const edm::EventSetup &es) {
00021
00022 BOOST_FOREACH(PluginMap::value_type plugin, plugins_) {
00023 plugin.second->setup(evt, es);
00024 }
00025
00026 using PhysicsTools::Calibration::MVAComputerContainer;
00027 edm::ESHandle<MVAComputerContainer> handle;
00028 if (eslabel_.size()) {
00029 es.get<TauTagMVAComputerRcd>().get(eslabel_, handle);
00030 } else {
00031 es.get<TauTagMVAComputerRcd>().get(handle);
00032 }
00033 const MVAComputerContainer *container = handle.product();
00034
00035 bool reload = computer_.update(container, name_.c_str());
00036
00037 if (reload && computer_.get())
00038 loadDiscriminantPlugins(container->find(name_));
00039 }
00040
00041 void RecoTauMVAHelper::loadDiscriminantPlugins(
00042 const PhysicsTools::Calibration::MVAComputer &comp) {
00043 typedef std::vector<PhysicsTools::Calibration::Variable> VarList;
00044
00045 const VarList &vars = comp.inputSet;
00046
00047 BOOST_FOREACH(const VarList::value_type& var, vars) {
00048
00049 if (std::strncmp(var.name.c_str(), "__", 2) != 0) {
00050
00051 PhysicsTools::AtomicId varId(var.name);
00052 if (!plugins_.count(varId)) {
00053 edm::ParameterSet fakePSet;
00054 fakePSet.addParameter("name", "MVA_" + var.name);
00055 plugins_.insert(
00056 varId, RecoTauDiscriminantPluginFactory::get()->create(
00057 reco::tau::discPluginName(var.name), fakePSet));
00058 }
00059 }
00060 }
00061 }
00062
00063 void RecoTauMVAHelper::fillValues(const reco::PFTauRef& tau) const {
00064
00065 for (PluginMap::const_iterator plugin = plugins_.begin();
00066 plugin != plugins_.end(); ++plugin) {
00067 PhysicsTools::AtomicId id = plugin->first;
00068 std::vector<double> pluginOutput = (plugin->second)->operator()(tau);
00069
00070 for(size_t instance = 0; instance < pluginOutput.size(); ++instance) {
00071 if (std::isnan(pluginOutput[instance])) {
00072 edm::LogError("CorruptedMVAInput") << "A nan was detected in"
00073 << " the tau MVA variable " << id << " returning zero instead!";
00074 pluginOutput[instance] = 0.0;
00075 }
00076 }
00077
00078
00079 std::for_each(pluginOutput.begin(), pluginOutput.end(),
00080 boost::bind(&PhysicsTools::Variable::ValueList::add,
00081 boost::ref(values_), id, _1));
00082 }
00083 }
00084
00085
00086 const PhysicsTools::Variable::ValueList&
00087 RecoTauMVAHelper::discriminants(const PFTauRef& tau) const {
00088 values_.clear();
00089 fillValues(tau);
00090 return values_;
00091 }
00092
00093
00094 double RecoTauMVAHelper::operator()(const reco::PFTauRef &tau) const {
00095
00096 values_.clear();
00097
00098 fillValues(tau);
00099
00100 return computer_->eval(values_);
00101 }
00102
00103 void RecoTauMVAHelper::train(const reco::PFTauRef &tau, bool target,
00104 double weight) const {
00105 static const PhysicsTools::AtomicId kTargetId("__TARGET__");
00106 static const PhysicsTools::AtomicId kWeightId("__WEIGHT__");
00107 if (!computer_)
00108 return;
00109 values_.clear();
00110 values_.add(kTargetId, target);
00111 values_.add(kWeightId, weight);
00112
00113 fillValues(tau);
00114 computer_->eval(values_);
00115 }
00116
00117 }}