CMS 3D CMS Logo

DeepTauBase.h
Go to the documentation of this file.
1 #ifndef RecoTauTag_RecoTau_DeepTauBase_h
2 #define RecoTauTag_RecoTau_DeepTauBase_h
3 
4 /*
5  * \class DeepTauBase
6  *
7  * Definition of the base class for tau identification using Deep NN.
8  *
9  * \author Konstantin Androsov, INFN Pisa
10  * \author Maria Rosaria Di Domenico, University of Siena & INFN Pisa
11  */
12 
13 #include <Math/VectorUtil.h>
17 #include "tensorflow/core/util/memmapped_file_system.h"
26 #include <TF1.h>
27 
28 namespace deep_tau {
29 
31  public:
32  explicit TauWPThreshold(const std::string& cut_str);
33  double operator()(const pat::Tau& tau) const;
34 
35  private:
36  std::unique_ptr<TF1> fn_;
37  double value_;
38  };
39 
40  class DeepTauCache {
41  public:
42  using GraphPtr = std::shared_ptr<tensorflow::GraphDef>;
43 
44  DeepTauCache(const std::map<std::string, std::string>& graph_names, bool mem_mapped);
45  ~DeepTauCache();
46 
47  // A Session allows concurrent calls to Run(), though a Session must
48  // be created / extended by a single thread.
49  tensorflow::Session& getSession(const std::string& name = "") const { return *sessions_.at(name); }
50  const tensorflow::GraphDef& getGraph(const std::string& name = "") const { return *graphs_.at(name); }
51 
52  private:
53  std::map<std::string, GraphPtr> graphs_;
54  std::map<std::string, tensorflow::Session*> sessions_;
55  std::map<std::string, std::unique_ptr<tensorflow::MemmappedEnv>> memmappedEnv_;
56  };
57 
58  class DeepTauBase : public edm::stream::EDProducer<edm::GlobalCache<DeepTauCache>> {
59  public:
60  using TauType = pat::Tau;
62  using TauCollection = std::vector<TauType>;
67  using LorentzVectorXYZ = ROOT::Math::LorentzVector<ROOT::Math::PxPyPzE4D<double>>;
69  using CutterPtr = std::unique_ptr<Cutter>;
70  using WPList = std::vector<CutterPtr>;
71 
72  struct Output {
73  std::vector<size_t> num_, den_;
74 
75  Output(const std::vector<size_t>& num, const std::vector<size_t>& den) : num_(num), den_(den) {}
76 
77  std::unique_ptr<TauDiscriminator> get_value(const edm::Handle<TauCollection>& taus,
78  const tensorflow::Tensor& pred,
79  const WPList& working_points) const;
80  };
81 
82  using OutputCollection = std::map<std::string, Output>;
83 
85  ~DeepTauBase() override {}
86 
87  void produce(edm::Event& event, const edm::EventSetup& es) override;
88 
89  static std::unique_ptr<DeepTauCache> initializeGlobalCache(const edm::ParameterSet& cfg);
90  static void globalEndJob(const DeepTauCache* cache) {}
91 
92  private:
93  virtual tensorflow::Tensor getPredictions(edm::Event& event,
94  const edm::EventSetup& es,
96  virtual void createOutputs(edm::Event& event, const tensorflow::Tensor& pred, edm::Handle<TauCollection> taus);
97 
98  protected:
102  std::map<std::string, WPList> workingPoints_;
105  };
106 
107 } // namespace deep_tau
108 
109 #endif
ConfigurationDescriptions.h
edm::RefProd
Definition: EDProductfwd.h:25
deep_tau::DeepTauBase::cache_
const DeepTauCache * cache_
Definition: DeepTauBase.h:104
deep_tau::DeepTauCache::sessions_
std::map< std::string, tensorflow::Session * > sessions_
Definition: DeepTauBase.h:54
TensorFlow.h
deep_tau::DeepTauCache::getSession
tensorflow::Session & getSession(const std::string &name="") const
Definition: DeepTauBase.h:49
metsig::tau
Definition: SignAlgoResolutions.h:49
pat::ElectronCollection
std::vector< Electron > ElectronCollection
Definition: Electron.h:36
deep_tau::DeepTauBase::WPList
std::vector< CutterPtr > WPList
Definition: DeepTauBase.h:70
deep_tau::DeepTauBase::produce
void produce(edm::Event &event, const edm::EventSetup &es) override
Definition: DeepTauBase.cc:93
edm::EDGetTokenT< TauCollection >
Tau3MuMonitor_cff.taus
taus
Definition: Tau3MuMonitor_cff.py:7
L1TRate_Offline_cfi.Tau
Tau
Definition: L1TRate_Offline_cfi.py:43
Muon.h
deep_tau::DeepTauBase::pfcandToken_
edm::EDGetTokenT< pat::PackedCandidateCollection > pfcandToken_
Definition: DeepTauBase.h:100
pat::Tau
Analysis-level tau class.
Definition: Tau.h:53
deep_tau::DeepTauBase::OutputCollection
std::map< std::string, Output > OutputCollection
Definition: DeepTauBase.h:82
PatBasicFWLiteJetAnalyzer_Selector_cfg.outputs
outputs
Definition: PatBasicFWLiteJetAnalyzer_Selector_cfg.py:48
EDProducer.h
deep_tau::DeepTauBase::vtxToken_
edm::EDGetTokenT< reco::VertexCollection > vtxToken_
Definition: DeepTauBase.h:101
deep_tau
Definition: DeepTauBase.h:28
deep_tau::DeepTauBase::LorentzVectorXYZ
ROOT::Math::LorentzVector< ROOT::Math::PxPyPzE4D< double > > LorentzVectorXYZ
Definition: DeepTauBase.h:67
deep_tau::TauWPThreshold::TauWPThreshold
TauWPThreshold(const std::string &cut_str)
Definition: DeepTauBase.cc:14
edm::Handle< TauCollection >
Tau.h
edm::Ref
Definition: AssociativeIterator.h:58
PFRecoTauClusterVariables.h
deep_tau::DeepTauBase
Definition: DeepTauBase.h:58
deep_tau::DeepTauBase::DeepTauBase
DeepTauBase(const edm::ParameterSet &cfg, const OutputCollection &outputs, const DeepTauCache *cache)
Definition: DeepTauBase.cc:76
deep_tau::DeepTauBase::globalEndJob
static void globalEndJob(const DeepTauCache *cache)
Definition: DeepTauBase.h:90
deep_tau::DeepTauBase::workingPoints_
std::map< std::string, WPList > workingPoints_
Definition: DeepTauBase.h:102
deep_tau::TauWPThreshold
Definition: DeepTauBase.h:30
pat::MuonCollection
std::vector< Muon > MuonCollection
Definition: Muon.h:35
deep_tau::TauWPThreshold::value_
double value_
Definition: DeepTauBase.h:37
ParameterSetDescription.h
utilities.cache
def cache(function)
Definition: utilities.py:3
AlCaHLTBitMon_QueryRunRegistry.string
string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
deep_tau::DeepTauBase::MuonCollection
pat::MuonCollection MuonCollection
Definition: DeepTauBase.h:66
deep_tau::DeepTauBase::outputs_
OutputCollection outputs_
Definition: DeepTauBase.h:103
deep_tau::DeepTauBase::Output
Definition: DeepTauBase.h:72
edm::ParameterSet
Definition: ParameterSet.h:36
deep_tau::DeepTauBase::initializeGlobalCache
static std::unique_ptr< DeepTauCache > initializeGlobalCache(const edm::ParameterSet &cfg)
Definition: DeepTauBase.cc:108
deep_tau::DeepTauBase::Output::Output
Output(const std::vector< size_t > &num, const std::vector< size_t > &den)
Definition: DeepTauBase.h:75
deep_tau::DeepTauBase::Output::get_value
std::unique_ptr< TauDiscriminator > get_value(const edm::Handle< TauCollection > &taus, const tensorflow::Tensor &pred, const WPList &working_points) const
Definition: DeepTauBase.cc:48
deep_tau::DeepTauBase::CutterPtr
std::unique_ptr< Cutter > CutterPtr
Definition: DeepTauBase.h:69
edm::stream::EDProducer
Definition: EDProducer.h:38
TauDiscriminatorContainer.h
edm::EventSetup
Definition: EventSetup.h:57
deep_tau::DeepTauBase::TauCollection
std::vector< TauType > TauCollection
Definition: DeepTauBase.h:62
EgammaValidation_cff.num
num
Definition: EgammaValidation_cff.py:34
looper.cfg
cfg
Definition: looper.py:297
deep_tau::DeepTauCache::graphs_
std::map< std::string, GraphPtr > graphs_
Definition: DeepTauBase.h:53
deep_tau::DeepTauBase::ElectronCollection
pat::ElectronCollection ElectronCollection
Definition: DeepTauBase.h:65
deep_tau::DeepTauBase::getPredictions
virtual tensorflow::Tensor getPredictions(edm::Event &event, const edm::EventSetup &es, edm::Handle< TauCollection > taus)=0
edm::ValueMap
Definition: ValueMap.h:107
deep_tau::DeepTauBase::Output::num_
std::vector< size_t > num_
Definition: DeepTauBase.h:73
deep_tau::DeepTauBase::~DeepTauBase
~DeepTauBase() override
Definition: DeepTauBase.h:85
deep_tau::DeepTauBase::createOutputs
virtual void createOutputs(edm::Event &event, const tensorflow::Tensor &pred, edm::Handle< TauCollection > taus)
Definition: DeepTauBase.cc:101
deep_tau::TauWPThreshold::operator()
double operator()(const pat::Tau &tau) const
Definition: DeepTauBase.cc:39
deep_tau::DeepTauCache::~DeepTauCache
~DeepTauCache()
Definition: DeepTauBase.cc:168
deep_tau::DeepTauCache::DeepTauCache
DeepTauCache(const std::map< std::string, std::string > &graph_names, bool mem_mapped)
Definition: DeepTauBase.cc:130
Skims_PA_cff.name
name
Definition: Skims_PA_cff.py:17
deep_tau::TauWPThreshold::fn_
std::unique_ptr< TF1 > fn_
Definition: DeepTauBase.h:36
Electron.h
deep_tau::DeepTauCache
Definition: DeepTauBase.h:40
ParameterSet.h
deep_tau::DeepTauCache::GraphPtr
std::shared_ptr< tensorflow::GraphDef > GraphPtr
Definition: DeepTauBase.h:42
event
Definition: event.py:1
reco::TauDiscriminatorContainer
edm::ValueMap< SingleTauDiscriminatorContainer > TauDiscriminatorContainer
Definition: TauDiscriminatorContainer.h:17
edm::Event
Definition: Event.h:73
StringObjectFunction.h
FWLite.working_points
working_points
Definition: FWLite.py:121
deep_tau::DeepTauBase::tausToken_
edm::EDGetTokenT< TauCollection > tausToken_
Definition: DeepTauBase.h:99
deep_tau::DeepTauBase::Output::den_
std::vector< size_t > den_
Definition: DeepTauBase.h:73
deep_tau::DeepTauCache::getGraph
const tensorflow::GraphDef & getGraph(const std::string &name="") const
Definition: DeepTauBase.h:50
deep_tau::DeepTauCache::memmappedEnv_
std::map< std::string, std::unique_ptr< tensorflow::MemmappedEnv > > memmappedEnv_
Definition: DeepTauBase.h:55