#include "RecoTauTag/RecoTau/interface/RecoTauMVAHelper.h"

#include <sstream>

#include <boost/foreach.hpp>
#include <boost/bind.hpp>

#include "CondFormats/DataRecord/interface/TauTagMVAComputerRcd.h"
#include "DataFormats/TauReco/interface/PFTau.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Utilities/interface/isFinite.h"
#include "RecoTauTag/RecoTau/interface/RecoTauDiscriminantPlugins.h"
00014
00015 namespace reco { namespace tau {
00016
00017 RecoTauMVAHelper::RecoTauMVAHelper(const std::string& name,
00018 const std::string& eslabel,
00019 const edm::ParameterSet& pluginOptions):
00020 name_(name), eslabel_(eslabel), pluginOptions_(pluginOptions) {}
00021
00022 void RecoTauMVAHelper::setEvent(const edm::Event& evt,
00023 const edm::EventSetup &es) {
00024
00025 using PhysicsTools::Calibration::MVAComputerContainer;
00026 edm::ESHandle<MVAComputerContainer> handle;
00027 if (eslabel_.size()) {
00028 es.get<TauTagMVAComputerRcd>().get(eslabel_, handle);
00029 } else {
00030 es.get<TauTagMVAComputerRcd>().get(handle);
00031 }
00032 const MVAComputerContainer *container = handle.product();
00033
00034 bool reload = computer_.update(container, name_.c_str());
00035
00036 if (reload && computer_.get())
00037 loadDiscriminantPlugins(container->find(name_));
00038
00039 BOOST_FOREACH(PluginMap::value_type plugin, plugins_) {
00040 plugin.second->setup(evt, es);
00041 }
00042 }
00043
00044 void RecoTauMVAHelper::loadDiscriminantPlugins(
00045 const PhysicsTools::Calibration::MVAComputer &comp) {
00046 typedef std::vector<PhysicsTools::Calibration::Variable> VarList;
00047
00048 const VarList &vars = comp.inputSet;
00049
00050 BOOST_FOREACH(const VarList::value_type& var, vars) {
00051
00052 if (std::strncmp(var.name.c_str(), "__", 2) != 0) {
00053
00054 PhysicsTools::AtomicId varId(var.name);
00055 if (!plugins_.count(varId)) {
00056 edm::ParameterSet options;
00057 if (pluginOptions_.exists(var.name)) {
00058 options = pluginOptions_.getParameter<edm::ParameterSet>(var.name);
00059 };
00060
00061 if (!options.exists("name"))
00062 options.addParameter("name", "MVA_" + var.name);
00063
00064
00065
00066
00067
00068
00069 std::string pluginName = reco::tau::discPluginName(var.name);
00070 if (options.exists("plugin")) {
00071 pluginName = options.getParameter<std::string>("plugin");
00072 }
00073 plugins_.insert(
00074 varId, RecoTauDiscriminantPluginFactory::get()->create(
00075 pluginName, options));
00076 }
00077 }
00078 }
00079 }
00080
00081 void RecoTauMVAHelper::fillValues(const reco::PFTauRef& tau) const {
00082
00083 for (PluginMap::const_iterator plugin = plugins_.begin();
00084 plugin != plugins_.end(); ++plugin) {
00085 PhysicsTools::AtomicId id = plugin->first;
00086 std::vector<double> pluginOutput = (plugin->second)->operator()(tau);
00087
00088 for(size_t instance = 0; instance < pluginOutput.size(); ++instance) {
00089 if (edm::isNotFinite(pluginOutput[instance])) {
00090 std::ostringstream error;
00091 error << "A nan was detected in"
00092 << " the tau MVA variable " << id << " returning zero instead!"
00093 << " The PFTau: " << *tau << std::endl;
00094 tau->dump(error);
00095 edm::LogError("CorruptedMVAInput") << error.str();
00096 pluginOutput[instance] = 0.0;
00097 }
00098 }
00099
00100
00101 std::for_each(pluginOutput.begin(), pluginOutput.end(),
00102 boost::bind(&PhysicsTools::Variable::ValueList::add,
00103 boost::ref(values_), id, _1));
00104 }
00105 }
00106
00107
00108 const PhysicsTools::Variable::ValueList&
00109 RecoTauMVAHelper::discriminants(const PFTauRef& tau) const {
00110 values_.clear();
00111 fillValues(tau);
00112 return values_;
00113 }
00114
00115
00116 double RecoTauMVAHelper::operator()(const reco::PFTauRef &tau) const {
00117
00118 values_.clear();
00119
00120 fillValues(tau);
00121
00122 return computer_->eval(values_);
00123 }
00124
00125 void RecoTauMVAHelper::train(const reco::PFTauRef &tau, bool target,
00126 double weight) const {
00127 static const PhysicsTools::AtomicId kTargetId("__TARGET__");
00128 static const PhysicsTools::AtomicId kWeightId("__WEIGHT__");
00129 if (!computer_)
00130 return;
00131 values_.clear();
00132 values_.add(kTargetId, target);
00133 values_.add(kWeightId, weight);
00134
00135 fillValues(tau);
00136 computer_->eval(values_);
00137 }
00138
00139 }}