CMS 3D CMS Logo

DeepTauBase.h
Go to the documentation of this file.
1 #ifndef RecoTauTag_RecoTau_DeepTauBase_h
2 #define RecoTauTag_RecoTau_DeepTauBase_h
3 
4 /*
5  * \class DeepTauBase
6  *
7  * Definition of the base class for tau identification using Deep NN.
8  *
9  * \author Konstantin Androsov, INFN Pisa
10  * \author Maria Rosaria Di Domenico, University of Siena & INFN Pisa
11  */
12 
13 #include <Math/VectorUtil.h>
17 #include "tensorflow/core/util/memmapped_file_system.h"
26 #include <TF1.h>
27 
28 namespace deep_tau {
29 
31 public:
32  explicit TauWPThreshold(const std::string& cut_str);
33  double operator()(const pat::Tau& tau) const;
34 
35 private:
36  std::unique_ptr<TF1> fn_;
37  double value_;
38 };
39 
40 class DeepTauCache {
41 public:
42  using GraphPtr = std::shared_ptr<tensorflow::GraphDef>;
43 
44  DeepTauCache(const std::string& graph_name, bool mem_mapped);
45  ~DeepTauCache();
46 
47  // A Session allows concurrent calls to Run(), though a Session must
48  // be created / extended by a single thread.
49  tensorflow::Session& getSession() const { return *session_; }
50  const tensorflow::GraphDef& getGraph() const { return *graph_; }
51 
52 private:
54  tensorflow::Session* session_;
55  std::unique_ptr<tensorflow::MemmappedEnv> memmappedEnv_;
56 };
57 
 // Base class for tau identification using a deep neural network.
 //
 // It is a stream EDProducer whose global cache type is DeepTauCache.
 //
 // NOTE(review): this listing skips numbered lines 61, 63-66, 68, 85, 94, 98,
 // 100 and 101. `TauDiscriminator` and `Cutter` are used below but not declared
 // in the visible text. The cross-reference index further down the file lists
 // ElectronCollection (65), MuonCollection (66), tausToken_ (98), outputs_ (100)
 // and cache_ (101) as belonging to this header.
