CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_6_1_2_SLHC4_patch1/src/RecoTauTag/RecoTau/src/RecoTauMVAHelper.cc

Go to the documentation of this file.
00001 #include "RecoTauTag/RecoTau/interface/RecoTauMVAHelper.h"
00002 #include "FWCore/MessageLogger/interface/MessageLogger.h"
00003 
00004 #include <boost/foreach.hpp>
00005 #include <boost/bind.hpp>
00006 #include <sstream>
00007 
00008 #include "CondFormats/DataRecord/interface/TauTagMVAComputerRcd.h"
00009 #include "DataFormats/TauReco/interface/PFTau.h"
00010 #include "RecoTauTag/RecoTau/interface/RecoTauDiscriminantPlugins.h"
00011 
00012 #include "FWCore/ParameterSet/interface/ParameterSet.h"
00013 #include "FWCore/Utilities/interface/isFinite.h"
00014 
00015 namespace reco { namespace tau {
00016 
00017 RecoTauMVAHelper::RecoTauMVAHelper(const std::string& name,
00018                                    const std::string& eslabel,
00019                                    const edm::ParameterSet& pluginOptions):
00020     name_(name), eslabel_(eslabel), pluginOptions_(pluginOptions) {}
00021 
00022 void RecoTauMVAHelper::setEvent(const edm::Event& evt,
00023                                 const edm::EventSetup &es) {
00024   // Update our MVA from the DB
00025   using PhysicsTools::Calibration::MVAComputerContainer;
00026   edm::ESHandle<MVAComputerContainer> handle;
00027   if (eslabel_.size()) {
00028     es.get<TauTagMVAComputerRcd>().get(eslabel_, handle);
00029   } else {
00030     es.get<TauTagMVAComputerRcd>().get(handle);
00031   }
00032   const MVAComputerContainer *container = handle.product();
00033   // Load our MVA
00034   bool reload = computer_.update(container, name_.c_str());
00035   // If the MVA changed, update our list of discriminant plugins
00036   if (reload && computer_.get())
00037     loadDiscriminantPlugins(container->find(name_));
00038   // Update the event info for all of our discriminators
00039   BOOST_FOREACH(PluginMap::value_type plugin, plugins_) {
00040     plugin.second->setup(evt, es);
00041   }
00042 }
00043 
00044 void RecoTauMVAHelper::loadDiscriminantPlugins(
00045     const PhysicsTools::Calibration::MVAComputer &comp) {
00046   typedef std::vector<PhysicsTools::Calibration::Variable> VarList;
00047   // List of input variables for this MVA.
00048   const VarList &vars = comp.inputSet;
00049   // Load the plugin for each of the Var if needed
00050   BOOST_FOREACH(const VarList::value_type& var, vars) {
00051     // Check to make sure it isn't a magic variable
00052     if (std::strncmp(var.name.c_str(), "__", 2) != 0) {
00053       // If we haven't added yet, build it.
00054       PhysicsTools::AtomicId varId(var.name);
00055       if (!plugins_.count(varId)) {
00056         edm::ParameterSet options;
00057         if (pluginOptions_.exists(var.name)) {
00058           options = pluginOptions_.getParameter<edm::ParameterSet>(var.name);
00059         };
00060         // Make sure it has a name (required by base class)
00061         if (!options.exists("name"))
00062           options.addParameter("name", "MVA_" + var.name);
00063         // Check if we want to specify the plugin name manually.  This is
00064         // required for things like the discriminant from discriminators, which
00065         // take an InputTag.  If we want to have more than one, we have to be
00066         // able take the MVA name (like FlightPathSig) and map it to
00067         // RecoTauDiscriminantFromDiscriminator[disc input tag = flight path sig]
00068         // Otherwise we just keep our regular plugin mapping.
00069         std::string pluginName = reco::tau::discPluginName(var.name);
00070         if (options.exists("plugin")) {
00071           pluginName = options.getParameter<std::string>("plugin");
00072         }
00073         plugins_.insert(
00074             varId, RecoTauDiscriminantPluginFactory::get()->create(
00075                 pluginName, options));
00076       }
00077     }
00078   }
00079 }
00080 
00081 void RecoTauMVAHelper::fillValues(const reco::PFTauRef& tau) const {
00082   // Loop over the relevant discriminators and the output
00083   for (PluginMap::const_iterator plugin = plugins_.begin();
00084        plugin != plugins_.end(); ++plugin) {
00085     PhysicsTools::AtomicId id = plugin->first;
00086     std::vector<double> pluginOutput = (plugin->second)->operator()(tau);
00087     // Check for nans
00088     for(size_t instance = 0; instance < pluginOutput.size(); ++instance) {
00089       if (edm::isNotFinite(pluginOutput[instance])) {
00090         std::ostringstream error;
00091         error << "A nan was detected in"
00092             << " the tau MVA variable " << id << " returning zero instead!"
00093             << " The PFTau: " << *tau << std::endl;
00094         tau->dump(error);
00095         edm::LogError("CorruptedMVAInput") << error.str();
00096         pluginOutput[instance] = 0.0;
00097       }
00098     }
00099     //std::cout << "id: " << id << " first: " << pluginOutput[0] << std::endl;
00100     // Build values and copy into values vector
00101     std::for_each(pluginOutput.begin(), pluginOutput.end(),
00102                   boost::bind(&PhysicsTools::Variable::ValueList::add,
00103                               boost::ref(values_), id, _1));
00104   }
00105 }
00106 
00107 // Get values
00108 const PhysicsTools::Variable::ValueList&
00109 RecoTauMVAHelper::discriminants(const PFTauRef& tau) const {
00110   values_.clear();
00111   fillValues(tau);
00112   return values_;
00113 }
00114 
00115 // Apply the MVA to a given tau
00116 double RecoTauMVAHelper::operator()(const reco::PFTauRef &tau) const {
00117   // Clear output
00118   values_.clear();
00119   // Build the values
00120   fillValues(tau);
00121   // Call the MVA
00122   return computer_->eval(values_);
00123 }
00124 
00125 void RecoTauMVAHelper::train(const reco::PFTauRef &tau, bool target,
00126                              double weight) const {
00127   static const PhysicsTools::AtomicId kTargetId("__TARGET__");
00128   static const PhysicsTools::AtomicId kWeightId("__WEIGHT__");
00129   if (!computer_)
00130     return;
00131   values_.clear();
00132   values_.add(kTargetId, target);
00133   values_.add(kWeightId, weight);
00134   // Build the discriminant values
00135   fillValues(tau);
00136   computer_->eval(values_);
00137 }
00138 
00139 }}  // end namespace reco::tau