#include "RecoTauTag/RecoTau/interface/RecoTauMVAHelper.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <boost/bind.hpp>
#include <boost/foreach.hpp>

#include "CondFormats/DataRecord/interface/TauTagMVAComputerRcd.h"
#include "DataFormats/TauReco/interface/PFTau.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "RecoTauTag/RecoTau/interface/RecoTauDiscriminantPlugins.h"
00013
00014 namespace reco { namespace tau {
00015
00016 RecoTauMVAHelper::RecoTauMVAHelper(const std::string& name,
00017 const std::string& eslabel,
00018 const edm::ParameterSet& pluginOptions):
00019 name_(name), eslabel_(eslabel), pluginOptions_(pluginOptions) {}
00020
00021 void RecoTauMVAHelper::setEvent(const edm::Event& evt,
00022 const edm::EventSetup &es) {
00023
00024 using PhysicsTools::Calibration::MVAComputerContainer;
00025 edm::ESHandle<MVAComputerContainer> handle;
00026 if (eslabel_.size()) {
00027 es.get<TauTagMVAComputerRcd>().get(eslabel_, handle);
00028 } else {
00029 es.get<TauTagMVAComputerRcd>().get(handle);
00030 }
00031 const MVAComputerContainer *container = handle.product();
00032
00033 bool reload = computer_.update(container, name_.c_str());
00034
00035 if (reload && computer_.get())
00036 loadDiscriminantPlugins(container->find(name_));
00037
00038 BOOST_FOREACH(PluginMap::value_type plugin, plugins_) {
00039 plugin.second->setup(evt, es);
00040 }
00041 }
00042
00043 void RecoTauMVAHelper::loadDiscriminantPlugins(
00044 const PhysicsTools::Calibration::MVAComputer &comp) {
00045 typedef std::vector<PhysicsTools::Calibration::Variable> VarList;
00046
00047 const VarList &vars = comp.inputSet;
00048
00049 BOOST_FOREACH(const VarList::value_type& var, vars) {
00050
00051 if (std::strncmp(var.name.c_str(), "__", 2) != 0) {
00052
00053 PhysicsTools::AtomicId varId(var.name);
00054 if (!plugins_.count(varId)) {
00055 edm::ParameterSet options;
00056 if (pluginOptions_.exists(var.name)) {
00057 options = pluginOptions_.getParameter<edm::ParameterSet>(var.name);
00058 };
00059
00060 if (!options.exists("name"))
00061 options.addParameter("name", "MVA_" + var.name);
00062
00063
00064
00065
00066
00067
00068 std::string pluginName = reco::tau::discPluginName(var.name);
00069 if (options.exists("plugin")) {
00070 pluginName = options.getParameter<std::string>("plugin");
00071 }
00072 plugins_.insert(
00073 varId, RecoTauDiscriminantPluginFactory::get()->create(
00074 pluginName, options));
00075 }
00076 }
00077 }
00078 }
00079
00080 void RecoTauMVAHelper::fillValues(const reco::PFTauRef& tau) const {
00081
00082 for (PluginMap::const_iterator plugin = plugins_.begin();
00083 plugin != plugins_.end(); ++plugin) {
00084 PhysicsTools::AtomicId id = plugin->first;
00085 std::vector<double> pluginOutput = (plugin->second)->operator()(tau);
00086
00087 for(size_t instance = 0; instance < pluginOutput.size(); ++instance) {
00088 if (std::isnan(pluginOutput[instance])) {
00089 std::ostringstream error;
00090 error << "A nan was detected in"
00091 << " the tau MVA variable " << id << " returning zero instead!"
00092 << " The PFTau: " << *tau << std::endl;
00093 tau->dump(error);
00094 edm::LogError("CorruptedMVAInput") << error.str();
00095 pluginOutput[instance] = 0.0;
00096 }
00097 }
00098
00099
00100 std::for_each(pluginOutput.begin(), pluginOutput.end(),
00101 boost::bind(&PhysicsTools::Variable::ValueList::add,
00102 boost::ref(values_), id, _1));
00103 }
00104 }
00105
00106
00107 const PhysicsTools::Variable::ValueList&
00108 RecoTauMVAHelper::discriminants(const PFTauRef& tau) const {
00109 values_.clear();
00110 fillValues(tau);
00111 return values_;
00112 }
00113
00114
00115 double RecoTauMVAHelper::operator()(const reco::PFTauRef &tau) const {
00116
00117 values_.clear();
00118
00119 fillValues(tau);
00120
00121 return computer_->eval(values_);
00122 }
00123
00124 void RecoTauMVAHelper::train(const reco::PFTauRef &tau, bool target,
00125 double weight) const {
00126 static const PhysicsTools::AtomicId kTargetId("__TARGET__");
00127 static const PhysicsTools::AtomicId kWeightId("__WEIGHT__");
00128 if (!computer_)
00129 return;
00130 values_.clear();
00131 values_.add(kTargetId, target);
00132 values_.add(kWeightId, weight);
00133
00134 fillValues(tau);
00135 computer_->eval(values_);
00136 }
00137
00138 }}