CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
PFRecoTauDiscriminationByMVAIsolation2.cc
Go to the documentation of this file.
1 
11 
17 
19 
22 
29 
32 
33 #include <TMath.h>
34 #include <TFile.h>
35 
36 #include <iostream>
37 
38 using namespace reco;
39 
40 namespace {
41  const GBRForest* loadMVAfromFile(const edm::FileInPath& inputFileName,
42  const std::string& mvaName,
43  std::vector<TFile*>& inputFilesToDelete) {
44  if (inputFileName.location() == edm::FileInPath::Unknown)
45  throw cms::Exception("PFRecoTauDiscriminationByIsolationMVA2::loadMVA")
46  << " Failed to find File = " << inputFileName << " !!\n";
47  TFile* inputFile = new TFile(inputFileName.fullPath().data());
48 
49  //const GBRForest* mva = dynamic_cast<GBRForest*>(inputFile->Get(mvaName.data())); // CV: dynamic_cast<GBRForest*> fails for some reason ?!
50  const GBRForest* mva = (GBRForest*)inputFile->Get(mvaName.data());
51  if (!mva)
52  throw cms::Exception("PFRecoTauDiscriminationByIsolationMVA2::loadMVA")
53  << " Failed to load MVA = " << mvaName.data() << " from file = " << inputFileName.fullPath().data()
54  << " !!\n";
55 
56  inputFilesToDelete.push_back(inputFile);
57 
58  return mva;
59  }
60 } // namespace
61 
63 public:
66  moduleLabel_(cfg.getParameter<std::string>("@module_label")),
67  mvaReader_(nullptr),
68  mvaInput_(nullptr) {
69  mvaName_ = cfg.getParameter<std::string>("mvaName");
70  loadMVAfromDB_ = cfg.getParameter<bool>("loadMVAfromDB");
71  if (!loadMVAfromDB_) {
72  inputFileName_ = cfg.getParameter<edm::FileInPath>("inputFileName");
73  } else {
74  mvaToken_ = esConsumes(edm::ESInputTag{"", mvaName_});
75  }
76  std::string mvaOpt_string = cfg.getParameter<std::string>("mvaOpt");
77  if (mvaOpt_string == "oldDMwoLT")
78  mvaOpt_ = kOldDMwoLT;
79  else if (mvaOpt_string == "oldDMwLT")
80  mvaOpt_ = kOldDMwLT;
81  else if (mvaOpt_string == "newDMwoLT")
82  mvaOpt_ = kNewDMwoLT;
83  else if (mvaOpt_string == "newDMwLT")
84  mvaOpt_ = kNewDMwLT;
85  else
86  throw cms::Exception("PFRecoTauDiscriminationByIsolationMVA2")
87  << " Invalid Configuration Parameter 'mvaOpt' = " << mvaOpt_string << " !!\n";
88 
89  if (mvaOpt_ == kOldDMwoLT || mvaOpt_ == kNewDMwoLT)
90  mvaInput_ = new float[6];
91  else if (mvaOpt_ == kOldDMwLT || mvaOpt_ == kNewDMwLT)
92  mvaInput_ = new float[12];
93  else
94  assert(0);
95 
96  tauTransverseImpactParameters_token_ =
97  consumes<PFTauTIPAssociationByRef>(cfg.getParameter<edm::InputTag>("srcTauTransverseImpactParameters"));
98 
99  basicTauDiscriminators_token_ =
100  consumes<reco::TauDiscriminatorContainer>(cfg.getParameter<edm::InputTag>("srcBasicTauDiscriminators"));
101  chargedIsoPtSum_index_ = cfg.getParameter<int>("srcChargedIsoPtSumIndex");
102  neutralIsoPtSum_index_ = cfg.getParameter<int>("srcNeutralIsoPtSumIndex");
103  pucorrPtSum_index_ = cfg.getParameter<int>("srcPUcorrPtSumIndex");
104 
105  verbosity_ = cfg.getParameter<int>("verbosity");
106  }
107 
108  void beginEvent(const edm::Event&, const edm::EventSetup&) override;
109 
110  reco::SingleTauDiscriminatorContainer discriminate(const PFTauRef&) const override;
111 
113  if (!loadMVAfromDB_)
114  delete mvaReader_;
115  delete[] mvaInput_;
116  for (std::vector<TFile*>::iterator it = inputFilesToDelete_.begin(); it != inputFilesToDelete_.end(); ++it) {
117  delete (*it);
118  }
119  }
120 
121  static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
122 
123 private:
125 
132  int mvaOpt_;
133  float* mvaInput_;
134 
139 
145 
147 
148  std::vector<TFile*> inputFilesToDelete_;
149 
151 };
152 
154  if (!mvaReader_) {
155  if (loadMVAfromDB_) {
156  mvaReader_ = &es.getData(mvaToken_);
157  } else {
158  mvaReader_ = loadMVAfromFile(inputFileName_, mvaName_, inputFilesToDelete_);
159  }
160  }
161 
162  evt.getByToken(tauTransverseImpactParameters_token_, tauLifetimeInfos_);
163 
164  evt.getByToken(basicTauDiscriminators_token_, basicTauDiscriminators_);
165 
166  evt.getByToken(Tau_token, taus_);
167 }
168 
171  // CV: define dummy category index in order to use RecoTauDiscriminantCutMultiplexer module to apply WP cuts
172  result.rawValues = {-1., 0.};
173 
174  // CV: computation of MVA value requires presence of leading charged hadron
175  if (tau->leadChargedHadrCand().isNull())
176  return 0.;
177 
178  int tauDecayMode = tau->decayMode();
179 
180  if (((mvaOpt_ == kOldDMwoLT || mvaOpt_ == kOldDMwLT) &&
181  (tauDecayMode == 0 || tauDecayMode == 1 || tauDecayMode == 2 || tauDecayMode == 10)) ||
182  ((mvaOpt_ == kNewDMwoLT || mvaOpt_ == kNewDMwLT) &&
183  (tauDecayMode == 0 || tauDecayMode == 1 || tauDecayMode == 2 || tauDecayMode == 5 || tauDecayMode == 6 ||
184  tauDecayMode == 10))) {
185  double chargedIsoPtSum = (*basicTauDiscriminators_)[tau].rawValues.at(chargedIsoPtSum_index_);
186  double neutralIsoPtSum = (*basicTauDiscriminators_)[tau].rawValues.at(neutralIsoPtSum_index_);
187  double puCorrPtSum = (*basicTauDiscriminators_)[tau].rawValues.at(pucorrPtSum_index_);
188 
189  const reco::PFTauTransverseImpactParameter& tauLifetimeInfo = *(*tauLifetimeInfos_)[tau];
190 
191  double decayDistX = tauLifetimeInfo.flightLength().x();
192  double decayDistY = tauLifetimeInfo.flightLength().y();
193  double decayDistZ = tauLifetimeInfo.flightLength().z();
194  double decayDistMag = TMath::Sqrt(decayDistX * decayDistX + decayDistY * decayDistY + decayDistZ * decayDistZ);
195 
196  if (mvaOpt_ == kOldDMwoLT || mvaOpt_ == kNewDMwoLT) {
197  mvaInput_[0] = TMath::Log(TMath::Max(1., Double_t(tau->pt())));
198  mvaInput_[1] = TMath::Abs(tau->eta());
199  mvaInput_[2] = TMath::Log(TMath::Max(1.e-2, chargedIsoPtSum));
200  mvaInput_[3] = TMath::Log(TMath::Max(1.e-2, neutralIsoPtSum - 0.125 * puCorrPtSum));
201  mvaInput_[4] = TMath::Log(TMath::Max(1.e-2, puCorrPtSum));
202  mvaInput_[5] = tauDecayMode;
203  } else if (mvaOpt_ == kOldDMwLT || mvaOpt_ == kNewDMwLT) {
204  mvaInput_[0] = TMath::Log(TMath::Max(1., Double_t(tau->pt())));
205  mvaInput_[1] = TMath::Abs(tau->eta());
206  mvaInput_[2] = TMath::Log(TMath::Max(1.e-2, chargedIsoPtSum));
207  mvaInput_[3] = TMath::Log(TMath::Max(1.e-2, neutralIsoPtSum - 0.125 * puCorrPtSum));
208  mvaInput_[4] = TMath::Log(TMath::Max(1.e-2, puCorrPtSum));
209  mvaInput_[5] = tauDecayMode;
210  mvaInput_[6] = TMath::Sign(+1., tauLifetimeInfo.dxy());
211  mvaInput_[7] = TMath::Sqrt(TMath::Abs(TMath::Min(1., tauLifetimeInfo.dxy())));
212  mvaInput_[8] = TMath::Min(10., TMath::Abs(tauLifetimeInfo.dxy_Sig()));
213  mvaInput_[9] = (tauLifetimeInfo.hasSecondaryVertex()) ? 1. : 0.;
214  mvaInput_[10] = TMath::Sqrt(decayDistMag);
215  mvaInput_[11] = TMath::Min(10., tauLifetimeInfo.flightLengthSig());
216  }
217 
218  double mvaValue = mvaReader_->GetClassifier(mvaInput_);
219  if (verbosity_) {
220  edm::LogPrint("PFTauDiscByMVAIsol2") << "<PFRecoTauDiscriminationByIsolationMVA2::discriminate>:";
221  edm::LogPrint("PFTauDiscByMVAIsol2") << " tau: Pt = " << tau->pt() << ", eta = " << tau->eta();
222  edm::LogPrint("PFTauDiscByMVAIsol2") << " isolation: charged = " << chargedIsoPtSum
223  << ", neutral = " << neutralIsoPtSum << ", PUcorr = " << puCorrPtSum;
224  edm::LogPrint("PFTauDiscByMVAIsol2") << " decay mode = " << tauDecayMode;
225  edm::LogPrint("PFTauDiscByMVAIsol2") << " impact parameter: distance = " << tauLifetimeInfo.dxy()
226  << ", significance = " << tauLifetimeInfo.dxy_Sig();
227  edm::LogPrint("PFTauDiscByMVAIsol2")
228  << " has decay vertex = " << tauLifetimeInfo.hasSecondaryVertex() << ":"
229  << " distance = " << decayDistMag << ", significance = " << tauLifetimeInfo.flightLengthSig();
230  edm::LogPrint("PFTauDiscByMVAIsol2") << "--> mvaValue = " << mvaValue;
231  }
232  result.rawValues.at(0) = mvaValue;
233  }
234  return result;
235 }
236 
238  // pfRecoTauDiscriminationByIsolationMVA2
240 
241  desc.add<std::string>("mvaName");
242  desc.add<bool>("loadMVAfromDB");
243  desc.addOptional<edm::FileInPath>("inputFileName");
244  desc.add<std::string>("mvaOpt");
245 
246  desc.add<edm::InputTag>("srcTauTransverseImpactParameters");
247  desc.add<edm::InputTag>("srcBasicTauDiscriminators");
248  desc.add<int>("srcChargedIsoPtSumIndex");
249  desc.add<int>("srcNeutralIsoPtSumIndex");
250  desc.add<int>("srcPUcorrPtSumIndex");
251  desc.add<int>("verbosity", 0);
252 
253  fillProducerDescriptions(desc); // inherited from the base
254 
255  descriptions.add("pfRecoTauDiscriminationByIsolationMVA2", desc);
256 }
257 
void beginEvent(const edm::Event &, const edm::EventSetup &) override
ParameterDescriptionBase * addOptional(U const &iLabel, T const &value)
tuple cfg
Definition: looper.py:296
bool getByToken(EDGetToken token, Handle< PROD > &result) const
Definition: Event.h:539
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:16
edm::ESGetToken< GBRForest, GBRWrapperRcd > mvaToken_
edm::AssociationVector< reco::PFTauRefProd, std::vector< reco::PFTauTransverseImpactParameterRef > > PFTauTIPAssociationByRef
edm::Handle< PFTauTIPAssociationByRef > tauLifetimeInfos_
assert(be >=bs)
edm::EDGetTokenT< PFTauTIPAssociationByRef > tauTransverseImpactParameters_token_
tuple result
Definition: mps_fire.py:311
bool getData(T &iHolder) const
Definition: EventSetup.h:122
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
edm::Handle< reco::TauDiscriminatorContainer > basicTauDiscriminators_
LocationCode location() const
Where was the file found?
Definition: FileInPath.cc:159
ParameterDescriptionBase * add(U const &iLabel, T const &value)
Log< level::Warning, true > LogPrint
bool isNull() const
Checks for null.
Definition: Ref.h:235
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
T getParameter(std::string const &) const
Definition: ParameterSet.h:303
void add(std::string const &label, ParameterSetDescription const &psetDescription)
std::string fullPath() const
Definition: FileInPath.cc:161
edm::EDGetTokenT< reco::TauDiscriminatorContainer > basicTauDiscriminators_token_
moduleLabel_(iConfig.getParameter< string >("@module_label"))
ESGetTokenH3DDVariant esConsumes(std::string const &Reccord, edm::ConsumesCollector &)
Definition: DeDxTools.cc:283
reco::SingleTauDiscriminatorContainer discriminate(const PFTauRef &) const override