58 class DeepTauBase : public edm::stream::EDProducer<edm::GlobalCache<DeepTauCache>> {
59 public:
 // Tau object type and the collection of taus handled by this producer.
60  using TauType = pat::Tau;
62  using TauCollection = std::vector<TauType>;
 // Four-vector with (px, py, pz, E) coordinates in double precision.
67  using LorentzVectorXYZ = ROOT::Math::LorentzVector<ROOT::Math::PxPyPzE4D<double>>;
 // Uniquely owned cutter, and a map from a string key to such cutters.
69  using CutterPtr = std::unique_ptr<Cutter>;
70  using WPMap = std::map<std::string, CutterPtr>;
71 
 // Description of one output of the producer.
 // NOTE(review): num_ / den_ are presumably index lists selecting prediction
 // entries for a numerator and a denominator -- confirm in DeepTauBase.cc.
72  struct Output {
 // Map from a string key to a uniquely owned discriminator object.
73  using ResultMap = std::map<std::string, std::unique_ptr<TauDiscriminator>>;
74  std::vector<size_t> num_, den_;
75 
 // Copies `num` and `den` into num_ and den_.
76  Output(const std::vector<size_t>& num, const std::vector<size_t>& den) : num_(num), den_(den) {}
77 
 // Builds a ResultMap from the tau collection, the prediction tensor and the
 // supplied working points. Does not modify this Output.
78  ResultMap get_value(const edm::Handle<TauCollection>& taus, const tensorflow::Tensor& pred,
79  const WPMap& working_points) const;
80  };
81 
 // Map from a string key to its Output description.
82  using OutputCollection = std::map<std::string, Output>;
83 
84 
 // NOTE(review): listing line 85 is absent here; no constructor is visible.
 // Virtual destructor with an empty body.
86  virtual ~DeepTauBase() {}
87 
 // Per-event entry point; overrides the base-class produce().
88  virtual void produce(edm::Event& event, const edm::EventSetup& es) override;
89 
 // Creates the DeepTauCache from the module configuration `cfg`.
90  static std::unique_ptr<DeepTauCache> initializeGlobalCache(const edm::ParameterSet& cfg);
 // End-of-job hook; the body is empty, so nothing is done with `cache`.
91  static void globalEndJob(const DeepTauCache* cache){ }
92 private:
 // Returns a prediction tensor for the event.
 // NOTE(review): this declaration continues on listing line 94, which is
 // missing from this text, so the remaining parameters are not visible.
93  virtual tensorflow::Tensor getPredictions(edm::Event& event, const edm::EventSetup& es,
 // Takes the event, the prediction tensor and the tau collection handle.
 // NOTE(review): presumably fills the event's output products -- confirm in
 // DeepTauBase.cc.
95  virtual void createOutputs(edm::Event& event, const tensorflow::Tensor& pred, edm::Handle<TauCollection> taus);
96 
97 protected:
 // Working-point maps keyed by string (presumably the output name -- TODO confirm).
99  std::map<std::string, WPMap> workingPoints_;
102 };
103 
104 } // namespace deep_tau
105 
106 #endif
std::unique_ptr< TF1 > fn_
Definition: DeepTauBase.h:36
std::unique_ptr< tensorflow::MemmappedEnv > memmappedEnv_
Definition: DeepTauBase.h:55
edm::EDGetTokenT< TauCollection > tausToken_
Definition: DeepTauBase.h:98
const tensorflow::GraphDef & getGraph() const
Definition: DeepTauBase.h:50
tensorflow::Session & getSession() const
Definition: DeepTauBase.h:49
const DeepTauCache * cache_
Definition: DeepTauBase.h:101
std::vector< size_t > num_
Definition: DeepTauBase.h:74
tensorflow::Session * session_
Definition: DeepTauBase.h:54
std::map< std::string, CutterPtr > WPMap
Definition: DeepTauBase.h:70
static void globalEndJob(const DeepTauCache *cache)
Definition: DeepTauBase.h:91
std::map< std::string, Output > OutputCollection
Definition: DeepTauBase.h:82
OutputCollection outputs_
Definition: DeepTauBase.h:100
Output(const std::vector< size_t > &num, const std::vector< size_t > &den)
Definition: DeepTauBase.h:76
std::map< std::string, std::unique_ptr< TauDiscriminator >> ResultMap
Definition: DeepTauBase.h:73
std::vector< Electron > ElectronCollection
Definition: Electron.h:37
std::shared_ptr< tensorflow::GraphDef > GraphPtr
Definition: DeepTauBase.h:42
TauWPThreshold(const std::string &cut_str)
Definition: DeepTauBase.cc:15
def get_value(data, filters)
Definition: das.py:54
pat::ElectronCollection ElectronCollection
Definition: DeepTauBase.h:65
std::vector< TauType > TauCollection
Definition: DeepTauBase.h:62
Analysis-level tau class.
Definition: Tau.h:55
def cache(function)
virtual ~DeepTauBase()
Definition: DeepTauBase.h:86
pat::MuonCollection MuonCollection
Definition: DeepTauBase.h:66
double operator()(const pat::Tau &tau) const
Definition: DeepTauBase.cc:40
std::vector< Muon > MuonCollection
Definition: Muon.h:34
std::unique_ptr< Cutter > CutterPtr
Definition: DeepTauBase.h:69
ROOT::Math::LorentzVector< ROOT::Math::PxPyPzE4D< double >> LorentzVectorXYZ
Definition: DeepTauBase.h:67
std::map< std::string, WPMap > workingPoints_
Definition: DeepTauBase.h:99
Definition: event.py:1