CMS 3D CMS Logo

DeepTauBase.h
Go to the documentation of this file.
1 #ifndef RecoTauTag_RecoTau_DeepTauBase_h
2 #define RecoTauTag_RecoTau_DeepTauBase_h
3 
4 /*
5  * \class DeepTauBase
6  *
7  * Definition of the base class for tau identification using Deep NN.
8  *
9  * \author Konstantin Androsov, INFN Pisa
10  * \author Maria Rosaria Di Domenico, University of Siena & INFN Pisa
11  */
12 
13 #include <Math/VectorUtil.h>
17 #include "tensorflow/core/util/memmapped_file_system.h"
33 #include <TF1.h>
34 #include <map>
35 
36 namespace deep_tau {
37 
39  public:
40  explicit TauWPThreshold(const std::string& cut_str);
41  double operator()(const reco::BaseTau& tau, bool isPFTau) const;
42 
43  private:
44  std::unique_ptr<TF1> fn_;
45  double value_;
46  };
47 
48  class DeepTauCache {
49  public:
50  using GraphPtr = std::shared_ptr<tensorflow::GraphDef>;
51 
52  DeepTauCache(const std::map<std::string, std::string>& graph_names, bool mem_mapped);
53  ~DeepTauCache();
54 
55  // A Session allows concurrent calls to Run(), though a Session must
56  // be created / extended by a single thread.
57  tensorflow::Session& getSession(const std::string& name = "") const { return *sessions_.at(name); }
58  const tensorflow::GraphDef& getGraph(const std::string& name = "") const { return *graphs_.at(name); }
59 
60  private:
61  std::map<std::string, GraphPtr> graphs_;
62  std::map<std::string, tensorflow::Session*> sessions_;
63  std::map<std::string, std::unique_ptr<tensorflow::MemmappedEnv>> memmappedEnv_;
64  };
65 
66  class DeepTauBase : public edm::stream::EDProducer<edm::GlobalCache<DeepTauCache>> {
67  public:
75  using LorentzVectorXYZ = ROOT::Math::LorentzVector<ROOT::Math::PxPyPzE4D<double>>;
77  using CutterPtr = std::unique_ptr<Cutter>;
78  using WPList = std::vector<CutterPtr>;
79 
80  struct Output {
81  std::vector<size_t> num_, den_;
82 
83  Output(const std::vector<size_t>& num, const std::vector<size_t>& den) : num_(num), den_(den) {}
84 
85  std::unique_ptr<TauDiscriminator> get_value(const edm::Handle<TauCollection>& taus,
86  const tensorflow::Tensor& pred,
87  const WPList* working_points,
88  bool is_online) const;
89  };
90 
91  using OutputCollection = std::map<std::string, Output>;
92 
94  ~DeepTauBase() override {}
95 
96  void produce(edm::Event& event, const edm::EventSetup& es) override;
97 
98  static std::unique_ptr<DeepTauCache> initializeGlobalCache(const edm::ParameterSet& cfg);
99  static void globalEndJob(const DeepTauCache* cache) {}
100 
101  template <typename ConsumeType>
102  struct TauDiscInfo {
106  double cut;
107  void fill(const edm::Event& evt) { evt.getByToken(disc_token, handle); }
108  };
109 
110  // select boolean operation on prediscriminants (and = 0x01, or = 0x00)
112  std::vector<TauDiscInfo<pat::PATTauDiscriminator>> patPrediscriminants_;
113  std::vector<TauDiscInfo<reco::PFTauDiscriminator>> recoPrediscriminants_;
114 
122  };
123 
124  private:
125  virtual tensorflow::Tensor getPredictions(edm::Event& event, edm::Handle<TauCollection> taus) = 0;
126  virtual void createOutputs(edm::Event& event, const tensorflow::Tensor& pred, edm::Handle<TauCollection> taus);
127 
128  protected:
132  std::map<std::string, WPList> workingPoints_;
133  const bool is_online_;
136 
137  static const std::map<BasicDiscriminator, std::string> stringFromDiscriminator_;
138  static const std::vector<BasicDiscriminator> requiredBasicDiscriminators_;
139  static const std::vector<BasicDiscriminator> requiredBasicDiscriminatorsdR03_;
140  };
141 
142 } // namespace deep_tau
143 
144 #endif
ROOT::Math::LorentzVector< ROOT::Math::PxPyPzE4D< double > > LorentzVectorXYZ
Definition: DeepTauBase.h:75
void fill(const edm::Event &evt)
Definition: DeepTauBase.h:107
DeepTauCache(const std::map< std::string, std::string > &graph_names, bool mem_mapped)
Definition: DeepTauBase.cc:215
std::unique_ptr< TauDiscriminator > get_value(const edm::Handle< TauCollection > &taus, const tensorflow::Tensor &pred, const WPList *working_points, bool is_online) const
Definition: DeepTauBase.cc:60
std::unique_ptr< TF1 > fn_
Definition: DeepTauBase.h:44
std::map< std::string, tensorflow::Session * > sessions_
Definition: DeepTauBase.h:62
edm::EDGetTokenT< TauCollection > tausToken_
Definition: DeepTauBase.h:129
const DeepTauCache * cache_
Definition: DeepTauBase.h:135
std::vector< size_t > num_
Definition: DeepTauBase.h:81
working_points
Definition: FWLite.py:134
static void globalEndJob(const DeepTauCache *cache)
Definition: DeepTauBase.h:99
bool getByToken(EDGetToken token, Handle< PROD > &result) const
Definition: Event.h:540
std::vector< size_t > den_
Definition: DeepTauBase.h:81
std::map< std::string, Output > OutputCollection
Definition: DeepTauBase.h:91
std::vector< CutterPtr > WPList
Definition: DeepTauBase.h:78
void produce(edm::Event &event, const edm::EventSetup &es) override
Definition: DeepTauBase.cc:151
static const std::vector< BasicDiscriminator > requiredBasicDiscriminatorsdR03_
Definition: DeepTauBase.h:139
OutputCollection outputs_
Definition: DeepTauBase.h:134
Output(const std::vector< size_t > &num, const std::vector< size_t > &den)
Definition: DeepTauBase.h:83
std::vector< Electron > ElectronCollection
Definition: Electron.h:36
std::shared_ptr< tensorflow::GraphDef > GraphPtr
Definition: DeepTauBase.h:50
edm::EDGetTokenT< ConsumeType > disc_token
Definition: DeepTauBase.h:105
const tensorflow::GraphDef & getGraph(const std::string &name="") const
Definition: DeepTauBase.h:58
TauWPThreshold(const std::string &cut_str)
Definition: DeepTauBase.cc:17
pat::ElectronCollection ElectronCollection
Definition: DeepTauBase.h:73
static const std::vector< BasicDiscriminator > requiredBasicDiscriminators_
Definition: DeepTauBase.h:138
DeepTauBase(const edm::ParameterSet &cfg, const OutputCollection &outputs, const DeepTauCache *cache)
Definition: DeepTauBase.cc:91
std::map< std::string, WPList > workingPoints_
Definition: DeepTauBase.h:132
tensorflow::Session & getSession(const std::string &name="") const
Definition: DeepTauBase.h:57
virtual tensorflow::Tensor getPredictions(edm::Event &event, edm::Handle< TauCollection > taus)=0
~DeepTauBase() override
Definition: DeepTauBase.h:94
virtual void createOutputs(edm::Event &event, const tensorflow::Tensor &pred, edm::Handle< TauCollection > taus)
Definition: DeepTauBase.cc:182
pat::MuonCollection MuonCollection
Definition: DeepTauBase.h:74
edm::ValueMap< SingleTauDiscriminatorContainer > TauDiscriminatorContainer
double operator()(const reco::BaseTau &tau, bool isPFTau) const
Definition: DeepTauBase.cc:46
edm::Handle< ConsumeType > handle
Definition: DeepTauBase.h:104
uint8_t andPrediscriminants_
Definition: DeepTauBase.h:111
static const std::map< BasicDiscriminator, std::string > stringFromDiscriminator_
Definition: DeepTauBase.h:137
static std::unique_ptr< DeepTauCache > initializeGlobalCache(const edm::ParameterSet &cfg)
Definition: DeepTauBase.cc:193
std::vector< Muon > MuonCollection
Definition: Muon.h:35
std::unique_ptr< Cutter > CutterPtr
Definition: DeepTauBase.h:77
def cache(function)
Definition: utilities.py:3
std::map< std::string, GraphPtr > graphs_
Definition: DeepTauBase.h:61
std::map< std::string, std::unique_ptr< tensorflow::MemmappedEnv > > memmappedEnv_
Definition: DeepTauBase.h:63
std::vector< TauDiscInfo< pat::PATTauDiscriminator > > patPrediscriminants_
Definition: DeepTauBase.h:112
edm::EDGetTokenT< CandidateCollection > pfcandToken_
Definition: DeepTauBase.h:130
std::vector< TauDiscInfo< reco::PFTauDiscriminator > > recoPrediscriminants_
Definition: DeepTauBase.h:113
Definition: event.py:1
edm::EDGetTokenT< reco::VertexCollection > vtxToken_
Definition: DeepTauBase.h:131