CMS 3D CMS Logo

Classes | Typedefs | Enumerations | Functions
tensorflow Namespace Reference

Classes

class  NoThreadPool
 
struct  Options
 
struct  SessionCache
 
class  TBBThreadPool
 

Typedefs

typedef std::pair< std::string, Tensor > NamedTensor
 
typedef std::vector< NamedTensor > NamedTensorList
 

Enumerations

enum  Backend {
  Backend::cpu, Backend::cuda, Backend::rocm, Backend::intel,
  Backend::best
}
 

Functions

bool checkEmptyInputs (const NamedTensorList &inputs)
 
bool closeSession (Session *&session)
 
bool closeSession (const Session *&session)
 
Session * createSession ()
 
Session * createSession (Options &options)
 
Session * createSession (const MetaGraphDef *metaGraphDef, const std::string &exportDir, Options &options)
 
Session * createSession (const GraphDef *graphDef)
 
Session * createSession (const GraphDef *graphDef, Options &options)
 
GraphDef * loadGraphDef (const std::string &pbFile)
 
MetaGraphDef * loadMetaGraph (const std::string &exportDir, const std::string &tag, Options &Options)
 
MetaGraphDef * loadMetaGraphDef (const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
 
MetaGraphDef * loadMetaGraphDef (const std::string &exportDir, const std::string &tag, Options &options)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
 
void run (const Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, thread::ThreadPoolInterface *threadPool)
 
void run (const Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, thread::ThreadPoolInterface *threadPool)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (const Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void setLogging (const std::string &level="3")
 

Typedef Documentation

◆ NamedTensor

typedef std::pair<std::string, Tensor> tensorflow::NamedTensor

Definition at line 30 of file TensorFlow.h.

◆ NamedTensorList

typedef std::vector<NamedTensor> tensorflow::NamedTensorList

Definition at line 31 of file TensorFlow.h.

Enumeration Type Documentation

◆ Backend

enum tensorflow::Backend
strong
Enumerator
cpu 
cuda 
rocm 
intel 
best 

Definition at line 28 of file TensorFlow.h.

Function Documentation

◆ checkEmptyInputs()

bool tensorflow::checkEmptyInputs ( const NamedTensorList &  inputs)

Definition at line 268 of file TensorFlow.cc.

References input, and PixelMapPlotter::inputs.

Referenced by run().

268  {
269  // check for empty tensors in the inputs
270  bool isEmpty = false;
271  for (const auto& input : inputs) {
272  // Checking using the shape
273  if (input.second.shape().num_elements() == 0) {
274  isEmpty = true;
275  break;
276  }
277  }
278  return isEmpty;
279  }
static std::string const input
Definition: EdmProvDump.cc:50

◆ closeSession() [1/2]

bool tensorflow::closeSession ( Session *&  session)

◆ closeSession() [2/2]

bool tensorflow::closeSession ( const Session *&  session)

Definition at line 258 of file TensorFlow.cc.

References closeSession(), and alignCSCRings::s.

258  {
259  auto s = const_cast<Session*>(session);
260  bool state = closeSession(s);
261 
262  // reset the pointer
263  session = nullptr;
264 
265  return state;
266  }
bool closeSession(Session *&session)
Definition: TensorFlow.cc:243

◆ createSession() [1/5]

Session * tensorflow::createSession ( )

◆ createSession() [2/5]

Session * tensorflow::createSession ( Options &  options)

Definition at line 151 of file TensorFlow.cc.

References Exception, and mps_update::status.

151  {
152  // objects to create the session
153  Status status;
154 
155  // create a new, empty session
156  Session* session = nullptr;
157  status = NewSession(options.getSessionOptions(), &session);
158  if (!status.ok()) {
159  throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
160  }
161 
162  return session;
163  }

◆ createSession() [3/5]

Session * tensorflow::createSession ( const MetaGraphDef *  metaGraphDef,
const std::string &  exportDir,
Options &  options 
)

Definition at line 165 of file TensorFlow.cc.

References createSession(), cms::soa::RestrictQualify::Default, Exception, convertSQLiteXML::ok, mps_update::status, and AlCaHLTBitMon_QueryRunRegistry::string.

165  {
166  // check for valid pointer
167  if (metaGraphDef == nullptr) {
168  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
169  }
170 
171  // check that the graph has nodes
172  if (metaGraphDef->graph_def().node_size() <= 0) {
173  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
174  }
175 
176  Session* session = createSession(options);
177 
178  // add the graph def from the meta graph
179  Status status;
180  status = session->Create(metaGraphDef->graph_def());
181  if (!status.ok()) {
182  throw cms::Exception("InvalidMetaGraphDef")
183  << "error while attaching metaGraphDef to session: " << status.ToString();
184  }
185 
186  // restore variables using the variable and index files in the export directory
187  // first, find names and paths
188  std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
189  std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
190  std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
191  std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
192  std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
193 
194  // when the index file is missing, there's nothing to do
195  if (!Env::Default()->FileExists(indexFile).ok()) {
196  return session;
197  }
198 
199  // create a tensor to store the variable file
200  Tensor varFileTensor(DT_STRING, TensorShape({}));
201  varFileTensor.scalar<tensorflow::tstring>()() = varFile;
202 
203  // run the restore op
204  status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
205  if (!status.ok()) {
206  throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
207  }
208 
209  return session;
210  }
Session * createSession()
Definition: TensorFlow.cc:146
constexpr bool Default
Definition: SoACommon.h:75

◆ createSession() [4/5]

Session * tensorflow::createSession ( const GraphDef *  graphDef)

Definition at line 212 of file TensorFlow.cc.

References createSession().

212  {
213  Options default_options{};
214  return createSession(graphDef, default_options);
215  }
Session * createSession()
Definition: TensorFlow.cc:146

◆ createSession() [5/5]

Session * tensorflow::createSession ( const GraphDef *  graphDef,
Options &  options 
)

Definition at line 217 of file TensorFlow.cc.

References createSession(), Exception, and mps_update::status.

217  {
218  // check for valid pointer
219  if (graphDef == nullptr) {
220  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
221  }
222 
223  // check that the graph has nodes
224  if (graphDef->node_size() <= 0) {
225  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
226  }
227 
228  // create a new, empty session
229  Session* session = createSession(options);
230 
231  // add the graph def
232  Status status;
233  status = session->Create(*graphDef);
234 
235  // check for success
236  if (!status.ok()) {
237  throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
238  }
239 
240  return session;
241  }
Session * createSession()
Definition: TensorFlow.cc:146

◆ loadGraphDef()

GraphDef * tensorflow::loadGraphDef ( const std::string &  pbFile)

◆ loadMetaGraph()

MetaGraphDef * tensorflow::loadMetaGraph ( const std::string &  exportDir,
const std::string &  tag,
Options &  Options 
)

Definition at line 122 of file TensorFlow.cc.

References loadMetaGraphDef(), and makeGlobalPositionRcd_cfg::tag.

122  {
123  edm::LogInfo("PhysicsTools/TensorFlow")
124  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
125 
126  return loadMetaGraphDef(exportDir, tag, options);
127  }
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
Definition: TensorFlow.cc:100
Log< level::Info, false > LogInfo

◆ loadMetaGraphDef() [1/2]

MetaGraphDef * tensorflow::loadMetaGraphDef ( const std::string &  exportDir,
const std::string &  tag = kSavedModelTagServe 
)

Definition at line 100 of file TensorFlow.cc.

References makeGlobalPositionRcd_cfg::tag.

Referenced by loadMetaGraph().

100  {
101  Options default_options{};
102  return loadMetaGraphDef(exportDir, tag, default_options);
103  }
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
Definition: TensorFlow.cc:100

◆ loadMetaGraphDef() [2/2]

MetaGraphDef * tensorflow::loadMetaGraphDef ( const std::string &  exportDir,
const std::string &  tag,
Options &  options 
)

Definition at line 105 of file TensorFlow.cc.

References Exception, mps_update::status, and makeGlobalPositionRcd_cfg::tag.

105  {
106  // objects to load the graph
107  Status status;
108  RunOptions runOptions;
109  SavedModelBundle bundle;
110 
111  // load the model
112  status = LoadSavedModel(options.getSessionOptions(), runOptions, exportDir, {tag}, &bundle);
113  if (!status.ok()) {
114  throw cms::Exception("InvalidMetaGraphDef")
115  << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
116  }
117 
118  // return a copy of the graph
119  return new MetaGraphDef(bundle.meta_graph_def);
120  }

◆ run() [1/8]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const thread::ThreadPoolOptions &  threadPoolOptions 
)

Definition at line 281 of file TensorFlow.cc.

References checkEmptyInputs(), Exception, PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and mps_update::status.

Referenced by PtAssignmentEngineDxy::call_tensorflow_dxy(), MkFitOutputConverter::computeDNNs(), ticl::PatternRecognitionbyCLUE3D< TILES >::energyRegressionAndID(), ticl::PatternRecognitionbyCA< TILES >::energyRegressionAndID(), ticl::PatternRecognitionbyFastJet< TILES >::energyRegressionAndID(), TrackstersMergeProducerV3::energyRegressionAndID(), TrackstersMergeProducer::energyRegressionAndID(), egammaTools::EgammaDNNHelper::evaluate(), reco::DeepSCGraphEvaluation::evaluate(), TSGForOIDNN::evaluateClassifier(), TauNNId::EvaluateNN(), BJetId::EvaluateNN(), TSGForOIDNN::evaluateRegressor(), DisplacedRegionSeedingVertexProducer::getDiscriminatorValue(), DeepTauId::getPartialPredictions(), DeepTauId::getPredictionsV2(), L2TauNNProducerAlpaka::getTauScore(), L2TauNNProducer::getTauScore(), DeepMETProducer::produce(), BaseMVAValueMapProducer< pat::Muon >::produce(), run(), DTOccupancyTestML::runOccupancyTest(), DeepCoreSeedGenerator::SeedEvaluation(), and HGCalConcentratorAutoEncoderImpl::select().

285  {
286  if (session == nullptr) {
287  throw cms::Exception("InvalidSession") << "cannot run empty session";
288  }
289 
290  // create empty run options
291  RunOptions runOptions;
292 
293  // Check if the inputs are empty
294  if (checkEmptyInputs(inputs))
295  return;
296 
297  // run and check the status
298  Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
299  if (!status.ok()) {
300  throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
301  }
302  }
bool checkEmptyInputs(const NamedTensorList &inputs)
Definition: TensorFlow.cc:268

◆ run() [2/8]

void tensorflow::run ( const Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const thread::ThreadPoolOptions &  threadPoolOptions 
)
inline

Definition at line 122 of file TensorFlow.h.

References PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

126  {
127  // TF takes a non-const session in the run call which is, however, thread-safe and logically
128  // const, thus const_cast is consistent
129  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolOptions);
130  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:181

◆ run() [3/8]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
thread::ThreadPoolInterface *  threadPool 
)

Definition at line 304 of file TensorFlow.cc.

References PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

308  {
309  // create thread pool options
310  thread::ThreadPoolOptions threadPoolOptions;
311  threadPoolOptions.inter_op_threadpool = threadPool;
312  threadPoolOptions.intra_op_threadpool = threadPool;
313 
314  // run
315  run(session, inputs, outputNames, outputs, threadPoolOptions);
316  }

◆ run() [4/8]

void tensorflow::run ( const Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
thread::ThreadPoolInterface *  threadPool 
)
inline

Definition at line 142 of file TensorFlow.h.

References PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

146  {
147  // TF takes a non-const session in the run call which is, however, thread-safe and logically
148  // const, thus const_cast is consistent
149  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPool);
150  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:181

◆ run() [5/8]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)

Definition at line 318 of file TensorFlow.cc.

References Exception, PixelMapPlotter::inputs, tensorflow::NoThreadPool::instance(), tensorflow::TBBThreadPool::instance(), jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

322  {
323  // lookup the thread pool and forward the call accordingly
324  if (threadPoolName == "no_threads") {
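325  run(session, inputs, outputNames, outputs, &NoThreadPool::instance());
326  } else if (threadPoolName == "tbb") {
327  // the TBBTreadPool singleton should be already initialized before with a number of threads
328  run(session, inputs, outputNames, outputs, &TBBThreadPool::instance());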
329  } else if (threadPoolName == "tensorflow") {
330  run(session, inputs, outputNames, outputs, nullptr);
331  } else {
332  throw cms::Exception("UnknownThreadPool")
333  << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
334  }
335  }
static PFTauRenderPlugin instance

◆ run() [6/8]

void tensorflow::run ( const Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)
inline

Definition at line 162 of file TensorFlow.h.

References PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

166  {
167  // TF takes a non-const session in the run call which is, however, thread-safe and logically
168  // const, thus const_cast is consistent
169  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolName);
170  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:181

◆ run() [7/8]

void tensorflow::run ( Session *  session,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)

◆ run() [8/8]

void tensorflow::run ( const Session *  session,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)
inline

Definition at line 181 of file TensorFlow.h.

References jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

184  {
185  // TF takes a non-const session in the run call which is, however, thread-safe and logically
186  // const, thus const_cast is consistent
187  run(const_cast<Session*>(session), outputNames, outputs, threadPoolName);
188  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:181

◆ setLogging()

void tensorflow::setLogging ( const std::string &  level = "3")

Definition at line 90 of file TensorFlow.cc.

References personalPlayback::level.

Referenced by reco::DeepSCGraphEvaluation::DeepSCGraphEvaluation(), DTOccupancyTestML::dqmEndLuminosityBlock(), HGCalConcentratorAutoEncoderImpl::HGCalConcentratorAutoEncoderImpl(), L1NNTauProducer::initializeGlobalCache(), L2TauNNProducerAlpaka::initializeGlobalCache(), L2TauNNProducer::initializeGlobalCache(), and TSGForOIDNN::TSGForOIDNN().

90  {
91  /*
92  * 0 = all messages are logged (default behavior)
93  * 1 = INFO messages are not printed
94  * 2 = INFO and WARNING messages are not printed
95  * 3 = INFO, WARNING, and ERROR messages are not printed
96  */
97  setenv("TF_CPP_MIN_LOG_LEVEL", level.c_str(), 0);
98  }