CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_4_4_5_patch3/src/RecoTauTag/RecoTau/src/RecoTauMVAHelper.cc

Go to the documentation of this file.
00001 #include "RecoTauTag/RecoTau/interface/RecoTauMVAHelper.h"
00002 #include "FWCore/MessageLogger/interface/MessageLogger.h"
00003 
00004 #include <boost/foreach.hpp>
00005 #include <boost/bind.hpp>
00006 #include <sstream>
00007 
00008 #include "CondFormats/DataRecord/interface/TauTagMVAComputerRcd.h"
00009 #include "DataFormats/TauReco/interface/PFTau.h"
00010 #include "RecoTauTag/RecoTau/interface/RecoTauDiscriminantPlugins.h"
00011 
00012 #include "FWCore/ParameterSet/interface/ParameterSet.h"
00013 
00014 namespace reco { namespace tau {
00015 
00016 RecoTauMVAHelper::RecoTauMVAHelper(const std::string& name,
00017                                    const std::string& eslabel,
00018                                    const edm::ParameterSet& pluginOptions):
00019     name_(name), eslabel_(eslabel), pluginOptions_(pluginOptions) {}
00020 
00021 void RecoTauMVAHelper::setEvent(const edm::Event& evt,
00022                                 const edm::EventSetup &es) {
00023   // Update our MVA from the DB
00024   using PhysicsTools::Calibration::MVAComputerContainer;
00025   edm::ESHandle<MVAComputerContainer> handle;
00026   if (eslabel_.size()) {
00027     es.get<TauTagMVAComputerRcd>().get(eslabel_, handle);
00028   } else {
00029     es.get<TauTagMVAComputerRcd>().get(handle);
00030   }
00031   const MVAComputerContainer *container = handle.product();
00032   // Load our MVA
00033   bool reload = computer_.update(container, name_.c_str());
00034   // If the MVA changed, update our list of discriminant plugins
00035   if (reload && computer_.get())
00036     loadDiscriminantPlugins(container->find(name_));
00037   // Update the event info for all of our discriminators
00038   BOOST_FOREACH(PluginMap::value_type plugin, plugins_) {
00039     plugin.second->setup(evt, es);
00040   }
00041 }
00042 
00043 void RecoTauMVAHelper::loadDiscriminantPlugins(
00044     const PhysicsTools::Calibration::MVAComputer &comp) {
00045   typedef std::vector<PhysicsTools::Calibration::Variable> VarList;
00046   // List of input variables for this MVA.
00047   const VarList &vars = comp.inputSet;
00048   // Load the plugin for each of the Var if needed
00049   BOOST_FOREACH(const VarList::value_type& var, vars) {
00050     // Check to make sure it isn't a magic variable
00051     if (std::strncmp(var.name.c_str(), "__", 2) != 0) {
00052       // If we haven't added yet, build it.
00053       PhysicsTools::AtomicId varId(var.name);
00054       if (!plugins_.count(varId)) {
00055         edm::ParameterSet options;
00056         if (pluginOptions_.exists(var.name)) {
00057           options = pluginOptions_.getParameter<edm::ParameterSet>(var.name);
00058         };
00059         // Make sure it has a name (required by base class)
00060         if (!options.exists("name"))
00061           options.addParameter("name", "MVA_" + var.name);
00062         // Check if we want to specify the plugin name manually.  This is
00063         // required for things like the discriminant from discriminators, which
00064         // take an InputTag.  If we want to have more than one, we have to be
00065         // able take the MVA name (like FlightPathSig) and map it to
00066         // RecoTauDiscriminantFromDiscriminator[disc input tag = flight path sig]
00067         // Otherwise we just keep our regular plugin mapping.
00068         std::string pluginName = reco::tau::discPluginName(var.name);
00069         if (options.exists("plugin")) {
00070           pluginName = options.getParameter<std::string>("plugin");
00071         }
00072         plugins_.insert(
00073             varId, RecoTauDiscriminantPluginFactory::get()->create(
00074                 pluginName, options));
00075       }
00076     }
00077   }
00078 }
00079 
00080 void RecoTauMVAHelper::fillValues(const reco::PFTauRef& tau) const {
00081   // Loop over the relevant discriminators and the output
00082   for (PluginMap::const_iterator plugin = plugins_.begin();
00083        plugin != plugins_.end(); ++plugin) {
00084     PhysicsTools::AtomicId id = plugin->first;
00085     std::vector<double> pluginOutput = (plugin->second)->operator()(tau);
00086     // Check for nans
00087     for(size_t instance = 0; instance < pluginOutput.size(); ++instance) {
00088       if (std::isnan(pluginOutput[instance])) {
00089         std::ostringstream error;
00090         error << "A nan was detected in"
00091             << " the tau MVA variable " << id << " returning zero instead!"
00092             << " The PFTau: " << *tau << std::endl;
00093         tau->dump(error);
00094         edm::LogError("CorruptedMVAInput") << error.str();
00095         pluginOutput[instance] = 0.0;
00096       }
00097     }
00098     //std::cout << "id: " << id << " first: " << pluginOutput[0] << std::endl;
00099     // Build values and copy into values vector
00100     std::for_each(pluginOutput.begin(), pluginOutput.end(),
00101                   boost::bind(&PhysicsTools::Variable::ValueList::add,
00102                               boost::ref(values_), id, _1));
00103   }
00104 }
00105 
00106 // Get values
00107 const PhysicsTools::Variable::ValueList&
00108 RecoTauMVAHelper::discriminants(const PFTauRef& tau) const {
00109   values_.clear();
00110   fillValues(tau);
00111   return values_;
00112 }
00113 
00114 // Apply the MVA to a given tau
00115 double RecoTauMVAHelper::operator()(const reco::PFTauRef &tau) const {
00116   // Clear output
00117   values_.clear();
00118   // Build the values
00119   fillValues(tau);
00120   // Call the MVA
00121   return computer_->eval(values_);
00122 }
00123 
00124 void RecoTauMVAHelper::train(const reco::PFTauRef &tau, bool target,
00125                              double weight) const {
00126   static const PhysicsTools::AtomicId kTargetId("__TARGET__");
00127   static const PhysicsTools::AtomicId kWeightId("__WEIGHT__");
00128   if (!computer_)
00129     return;
00130   values_.clear();
00131   values_.add(kTargetId, target);
00132   values_.add(kWeightId, weight);
00133   // Build the discriminant values
00134   fillValues(tau);
00135   computer_->eval(values_);
00136 }
00137 
00138 }}  // end namespace reco::tau