CMS 3D CMS Logo

PATTauDiscriminationByMVAIsolationRun2.cc
Go to the documentation of this file.
1 
2 /*
3  * \class PATTauDiscriminationByMVAIsolationRun2
4  *
5  * MVA based discriminator against jet -> tau fakes
6  *
7  * Adapted from RecoTauTag/RecoTau/plugins/PFRecoTauDiscriminationByMVAIsolationRun2.cc
8  * to enable computation of MVA isolation on MiniAOD
9  *
10  * \author Alexander Nehrkorn, RWTH Aachen
11  */
12 
13 // todo 1: remove leadingTrackChi2 as input variable from:
14 // - here
15 // - TauPFEssential
16 // - PFRecoTauDiscriminationByMVAIsolationRun2
17 // - Training of BDT
18 
20 
26 
28 
31 
37 
40 
41 #include <TFile.h>
42 
43 #include <iostream>
#include <memory>
44 
45 using namespace pat;
46 
47 namespace {
48  const GBRForest* loadMVAfromFile(const edm::FileInPath& inputFileName,
49  const std::string& mvaName,
50  std::vector<TFile*>& inputFilesToDelete) {
51  if (inputFileName.location() == edm::FileInPath::Unknown)
52  throw cms::Exception("PATTauDiscriminationByIsolationMVARun2::loadMVA")
53  << " Failed to find File = " << inputFileName << " !!\n";
54  TFile* inputFile = new TFile(inputFileName.fullPath().data());
55 
56  //const GBRForest* mva = dynamic_cast<GBRForest*>(inputFile->Get(mvaName.data())); // CV: dynamic_cast<GBRForest*> fails for some reason ?!
57  const GBRForest* mva = (GBRForest*)inputFile->Get(mvaName.data());
58  if (!mva)
59  throw cms::Exception("PATTauDiscriminationByIsolationMVARun2::loadMVA")
60  << " Failed to load MVA = " << mvaName.data() << " from file = " << inputFileName.fullPath().data()
61  << " !!\n";
62 
63  inputFilesToDelete.push_back(inputFile);
64 
65  return mva;
66  }
67 } // namespace
68 
69 namespace reco {
70  namespace tau {
71 
// PATTauDiscriminationByMVAIsolationRun2: computes an MVA isolation score for pat::Taus
// (MVA-based discrimination against jet -> tau fakes, on MiniAOD). The GBRForest is taken
// either from a ROOT file or from the EventSetup, depending on the 'loadMVAfromDB' parameter.
73  public:
// Constructor: reads the module configuration; throws cms::Exception for an unknown 'mvaOpt'.
76  moduleLabel_(cfg.getParameter<std::string>("@module_label")),
77  mvaReader_(nullptr),
78  mvaInput_(nullptr) {
79  mvaName_ = cfg.getParameter<std::string>("mvaName");
80  loadMVAfromDB_ = cfg.getParameter<bool>("loadMVAfromDB");
// Either remember the ROOT file to load from (loading itself happens in beginEvent), or
// declare the EventSetup dependency using mvaName as the ESInputTag data label.
81  if (!loadMVAfromDB_) {
82  inputFileName_ = cfg.getParameter<edm::FileInPath>("inputFileName");
83  } else {
84  mvaToken_ = esConsumes(edm::ESInputTag{"", mvaName_});
85  }
// Map the 'mvaOpt' string onto the internal option. The option is passed to
// fillIsoMVARun2Inputs() and determines the size of mvaInput_ below.
86  std::string mvaOpt_string = cfg.getParameter<std::string>("mvaOpt");
87  if (mvaOpt_string == "oldDMwoLT")
88  mvaOpt_ = kOldDMwoLT;
89  else if (mvaOpt_string == "oldDMwLT")
90  mvaOpt_ = kOldDMwLT;
91  else if (mvaOpt_string == "newDMwoLT")
92  mvaOpt_ = kNewDMwoLT;
93  else if (mvaOpt_string == "newDMwLT")
94  mvaOpt_ = kNewDMwLT;
95  else if (mvaOpt_string == "DBoldDMwLT")
96  mvaOpt_ = kDBoldDMwLT;
97  else if (mvaOpt_string == "DBnewDMwLT")
98  mvaOpt_ = kDBnewDMwLT;
99  else if (mvaOpt_string == "PWoldDMwLT")
100  mvaOpt_ = kPWoldDMwLT;
101  else if (mvaOpt_string == "PWnewDMwLT")
102  mvaOpt_ = kPWnewDMwLT;
103  else if (mvaOpt_string == "DBoldDMwLTwGJ")
104  mvaOpt_ = kDBoldDMwLTwGJ;
105  else if (mvaOpt_string == "DBnewDMwLTwGJ")
106  mvaOpt_ = kDBnewDMwLTwGJ;
107  else if (mvaOpt_string == "DBnewDMwLTwGJPhase2")
108  mvaOpt_ = kDBnewDMwLTwGJPhase2;
109  else
110  throw cms::Exception("PATTauDiscriminationByMVAIsolationRun2")
111  << " Invalid Configuration Parameter 'mvaOpt' = " << mvaOpt_string << " !!\n";
112 
// Scratch buffer for the MVA inputs, sized per option.
// NOTE(review): the sizes presumably equal the number of inputs fillIsoMVARun2Inputs() writes
// for each option -- verify both sides together when adding or changing an option.
// NOTE(review): mvaInput_ is a raw owning array. If one of the getParameter calls further
// below throws, the destructor does not run and this allocation leaks.
113  if (mvaOpt_ == kOldDMwoLT || mvaOpt_ == kNewDMwoLT)
114  mvaInput_ = new float[6];
115  else if (mvaOpt_ == kOldDMwLT || mvaOpt_ == kNewDMwLT)
116  mvaInput_ = new float[12];
117  else if (mvaOpt_ == kDBoldDMwLT || mvaOpt_ == kDBnewDMwLT || mvaOpt_ == kPWoldDMwLT || mvaOpt_ == kPWnewDMwLT ||
118  mvaOpt_ == kDBoldDMwLTwGJ || mvaOpt_ == kDBnewDMwLTwGJ)
119  mvaInput_ = new float[23];
120  else if (mvaOpt_ == kDBnewDMwLTwGJPhase2)
121  mvaInput_ = new float[30];
// Unreachable: every option assigned above is handled, and unknown strings already threw.
122  else
123  assert(0);
124 
// Names of the pat::Tau tauID entries that supply the isolation inputs (they are also
// printed via tau->tauID(...) in discriminate()).
125  chargedIsoPtSums_ = cfg.getParameter<std::string>("srcChargedIsoPtSum");
126  neutralIsoPtSums_ = cfg.getParameter<std::string>("srcNeutralIsoPtSum");
127  puCorrPtSums_ = cfg.getParameter<std::string>("srcPUcorrPtSum");
128  photonPtSumOutsideSignalCone_ = cfg.getParameter<std::string>("srcPhotonPtSumOutsideSignalCone");
129  footprintCorrection_ = cfg.getParameter<std::string>("srcFootprintCorrection");
130 
131  verbosity_ = cfg.getParameter<int>("verbosity");
132  }
133 
// Makes mvaReader_ usable and fetches the tau collection for the event.
134  void beginEvent(const edm::Event&, const edm::EventSetup&) override;
135 
// Returns the raw MVA isolation score for a single tau.
136  reco::SingleTauDiscriminatorContainer discriminate(const TauRef&) const override;
137 
// Destructor body: frees what this module owns. The GBRForest is deleted only when it was
// loaded from a file; one obtained from the EventSetup is not owned by this module. It also
// frees the input buffer and the TFiles opened by loadMVAfromFile.
139  if (!loadMVAfromDB_)
140  delete mvaReader_;
141  delete[] mvaInput_;
142  for (std::vector<TFile*>::iterator it = inputFilesToDelete_.begin(); it != inputFilesToDelete_.end(); ++it) {
143  delete (*it);
144  }
145  }
146 
147  static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
148 
149  private:
151 
// Internal option parsed from the 'mvaOpt' configuration string.
157  int mvaOpt_;
// Owning heap array: new[] in the constructor, delete[] in the destructor. It is reused for
// every tau. NOTE(review): written from the const discriminate(); assumes a single instance is
// never evaluated concurrently -- confirm.
158  float* mvaInput_;
159 
165 
// TFiles opened by loadMVAfromFile(); deleted in the destructor.
167  std::vector<TFile*> inputFilesToDelete_;
168 
170  };
171 
172  void PATTauDiscriminationByMVAIsolationRun2::beginEvent(const edm::Event& evt, const edm::EventSetup& es) {
173  if (!mvaReader_) {
174  if (loadMVAfromDB_) {
175  mvaReader_ = &es.getData(mvaToken_);
176  } else {
177  mvaReader_ = loadMVAfromFile(inputFileName_, mvaName_, inputFilesToDelete_);
178  }
179  }
180 
181  evt.getByToken(Tau_token, taus_);
182  }
183 
// Computes the raw MVA isolation score of one tau. The returned container holds exactly one
// raw value:
//   0   if the tau has no leading charged hadron (required by the MVA),
//   -1  if fillIsoMVARun2Inputs() returns false (presumably: inputs could not be computed),
//   otherwise the GBRForest classifier output for the filled inputs.
// Relies on mvaReader_ having been set in beginEvent().
// NOTE(review): writes the shared mvaInput_ buffer although the method is const; assumes the
// module instance is not evaluated concurrently -- confirm.
184  reco::SingleTauDiscriminatorContainer PATTauDiscriminationByMVAIsolationRun2::discriminate(const TauRef& tau) const {
185  // CV: define dummy category index in order to use RecoTauDiscriminantCutMultiplexer module to apply WP cuts
187  result.rawValues = {-1.};
188 
189  // CV: computation of MVA value requires presence of leading charged hadron
190  if (tau->leadChargedHadrCand().isNull()) {
191  result.rawValues.at(0) = 0.;
192  return result;
193  }
194 
// Fill the option-specific input vector; evaluate the MVA only if filling succeeded.
195  if (reco::tau::fillIsoMVARun2Inputs(mvaInput_,
196  *tau,
197  mvaOpt_,
198  chargedIsoPtSums_,
199  neutralIsoPtSums_,
200  puCorrPtSums_,
201  photonPtSumOutsideSignalCone_,
202  footprintCorrection_)) {
203  double mvaValue = mvaReader_->GetClassifier(mvaInput_);
204  if (verbosity_) {
205  edm::LogPrint("PATTauDiscByMVAIsolRun2") << "<PATTauDiscriminationByMVAIsolationRun2::discriminate>:";
206  edm::LogPrint("PATTauDiscByMVAIsolRun2") << " tau: Pt = " << tau->pt() << ", eta = " << tau->eta();
207  edm::LogPrint("PATTauDiscByMVAIsolRun2")
208  << " isolation: charged = " << tau->tauID(chargedIsoPtSums_)
209  << ", neutral = " << tau->tauID(neutralIsoPtSums_) << ", PUcorr = " << tau->tauID(puCorrPtSums_);
210  edm::LogPrint("PATTauDiscByMVAIsolRun2") << " decay mode = " << tau->decayMode();
211  edm::LogPrint("PATTauDiscByMVAIsolRun2")
212  << " impact parameter: distance = " << tau->dxy() << ", significance = " << tau->dxy_Sig();
// NOTE(review): ":" followed by ", significance" prints e.g. "1:, significance = ..." --
// looks like a formatting slip in this debug message.
213  edm::LogPrint("PATTauDiscByMVAIsolRun2") << " has decay vertex = " << tau->hasSecondaryVertex() << ":"
214  << ", significance = " << tau->flightLengthSig();
215  edm::LogPrint("PATTauDiscByMVAIsolRun2") << "--> mvaValue = " << mvaValue;
216  }
217  result.rawValues.at(0) = mvaValue;
218  }
219  return result;
220  }
221 
223  // patTauDiscriminationByMVAIsolationRun2
// All parameters below are required except 'inputFileName' (optional) and 'verbosity'
// (default 0). Note that the constructor reads 'inputFileName' whenever 'loadMVAfromDB'
// is false, so it must be set in that case.
225 
226  desc.add<std::string>("mvaName");
227  desc.add<bool>("loadMVAfromDB");
228  desc.addOptional<edm::FileInPath>("inputFileName");
229  desc.add<std::string>("mvaOpt");
230 
// Names of the pat::Tau tauID entries used as isolation inputs.
231  desc.add<std::string>("srcChargedIsoPtSum");
232  desc.add<std::string>("srcNeutralIsoPtSum");
233  desc.add<std::string>("srcPUcorrPtSum");
234  desc.add<std::string>("srcPhotonPtSumOutsideSignalCone");
235  desc.add<std::string>("srcFootprintCorrection");
236  desc.add<int>("verbosity", 0);
237 
238  fillProducerDescriptions(desc); // inherited from the base
239 
// Register the description under the label "patTauDiscriminationByMVAIsolationRun2".
240  descriptions.add("patTauDiscriminationByMVAIsolationRun2", desc);
241  }
242 
244 
245  } // namespace tau
246 } // namespace reco
ESGetTokenH3DDVariant esConsumes(std::string const &Record, edm::ConsumesCollector &)
Definition: DeDxTools.cc:283
T const & getData(const ESGetToken< T, R > &iToken) const noexcept(false)
Definition: EventSetup.h:119
bool fillIsoMVARun2Inputs(float *mvaInput, const pat::Tau &tau, int mvaOpt, const std::string &nameCharged, const std::string &nameNeutral, const std::string &namePu, const std::string &nameOutside, const std::string &nameFootprint)
bool getByToken(EDGetToken token, Handle< PROD > &result) const
Definition: Event.h:536
assert(be >=bs)
Definition: HeavyIon.h:7
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:16
Log< level::Warning, true > LogPrint
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
void add(std::string const &label, ParameterSetDescription const &psetDescription)
fixed size matrix