CMS 3D CMS Logo

List of all members | Public Types | Public Member Functions | Private Attributes
deep_tau::DeepTauCache Class Reference

Public Types

using GraphPtr = std::shared_ptr< tensorflow::GraphDef >
 

Public Member Functions

 DeepTauCache (const std::map< std::string, std::string > &graph_names, bool mem_mapped)
 
const tensorflow::GraphDef & getGraph (const std::string &name="") const
 
tensorflow::Session & getSession (const std::string &name="") const
 
 ~DeepTauCache ()
 

Private Attributes

std::map< std::string, GraphPtr > graphs_
 
std::map< std::string, std::unique_ptr< tensorflow::MemmappedEnv > > memmappedEnv_
 
std::map< std::string, tensorflow::Session * > sessions_
 

Detailed Description

Definition at line 99 of file DeepTauId.cc.

Member Typedef Documentation

◆ GraphPtr

using deep_tau::DeepTauCache::GraphPtr = std::shared_ptr<tensorflow::GraphDef>

Definition at line 101 of file DeepTauId.cc.

Constructor & Destructor Documentation

◆ DeepTauCache()

deep_tau::DeepTauCache::DeepTauCache ( const std::map< std::string, std::string > &  graph_names,
bool  mem_mapped 
)
inline

Definition at line 103 of file DeepTauId.cc.

References tensorflow::cpu, tensorflow::createSession(), cms::soa::RestrictQualify::Default, Exception, HLT_2023v12_cff::graph_file, graphs_, tensorflow::loadGraphDef(), HLT_2023v12_cff::mem_mapped, memmappedEnv_, AlcaSiPixelAliHarvester0T_cff::options, sessions_, and AlCaHLTBitMon_QueryRunRegistry::string.

103  {
104  for (const auto& graph_entry : graph_names) {
105  // Backend : to be parametrized from the python config
106  tensorflow::Options options{tensorflow::Backend::cpu};
107 
108  const std::string& entry_name = graph_entry.first;
109  const std::string& graph_file = graph_entry.second;
110  if (mem_mapped) {
111  memmappedEnv_[entry_name] = std::make_unique<tensorflow::MemmappedEnv>(tensorflow::Env::Default());
112  const tensorflow::Status mmap_status = memmappedEnv_.at(entry_name)->InitializeFromFile(graph_file);
113  if (!mmap_status.ok()) {
114  throw cms::Exception("DeepTauCache: unable to initalize memmapped environment for ")
115  << graph_file << ". \n"
116  << mmap_status.ToString();
117  }
118 
119  graphs_[entry_name] = std::make_unique<tensorflow::GraphDef>();
120  const tensorflow::Status load_graph_status =
121  ReadBinaryProto(memmappedEnv_.at(entry_name).get(),
122  tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
123  graphs_.at(entry_name).get());
124  if (!load_graph_status.ok())
125  throw cms::Exception("DeepTauCache: unable to load graph from ") << graph_file << ". \n"
126  << load_graph_status.ToString();
127 
128  options.getSessionOptions().config.mutable_graph_options()->mutable_optimizer_options()->set_opt_level(
129  ::tensorflow::OptimizerOptions::L0);
130  options.getSessionOptions().env = memmappedEnv_.at(entry_name).get();
131 
132  sessions_[entry_name] = tensorflow::createSession(graphs_.at(entry_name).get(), options);
133 
134  } else {
135  graphs_[entry_name].reset(tensorflow::loadGraphDef(graph_file));
136  sessions_[entry_name] = tensorflow::createSession(graphs_.at(entry_name).get(), options);
137  }
138  }
139  };
std::map< std::string, tensorflow::Session * > sessions_
Definition: DeepTauId.cc:152
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:120
Session * createSession()
Definition: TensorFlow.cc:137
constexpr bool Default
Definition: SoACommon.h:73
std::map< std::string, GraphPtr > graphs_
Definition: DeepTauId.cc:151
std::map< std::string, std::unique_ptr< tensorflow::MemmappedEnv > > memmappedEnv_
Definition: DeepTauId.cc:153

◆ ~DeepTauCache()

deep_tau::DeepTauCache::~DeepTauCache ( )
inline

Definition at line 140 of file DeepTauId.cc.

References tensorflow::closeSession(), and sessions_.

140  {
141  for (auto& session_entry : sessions_)
142  tensorflow::closeSession(session_entry.second);
143  }
std::map< std::string, tensorflow::Session * > sessions_
Definition: DeepTauId.cc:152
bool closeSession(Session *&session)
Definition: TensorFlow.cc:234

Member Function Documentation

◆ getGraph()

const tensorflow::GraphDef& deep_tau::DeepTauCache::getGraph ( const std::string &  name = "") const
inline

Definition at line 148 of file DeepTauId.cc.

References graphs_, and Skims_PA_cff::name.

148 { return *graphs_.at(name); }
std::map< std::string, GraphPtr > graphs_
Definition: DeepTauId.cc:151

◆ getSession()

tensorflow::Session& deep_tau::DeepTauCache::getSession ( const std::string &  name = "") const
inline

Definition at line 147 of file DeepTauId.cc.

References Skims_PA_cff::name, and sessions_.

Referenced by DeepTauId::getPartialPredictions(), and DeepTauId::getPredictionsV2().

147 { return *sessions_.at(name); }
std::map< std::string, tensorflow::Session * > sessions_
Definition: DeepTauId.cc:152

Member Data Documentation

◆ graphs_

std::map<std::string, GraphPtr> deep_tau::DeepTauCache::graphs_
private

Definition at line 151 of file DeepTauId.cc.

Referenced by DeepTauCache(), and getGraph().

◆ memmappedEnv_

std::map<std::string, std::unique_ptr<tensorflow::MemmappedEnv> > deep_tau::DeepTauCache::memmappedEnv_
private

Definition at line 153 of file DeepTauId.cc.

Referenced by DeepTauCache().

◆ sessions_

std::map<std::string, tensorflow::Session*> deep_tau::DeepTauCache::sessions_
private

Definition at line 152 of file DeepTauId.cc.

Referenced by DeepTauCache(), getSession(), and ~DeepTauCache().