CMS 3D CMS Logo

PATTauDiscriminationByMVAIsolationRun2.cc
Go to the documentation of this file.
1 
2 /*
3  * \class PATTauDiscriminationByMVAIsolationRun2
4  *
5  * MVA based discriminator against jet -> tau fakes
6  *
7  * Adapted from RecoTauTag/RecoTau/plugins/PFRecoTauDiscriminationByMVAIsolationRun2.cc
8  * to enable computation of MVA isolation on MiniAOD
9  *
10  * \author Alexander Nehrkorn, RWTH Aachen
11  */
12 
13 // todo 1: remove leadingTrackChi2 as input variable from:
14 // - here
15 // - TauPFEssential
16 // - PFRecoTauDiscriminationByMVAIsolationRun2
17 // - Training of BDT
18 
20 
26 
28 
34 
38 
39 #include <TFile.h>
40 
41 #include <iostream>
42 
43 using namespace pat;
44 
45 namespace
46 {
47  const GBRForest* loadMVAfromFile(const edm::FileInPath& inputFileName, const std::string& mvaName, std::vector<TFile*>& inputFilesToDelete)
48  {
49  if ( inputFileName.location() == edm::FileInPath::Unknown ) throw cms::Exception("PATTauDiscriminationByIsolationMVARun2::loadMVA")
50  << " Failed to find File = " << inputFileName << " !!\n";
51  TFile* inputFile = new TFile(inputFileName.fullPath().data());
52 
53  //const GBRForest* mva = dynamic_cast<GBRForest*>(inputFile->Get(mvaName.data())); // CV: dynamic_cast<GBRForest*> fails for some reason ?!
54  const GBRForest* mva = (GBRForest*)inputFile->Get(mvaName.data());
55  if ( !mva )
56  throw cms::Exception("PATTauDiscriminationByIsolationMVARun2::loadMVA")
57  << " Failed to load MVA = " << mvaName.data() << " from file = " << inputFileName.fullPath().data() << " !!\n";
58 
59  inputFilesToDelete.push_back(inputFile);
60 
61  return mva;
62  }
63 
64  const GBRForest* loadMVAfromDB(const edm::EventSetup& es, const std::string& mvaName)
65  {
67  es.get<GBRWrapperRcd>().get(mvaName, mva);
68  return mva.product();
69  }
70 }
71 
72 namespace reco { namespace tau {
73 
75 {
76  public:
79  moduleLabel_(cfg.getParameter<std::string>("@module_label")),
80  mvaReader_(nullptr),
81  mvaInput_(nullptr),
82  category_output_()
83  {
84  mvaName_ = cfg.getParameter<std::string>("mvaName");
85  loadMVAfromDB_ = cfg.exists("loadMVAfromDB") ? cfg.getParameter<bool>("loadMVAfromDB") : false;
86  if ( !loadMVAfromDB_ ) {
87  if(cfg.exists("inputFileName")){
88  inputFileName_ = cfg.getParameter<edm::FileInPath>("inputFileName");
89  }else throw cms::Exception("MVA input not defined") << "Requested to load tau MVA input from ROOT file but no file provided in cfg file";
90  }
91  std::string mvaOpt_string = cfg.getParameter<std::string>("mvaOpt");
92  if ( mvaOpt_string == "oldDMwoLT" ) mvaOpt_ = kOldDMwoLT;
93  else if ( mvaOpt_string == "oldDMwLT" ) mvaOpt_ = kOldDMwLT;
94  else if ( mvaOpt_string == "newDMwoLT" ) mvaOpt_ = kNewDMwoLT;
95  else if ( mvaOpt_string == "newDMwLT" ) mvaOpt_ = kNewDMwLT;
96  else if ( mvaOpt_string == "DBoldDMwLT" ) mvaOpt_ = kDBoldDMwLT;
97  else if ( mvaOpt_string == "DBnewDMwLT" ) mvaOpt_ = kDBnewDMwLT;
98  else if ( mvaOpt_string == "PWoldDMwLT" ) mvaOpt_ = kPWoldDMwLT;
99  else if ( mvaOpt_string == "PWnewDMwLT" ) mvaOpt_ = kPWnewDMwLT;
100  else if ( mvaOpt_string == "DBoldDMwLTwGJ" ) mvaOpt_ = kDBoldDMwLTwGJ;
101  else if ( mvaOpt_string == "DBnewDMwLTwGJ" ) mvaOpt_ = kDBnewDMwLTwGJ;
102  else throw cms::Exception("PATTauDiscriminationByMVAIsolationRun2")
103  << " Invalid Configuration Parameter 'mvaOpt' = " << mvaOpt_string << " !!\n";
104 
105  if ( mvaOpt_ == kOldDMwoLT || mvaOpt_ == kNewDMwoLT ) mvaInput_ = new float[6];
106  else if ( mvaOpt_ == kOldDMwLT || mvaOpt_ == kNewDMwLT ) mvaInput_ = new float[12];
107  else if ( mvaOpt_ == kDBoldDMwLT || mvaOpt_ == kDBnewDMwLT ||
108  mvaOpt_ == kPWoldDMwLT || mvaOpt_ == kPWnewDMwLT ||
109  mvaOpt_ == kDBoldDMwLTwGJ || mvaOpt_ == kDBnewDMwLTwGJ) mvaInput_ = new float[23];
110  else assert(0);
111 
112  chargedIsoPtSums_ = cfg.getParameter<std::string>("srcChargedIsoPtSum");
113  neutralIsoPtSums_ = cfg.getParameter<std::string>("srcNeutralIsoPtSum");
114  puCorrPtSums_ = cfg.getParameter<std::string>("srcPUcorrPtSum");
115  photonPtSumOutsideSignalCone_ = cfg.getParameter<std::string>("srcPhotonPtSumOutsideSignalCone");
116  footprintCorrection_ = cfg.getParameter<std::string>("srcFootprintCorrection");
117 
118  verbosity_ = ( cfg.exists("verbosity") ) ?
119  cfg.getParameter<int>("verbosity") : 0;
120 
121  produces<pat::PATTauDiscriminator>("category");
122  }
123 
124  void beginEvent(const edm::Event&, const edm::EventSetup&) override;
125 
126  double discriminate(const TauRef&) const override;
127 
128  void endEvent(edm::Event&) override;
129 
131  {
132  if(!loadMVAfromDB_) delete mvaReader_;
133  delete[] mvaInput_;
134  for ( std::vector<TFile*>::iterator it = inputFilesToDelete_.begin();
135  it != inputFilesToDelete_.end(); ++it ) {
136  delete (*it);
137  }
138  }
139 
140  private:
141 
143 
148  int mvaOpt_;
149  float* mvaInput_;
150 
156 
158  std::unique_ptr<pat::PATTauDiscriminator> category_output_;
159  std::vector<TFile*> inputFilesToDelete_;
160 
162 };
163 
164 void PATTauDiscriminationByMVAIsolationRun2::beginEvent(const edm::Event& evt, const edm::EventSetup& es)
165 {
166  if( !mvaReader_ ) {
167  if ( loadMVAfromDB_ ) {
168  mvaReader_ = loadMVAfromDB(es, mvaName_);
169  } else {
170  mvaReader_ = loadMVAfromFile(inputFileName_, mvaName_, inputFilesToDelete_);
171  }
172  }
173 
174  evt.getByToken(Tau_token, taus_);
175  category_output_.reset(new pat::PATTauDiscriminator(TauRefProd(taus_)));
176 }
177 
178 double PATTauDiscriminationByMVAIsolationRun2::discriminate(const TauRef& tau) const
179 {
180  // CV: define dummy category index in order to use RecoTauDiscriminantCutMultiplexer module to appy WP cuts
181  double category = 0.;
182  category_output_->setValue(tauIndex_, category);
183 
184  // CV: computation of MVA value requires presence of leading charged hadron
185  if ( tau->leadChargedHadrCand().isNull() ) return 0.;
186 
187  if (reco::tau::fillIsoMVARun2Inputs(mvaInput_, *tau, mvaOpt_, chargedIsoPtSums_, neutralIsoPtSums_, puCorrPtSums_, photonPtSumOutsideSignalCone_, footprintCorrection_)) {
188  double mvaValue = mvaReader_->GetClassifier(mvaInput_);
189  if ( verbosity_ ) {
190  edm::LogPrint("PATTauDiscByMVAIsolRun2") << "<PATTauDiscriminationByMVAIsolationRun2::discriminate>:";
191  edm::LogPrint("PATTauDiscByMVAIsolRun2") << " tau: Pt = " << tau->pt() << ", eta = " << tau->eta();
192  edm::LogPrint("PATTauDiscByMVAIsolRun2") << " isolation: charged = " << tau->tauID(chargedIsoPtSums_) << ", neutral = " << tau->tauID(neutralIsoPtSums_) << ", PUcorr = " << tau->tauID(puCorrPtSums_);
193  edm::LogPrint("PATTauDiscByMVAIsolRun2") << " decay mode = " << tau->decayMode();
194  edm::LogPrint("PATTauDiscByMVAIsolRun2") << " impact parameter: distance = " << tau->dxy() << ", significance = " << tau->dxy_Sig();
195  edm::LogPrint("PATTauDiscByMVAIsolRun2") << " has decay vertex = " << tau->hasSecondaryVertex() << ":"
196  << ", significance = " << tau->flightLengthSig();
197  edm::LogPrint("PATTauDiscByMVAIsolRun2") << "--> mvaValue = " << mvaValue;
198  }
199  return mvaValue;
200  }
201  return -1.;
202 }
203 
204 void PATTauDiscriminationByMVAIsolationRun2::endEvent(edm::Event& evt)
205 {
206  // add all category indices to event
207  evt.put(std::move(category_output_), "category");
208 }
209 
211 
212 }} //namespace
T getParameter(std::string const &) const
OrphanHandle< PROD > put(std::unique_ptr< PROD > product)
Put a new product.
Definition: Event.h:137
edm::RefProd< TauCollection > TauRefProd
Definition: Tau.h:39
bool getByToken(EDGetToken token, Handle< PROD > &result) const
Definition: Event.h:579
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:17
bool exists(std::string const &parameterName) const
checks if a parameter exists
bool fillIsoMVARun2Inputs(float *mvaInput, const pat::Tau &tau, int mvaOpt, const std::string &nameCharged, const std::string &nameNeutral, const std::string &namePu, const std::string &nameOutside, const std::string &nameFootprint)
#define nullptr
Definition: HeavyIon.h:7
LocationCode location() const
Where was the file found?
Definition: FileInPath.cc:191
bool isNull() const
Checks for null.
Definition: Ref.h:250
fixed size matrix
T get() const
Definition: EventSetup.h:62
std::string fullPath() const
Definition: FileInPath.cc:197
T const * product() const
Definition: ESHandle.h:86
def move(src, dest)
Definition: eostools.py:511