CMS 3D CMS Logo

/afs/cern.ch/work/a/aaltunda/public/www/CMSSW_5_3_14/src/RecoTauTag/RecoTau/plugins/RecoTauMVADiscriminator.cc

Go to the documentation of this file.
00001 /*
00002  * RecoTauMVADiscriminator
00003  *
00004  * Apply an MVA discriminator to a collection of PFTaus.  Output is a
00005  * PFTauDiscriminator.  The module takes the following options:
00006  *  > dbLabel - should match "appendToDataLabel" option of PoolDBSource
00007  *              if it exists.
00008  *  > mvas    - a vector of PSets, each of which contains nCharged, nPiZeros
00009  *              and a string giving the name of the correct MVA in the
00010  *              MVA ComputerContainer provided by PoolDBSource.  This maps decay
00011  *              modes to MVA implementations.
00012  *  > defaultMVA - MVA to use if the decay mode does not match one specified in
00013  *              mvas.
00014  *  > remapOutput - TMVA gives its output from (-1, 1).  If this is enabled,
00015  *              remap it to (0, 1).
00016  *
00017  *  The interface to the MVA framework is handled by the RecoTauMVAHelper class.
00018  *
00019  * Author: Evan K. Friis (UC Davis)
00020  *
00021  */
00022 
00023 #include <boost/foreach.hpp>
00024 #include <boost/ptr_container/ptr_map.hpp>
00025 
00026 #include "RecoTauTag/RecoTau/interface/TauDiscriminationProducerBase.h"
00027 #include "RecoTauTag/RecoTau/interface/RecoTauMVAHelper.h"
00028 #include "RecoTauTag/RecoTau/interface/PFTauDecayModeTools.h"
00029 
00030 #include "DataFormats/TauReco/interface/PFTau.h"
00031 #include "DataFormats/TauReco/interface/PFTauFwd.h"
00032 
00033 #include "FWCore/MessageLogger/interface/MessageLogger.h"
00034 #include "FWCore/ParameterSet/interface/ParameterSet.h"
00035 
00036 
00037 class RecoTauMVADiscriminator : public PFTauDiscriminationProducerBase {
00038   public:
00039     explicit RecoTauMVADiscriminator(const edm::ParameterSet& pset);
00040     ~RecoTauMVADiscriminator() {}
00041 
00042     void beginEvent(const edm::Event&, const edm::EventSetup&);
00043     double discriminate(const reco::PFTauRef&);
00044 
00045   private:
00046     // Map a decay mode to an MVA getter
00047     typedef boost::ptr_map<reco::PFTau::hadronicDecayMode,
00048             reco::tau::RecoTauMVAHelper> MVAMap;
00049 
00050     std::auto_ptr<reco::tau::RecoTauMVAHelper> defaultMVA_;
00051 
00052     MVAMap mvas_;
00053     std::string dbLabel_;
00054     double unsupportedDMValue_;
00055     bool remapOutput_;
00056 };
00057 
00058 RecoTauMVADiscriminator::RecoTauMVADiscriminator(const edm::ParameterSet& pset)
00059   :PFTauDiscriminationProducerBase(pset) {
00060   std::string dbLabel;
00061   if (pset.exists("dbLabel"))
00062     dbLabel = pset.getParameter<std::string>("dbLabel");
00063 
00064   unsupportedDMValue_ = (pset.exists("unsupportedDecayModeValue")) ?
00065       pset.getParameter<double>("unsupportedDecayModeValue")
00066       : prediscriminantFailValue_;
00067 
00068   remapOutput_ = pset.getParameter<bool>("remapOutput");
00069 
00070   edm::ParameterSet discriminantOptions = pset.getParameter<edm::ParameterSet>(
00071       "discriminantOptions");
00072 
00073   typedef std::vector<edm::ParameterSet> VPSet;
00074   const VPSet& mvas = pset.getParameter<VPSet>("mvas");
00075 
00076   for (VPSet::const_iterator mva = mvas.begin(); mva != mvas.end(); ++mva) {
00077     unsigned int nCharged = mva->getParameter<unsigned int>("nCharged");
00078     unsigned int nPiZeros = mva->getParameter<unsigned int>("nPiZeros");
00079     reco::PFTau::hadronicDecayMode decayMode = reco::tau::translateDecayMode(
00080         nCharged, nPiZeros);
00081     // Check to ensure this decay mode is not already added
00082     if (!mvas_.count(decayMode)) {
00083       std::string computerName = mva->getParameter<std::string>("mvaLabel");
00084       // Add it
00085       mvas_.insert(
00086           decayMode, new reco::tau::RecoTauMVAHelper(
00087             computerName, dbLabel, discriminantOptions));
00088     } else {
00089       edm::LogError("DecayModeNotUnique") << "The tau decay mode with "
00090         "nCharged/nPiZero = " << nCharged << "/" << nPiZeros << " dm: "
00091         << decayMode <<
00092         " is associated to multiple MVA implmentations, "
00093         "the second instantiation is being ignored!!!";
00094     }
00095   }
00096 
00097   // Check if we a catch-all MVA is desired.
00098   if (pset.exists("defaultMVA")) {
00099     defaultMVA_.reset(new reco::tau::RecoTauMVAHelper(
00100             pset.getParameter<std::string>("defaultMVA"),
00101             dbLabel, discriminantOptions));
00102   }
00103 
00104 }
00105 
00106 void RecoTauMVADiscriminator::beginEvent(const edm::Event& evt,
00107                                          const edm::EventSetup& es) {
00108   // Pass the event setup so the MVAHelpers can get the MVAs from the DB
00109   BOOST_FOREACH(MVAMap::value_type mva, mvas_) {
00110       mva.second->setEvent(evt, es);
00111   }
00112   if (defaultMVA_.get())
00113     defaultMVA_->setEvent(evt, es);
00114 }
00115 
00116 // Get the MVA output for a given PFTau
00117 double RecoTauMVADiscriminator::discriminate(const reco::PFTauRef& tau) {
00118   // Find the right MVA for this tau's decay mode
00119   MVAMap::iterator mva = mvas_.find(tau->decayMode());
00120   // If this DM has an associated decay mode, get and return the result.
00121   double output = unsupportedDMValue_;
00122   if (mva != mvas_.end() || defaultMVA_.get()) {
00123     if (mva != mvas_.end())
00124       output = mva->second->operator()(tau);
00125     else
00126       output = defaultMVA_->operator()(tau);
00127     // TMVA produces output from -1 to 1
00128     if (remapOutput_) {
00129       output += 1.;
00130       output /= 2.;
00131     }
00132   }
00133   return output;
00134 }
00135 
00136 #include "FWCore/Framework/interface/MakerMacros.h"
// Expose RecoTauMVADiscriminator as a framework module via the macro from
// MakerMacros.h (presumably registering it with the plugin factory so it can
// be instantiated by name from a configuration — confirm against FWCore docs).
00137 DEFINE_FWK_MODULE(RecoTauMVADiscriminator);