CMS 3D CMS Logo

DeepTauBase.h
Go to the documentation of this file.
1 #ifndef RecoTauTag_RecoTau_DeepTauBase_h
2 #define RecoTauTag_RecoTau_DeepTauBase_h
3 
4 /*
5  * \class DeepTauBase
6  *
7  * Definition of the base class for tau identification using a deep neural network (DNN).
8  *
9  * \author Konstantin Androsov, INFN Pisa
10  * \author Maria Rosaria Di Domenico, University of Siena & INFN Pisa
11  */
12 
13 #include <Math/VectorUtil.h>
17 #include "tensorflow/core/util/memmapped_file_system.h"
33 #include <TF1.h>
34 #include <map>
35 
36 namespace deep_tau {
37 
39  public:
40  explicit TauWPThreshold(const std::string& cut_str);
41  double operator()(const reco::BaseTau& tau, bool isPFTau) const;
42 
43  private:
44  std::unique_ptr<TF1> fn_;
45  double value_;
46  };
47 
48  class DeepTauCache {
49  public:
50  using GraphPtr = std::shared_ptr<tensorflow::GraphDef>;
51 
52  DeepTauCache(const std::map<std::string, std::string>& graph_names, bool mem_mapped);
53  ~DeepTauCache();
54 
55  // A Session allows concurrent calls to Run(), though a Session must
56  // be created / extended by a single thread.
57  tensorflow::Session& getSession(const std::string& name = "") const { return *sessions_.at(name); }
58  const tensorflow::GraphDef& getGraph(const std::string& name = "") const { return *graphs_.at(name); }
59 
60  private:
61  std::map<std::string, GraphPtr> graphs_;
62  std::map<std::string, tensorflow::Session*> sessions_;
63  std::map<std::string, std::unique_ptr<tensorflow::MemmappedEnv>> memmappedEnv_;
64  };
65 
66  class DeepTauBase : public edm::stream::EDProducer<edm::GlobalCache<DeepTauCache>> {
67  public:
75  using LorentzVectorXYZ = ROOT::Math::LorentzVector<ROOT::Math::PxPyPzE4D<double>>;
77  using CutterPtr = std::unique_ptr<Cutter>;
78  using WPList = std::vector<CutterPtr>;
79 
80  struct Output {
81  std::vector<size_t> num_, den_;
82 
83  Output(const std::vector<size_t>& num, const std::vector<size_t>& den) : num_(num), den_(den) {}
84 
85  std::unique_ptr<TauDiscriminator> get_value(const edm::Handle<TauCollection>& taus,
86  const tensorflow::Tensor& pred,
87  const WPList* working_points,
88  bool is_online) const;
89  };
90 
91  using OutputCollection = std::map<std::string, Output>;
92 
94  ~DeepTauBase() override {}
95 
96  void produce(edm::Event& event, const edm::EventSetup& es) override;
97 
98  static std::unique_ptr<DeepTauCache> initializeGlobalCache(const edm::ParameterSet& cfg);
99  static void globalEndJob(const DeepTauCache* cache) {}
100 
101  template <typename ConsumeType>
102  struct TauDiscInfo {
106  double cut;
107  void fill(const edm::Event& evt) { evt.getByToken(disc_token, handle); }
108  };
109 
110  // select boolean operation on prediscriminants (and = 0x01, or = 0x00)
112  std::vector<TauDiscInfo<pat::PATTauDiscriminator>> patPrediscriminants_;
113  std::vector<TauDiscInfo<reco::PFTauDiscriminator>> recoPrediscriminants_;
114 
122  };
123 
124  private:
125  virtual tensorflow::Tensor getPredictions(edm::Event& event, edm::Handle<TauCollection> taus) = 0;
126  virtual void createOutputs(edm::Event& event, const tensorflow::Tensor& pred, edm::Handle<TauCollection> taus);
127 
128  protected:
132  std::map<std::string, WPList> workingPoints_;
133  const bool is_online_;
136 
137  static const std::map<BasicDiscriminator, std::string> stringFromDiscriminator_;
138  static const std::vector<BasicDiscriminator> requiredBasicDiscriminators_;
139  static const std::vector<BasicDiscriminator> requiredBasicDiscriminatorsdR03_;
140  };
141 
142 } // namespace deep_tau
143 
144 #endif
ConfigurationDescriptions.h
deep_tau::DeepTauBase::TauDiscInfo::fill
void fill(const edm::Event &evt)
Definition: DeepTauBase.h:107
edm::RefProd
Definition: EDProductfwd.h:25
deep_tau::DeepTauBase::cache_
const DeepTauCache * cache_
Definition: DeepTauBase.h:135
deep_tau::DeepTauBase::TauDiscInfo::label
edm::InputTag label
Definition: DeepTauBase.h:103
deep_tau::DeepTauCache::sessions_
std::map< std::string, tensorflow::Session * > sessions_
Definition: DeepTauBase.h:62
ProcessHistoryID.h
deep_tau::DeepTauBase::NeutralIsoPtSumWeight
Definition: DeepTauBase.h:118
deep_tau::DeepTauBase::patPrediscriminants_
std::vector< TauDiscInfo< pat::PATTauDiscriminator > > patPrediscriminants_
Definition: DeepTauBase.h:112
deep_tau::DeepTauBase::ChargedIsoPtSum
Definition: DeepTauBase.h:116
PATTauDiscriminator.h
deep_tau::DeepTauBase::Output::get_value
std::unique_ptr< TauDiscriminator > get_value(const edm::Handle< TauCollection > &taus, const tensorflow::Tensor &pred, const WPList *working_points, bool is_online) const
Definition: DeepTauBase.cc:55
TensorFlow.h
deep_tau::DeepTauCache::getSession
tensorflow::Session & getSession(const std::string &name="") const
Definition: DeepTauBase.h:57
deep_tau::DeepTauBase::pfcandToken_
edm::EDGetTokenT< CandidateCollection > pfcandToken_
Definition: DeepTauBase.h:130
deep_tau::DeepTauBase::FootprintCorrection
Definition: DeepTauBase.h:119
PFTauDiscriminator.h
metsig::tau
Definition: SignAlgoResolutions.h:49
pat::ElectronCollection
std::vector< Electron > ElectronCollection
Definition: Electron.h:36
deep_tau::DeepTauBase::WPList
std::vector< CutterPtr > WPList
Definition: DeepTauBase.h:78
deep_tau::DeepTauBase::produce
void produce(edm::Event &event, const edm::EventSetup &es) override
Definition: DeepTauBase.cc:146
edm::EDGetTokenT< ConsumeType >
Tau3MuMonitor_cff.taus
taus
Definition: Tau3MuMonitor_cff.py:7
Muon.h
deep_tau::DeepTauBase::BasicDiscriminator
BasicDiscriminator
Definition: DeepTauBase.h:115
deep_tau::DeepTauBase::OutputCollection
std::map< std::string, Output > OutputCollection
Definition: DeepTauBase.h:91
PatBasicFWLiteJetAnalyzer_Selector_cfg.outputs
outputs
Definition: PatBasicFWLiteJetAnalyzer_Selector_cfg.py:48
deep_tau::DeepTauBase::requiredBasicDiscriminatorsdR03_
static const std::vector< BasicDiscriminator > requiredBasicDiscriminatorsdR03_
Definition: DeepTauBase.h:139
EDProducer.h
deep_tau::DeepTauBase::vtxToken_
edm::EDGetTokenT< reco::VertexCollection > vtxToken_
Definition: DeepTauBase.h:131
deep_tau
Definition: DeepTauBase.h:36
deep_tau::DeepTauBase::PUcorrPtSum
Definition: DeepTauBase.h:121
deep_tau::DeepTauBase::LorentzVectorXYZ
ROOT::Math::LorentzVector< ROOT::Math::PxPyPzE4D< double > > LorentzVectorXYZ
Definition: DeepTauBase.h:75
deep_tau::TauWPThreshold::TauWPThreshold
TauWPThreshold(const std::string &cut_str)
Definition: DeepTauBase.cc:17
edm::Handle< TauCollection >
Tau.h
deep_tau::TauWPThreshold::operator()
double operator()(const reco::BaseTau &tau, bool isPFTau) const
Definition: DeepTauBase.cc:42
edm::Ref
Definition: AssociativeIterator.h:58
PFRecoTauClusterVariables.h
deep_tau::DeepTauBase
Definition: DeepTauBase.h:66
deep_tau::DeepTauBase::DeepTauBase
DeepTauBase(const edm::ParameterSet &cfg, const OutputCollection &outputs, const DeepTauCache *cache)
Definition: DeepTauBase.cc:86
deep_tau::DeepTauBase::globalEndJob
static void globalEndJob(const DeepTauCache *cache)
Definition: DeepTauBase.h:99
reco::BaseTau
Definition: BaseTau.h:18
deep_tau::DeepTauBase::requiredBasicDiscriminators_
static const std::vector< BasicDiscriminator > requiredBasicDiscriminators_
Definition: DeepTauBase.h:138
ProductProvenance.h
deep_tau::DeepTauBase::workingPoints_
std::map< std::string, WPList > workingPoints_
Definition: DeepTauBase.h:132
deep_tau::DeepTauBase::NeutralIsoPtSum
Definition: DeepTauBase.h:117
deep_tau::TauWPThreshold
Definition: DeepTauBase.h:38
deep_tau::DeepTauBase::getPredictions
virtual tensorflow::Tensor getPredictions(edm::Event &event, edm::Handle< TauCollection > taus)=0
deep_tau::DeepTauBase::is_online_
const bool is_online_
Definition: DeepTauBase.h:133
pat::MuonCollection
std::vector< Muon > MuonCollection
Definition: Muon.h:35
edm::Event::getByToken
bool getByToken(EDGetToken token, Handle< PROD > &result) const
Definition: Event.h:539
deep_tau::TauWPThreshold::value_
double value_
Definition: DeepTauBase.h:45
ParameterSetDescription.h
RefToBase.h
utilities.cache
def cache(function)
Definition: utilities.py:3
deep_tau::DeepTauBase::TauDiscInfo
Definition: DeepTauBase.h:102
deep_tau::DeepTauBase::MuonCollection
pat::MuonCollection MuonCollection
Definition: DeepTauBase.h:74
deep_tau::DeepTauBase::outputs_
OutputCollection outputs_
Definition: DeepTauBase.h:134
edm::View
Definition: CaloClusterFwd.h:14
deep_tau::DeepTauBase::Output
Definition: DeepTauBase.h:80
deep_tau::DeepTauBase::TauDiscInfo::cut
double cut
Definition: DeepTauBase.h:106
edm::ParameterSet
Definition: ParameterSet.h:47
deep_tau::DeepTauBase::initializeGlobalCache
static std::unique_ptr< DeepTauCache > initializeGlobalCache(const edm::ParameterSet &cfg)
Definition: DeepTauBase.cc:188
deep_tau::DeepTauBase::Output::Output
Output(const std::vector< size_t > &num, const std::vector< size_t > &den)
Definition: DeepTauBase.h:83
deep_tau::DeepTauBase::TauDiscInfo::disc_token
edm::EDGetTokenT< ConsumeType > disc_token
Definition: DeepTauBase.h:105
deep_tau::DeepTauBase::CutterPtr
std::unique_ptr< Cutter > CutterPtr
Definition: DeepTauBase.h:77
deep_tau::DeepTauBase::stringFromDiscriminator_
static const std::map< BasicDiscriminator, std::string > stringFromDiscriminator_
Definition: DeepTauBase.h:137
edm::stream::EDProducer
Definition: EDProducer.h:36
TauDiscriminatorContainer.h
edm::EventSetup
Definition: EventSetup.h:58
deep_tau::DeepTauBase::PhotonPtSumOutsideSignalCone
Definition: DeepTauBase.h:120
AlCaHLTBitMon_QueryRunRegistry.string
string string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
EgammaValidation_cff.num
num
Definition: EgammaValidation_cff.py:33
looper.cfg
cfg
Definition: looper.py:296
deep_tau::DeepTauCache::graphs_
std::map< std::string, GraphPtr > graphs_
Definition: DeepTauBase.h:61
deep_tau::DeepTauBase::ElectronCollection
pat::ElectronCollection ElectronCollection
Definition: DeepTauBase.h:73
deep_tau::DeepTauBase::recoPrediscriminants_
std::vector< TauDiscInfo< reco::PFTauDiscriminator > > recoPrediscriminants_
Definition: DeepTauBase.h:113
edm::ValueMap
Definition: ValueMap.h:107
deep_tau::DeepTauBase::Output::num_
std::vector< size_t > num_
Definition: DeepTauBase.h:81
deep_tau::DeepTauBase::~DeepTauBase
~DeepTauBase() override
Definition: DeepTauBase.h:94
deep_tau::DeepTauBase::createOutputs
virtual void createOutputs(edm::Event &event, const tensorflow::Tensor &pred, edm::Handle< TauCollection > taus)
Definition: DeepTauBase.cc:177
deep_tau::DeepTauCache::~DeepTauCache
~DeepTauCache()
Definition: DeepTauBase.cc:248
Provenance.h
deep_tau::DeepTauCache::DeepTauCache
DeepTauCache(const std::map< std::string, std::string > &graph_names, bool mem_mapped)
Definition: DeepTauBase.cc:210
Skims_PA_cff.name
name
Definition: Skims_PA_cff.py:17
deep_tau::TauWPThreshold::fn_
std::unique_ptr< TF1 > fn_
Definition: DeepTauBase.h:44
Electron.h
deep_tau::DeepTauBase::andPrediscriminants_
uint8_t andPrediscriminants_
Definition: DeepTauBase.h:111
deep_tau::DeepTauCache
Definition: DeepTauBase.h:48
View.h
ParameterSet.h
deep_tau::DeepTauBase::TauDiscInfo::handle
edm::Handle< ConsumeType > handle
Definition: DeepTauBase.h:104
deep_tau::DeepTauCache::GraphPtr
std::shared_ptr< tensorflow::GraphDef > GraphPtr
Definition: DeepTauBase.h:50
event
Definition: event.py:1
reco::TauDiscriminatorContainer
edm::ValueMap< SingleTauDiscriminatorContainer > TauDiscriminatorContainer
Definition: TauDiscriminatorContainer.h:17
edm::Event
Definition: Event.h:73
StringObjectFunction.h
FWLite.working_points
working_points
Definition: FWLite.py:121
deep_tau::DeepTauBase::tausToken_
edm::EDGetTokenT< TauCollection > tausToken_
Definition: DeepTauBase.h:129
edm::InputTag
Definition: InputTag.h:15
deep_tau::DeepTauBase::Output::den_
std::vector< size_t > den_
Definition: DeepTauBase.h:81
deep_tau::DeepTauCache::getGraph
const tensorflow::GraphDef & getGraph(const std::string &name="") const
Definition: DeepTauBase.h:58
deep_tau::DeepTauCache::memmappedEnv_
std::map< std::string, std::unique_ptr< tensorflow::MemmappedEnv > > memmappedEnv_
Definition: DeepTauBase.h:63