CMS 3D CMS Logo

DeepTauBase.cc
Go to the documentation of this file.
1 /*
2  * \class DeepTauBase
3  *
4  * Implementation of the base class for tau identification using Deep NN.
5  *
6  * \author Konstantin Androsov, INFN Pisa
7  * \author Maria Rosaria Di Domenico, University of Siena & INFN Pisa
8  */
9 
10 //TODO: port to offline RECO/AOD inputs to allow usage with offline AOD
11 //TODO: Take into account that PFTaus can also be built with pat::PackedCandidates
12 
14 
15 namespace deep_tau {
16 
18  bool simple_value = false;
19  try {
20  size_t pos = 0;
21  value_ = std::stod(cut_str, &pos);
22  simple_value = (pos == cut_str.size());
23  } catch (std::invalid_argument&) {
24  } catch (std::out_of_range&) {
25  }
26  if (!simple_value) {
27  static const std::string prefix =
28  "[&](double *x, double *p) { const int decayMode = p[0];"
29  "const double pt = p[1]; const double eta = p[2];";
30  static const int n_params = 3;
31  static const auto handler = [](int, Bool_t, const char*, const char*) -> void {};
32 
33  const std::string fn_str = prefix + cut_str + "}";
34  auto old_handler = SetErrorHandler(handler);
35  fn_ = std::make_unique<TF1>("fn_", fn_str.c_str(), 0, 1, n_params);
36  SetErrorHandler(old_handler);
37  if (!fn_->IsValid())
38  throw cms::Exception("TauWPThreshold: invalid formula") << "Invalid WP cut formula = '" << cut_str << "'.";
39  }
40  }
41 
42  double TauWPThreshold::operator()(const reco::BaseTau& tau, bool isPFTau) const {
43  if (!fn_)
44  return value_;
45 
46  if (isPFTau)
47  fn_->SetParameter(0, dynamic_cast<const reco::PFTau&>(tau).decayMode());
48  else
49  fn_->SetParameter(0, dynamic_cast<const pat::Tau&>(tau).decayMode());
50  fn_->SetParameter(1, tau.pt());
51  fn_->SetParameter(2, tau.eta());
52  return fn_->Eval(0);
53  }
54 
55  std::unique_ptr<DeepTauBase::TauDiscriminator> DeepTauBase::Output::get_value(const edm::Handle<TauCollection>& taus,
56  const tensorflow::Tensor& pred,
57  const WPList* working_points,
58  bool is_online) const {
59  std::vector<reco::SingleTauDiscriminatorContainer> outputbuffer(taus->size());
60 
61  for (size_t tau_index = 0; tau_index < taus->size(); ++tau_index) {
62  float x = 0;
63  for (size_t num_elem : num_)
64  x += pred.matrix<float>()(tau_index, num_elem);
65  if (x != 0 && !den_.empty()) {
66  float den_val = 0;
67  for (size_t den_elem : den_)
68  den_val += pred.matrix<float>()(tau_index, den_elem);
69  x = den_val != 0 ? x / den_val : std::numeric_limits<float>::max();
70  }
71  outputbuffer[tau_index].rawValues.push_back(x);
72  if (working_points) {
73  for (const auto& wp : *working_points) {
74  const bool pass = x > (*wp)(taus->at(tau_index), is_online);
75  outputbuffer[tau_index].workingPoints.push_back(pass);
76  }
77  }
78  }
79  std::unique_ptr<TauDiscriminator> output = std::make_unique<TauDiscriminator>();
81  filler.insert(taus, outputbuffer.begin(), outputbuffer.end());
82  filler.fill();
83  return output;
84  }
85 
87  const OutputCollection& outputCollection,
88  const DeepTauCache* cache)
89  : tausToken_(consumes<TauCollection>(cfg.getParameter<edm::InputTag>("taus"))),
90  pfcandToken_(consumes<CandidateCollection>(cfg.getParameter<edm::InputTag>("pfcands"))),
91  vtxToken_(consumes<reco::VertexCollection>(cfg.getParameter<edm::InputTag>("vertices"))),
92  is_online_(cfg.getParameter<bool>("is_online")),
93  outputs_(outputCollection),
94  cache_(cache) {
95  for (const auto& output_desc : outputs_) {
96  produces<TauDiscriminator>(output_desc.first);
97  const auto& cut_list = cfg.getParameter<std::vector<std::string>>(output_desc.first + "WP");
98  for (const std::string& cut_str : cut_list) {
99  workingPoints_[output_desc.first].push_back(std::make_unique<Cutter>(cut_str));
100  }
101  }
102 
103  // prediscriminant operator
104  // require the tau to pass the following prediscriminants
105  const edm::ParameterSet& prediscriminantConfig = cfg.getParameter<edm::ParameterSet>("Prediscriminants");
106 
107  // determine boolean operator used on the prediscriminants
108  std::string pdBoolOperator = prediscriminantConfig.getParameter<std::string>("BooleanOperator");
109  // convert string to lowercase
110  transform(pdBoolOperator.begin(), pdBoolOperator.end(), pdBoolOperator.begin(), ::tolower);
111 
112  if (pdBoolOperator == "and") {
113  andPrediscriminants_ = 0x1; //use chars instead of bools so we can do a bitwise trick later
114  } else if (pdBoolOperator == "or") {
115  andPrediscriminants_ = 0x0;
116  } else {
117  throw cms::Exception("TauDiscriminationProducerBase")
118  << "PrediscriminantBooleanOperator defined incorrectly, options are: AND,OR";
119  }
120 
121  // get the list of prediscriminants
122  std::vector<std::string> prediscriminantsNames =
123  prediscriminantConfig.getParameterNamesForType<edm::ParameterSet>();
124 
125  for (auto const& iDisc : prediscriminantsNames) {
126  const edm::ParameterSet& iPredisc = prediscriminantConfig.getParameter<edm::ParameterSet>(iDisc);
127  const edm::InputTag& label = iPredisc.getParameter<edm::InputTag>("Producer");
128  double cut = iPredisc.getParameter<double>("cut");
129 
130  if (is_online_) {
131  TauDiscInfo<reco::PFTauDiscriminator> thisDiscriminator;
132  thisDiscriminator.label = label;
133  thisDiscriminator.cut = cut;
134  thisDiscriminator.disc_token = consumes<reco::PFTauDiscriminator>(label);
135  recoPrediscriminants_.push_back(thisDiscriminator);
136  } else {
137  TauDiscInfo<pat::PATTauDiscriminator> thisDiscriminator;
138  thisDiscriminator.label = label;
139  thisDiscriminator.cut = cut;
140  thisDiscriminator.disc_token = consumes<pat::PATTauDiscriminator>(label);
141  patPrediscriminants_.push_back(thisDiscriminator);
142  }
143  }
144  }
145 
148  event.getByToken(tausToken_, taus);
149  edm::ProductID tauProductID = taus.id();
150 
151  // load prediscriminators
152  size_t nPrediscriminants =
154  for (size_t iDisc = 0; iDisc < nPrediscriminants; ++iDisc) {
155  edm::ProductID discKeyId;
156  if (is_online_) {
157  recoPrediscriminants_[iDisc].fill(event);
158  discKeyId = recoPrediscriminants_[iDisc].handle->keyProduct().id();
159  } else {
160  patPrediscriminants_[iDisc].fill(event);
161  discKeyId = patPrediscriminants_[iDisc].handle->keyProduct().id();
162  }
163 
164  // Check to make sure the product is correct for the discriminator.
165  // If not, throw a more informative exception.
166  if (tauProductID != discKeyId) {
167  throw cms::Exception("MisconfiguredPrediscriminant")
168  << "The tau collection has product ID: " << tauProductID
169  << " but the pre-discriminator is keyed with product ID: " << discKeyId << std::endl;
170  }
171  }
172 
173  const tensorflow::Tensor& pred = getPredictions(event, taus);
174  createOutputs(event, pred, taus);
175  }
176 
178  for (const auto& output_desc : outputs_) {
179  const WPList* working_points = nullptr;
180  if (workingPoints_.find(output_desc.first) != workingPoints_.end()) {
181  working_points = &workingPoints_.at(output_desc.first);
182  }
183  auto result = output_desc.second.get_value(taus, pred, working_points, is_online_);
184  event.put(std::move(result), output_desc.first);
185  }
186  }
187 
188  std::unique_ptr<DeepTauCache> DeepTauBase::initializeGlobalCache(const edm::ParameterSet& cfg) {
189  const auto graph_name_vector = cfg.getParameter<std::vector<std::string>>("graph_file");
190  std::map<std::string, std::string> graph_names;
191  for (const auto& entry : graph_name_vector) {
192  const size_t sep_pos = entry.find(':');
193  std::string entry_name, graph_file;
194  if (sep_pos != std::string::npos) {
195  entry_name = entry.substr(0, sep_pos);
196  graph_file = entry.substr(sep_pos + 1);
197  } else {
198  entry_name = "";
199  graph_file = entry;
200  }
201  graph_file = edm::FileInPath(graph_file).fullPath();
202  if (graph_names.count(entry_name))
203  throw cms::Exception("DeepTauCache") << "Duplicated graph entries";
204  graph_names[entry_name] = graph_file;
205  }
206  bool mem_mapped = cfg.getParameter<bool>("mem_mapped");
207  return std::make_unique<DeepTauCache>(graph_names, mem_mapped);
208  }
209 
210  DeepTauCache::DeepTauCache(const std::map<std::string, std::string>& graph_names, bool mem_mapped) {
211  for (const auto& graph_entry : graph_names) {
212  tensorflow::SessionOptions options;
214 
215  const std::string& entry_name = graph_entry.first;
216  const std::string& graph_file = graph_entry.second;
217  if (mem_mapped) {
218  memmappedEnv_[entry_name] = std::make_unique<tensorflow::MemmappedEnv>(tensorflow::Env::Default());
219  const tensorflow::Status mmap_status = memmappedEnv_.at(entry_name)->InitializeFromFile(graph_file);
220  if (!mmap_status.ok()) {
221  throw cms::Exception("DeepTauCache: unable to initalize memmapped environment for ")
222  << graph_file << ". \n"
223  << mmap_status.ToString();
224  }
225 
226  graphs_[entry_name] = std::make_unique<tensorflow::GraphDef>();
227  const tensorflow::Status load_graph_status =
228  ReadBinaryProto(memmappedEnv_.at(entry_name).get(),
229  tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
230  graphs_.at(entry_name).get());
231  if (!load_graph_status.ok())
232  throw cms::Exception("DeepTauCache: unable to load graph from ") << graph_file << ". \n"
233  << load_graph_status.ToString();
234 
235  options.config.mutable_graph_options()->mutable_optimizer_options()->set_opt_level(
236  ::tensorflow::OptimizerOptions::L0);
237  options.env = memmappedEnv_.at(entry_name).get();
238 
239  sessions_[entry_name] = tensorflow::createSession(graphs_.at(entry_name).get(), options);
240 
241  } else {
242  graphs_[entry_name].reset(tensorflow::loadGraphDef(graph_file));
243  sessions_[entry_name] = tensorflow::createSession(graphs_.at(entry_name).get(), options);
244  }
245  }
246  }
247 
249  for (auto& session_entry : sessions_)
250  tensorflow::closeSession(session_entry.second);
251  }
252 
253 } // namespace deep_tau
deep_tau::DeepTauBase::cache_
const DeepTauCache * cache_
Definition: DeepTauBase.h:135
tensorflow::createSession
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
deep_tau::DeepTauBase::TauDiscInfo::label
edm::InputTag label
Definition: DeepTauBase.h:103
deep_tau::DeepTauCache::sessions_
std::map< std::string, tensorflow::Session * > sessions_
Definition: DeepTauBase.h:62
deep_tau::DeepTauBase::patPrediscriminants_
std::vector< TauDiscInfo< pat::PATTauDiscriminator > > patPrediscriminants_
Definition: DeepTauBase.h:112
DeepTauBase.h
electrons_cff.bool
bool
Definition: electrons_cff.py:366
deep_tau::DeepTauBase::Output::get_value
std::unique_ptr< TauDiscriminator > get_value(const edm::Handle< TauCollection > &taus, const tensorflow::Tensor &pred, const WPList *working_points, bool is_online) const
Definition: DeepTauBase.cc:55
deep_tau::DeepTauBase::pfcandToken_
edm::EDGetTokenT< CandidateCollection > pfcandToken_
Definition: DeepTauBase.h:130
TkAlMuonSelectors_cfi.cut
cut
Definition: TkAlMuonSelectors_cfi.py:5
metsig::tau
Definition: SignAlgoResolutions.h:49
deep_tau::DeepTauBase::WPList
std::vector< CutterPtr > WPList
Definition: DeepTauBase.h:78
deep_tau::DeepTauBase::produce
void produce(edm::Event &event, const edm::EventSetup &es) override
Definition: DeepTauBase.cc:146
convertSQLitetoXML_cfg.output
output
Definition: convertSQLitetoXML_cfg.py:72
pfClustersFromHGC3DClusters_cfi.wp
wp
Definition: pfClustersFromHGC3DClusters_cfi.py:20
Tau3MuMonitor_cff.taus
taus
Definition: Tau3MuMonitor_cff.py:7
edm
HLT enums.
Definition: AlignableModifier.h:19
mps_splice.entry
entry
Definition: mps_splice.py:68
deep_tau::DeepTauBase::OutputCollection
std::map< std::string, Output > OutputCollection
Definition: DeepTauBase.h:91
reco::VertexCollection
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
pos
Definition: PixelAliasList.h:18
HLT_FULL_cff.InputTag
InputTag
Definition: HLT_FULL_cff.py:89281
deep_tau::DeepTauBase::vtxToken_
edm::EDGetTokenT< reco::VertexCollection > vtxToken_
Definition: DeepTauBase.h:131
deep_tau
Definition: DeepTauBase.h:36
tensorflow::setThreading
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:17
deep_tau::TauWPThreshold::TauWPThreshold
TauWPThreshold(const std::string &cut_str)
Definition: DeepTauBase.cc:17
reco
fixed size matrix
Definition: AlignmentAlgorithmBase.h:45
btagGenBb_cfi.Status
Status
Definition: btagGenBb_cfi.py:4
edm::Handle< TauCollection >
deep_tau::TauWPThreshold::operator()
double operator()(const reco::BaseTau &tau, bool isPFTau) const
Definition: DeepTauBase.cc:42
options
Definition: options.py:1
edm::FileInPath
Definition: FileInPath.h:64
deep_tau::DeepTauBase::DeepTauBase
DeepTauBase(const edm::ParameterSet &cfg, const OutputCollection &outputs, const DeepTauCache *cache)
Definition: DeepTauBase.cc:86
tensorflow::closeSession
bool closeSession(Session *&session)
Definition: TensorFlow.cc:198
reco::BaseTau
Definition: BaseTau.h:18
deep_tau::DeepTauBase::workingPoints_
std::map< std::string, WPList > workingPoints_
Definition: DeepTauBase.h:132
deep_tau::DeepTauBase::getPredictions
virtual tensorflow::Tensor getPredictions(edm::Event &event, edm::Handle< TauCollection > taus)=0
deep_tau::DeepTauBase::is_online_
const bool is_online_
Definition: DeepTauBase.h:133
HcalDetIdTransform::transform
unsigned transform(const HcalDetId &id, unsigned transformCode)
Definition: HcalDetIdTransform.cc:7
deep_tau::TauWPThreshold::value_
double value_
Definition: DeepTauBase.h:45
taus_cff.decayMode
decayMode
Definition: taus_cff.py:58
utilities.cache
def cache(function)
Definition: utilities.py:3
AlCaHLTBitMon_QueryRunRegistry.string
string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
deep_tau::DeepTauBase::TauDiscInfo
Definition: DeepTauBase.h:102
deep_tau::DeepTauBase::outputs_
OutputCollection outputs_
Definition: DeepTauBase.h:134
edm::View
Definition: CaloClusterFwd.h:14
deep_tau::DeepTauBase::TauDiscInfo::cut
double cut
Definition: DeepTauBase.h:106
edm::ParameterSet
Definition: ParameterSet.h:47
deep_tau::DeepTauBase::initializeGlobalCache
static std::unique_ptr< DeepTauCache > initializeGlobalCache(const edm::ParameterSet &cfg)
Definition: DeepTauBase.cc:188
SiStripPI::max
Definition: SiStripPayloadInspectorHelper.h:169
deep_tau::DeepTauBase::TauDiscInfo::disc_token
edm::EDGetTokenT< ConsumeType > disc_token
Definition: DeepTauBase.h:105
trigObjTnPSource_cfi.filler
filler
Definition: trigObjTnPSource_cfi.py:21
createfilelist.int
int
Definition: createfilelist.py:10
edm::ParameterSet::getParameterNamesForType
std::vector< std::string > getParameterNamesForType(bool trackiness=true) const
Definition: ParameterSet.h:179
edm::EventSetup
Definition: EventSetup.h:58
looper.cfg
cfg
Definition: looper.py:297
deep_tau::DeepTauCache::graphs_
std::map< std::string, GraphPtr > graphs_
Definition: DeepTauBase.h:61
tensorflow::loadGraphDef
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:68
eostools.move
def move(src, dest)
Definition: eostools.py:511
deep_tau::DeepTauBase::recoPrediscriminants_
std::vector< TauDiscInfo< reco::PFTauDiscriminator > > recoPrediscriminants_
Definition: DeepTauBase.h:113
deep_tau::DeepTauBase::Output::num_
std::vector< size_t > num_
Definition: DeepTauBase.h:81
Exception
Definition: hltDiff.cc:245
deep_tau::DeepTauBase::createOutputs
virtual void createOutputs(edm::Event &event, const tensorflow::Tensor &pred, edm::Handle< TauCollection > taus)
Definition: DeepTauBase.cc:177
deep_tau::DeepTauCache::~DeepTauCache
~DeepTauCache()
Definition: DeepTauBase.cc:248
deep_tau::DeepTauCache::DeepTauCache
DeepTauCache(const std::map< std::string, std::string > &graph_names, bool mem_mapped)
Definition: DeepTauBase.cc:210
AlcaSiPixelAliHarvester0T_cff.options
options
Definition: AlcaSiPixelAliHarvester0T_cff.py:42
edm::ParameterSet::getParameter
T getParameter(std::string const &) const
Definition: ParameterSet.h:303
deep_tau::TauWPThreshold::fn_
std::unique_ptr< TF1 > fn_
Definition: DeepTauBase.h:44
deep_tau::DeepTauBase::andPrediscriminants_
uint8_t andPrediscriminants_
Definition: DeepTauBase.h:111
mps_fire.result
result
Definition: mps_fire.py:311
cms::Exception
Definition: Exception.h:70
edm::helper::Filler
Definition: ValueMap.h:22
deep_tau::DeepTauCache
Definition: DeepTauBase.h:48
event
Definition: event.py:1
edm::Event
Definition: Event.h:73
FWLite.working_points
working_points
Definition: FWLite.py:121
deep_tau::DeepTauBase::tausToken_
edm::EDGetTokenT< TauCollection > tausToken_
Definition: DeepTauBase.h:129
edm::InputTag
Definition: InputTag.h:15
label
const char * label
Definition: PFTauDecayModeTools.cc:11
deep_tau::DeepTauBase::Output::den_
std::vector< size_t > den_
Definition: DeepTauBase.h:81
edm::ProductID
Definition: ProductID.h:27
hcallasereventfilter2012_cfi.prefix
prefix
Definition: hcallasereventfilter2012_cfi.py:10
edm::FileInPath::fullPath
std::string fullPath() const
Definition: FileInPath.cc:161
deep_tau::DeepTauCache::memmappedEnv_
std::map< std::string, std::unique_ptr< tensorflow::MemmappedEnv > > memmappedEnv_
Definition: DeepTauBase.h:63