CMS 3D CMS Logo

DeepTauBase.cc
Go to the documentation of this file.
1 /*
2  * \class DeepTauBase
3  *
4  * Implementation of the base class for tau identification using Deep NN.
5  *
6  * \author Konstantin Androsov, INFN Pisa
7  * \author Maria Rosaria Di Domenico, University of Siena & INFN Pisa
8  */
9 
10 
12 
13 namespace deep_tau {
14 
16 {
17  bool simple_value = false;
18  try {
19  size_t pos = 0;
20  value_ = std::stod(cut_str, &pos);
21  simple_value = (pos == cut_str.size());
22  } catch(std::invalid_argument&) {
23  } catch(std::out_of_range&) {
24  }
25  if(!simple_value) {
26  static const std::string prefix = "[&](double *x, double *p) { const int decayMode = p[0];"
27  "const double pt = p[1]; const double eta = p[2];";
28  static const int n_params = 3;
29  static const auto handler = [](int, Bool_t, const char*, const char*) -> void {};
30 
31  const std::string fn_str = prefix + cut_str + "}";
32  auto old_handler = SetErrorHandler(handler);
33  fn_ = std::make_unique<TF1>("fn_", fn_str.c_str(), 0, 1, n_params);
34  SetErrorHandler(old_handler);
35  if(!fn_->IsValid())
36  throw cms::Exception("TauWPThreshold: invalid formula") << "Invalid WP cut formula = '" << cut_str << "'.";
37  }
38 }
39 
41 {
42  if(!fn_)
43  return value_;
44  fn_->SetParameter(0, tau.decayMode());
45  fn_->SetParameter(1, tau.pt());
46  fn_->SetParameter(2, tau.eta());
47  return fn_->Eval(0);
48 }
49 
51  const tensorflow::Tensor& pred,
52  const WPMap& working_points) const
53 {
55  output[""] = std::make_unique<TauDiscriminator>(TauRefProd(taus));
56  for(const auto& wp : working_points)
57  output[wp.first] = std::make_unique<TauDiscriminator>(TauRefProd(taus));
58 
59  for(size_t tau_index = 0; tau_index < taus->size(); ++tau_index) {
60  float x = 0;
61  for(size_t num_elem : num_)
62  x += pred.matrix<float>()(tau_index, num_elem);
63  if(x != 0 && !den_.empty()) {
64  float den_val = 0;
65  for(size_t den_elem : den_)
66  den_val += pred.matrix<float>()(tau_index, den_elem);
67  x = den_val != 0 ? x / den_val : std::numeric_limits<float>::max();
68  }
69  output[""]->setValue(tau_index, x);
70  for(const auto& wp : working_points) {
71  const auto& tau = taus->at(tau_index);
72  const bool pass = x > (*wp.second)(tau);
73  output[wp.first]->setValue(tau_index, pass);
74  }
75  }
76  return output;
77 }
78 
80  const DeepTauCache* cache) :
81  tausToken_(consumes<TauCollection>(cfg.getParameter<edm::InputTag>("taus"))),
82  pfcandToken_(consumes<pat::PackedCandidateCollection>(cfg.getParameter<edm::InputTag>("pfcands"))),
83  vtxToken_(consumes<reco::VertexCollection>(cfg.getParameter<edm::InputTag>("vertices"))),
84  outputs_(outputCollection),
85  cache_(cache)
86 {
87  for(const auto& output_desc : outputs_) {
88  produces<TauDiscriminator>(output_desc.first);
89  const auto& cut_pset = cfg.getParameter<edm::ParameterSet>(output_desc.first + "WP");
90  for(const std::string& wp_name : cut_pset.getParameterNames()) {
91  const auto& cut_str = cut_pset.getParameter<std::string>(wp_name);
92  workingPoints_[output_desc.first][wp_name] = std::make_unique<Cutter>(cut_str);
93  produces<TauDiscriminator>(output_desc.first + wp_name);
94  }
95  }
96 }
97 
99 {
101  event.getByToken(tausToken_, taus);
102 
103  const tensorflow::Tensor& pred = getPredictions(event, es, taus);
104  createOutputs(event, pred, taus);
105 }
106 
108 {
109  for(const auto& output_desc : outputs_) {
110  auto result_map = output_desc.second.get_value(taus, pred, workingPoints_.at(output_desc.first));
111  for(auto& result : result_map)
112  event.put(std::move(result.second), output_desc.first + result.first);
113  }
114 }
115 
116 std::unique_ptr<DeepTauCache> DeepTauBase::initializeGlobalCache(const edm::ParameterSet& cfg )
117 {
118  const auto graph_name_vector = cfg.getParameter<std::vector<std::string>>("graph_file");
119  std::map<std::string, std::string> graph_names;
120  for(const auto& entry : graph_name_vector) {
121  const size_t sep_pos = entry.find(':');
122  std::string entry_name, graph_file;
123  if(sep_pos != std::string::npos) {
124  entry_name = entry.substr(0, sep_pos);
125  graph_file = entry.substr(sep_pos + 1);
126  } else {
127  entry_name = "";
128  graph_file = entry;
129  }
130  graph_file = edm::FileInPath(graph_file).fullPath();
131  if(graph_names.count(entry_name))
132  throw cms::Exception("DeepTauCache") << "Duplicated graph entries";
133  graph_names[entry_name] = graph_file;
134  }
135  bool mem_mapped = cfg.getParameter<bool>("mem_mapped");
136  return std::make_unique<DeepTauCache>(graph_names, mem_mapped);
137 }
138 
139 DeepTauCache::DeepTauCache(const std::map<std::string, std::string>& graph_names, bool mem_mapped)
140 {
141  for(const auto& graph_entry : graph_names) {
142  tensorflow::SessionOptions options;
143  tensorflow::setThreading(options, 1, "no_threads");
144 
145  const std::string& entry_name = graph_entry.first;
146  const std::string& graph_file = graph_entry.second;
147  if(mem_mapped) {
148  memmappedEnv_[entry_name] = std::make_unique<tensorflow::MemmappedEnv>(tensorflow::Env::Default());
149  const tensorflow::Status mmap_status = memmappedEnv_.at(entry_name)->InitializeFromFile(graph_file);
150  if(!mmap_status.ok()) {
151  throw cms::Exception("DeepTauCache: unable to initalize memmapped environment for ")
152  << graph_file << ". \n" << mmap_status.ToString();
153  }
154 
155  graphs_[entry_name] = std::make_unique<tensorflow::GraphDef>();
156  const tensorflow::Status load_graph_status = ReadBinaryProto(memmappedEnv_.at(entry_name).get(),
157  tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
158  graphs_.at(entry_name).get());
159  if(!load_graph_status.ok())
160  throw cms::Exception("DeepTauCache: unable to load graph from ") << graph_file << ". \n"
161  << mmap_status.ToString();
162  options.config.mutable_graph_options()->mutable_optimizer_options()->set_opt_level(::tensorflow::OptimizerOptions::L0);
163  options.env = memmappedEnv_.at(entry_name).get();
164 
165  sessions_[entry_name] = tensorflow::createSession(graphs_.at(entry_name).get(), options);
166 
167  } else {
168  graphs_[entry_name].reset(tensorflow::loadGraphDef(graph_file));
169  sessions_[entry_name] = tensorflow::createSession(graphs_.at(entry_name).get(), options);
170  }
171  }
172 }
173 
175 {
176  for(auto& session_entry : sessions_)
177  tensorflow::closeSession(session_entry.second);
178 }
179 
180 } // namespace deep_tau
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:87
T getParameter(std::string const &) const
edm::RefProd< TauCollection > TauRefProd
Definition: Tau.h:39
double eta() const final
momentum pseudorapidity
#define Default
Definition: vmac.h:110
DeepTauCache(const std::map< std::string, std::string > &graph_names, bool mem_mapped)
Definition: DeepTauBase.cc:139
std::unique_ptr< TF1 > fn_
Definition: DeepTauBase.h:36
edm::EDGetTokenT< TauCollection > tausToken_
Definition: DeepTauBase.h:98
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:68
std::vector< pat::PackedCandidate > PackedCandidateCollection
std::map< std::string, CutterPtr > WPMap
Definition: DeepTauBase.h:70
double pt() const final
transverse momentum
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
std::map< std::string, Output > OutputCollection
Definition: DeepTauBase.h:82
void produce(edm::Event &event, const edm::EventSetup &es) override
Definition: DeepTauBase.cc:98
Definition: HeavyIon.h:7
OutputCollection outputs_
Definition: DeepTauBase.h:102
std::map< std::string, std::unique_ptr< TauDiscriminator >> ResultMap
Definition: DeepTauBase.h:73
TauWPThreshold(const std::string &cut_str)
Definition: DeepTauBase.cc:15
bool closeSession(Session *&session)
Definition: TensorFlow.cc:193
virtual tensorflow::Tensor getPredictions(edm::Event &event, const edm::EventSetup &es, edm::Handle< TauCollection > taus)=0
std::vector< TauType > TauCollection
Definition: DeepTauBase.h:62
DeepTauBase(const edm::ParameterSet &cfg, const OutputCollection &outputs, const DeepTauCache *cache)
Definition: DeepTauBase.cc:79
Analysis-level tau class.
Definition: Tau.h:55
def cache(function)
virtual void createOutputs(edm::Event &event, const tensorflow::Tensor &pred, edm::Handle< TauCollection > taus)
Definition: DeepTauBase.cc:107
void setThreading(SessionOptions &sessionOptions, int nThreads, const std::string &singleThreadPool="no_threads")
Definition: TensorFlow.cc:19
int decayMode() const
reconstructed tau decay mode (specific to PFTau)
Definition: Tau.h:365
double operator()(const pat::Tau &tau) const
Definition: DeepTauBase.cc:40
static std::unique_ptr< DeepTauCache > initializeGlobalCache(const edm::ParameterSet &cfg)
Definition: DeepTauBase.cc:116
ResultMap get_value(const edm::Handle< TauCollection > &taus, const tensorflow::Tensor &pred, const WPMap &working_points) const
Definition: DeepTauBase.cc:50
fixed size matrix
HLT enums.
std::string fullPath() const
Definition: FileInPath.cc:197
std::map< std::string, WPMap > workingPoints_
Definition: DeepTauBase.h:101
def move(src, dest)
Definition: eostools.py:510
Definition: event.py:1