CMS 3D CMS Logo

DeepTauBase.h
Go to the documentation of this file.
1 #ifndef RecoTauTag_RecoTau_DeepTauBase_h
2 #define RecoTauTag_RecoTau_DeepTauBase_h
3 
4 /*
5  * \class DeepTauBase
6  *
7  * Definition of the base class for tau identification using Deep NN.
8  *
9  * \author Konstantin Androsov, INFN Pisa
10  * \author Maria Rosaria Di Domenico, University of Siena & INFN Pisa
11  */
12 
13 #include <Math/VectorUtil.h>
17 #include "tensorflow/core/util/memmapped_file_system.h"
26 #include <TF1.h>
27 
28 namespace deep_tau {
29 
31 public:
32  explicit TauWPThreshold(const std::string& cut_str);
33  double operator()(const pat::Tau& tau) const;
34 
35 private:
36  std::unique_ptr<TF1> fn_;
37  double value_;
38 };
39 
40 class DeepTauCache {
41 public:
42  using GraphPtr = std::shared_ptr<tensorflow::GraphDef>;
43 
44  DeepTauCache(const std::map<std::string, std::string>& graph_names, bool mem_mapped);
45  ~DeepTauCache();
46 
47  // A Session allows concurrent calls to Run(), though a Session must
48  // be created / extended by a single thread.
49  tensorflow::Session& getSession(const std::string& name = "") const { return *sessions_.at(name); }
50  const tensorflow::GraphDef& getGraph(const std::string& name = "") const { return *graphs_.at(name); }
51 
52 private:
53  std::map<std::string, GraphPtr> graphs_;
54  std::map<std::string, tensorflow::Session*> sessions_;
55  std::map<std::string, std::unique_ptr<tensorflow::MemmappedEnv>> memmappedEnv_;
56 };
57 
// DeepTauBase: base class for Deep-NN tau identification producers. It is an
// edm::stream::EDProducer whose global cache (DeepTauCache) holds the TensorFlow
// graphs/sessions; the cache is built by initializeGlobalCache() from the module
// configuration.
// NOTE(review): this listing omits several original source lines inside the class:
//   61, 63, 64, 68 (further type aliases; 68 presumably declares Cutter, which is
//   used by CutterPtr; TODO confirm),
//   65 (ElectronCollection alias), 66 (MuonCollection alias),
//   85 (presumably the constructor; TODO confirm),
//   94 (the tail of the getPredictions declaration),
//   98 (tausToken_), 99 (pfcandToken_), 100 (vtxToken_), 102 (outputs_), 103 (cache_).
58 class DeepTauBase : public edm::stream::EDProducer<edm::GlobalCache<DeepTauCache>> {
59 public:
60  using TauType = pat::Tau;
62  using TauCollection = std::vector<TauType>;
67  using LorentzVectorXYZ = ROOT::Math::LorentzVector<ROOT::Math::PxPyPzE4D<double>>;
69  using CutterPtr = std::unique_ptr<Cutter>;
// Working-point name -> cut object used to evaluate that working point.
70  using WPMap = std::map<std::string, CutterPtr>;
71 
// One named output of the network. num_ and den_ are index lists, presumably
// selecting prediction columns that form the numerator/denominator of the
// discriminator value (TODO confirm in DeepTauBase.cc). get_value() builds a
// ResultMap (name -> TauDiscriminator) from the tau collection, the prediction
// tensor and the working points.
72  struct Output {
73  using ResultMap = std::map<std::string, std::unique_ptr<TauDiscriminator>>;
74  std::vector<size_t> num_, den_;
75 
76  Output(const std::vector<size_t>& num, const std::vector<size_t>& den) : num_(num), den_(den) {}
77 
78  ResultMap get_value(const edm::Handle<TauCollection>& taus, const tensorflow::Tensor& pred,
79  const WPMap& working_points) const;
80  };
81 
// Output name -> Output description.
82  using OutputCollection = std::map<std::string, Output>;
83 
84 
86  ~DeepTauBase() override {}
87 
88  void produce(edm::Event& event, const edm::EventSetup& es) override;
89 
// Builds the shared DeepTauCache from the module configuration.
90  static std::unique_ptr<DeepTauCache> initializeGlobalCache(const edm::ParameterSet& cfg);
// No end-of-job action is needed for the cache.
91  static void globalEndJob(const DeepTauCache* cache){ }
92 private:
// Hooks for derived classes: compute the network prediction for an event, then turn
// it into event products. The rest of the getPredictions parameter list (source
// line 94) is not shown in this listing.
93  virtual tensorflow::Tensor getPredictions(edm::Event& event, const edm::EventSetup& es,
95  virtual void createOutputs(edm::Event& event, const tensorflow::Tensor& pred, edm::Handle<TauCollection> taus);
96 
97 protected:
// Per-output working points: output name (presumably) -> WPMap.
101  std::map<std::string, WPMap> workingPoints_;
104 };
105 
106 } // namespace deep_tau
107 
108 #endif
std::unique_ptr< TF1 > fn_
Definition: DeepTauBase.h:36
std::map< std::string, tensorflow::Session * > sessions_
Definition: DeepTauBase.h:54
edm::EDGetTokenT< TauCollection > tausToken_
Definition: DeepTauBase.h:98
const DeepTauCache * cache_
Definition: DeepTauBase.h:103
std::vector< size_t > num_
Definition: DeepTauBase.h:74
const tensorflow::GraphDef & getGraph(const std::string &name="") const
Definition: DeepTauBase.h:50
working_points
Definition: FWLite.py:126
std::map< std::string, CutterPtr > WPMap
Definition: DeepTauBase.h:70
static void globalEndJob(const DeepTauCache *cache)
Definition: DeepTauBase.h:91
std::map< std::string, Output > OutputCollection
Definition: DeepTauBase.h:82
OutputCollection outputs_
Definition: DeepTauBase.h:102
Output(const std::vector< size_t > &num, const std::vector< size_t > &den)
Definition: DeepTauBase.h:76
std::map< std::string, std::unique_ptr< TauDiscriminator >> ResultMap
Definition: DeepTauBase.h:73
std::vector< Electron > ElectronCollection
Definition: Electron.h:37
std::shared_ptr< tensorflow::GraphDef > GraphPtr
Definition: DeepTauBase.h:42
TauWPThreshold(const std::string &cut_str)
Definition: DeepTauBase.cc:15
def get_value(data, filters)
Definition: das.py:55
pat::ElectronCollection ElectronCollection
Definition: DeepTauBase.h:65
std::vector< TauType > TauCollection
Definition: DeepTauBase.h:62
Analysis-level tau class.
Definition: Tau.h:56
~DeepTauBase() override
Definition: DeepTauBase.h:86
pat::MuonCollection MuonCollection
Definition: DeepTauBase.h:66
double operator()(const pat::Tau &tau) const
Definition: DeepTauBase.cc:40
std::vector< Muon > MuonCollection
Definition: Muon.h:35
std::unique_ptr< Cutter > CutterPtr
Definition: DeepTauBase.h:69
def cache(function)
Definition: utilities.py:3
std::map< std::string, GraphPtr > graphs_
Definition: DeepTauBase.h:53
std::map< std::string, std::unique_ptr< tensorflow::MemmappedEnv > > memmappedEnv_
Definition: DeepTauBase.h:55
ROOT::Math::LorentzVector< ROOT::Math::PxPyPzE4D< double >> LorentzVectorXYZ
Definition: DeepTauBase.h:67
std::map< std::string, WPMap > workingPoints_
Definition: DeepTauBase.h:101
Definition: event.py:1
tensorflow::Session & getSession(const std::string &name="") const
Definition: DeepTauBase.h:49
edm::EDGetTokenT< pat::PackedCandidateCollection > pfcandToken_
Definition: DeepTauBase.h:99
edm::EDGetTokenT< reco::VertexCollection > vtxToken_
Definition: DeepTauBase.h:100