CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
DeepTauBase.cc
Go to the documentation of this file.
1 /*
2  * \class DeepTauBase
3  *
4  * Implementation of the base class for tau identification using Deep NN.
5  *
6  * \author Konstantin Androsov, INFN Pisa
7  * \author Maria Rosaria Di Domenico, University of Siena & INFN Pisa
8  */
9 
10 //TODO: port to offline RECO/AOD inputs to allow usage with offline AOD
11 //TODO: Take into account that PFTaus can also be built with pat::PackedCandidates
12 
14 
15 namespace deep_tau {
16 
// Body of TauWPThreshold::TauWPThreshold(const std::string& cut_str); the signature line is not shown in this listing.
// A working-point cut is either a plain number, stored in value_, or a C++ expression compiled into fn_.
// First try to parse the whole string as a double. Parse failures are deliberately swallowed:
// they simply mean the string is a formula rather than a constant.
18  bool simple_value = false;
19  try {
20  size_t pos = 0;
21  value_ = std::stod(cut_str, &pos);
// Only a fully consumed string counts as a constant (e.g. "0.5abc" is treated as a formula).
22  simple_value = (pos == cut_str.size());
23  } catch (std::invalid_argument&) {
24  } catch (std::out_of_range&) {
25  }
26  if (!simple_value) {
// The cut is wrapped into a lambda-expression string for TF1. Inside the expression, decayMode, pt and eta
// are bound to TF1 parameters p[0], p[1] and p[2]; operator() sets those parameters before Eval.
27  static const std::string prefix =
28  "[&](double *x, double *p) { const int decayMode = p[0];"
29  "const double pt = p[1]; const double eta = p[2];";
30  static const int n_params = 3;
// No-op ROOT error handler: suppresses ROOT's error output while the formula is built.
// Validity is checked explicitly with IsValid() below instead.
31  static const auto handler = [](int, Bool_t, const char*, const char*) -> void {};
32 
33  std::string fn_str = prefix;
// A bare expression gets an implicit "return"; a cut that already contains "return" is used as a function body.
34  if (cut_str.find("return") == std::string::npos)
35  fn_str += " return " + cut_str + ";}";
36  else
37  fn_str += cut_str + "}";
// NOTE(review): if the TF1 constructor throws, the previous error handler is never restored;
// an RAII guard around SetErrorHandler would make this exception-safe.
38  auto old_handler = SetErrorHandler(handler);
39  fn_ = std::make_unique<TF1>("fn_", fn_str.c_str(), 0, 1, n_params);
40  SetErrorHandler(old_handler);
41  if (!fn_->IsValid())
42  throw cms::Exception("TauWPThreshold: invalid formula") << "Invalid WP cut formula = '" << cut_str << "'.";
43  }
44  }
45 
46  double TauWPThreshold::operator()(const reco::BaseTau& tau, bool isPFTau) const {
47  if (!fn_) {
48  return value_;
49  }
50 
51  if (isPFTau)
52  fn_->SetParameter(0, dynamic_cast<const reco::PFTau&>(tau).decayMode());
53  else
54  fn_->SetParameter(0, dynamic_cast<const pat::Tau&>(tau).decayMode());
55  fn_->SetParameter(1, tau.pt());
56  fn_->SetParameter(2, tau.eta());
57  return fn_->Eval(0);
58  }
59 
// DeepTauBase::Output::get_value: converts the network output into one discriminator entry per tau.
// For each tau, row tau_index of `pred` is read as a float matrix and a raw score x is computed:
//   x = sum of the columns listed in num_;
//   if x != 0 and den_ is non-empty, x is divided by the sum of the columns in den_
//   (std::numeric_limits<float>::max() when that sum is 0).
// x is stored in rawValues. If working_points is non-null, one pass flag per working point
// (x strictly greater than the threshold) is appended to workingPoints, in list order.
// is_online is forwarded as the second (isPFTau) argument of the cutter. Presumably online taus are
// reco::PFTau and offline ones pat::Tau; confirm against the Cutter type behind WPList.
// Returns a newly allocated TauDiscriminator filled from the per-tau buffer.
60  std::unique_ptr<DeepTauBase::TauDiscriminator> DeepTauBase::Output::get_value(const edm::Handle<TauCollection>& taus,
61  const tensorflow::Tensor& pred,
62  const WPList* working_points,
63  bool is_online) const {
64  std::vector<reco::SingleTauDiscriminatorContainer> outputbuffer(taus->size());
65 
66  for (size_t tau_index = 0; tau_index < taus->size(); ++tau_index) {
67  float x = 0;
// NOTE(review): pred.matrix<float>() is rebuilt for every element; it could be hoisted out of the loop.
68  for (size_t num_elem : num_)
69  x += pred.matrix<float>()(tau_index, num_elem);
// A zero numerator is kept as 0 without dividing, even if the denominator is also 0.
70  if (x != 0 && !den_.empty()) {
71  float den_val = 0;
72  for (size_t den_elem : den_)
73  den_val += pred.matrix<float>()(tau_index, den_elem);
74  x = den_val != 0 ? x / den_val : std::numeric_limits<float>::max();
75  }
76  outputbuffer[tau_index].rawValues.push_back(x);
77  if (working_points) {
78  for (const auto& wp : *working_points) {
79  const bool pass = x > (*wp)(taus->at(tau_index), is_online);
80  outputbuffer[tau_index].workingPoints.push_back(pass);
81  }
82  }
83  }
84  std::unique_ptr<TauDiscriminator> output = std::make_unique<TauDiscriminator>();
// NOTE(review): the declaration of `filler` (source line 85) is missing from this listing.
86  filler.insert(taus, outputbuffer.begin(), outputbuffer.end());
87  filler.fill();
88  return output;
89  }
90 
// DeepTauBase constructor; its first signature line is not shown in this listing.
// - Declares consumes() for the "taus", "pfcands" and "vertices" inputs and reads "is_online".
// - For every entry of outputCollection, declares a TauDiscriminator product labelled with the output name.
//   It also builds one Cutter (presumably TauWPThreshold) per string in the "<output name>WP" parameter.
// - Reads the "Prediscriminants" PSet:
//   - "BooleanOperator" must be "and" or "or" (case-insensitive), otherwise a cms::Exception is thrown.
//   - Each nested PSet supplies a "Producer" tag and a "cut".
//   - When online, reco::PFTauDiscriminator products are consumed; otherwise pat::PATTauDiscriminator ones.
92  const OutputCollection& outputCollection,
93  const DeepTauCache* cache)
94  : tausToken_(consumes<TauCollection>(cfg.getParameter<edm::InputTag>("taus"))),
95  pfcandToken_(consumes<CandidateCollection>(cfg.getParameter<edm::InputTag>("pfcands"))),
96  vtxToken_(consumes<reco::VertexCollection>(cfg.getParameter<edm::InputTag>("vertices"))),
97  is_online_(cfg.getParameter<bool>("is_online")),
98  outputs_(outputCollection),
99  cache_(cache) {
100  for (const auto& output_desc : outputs_) {
101  produces<TauDiscriminator>(output_desc.first);
102  const auto& cut_list = cfg.getParameter<std::vector<std::string>>(output_desc.first + "WP");
103  for (const std::string& cut_str : cut_list) {
104  workingPoints_[output_desc.first].push_back(std::make_unique<Cutter>(cut_str));
105  }
106  }
107 
108  // prediscriminant operator
109  // require the tau to pass the following prediscriminants
110  const edm::ParameterSet& prediscriminantConfig = cfg.getParameter<edm::ParameterSet>("Prediscriminants");
111 
112  // determine boolean operator used on the prediscriminants
113  std::string pdBoolOperator = prediscriminantConfig.getParameter<std::string>("BooleanOperator");
114  // convert string to lowercase
// NOTE(review): ::tolower on a plain char is undefined for negative values; casting to unsigned char would be safer.
115  transform(pdBoolOperator.begin(), pdBoolOperator.end(), pdBoolOperator.begin(), ::tolower);
116 
117  if (pdBoolOperator == "and") {
118  andPrediscriminants_ = 0x1; //use chars instead of bools so we can do a bitwise trick later
119  } else if (pdBoolOperator == "or") {
120  andPrediscriminants_ = 0x0;
121  } else {
122  throw cms::Exception("TauDiscriminationProducerBase")
123  << "PrediscriminantBooleanOperator defined incorrectly, options are: AND,OR";
124  }
125 
126  // get the list of prediscriminants
127  std::vector<std::string> prediscriminantsNames =
128  prediscriminantConfig.getParameterNamesForType<edm::ParameterSet>();
129 
130  for (auto const& iDisc : prediscriminantsNames) {
131  const edm::ParameterSet& iPredisc = prediscriminantConfig.getParameter<edm::ParameterSet>(iDisc);
132  const edm::InputTag& label = iPredisc.getParameter<edm::InputTag>("Producer");
133  double cut = iPredisc.getParameter<double>("cut");
134 
// The discriminator product type follows the input tau type: reco (online) vs PAT (offline).
135  if (is_online_) {
136  TauDiscInfo<reco::PFTauDiscriminator> thisDiscriminator;
137  thisDiscriminator.label = label;
138  thisDiscriminator.cut = cut;
139  thisDiscriminator.disc_token = consumes<reco::PFTauDiscriminator>(label);
140  recoPrediscriminants_.push_back(thisDiscriminator);
141  } else {
142  TauDiscInfo<pat::PATTauDiscriminator> thisDiscriminator;
143  thisDiscriminator.label = label;
144  thisDiscriminator.cut = cut;
145  thisDiscriminator.disc_token = consumes<pat::PATTauDiscriminator>(label);
146  patPrediscriminants_.push_back(thisDiscriminator);
147  }
148  }
149  }
150 
// Body of DeepTauBase::produce(edm::Event&, const edm::EventSetup&).
// NOTE(review): not shown in this listing:
//   - the signature and the declaration of `taus` (source lines 151-152);
//   - the tail of the nPrediscriminants initializer (source line 158).
// Steps:
//   1. Read the tau collection.
//   2. Load every prediscriminant (reco flavour when online, PAT otherwise) and require it to be
//      keyed to the same product as the taus, throwing "MisconfiguredPrediscriminant" otherwise.
//   3. Run the network via getPredictions().
//   4. Hand the result to createOutputs().
153  event.getByToken(tausToken_, taus);
154  edm::ProductID tauProductID = taus.id();
155 
156  // load prediscriminators
157  size_t nPrediscriminants =
159  for (size_t iDisc = 0; iDisc < nPrediscriminants; ++iDisc) {
160  edm::ProductID discKeyId;
161  if (is_online_) {
162  recoPrediscriminants_[iDisc].fill(event);
163  discKeyId = recoPrediscriminants_[iDisc].handle->keyProduct().id();
164  } else {
165  patPrediscriminants_[iDisc].fill(event);
166  discKeyId = patPrediscriminants_[iDisc].handle->keyProduct().id();
167  }
168 
169  // Check to make sure the product is correct for the discriminator.
170  // If not, throw a more informative exception.
171  if (tauProductID != discKeyId) {
172  throw cms::Exception("MisconfiguredPrediscriminant")
173  << "The tau collection has product ID: " << tauProductID
174  << " but the pre-discriminator is keyed with product ID: " << discKeyId << std::endl;
175  }
176  }
177 
// getPredictions returns the tensor by value; binding it to a const reference extends the temporary's lifetime.
178  const tensorflow::Tensor& pred = getPredictions(event, taus);
179  createOutputs(event, pred, taus);
180  }
181 
// Body of DeepTauBase::createOutputs(edm::Event&, const tensorflow::Tensor&, edm::Handle<TauCollection>).
// The signature line (source 182) is not shown in this listing.
// For every configured output, builds the discriminator with Output::get_value and puts it into the event,
// labelled with the output name. Working points are passed only if some were configured for that output.
// NOTE(review): find() followed by at() looks the key up twice; reusing the iterator from find() would avoid that.
183  for (const auto& output_desc : outputs_) {
184  const WPList* working_points = nullptr;
185  if (workingPoints_.find(output_desc.first) != workingPoints_.end()) {
186  working_points = &workingPoints_.at(output_desc.first);
187  }
188  auto result = output_desc.second.get_value(taus, pred, working_points, is_online_);
189  event.put(std::move(result), output_desc.first);
190  }
191  }
192 
193  std::unique_ptr<DeepTauCache> DeepTauBase::initializeGlobalCache(const edm::ParameterSet& cfg) {
194  const auto graph_name_vector = cfg.getParameter<std::vector<std::string>>("graph_file");
195  std::map<std::string, std::string> graph_names;
196  for (const auto& entry : graph_name_vector) {
197  const size_t sep_pos = entry.find(':');
198  std::string entry_name, graph_file;
199  if (sep_pos != std::string::npos) {
200  entry_name = entry.substr(0, sep_pos);
201  graph_file = entry.substr(sep_pos + 1);
202  } else {
203  entry_name = "";
204  graph_file = entry;
205  }
206  graph_file = edm::FileInPath(graph_file).fullPath();
207  if (graph_names.count(entry_name))
208  throw cms::Exception("DeepTauCache") << "Duplicated graph entries";
209  graph_names[entry_name] = graph_file;
210  }
211  bool mem_mapped = cfg.getParameter<bool>("mem_mapped");
212  return std::make_unique<DeepTauCache>(graph_names, mem_mapped);
213  }
214 
215  DeepTauCache::DeepTauCache(const std::map<std::string, std::string>& graph_names, bool mem_mapped) {
216  for (const auto& graph_entry : graph_names) {
217  tensorflow::SessionOptions options;
218  tensorflow::setThreading(options, 1);
219 
220  const std::string& entry_name = graph_entry.first;
221  const std::string& graph_file = graph_entry.second;
222  if (mem_mapped) {
223  memmappedEnv_[entry_name] = std::make_unique<tensorflow::MemmappedEnv>(tensorflow::Env::Default());
224  const tensorflow::Status mmap_status = memmappedEnv_.at(entry_name)->InitializeFromFile(graph_file);
225  if (!mmap_status.ok()) {
226  throw cms::Exception("DeepTauCache: unable to initalize memmapped environment for ")
227  << graph_file << ". \n"
228  << mmap_status.ToString();
229  }
230 
231  graphs_[entry_name] = std::make_unique<tensorflow::GraphDef>();
232  const tensorflow::Status load_graph_status =
233  ReadBinaryProto(memmappedEnv_.at(entry_name).get(),
234  tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
235  graphs_.at(entry_name).get());
236  if (!load_graph_status.ok())
237  throw cms::Exception("DeepTauCache: unable to load graph from ") << graph_file << ". \n"
238  << load_graph_status.ToString();
239 
240  options.config.mutable_graph_options()->mutable_optimizer_options()->set_opt_level(
241  ::tensorflow::OptimizerOptions::L0);
242  options.env = memmappedEnv_.at(entry_name).get();
243 
244  sessions_[entry_name] = tensorflow::createSession(graphs_.at(entry_name).get(), options);
245 
246  } else {
247  graphs_[entry_name].reset(tensorflow::loadGraphDef(graph_file));
248  sessions_[entry_name] = tensorflow::createSession(graphs_.at(entry_name).get(), options);
249  }
250  }
251  }
252 
// Body of the DeepTauCache destructor; the signature line (source 253) is not shown in this listing.
// Closes every session created in the constructor. The destructor body runs before the members
// (graphs_, memmappedEnv_) are destroyed, so sessions are closed while their graphs and environments still exist.
// closeSession takes the Session pointer by reference (Session*&).
254  for (auto& session_entry : sessions_)
255  tensorflow::closeSession(session_entry.second);
256  }
257 
258 } // namespace deep_tau
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
tuple cfg
Definition: looper.py:296
double pt() const final
transverse momentum
DeepTauCache(const std::map< std::string, std::string > &graph_names, bool mem_mapped)
Definition: DeepTauBase.cc:215
ProductID id() const
Definition: HandleBase.cc:29
std::unique_ptr< TF1 > fn_
Definition: DeepTauBase.h:44
std::map< std::string, tensorflow::Session * > sessions_
Definition: DeepTauBase.h:62
edm::EDGetTokenT< TauCollection > tausToken_
Definition: DeepTauBase.h:129
dictionary working_points
Definition: FWLite.py:122
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:68
double operator()(const reco::BaseTau &tau, bool isPFTau) const
Definition: DeepTauBase.cc:46
const DeepTauCache * cache_
Definition: DeepTauBase.h:135
std::vector< size_t > num_
Definition: DeepTauBase.h:81
std::vector< Vertex > VertexCollection
Definition: Vertex.h:12
std::vector< size_t > den_
Definition: DeepTauBase.h:81
std::map< std::string, Output > OutputCollection
Definition: DeepTauBase.h:91
std::vector< CutterPtr > WPList
Definition: DeepTauBase.h:78
void produce(edm::Event &event, const edm::EventSetup &es) override
Definition: DeepTauBase.cc:151
std::unique_ptr< TauDiscriminator > get_value(const edm::Handle< TauCollection > &taus, const tensorflow::Tensor &pred, const WPList *working_points, bool is_online) const
Definition: DeepTauBase.cc:60
tuple result
Definition: mps_fire.py:311
OutputCollection outputs_
Definition: DeepTauBase.h:134
std::vector< std::string > getParameterNamesForType(bool trackiness=true) const
Definition: ParameterSet.h:179
char const * label
SiPixelHitStatus Status
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:17
edm::EDGetTokenT< ConsumeType > disc_token
Definition: DeepTauBase.h:105
TauWPThreshold(const std::string &cut_str)
Definition: DeepTauBase.cc:17
def move
Definition: eostools.py:511
bool closeSession(Session *&session)
Definition: TensorFlow.cc:198
DeepTauBase(const edm::ParameterSet &cfg, const OutputCollection &outputs, const DeepTauCache *cache)
Definition: DeepTauBase.cc:91
std::map< std::string, WPList > workingPoints_
Definition: DeepTauBase.h:132
virtual tensorflow::Tensor getPredictions(edm::Event &event, edm::Handle< TauCollection > taus)=0
virtual void createOutputs(edm::Event &event, const tensorflow::Tensor &pred, edm::Handle< TauCollection > taus)
Definition: DeepTauBase.cc:182
uint8_t andPrediscriminants_
Definition: DeepTauBase.h:111
T getParameter(std::string const &) const
Definition: ParameterSet.h:303
static std::unique_ptr< DeepTauCache > initializeGlobalCache(const edm::ParameterSet &cfg)
Definition: DeepTauBase.cc:193
std::map< std::string, GraphPtr > graphs_
Definition: DeepTauBase.h:61
float x
list entry
Definition: mps_splice.py:68
std::string fullPath() const
Definition: FileInPath.cc:161
std::map< std::string, std::unique_ptr< tensorflow::MemmappedEnv > > memmappedEnv_
Definition: DeepTauBase.h:63
std::vector< TauDiscInfo< pat::PATTauDiscriminator > > patPrediscriminants_
Definition: DeepTauBase.h:112
edm::EDGetTokenT< CandidateCollection > pfcandToken_
Definition: DeepTauBase.h:130
std::vector< TauDiscInfo< reco::PFTauDiscriminator > > recoPrediscriminants_
Definition: DeepTauBase.h:113
def cache
Definition: utilities.py:3
double eta() const final
momentum pseudorapidity
unsigned transform(const HcalDetId &id, unsigned transformCode)
edm::EDGetTokenT< reco::VertexCollection > vtxToken_
Definition: DeepTauBase.h:131