CMS 3D CMS Logo

DeepTauBase.cc
Go to the documentation of this file.
1 /*
2  * \class DeepTauBase
3  *
4  * Implementation of the base class for tau identification using Deep NN.
5  *
6  * \author Konstantin Androsov, INFN Pisa
7  * \author Maria Rosaria Di Domenico, University of Siena & INFN Pisa
8  */
9 
10 
12 
13 namespace deep_tau {
14 
16 {
// Parses a working-point cut string.
// If the whole of cut_str parses as a double, that number becomes the constant
// threshold stored in value_ and fn_ stays unset.
// Otherwise cut_str is used as the body of a C++ lambda expression handed to
// ROOT's TF1. Inside that body the names decayMode, pt and eta are available;
// they are bound to the formula parameters p[0], p[1] and p[2].
// Throws cms::Exception if the resulting formula is not valid.
17  bool simple_value = false;
18  try {
19  size_t pos = 0;
20  value_ = std::stod(cut_str, &pos);
// Only a full-string parse counts as a plain number. With trailing text,
// value_ keeps the partially parsed number, but fn_ set below takes precedence.
21  simple_value = (pos == cut_str.size());
22  } catch(std::invalid_argument&) {
// Not a number at all: deliberately ignored; handled by the formula path below.
23  } catch(std::out_of_range&) {
// Numeric but outside double range: also deferred to the formula path below.
24  }
25  if(!simple_value) {
// Wrap the cut body in a lambda with the TF1 (double* x, double* p) signature.
// The parameter order here must match the SetParameter calls in operator().
26  static const std::string prefix = "[&](double *x, double *p) { const int decayMode = p[0];"
27  "const double pt = p[1]; const double eta = p[2];";
28  static const int n_params = 3;
// No-op ROOT error handler. It is installed temporarily so that compilation
// errors of a bad formula are not printed; validity is checked via IsValid() instead.
29  static const auto handler = [](int, Bool_t, const char*, const char*) -> void {};
30 
31  const std::string fn_str = prefix + cut_str + "}";
32  auto old_handler = SetErrorHandler(handler);
// NOTE(review): if the TF1 constructor throws, old_handler is never restored.
// An RAII guard around the handler swap would make this exception-safe.
33  fn_ = std::make_unique<TF1>("fn_", fn_str.c_str(), 0, 1, n_params);
34  SetErrorHandler(old_handler);
35  if(!fn_->IsValid())
36  throw cms::Exception("TauWPThreshold: invalid formula") << "Invalid WP cut formula = '" << cut_str << "'.";
37  }
38 }
39 
41 {
// Returns the cut threshold for `tau`.
// - If cut_str was a plain number, this is the constant value_.
// - Otherwise it is the compiled formula, evaluated with this tau's decay mode,
//   pt and eta.
// Parameter indices 0/1/2 match the p[0]/p[1]/p[2] bindings in the formula
// prefix built by the constructor.
42  if(!fn_)
43  return value_;
44  fn_->SetParameter(0, tau.decayMode());
45  fn_->SetParameter(1, tau.pt());
46  fn_->SetParameter(2, tau.eta());
// The formula prefix does not name x. The evaluation point (0) only matters
// if the cut expression itself reads x.
47  return fn_->Eval(0);
48 }
48 }
49 
51  const tensorflow::Tensor& pred,
52  const WPMap& working_points) const
53 {
// Builds the discriminators for one output:
// - key "": the raw score;
// - one entry per working point (key = WP name): a pass/fail flag.
// `output` is declared on a line not visible in this listing
// (presumably `ResultMap output;` — confirm).
//
// Score per tau: the sum of the prediction columns listed in num_. If den_ is
// non-empty and that sum is non-zero, it is divided by the sum of the den_
// columns. A zero denominator yields std::numeric_limits<float>::max().
// pred is read through matrix<float>() as a 2-D float tensor indexed by
// (tau_index, column).
55  output[""] = std::make_unique<TauDiscriminator>(TauRefProd(taus));
56  for(const auto& wp : working_points)
57  output[wp.first] = std::make_unique<TauDiscriminator>(TauRefProd(taus));
58 
59  for(size_t tau_index = 0; tau_index < taus->size(); ++tau_index) {
60  float x = 0;
61  for(size_t num_elem : num_)
62  x += pred.matrix<float>()(tau_index, num_elem);
63  if(x != 0 && !den_.empty()) {
64  float den_val = 0;
65  for(size_t den_elem : den_)
66  den_val += pred.matrix<float>()(tau_index, den_elem);
67  x = den_val != 0 ? x / den_val : std::numeric_limits<float>::max();
68  }
69  output[""]->setValue(tau_index, x);
// A tau passes a WP when its score is strictly greater than the threshold
// returned by that WP's cutter for this particular tau.
70  for(const auto& wp : working_points) {
71  const auto& tau = taus->at(tau_index);
72  const bool pass = x > (*wp.second)(tau);
73  output[wp.first]->setValue(tau_index, pass);
74  }
75  }
76  return output;
77 }
78 
80  tausToken_(consumes<TauCollection>(cfg.getParameter<edm::InputTag>("taus"))),
81  outputs_(outputCollection),
82  cache_(cache)
83 {
84  for(const auto& output_desc : outputs_) {
85  produces<TauDiscriminator>(output_desc.first);
86  const auto& cut_pset = cfg.getParameter<edm::ParameterSet>(output_desc.first + "WP");
87  for(const std::string& wp_name : cut_pset.getParameterNames()) {
88  const auto& cut_str = cut_pset.getParameter<std::string>(wp_name);
89  workingPoints_[output_desc.first][wp_name] = std::make_unique<Cutter>(cut_str);
90  produces<TauDiscriminator>(output_desc.first + wp_name);
91  }
92  }
93 }
94 
96 {
// Per-event entry point:
// 1. read the tau collection;
// 2. let the derived class run inference (getPredictions is pure virtual);
// 3. store all discriminators in the event.
// `taus` is declared on a line not visible in this listing
// (presumably edm::Handle<TauCollection> — confirm).
98  event.getByToken(tausToken_, taus);
99 
// getPredictions returns the tensor by value. Binding it to a const reference
// extends the temporary's lifetime to the end of this function.
100  const tensorflow::Tensor& pred = getPredictions(event, es, taus);
101  createOutputs(event, pred, taus);
102 }
103 
105 {
// For every configured output, computes its ResultMap and puts each entry into
// the event under the instance label "<output name><result key>". The raw score
// has the empty key, so its label is just "<output name>". These match the
// labels registered with produces<>() in the constructor.
106  for(const auto& output_desc : outputs_) {
// NOTE(review): at() throws std::out_of_range if workingPoints_ has no entry
// for this output; the constructor must guarantee that one exists.
107  auto result_map = output_desc.second.get_value(taus, pred, workingPoints_.at(output_desc.first));
108  for(auto& result : result_map)
109  event.put(std::move(result.second), output_desc.first + result.first);
110  }
111 }
112 
113 std::unique_ptr<DeepTauCache> DeepTauBase::initializeGlobalCache(const edm::ParameterSet& cfg )
114 {
115  std::string graph_name = edm::FileInPath(cfg.getParameter<std::string>("graph_file")).fullPath();
116  bool mem_mapped = cfg.getParameter<bool>("mem_mapped");
117  return std::make_unique<DeepTauCache>(graph_name, mem_mapped);
118 }
119 
120 DeepTauCache::DeepTauCache(const std::string& graph_name, bool mem_mapped)
121 {
122  tensorflow::SessionOptions options;
123  tensorflow::setThreading(options, 1, "no_threads");
124 
125  if(mem_mapped) {
126  memmappedEnv_ = std::make_unique<tensorflow::MemmappedEnv>(tensorflow::Env::Default());
127  const tensorflow::Status mmap_status = memmappedEnv_.get()->InitializeFromFile(graph_name);
128  if(!mmap_status.ok())
129  throw cms::Exception("DeepTauCache: unable to initalize memmapped environment for ") << graph_name << ". \n"
130  << mmap_status.ToString();
131 
132  graph_ = std::make_unique<tensorflow::GraphDef>();
133  const tensorflow::Status load_graph_status = ReadBinaryProto(memmappedEnv_.get(),
134  tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
135  graph_.get());
136  if(!load_graph_status.ok())
137  throw cms::Exception("DeepTauCache: unable to load graph_ from ") << graph_name << ". \n"
138  << mmap_status.ToString();
139  options.config.mutable_graph_options()->mutable_optimizer_options()->set_opt_level(::tensorflow::OptimizerOptions::L0);
140  options.env = memmappedEnv_.get();
141 
142  session_ = tensorflow::createSession(graph_.get(), options);
143 
144  } else {
145  graph_.reset(tensorflow::loadGraphDef(graph_name));
146  session_ = tensorflow::createSession(graph_.get(), options);
147  }
148 }
149 
151 {
// Releases the TensorFlow session created in the constructor.
// closeSession takes the pointer by reference (presumably resetting it — confirm).
// Its bool result is ignored here.
152  tensorflow::closeSession(session_);
153 }
154 
155 } // namespace deep_tau
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:87
T getParameter(std::string const &) const
edm::RefProd< TauCollection > TauRefProd
Definition: Tau.h:39
double eta() const final
momentum pseudorapidity
#define Default
Definition: vmac.h:110
std::unique_ptr< TF1 > fn_
Definition: DeepTauBase.h:36
edm::EDGetTokenT< TauCollection > tausToken_
Definition: DeepTauBase.h:98
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:68
std::map< std::string, CutterPtr > WPMap
Definition: DeepTauBase.h:70
double pt() const final
transverse momentum
DeepTauCache(const std::string &graph_name, bool mem_mapped)
Definition: DeepTauBase.cc:120
std::map< std::string, Output > OutputCollection
Definition: DeepTauBase.h:82
void produce(edm::Event &event, const edm::EventSetup &es) override
Definition: DeepTauBase.cc:95
OutputCollection outputs_
Definition: DeepTauBase.h:100
std::map< std::string, std::unique_ptr< TauDiscriminator >> ResultMap
Definition: DeepTauBase.h:73
TauWPThreshold(const std::string &cut_str)
Definition: DeepTauBase.cc:15
bool closeSession(Session *&session)
Definition: TensorFlow.cc:193
virtual tensorflow::Tensor getPredictions(edm::Event &event, const edm::EventSetup &es, edm::Handle< TauCollection > taus)=0
std::vector< TauType > TauCollection
Definition: DeepTauBase.h:62
DeepTauBase(const edm::ParameterSet &cfg, const OutputCollection &outputs, const DeepTauCache *cache)
Definition: DeepTauBase.cc:79
Analysis-level tau class.
Definition: Tau.h:55
def cache(function)
virtual void createOutputs(edm::Event &event, const tensorflow::Tensor &pred, edm::Handle< TauCollection > taus)
Definition: DeepTauBase.cc:104
void setThreading(SessionOptions &sessionOptions, int nThreads, const std::string &singleThreadPool="no_threads")
Definition: TensorFlow.cc:19
int decayMode() const
reconstructed tau decay mode (specific to PFTau)
Definition: Tau.h:365
double operator()(const pat::Tau &tau) const
Definition: DeepTauBase.cc:40
static std::unique_ptr< DeepTauCache > initializeGlobalCache(const edm::ParameterSet &cfg)
Definition: DeepTauBase.cc:113
ResultMap get_value(const edm::Handle< TauCollection > &taus, const tensorflow::Tensor &pred, const WPMap &working_points) const
Definition: DeepTauBase.cc:50
HLT enums.
std::string fullPath() const
Definition: FileInPath.cc:197
std::map< std::string, WPMap > workingPoints_
Definition: DeepTauBase.h:99
def move(src, dest)
Definition: eostools.py:510
Definition: event.py:1