CMS 3D CMS Logo

DeepTauBase.cc
Go to the documentation of this file.
1 /*
2  * \class DeepTauBase
3  *
4  * Implementation of the base class for tau identification using Deep NN.
5  *
6  * \author Konstantin Androsov, INFN Pisa
7  * \author Maria Rosaria Di Domenico, University of Siena & INFN Pisa
8  */
9 
11 
12 namespace deep_tau {
13 
15  bool simple_value = false;
16  try {
17  size_t pos = 0;
18  value_ = std::stod(cut_str, &pos);
19  simple_value = (pos == cut_str.size());
20  } catch (std::invalid_argument&) {
21  } catch (std::out_of_range&) {
22  }
23  if (!simple_value) {
24  static const std::string prefix =
25  "[&](double *x, double *p) { const int decayMode = p[0];"
26  "const double pt = p[1]; const double eta = p[2];";
27  static const int n_params = 3;
28  static const auto handler = [](int, Bool_t, const char*, const char*) -> void {};
29 
30  const std::string fn_str = prefix + cut_str + "}";
31  auto old_handler = SetErrorHandler(handler);
32  fn_ = std::make_unique<TF1>("fn_", fn_str.c_str(), 0, 1, n_params);
33  SetErrorHandler(old_handler);
34  if (!fn_->IsValid())
35  throw cms::Exception("TauWPThreshold: invalid formula") << "Invalid WP cut formula = '" << cut_str << "'.";
36  }
37  }
38 
39  double TauWPThreshold::operator()(const pat::Tau& tau) const {
40  if (!fn_)
41  return value_;
42  fn_->SetParameter(0, tau.decayMode());
43  fn_->SetParameter(1, tau.pt());
44  fn_->SetParameter(2, tau.eta());
45  return fn_->Eval(0);
46  }
47 
48  std::unique_ptr<DeepTauBase::TauDiscriminator> DeepTauBase::Output::get_value(const edm::Handle<TauCollection>& taus,
49  const tensorflow::Tensor& pred,
50  const WPList& working_points) const {
51  std::vector<reco::SingleTauDiscriminatorContainer> outputbuffer(taus->size());
52 
53  for (size_t tau_index = 0; tau_index < taus->size(); ++tau_index) {
54  float x = 0;
55  for (size_t num_elem : num_)
56  x += pred.matrix<float>()(tau_index, num_elem);
57  if (x != 0 && !den_.empty()) {
58  float den_val = 0;
59  for (size_t den_elem : den_)
60  den_val += pred.matrix<float>()(tau_index, den_elem);
61  x = den_val != 0 ? x / den_val : std::numeric_limits<float>::max();
62  }
63  outputbuffer[tau_index].rawValues.push_back(x);
64  for (const auto& wp : working_points) {
65  const bool pass = x > (*wp)(taus->at(tau_index));
66  outputbuffer[tau_index].workingPoints.push_back(pass);
67  }
68  }
69  std::unique_ptr<TauDiscriminator> output = std::make_unique<TauDiscriminator>();
71  filler.insert(taus, outputbuffer.begin(), outputbuffer.end());
72  filler.fill();
73  return output;
74  }
75 
77  const OutputCollection& outputCollection,
78  const DeepTauCache* cache)
79  : tausToken_(consumes<TauCollection>(cfg.getParameter<edm::InputTag>("taus"))),
80  pfcandToken_(consumes<pat::PackedCandidateCollection>(cfg.getParameter<edm::InputTag>("pfcands"))),
81  vtxToken_(consumes<reco::VertexCollection>(cfg.getParameter<edm::InputTag>("vertices"))),
82  outputs_(outputCollection),
83  cache_(cache) {
84  for (const auto& output_desc : outputs_) {
85  produces<TauDiscriminator>(output_desc.first);
86  const auto& cut_list = cfg.getParameter<std::vector<std::string>>(output_desc.first + "WP");
87  for (const std::string& cut_str : cut_list) {
88  workingPoints_[output_desc.first].push_back(std::make_unique<Cutter>(cut_str));
89  }
90  }
91  }
92 
95  event.getByToken(tausToken_, taus);
96 
97  const tensorflow::Tensor& pred = getPredictions(event, es, taus);
98  createOutputs(event, pred, taus);
99  }
100 
102  for (const auto& output_desc : outputs_) {
103  auto result = output_desc.second.get_value(taus, pred, workingPoints_.at(output_desc.first));
104  event.put(std::move(result), output_desc.first);
105  }
106  }
107 
108  std::unique_ptr<DeepTauCache> DeepTauBase::initializeGlobalCache(const edm::ParameterSet& cfg) {
109  const auto graph_name_vector = cfg.getParameter<std::vector<std::string>>("graph_file");
110  std::map<std::string, std::string> graph_names;
111  for (const auto& entry : graph_name_vector) {
112  const size_t sep_pos = entry.find(':');
113  std::string entry_name, graph_file;
114  if (sep_pos != std::string::npos) {
115  entry_name = entry.substr(0, sep_pos);
116  graph_file = entry.substr(sep_pos + 1);
117  } else {
118  entry_name = "";
119  graph_file = entry;
120  }
121  graph_file = edm::FileInPath(graph_file).fullPath();
122  if (graph_names.count(entry_name))
123  throw cms::Exception("DeepTauCache") << "Duplicated graph entries";
124  graph_names[entry_name] = graph_file;
125  }
126  bool mem_mapped = cfg.getParameter<bool>("mem_mapped");
127  return std::make_unique<DeepTauCache>(graph_names, mem_mapped);
128  }
129 
130  DeepTauCache::DeepTauCache(const std::map<std::string, std::string>& graph_names, bool mem_mapped) {
131  for (const auto& graph_entry : graph_names) {
132  tensorflow::SessionOptions options;
134 
135  const std::string& entry_name = graph_entry.first;
136  const std::string& graph_file = graph_entry.second;
137  if (mem_mapped) {
138  memmappedEnv_[entry_name] = std::make_unique<tensorflow::MemmappedEnv>(tensorflow::Env::Default());
139  const tensorflow::Status mmap_status = memmappedEnv_.at(entry_name)->InitializeFromFile(graph_file);
140  if (!mmap_status.ok()) {
141  throw cms::Exception("DeepTauCache: unable to initalize memmapped environment for ")
142  << graph_file << ". \n"
143  << mmap_status.ToString();
144  }
145 
146  graphs_[entry_name] = std::make_unique<tensorflow::GraphDef>();
147  const tensorflow::Status load_graph_status =
148  ReadBinaryProto(memmappedEnv_.at(entry_name).get(),
149  tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
150  graphs_.at(entry_name).get());
151  if (!load_graph_status.ok())
152  throw cms::Exception("DeepTauCache: unable to load graph from ") << graph_file << ". \n"
153  << load_graph_status.ToString();
154 
155  options.config.mutable_graph_options()->mutable_optimizer_options()->set_opt_level(
156  ::tensorflow::OptimizerOptions::L0);
157  options.env = memmappedEnv_.at(entry_name).get();
158 
159  sessions_[entry_name] = tensorflow::createSession(graphs_.at(entry_name).get(), options);
160 
161  } else {
162  graphs_[entry_name].reset(tensorflow::loadGraphDef(graph_file));
163  sessions_[entry_name] = tensorflow::createSession(graphs_.at(entry_name).get(), options);
164  }
165  }
166  }
167 
169  for (auto& session_entry : sessions_)
170  tensorflow::closeSession(session_entry.second);
171  }
172 
173 } // namespace deep_tau
deep_tau::DeepTauBase::cache_
const DeepTauCache * cache_
Definition: DeepTauBase.h:104
tensorflow::createSession
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
deep_tau::DeepTauCache::sessions_
std::map< std::string, tensorflow::Session * > sessions_
Definition: DeepTauBase.h:54
DeepTauBase.h
metsig::tau
Definition: SignAlgoResolutions.h:49
deep_tau::DeepTauBase::WPList
std::vector< CutterPtr > WPList
Definition: DeepTauBase.h:70
deep_tau::DeepTauBase::produce
void produce(edm::Event &event, const edm::EventSetup &es) override
Definition: DeepTauBase.cc:93
convertSQLitetoXML_cfg.output
output
Definition: convertSQLitetoXML_cfg.py:32
pfClustersFromHGC3DClusters_cfi.wp
wp
Definition: pfClustersFromHGC3DClusters_cfi.py:20
Tau3MuMonitor_cff.taus
taus
Definition: Tau3MuMonitor_cff.py:7
edm
HLT enums.
Definition: AlignableModifier.h:19
mps_splice.entry
entry
Definition: mps_splice.py:68
deep_tau::DeepTauBase::pfcandToken_
edm::EDGetTokenT< pat::PackedCandidateCollection > pfcandToken_
Definition: DeepTauBase.h:100
pat::Tau
Analysis-level tau class.
Definition: Tau.h:53
deep_tau::DeepTauBase::OutputCollection
std::map< std::string, Output > OutputCollection
Definition: DeepTauBase.h:82
reco::VertexCollection
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
pos
Definition: PixelAliasList.h:18
deep_tau::DeepTauBase::vtxToken_
edm::EDGetTokenT< reco::VertexCollection > vtxToken_
Definition: DeepTauBase.h:101
deep_tau
Definition: DeepTauBase.h:28
tensorflow::setThreading
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:17
deep_tau::TauWPThreshold::TauWPThreshold
TauWPThreshold(const std::string &cut_str)
Definition: DeepTauBase.cc:14
reco
fixed size matrix
Definition: AlignmentAlgorithmBase.h:45
btagGenBb_cfi.Status
Status
Definition: btagGenBb_cfi.py:4
edm::Handle< TauCollection >
options
Definition: options.py:1
edm::FileInPath
Definition: FileInPath.h:64
deep_tau::DeepTauBase::DeepTauBase
DeepTauBase(const edm::ParameterSet &cfg, const OutputCollection &outputs, const DeepTauCache *cache)
Definition: DeepTauBase.cc:76
tensorflow::closeSession
bool closeSession(Session *&session)
Definition: TensorFlow.cc:196
deep_tau::DeepTauBase::workingPoints_
std::map< std::string, WPList > workingPoints_
Definition: DeepTauBase.h:102
deep_tau::TauWPThreshold::value_
double value_
Definition: DeepTauBase.h:37
utilities.cache
def cache(function)
Definition: utilities.py:3
AlCaHLTBitMon_QueryRunRegistry.string
string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
deep_tau::DeepTauBase::outputs_
OutputCollection outputs_
Definition: DeepTauBase.h:103
HLT_2018_cff.InputTag
InputTag
Definition: HLT_2018_cff.py:79016
edm::ParameterSet
Definition: ParameterSet.h:36
deep_tau::DeepTauBase::initializeGlobalCache
static std::unique_ptr< DeepTauCache > initializeGlobalCache(const edm::ParameterSet &cfg)
Definition: DeepTauBase.cc:108
SiStripPI::max
Definition: SiStripPayloadInspectorHelper.h:169
deep_tau::DeepTauBase::Output::get_value
std::unique_ptr< TauDiscriminator > get_value(const edm::Handle< TauCollection > &taus, const tensorflow::Tensor &pred, const WPList &working_points) const
Definition: DeepTauBase.cc:48
trigObjTnPSource_cfi.filler
filler
Definition: trigObjTnPSource_cfi.py:21
createfilelist.int
int
Definition: createfilelist.py:10
edm::EventSetup
Definition: EventSetup.h:57
pat
Definition: HeavyIon.h:7
deep_tau::DeepTauBase::TauCollection
std::vector< TauType > TauCollection
Definition: DeepTauBase.h:62
looper.cfg
cfg
Definition: looper.py:297
deep_tau::DeepTauCache::graphs_
std::map< std::string, GraphPtr > graphs_
Definition: DeepTauBase.h:53
pat::PackedCandidateCollection
std::vector< pat::PackedCandidate > PackedCandidateCollection
Definition: PackedCandidate.h:1130
tensorflow::loadGraphDef
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:68
deep_tau::DeepTauBase::getPredictions
virtual tensorflow::Tensor getPredictions(edm::Event &event, const edm::EventSetup &es, edm::Handle< TauCollection > taus)=0
eostools.move
def move(src, dest)
Definition: eostools.py:511
deep_tau::DeepTauBase::Output::num_
std::vector< size_t > num_
Definition: DeepTauBase.h:73
Exception
Definition: hltDiff.cc:246
deep_tau::DeepTauBase::createOutputs
virtual void createOutputs(edm::Event &event, const tensorflow::Tensor &pred, edm::Handle< TauCollection > taus)
Definition: DeepTauBase.cc:101
deep_tau::TauWPThreshold::operator()
double operator()(const pat::Tau &tau) const
Definition: DeepTauBase.cc:39
deep_tau::DeepTauCache::~DeepTauCache
~DeepTauCache()
Definition: DeepTauBase.cc:168
Default
#define Default
Definition: vmac.h:110
deep_tau::DeepTauCache::DeepTauCache
DeepTauCache(const std::map< std::string, std::string > &graph_names, bool mem_mapped)
Definition: DeepTauBase.cc:130
AlcaSiPixelAliHarvester0T_cff.options
options
Definition: AlcaSiPixelAliHarvester0T_cff.py:42
deep_tau::TauWPThreshold::fn_
std::unique_ptr< TF1 > fn_
Definition: DeepTauBase.h:36
mps_fire.result
result
Definition: mps_fire.py:303
cms::Exception
Definition: Exception.h:70
edm::helper::Filler
Definition: ValueMap.h:22
deep_tau::DeepTauCache
Definition: DeepTauBase.h:40
event
Definition: event.py:1
edm::Event
Definition: Event.h:73
FWLite.working_points
working_points
Definition: FWLite.py:121
deep_tau::DeepTauBase::tausToken_
edm::EDGetTokenT< TauCollection > tausToken_
Definition: DeepTauBase.h:99
deep_tau::DeepTauBase::Output::den_
std::vector< size_t > den_
Definition: DeepTauBase.h:73
edm::FileInPath::fullPath
std::string fullPath() const
Definition: FileInPath.cc:163
ZMuMuAnalysisNtupler_cff.prefix
prefix
Definition: ZMuMuAnalysisNtupler_cff.py:14
deep_tau::DeepTauCache::memmappedEnv_
std::map< std::string, std::unique_ptr< tensorflow::MemmappedEnv > > memmappedEnv_
Definition: DeepTauBase.h:55