CMS 3D CMS Logo

Classes | Typedefs | Enumerations | Functions
tensorflow Namespace Reference

Classes

class  NoThreadPool
 
struct  Options
 
struct  SessionCache
 
class  TBBThreadPool
 

Typedefs

typedef std::pair< std::string, Tensor > NamedTensor
 
typedef std::vector< NamedTensor > NamedTensorList
 

Enumerations

enum  Backend {
  Backend::cpu, Backend::cuda, Backend::rocm, Backend::intel,
  Backend::best
}
 

Functions

bool checkEmptyInputs (const NamedTensorList &inputs)
 
bool closeSession (Session *&session)
 
bool closeSession (const Session *&session)
 
Session * createSession ()
 
Session * createSession (Options &options)
 
Session * createSession (const MetaGraphDef *metaGraphDef, const std::string &exportDir, Options &options)
 
Session * createSession (const GraphDef *graphDef)
 
Session * createSession (const GraphDef *graphDef, Options &options)
 
GraphDef * loadGraphDef (const std::string &pbFile)
 
MetaGraphDef * loadMetaGraph (const std::string &exportDir, const std::string &tag, Options &Options)
 
MetaGraphDef * loadMetaGraphDef (const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
 
MetaGraphDef * loadMetaGraphDef (const std::string &exportDir, const std::string &tag, Options &options)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
 
void run (const Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, thread::ThreadPoolInterface *threadPool)
 
void run (const Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, thread::ThreadPoolInterface *threadPool)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (const Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void setLogging (const std::string &level="3")
 

Typedef Documentation

◆ NamedTensor

typedef std::pair<std::string, Tensor> tensorflow::NamedTensor

Definition at line 30 of file TensorFlow.h.

◆ NamedTensorList

typedef std::vector<NamedTensor> tensorflow::NamedTensorList

Definition at line 31 of file TensorFlow.h.

Enumeration Type Documentation

◆ Backend

enum tensorflow::Backend
strong
Enumerator
cpu 
cuda 
rocm 
intel 
best 

Definition at line 28 of file TensorFlow.h.

Function Documentation

◆ checkEmptyInputs()

bool tensorflow::checkEmptyInputs ( const NamedTensorList &  inputs)

Definition at line 259 of file TensorFlow.cc.

References input, and PixelMapPlotter::inputs.

Referenced by run().

259  {
260  // check for empty tensors in the inputs
261  bool isEmpty = false;
262  for (const auto& input : inputs) {
263  // Checking using the shape
264  if (input.second.shape().num_elements() == 0) {
265  isEmpty = true;
266  break;
267  }
268  }
269  return isEmpty;
270  }
static std::string const input
Definition: EdmProvDump.cc:50

◆ closeSession() [1/2]

bool tensorflow::closeSession ( Session *&  session)

Definition at line 234 of file TensorFlow.cc.

◆ closeSession() [2/2]

bool tensorflow::closeSession ( const Session *&  session)

Definition at line 249 of file TensorFlow.cc.

References closeSession(), and alignCSCRings::s.

249  {
250  auto s = const_cast<Session*>(session);
251  bool state = closeSession(s);
252 
253  // reset the pointer
254  session = nullptr;
255 
256  return state;
257  }
bool closeSession(Session *&session)
Definition: TensorFlow.cc:234

◆ createSession() [1/5]

Session * tensorflow::createSession ( )

Definition at line 137 of file TensorFlow.cc.

◆ createSession() [2/5]

Session * tensorflow::createSession ( Options &  options)

Definition at line 142 of file TensorFlow.cc.

References Exception, and mps_update::status.

142  {
143  // objects to create the session
144  Status status;
145 
146  // create a new, empty session
147  Session* session = nullptr;
148  status = NewSession(options.getSessionOptions(), &session);
149  if (!status.ok()) {
150  throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
151  }
152 
153  return session;
154  }

◆ createSession() [3/5]

Session * tensorflow::createSession ( const MetaGraphDef *  metaGraphDef,
const std::string &  exportDir,
Options &  options 
)

Definition at line 156 of file TensorFlow.cc.

References createSession(), cms::soa::RestrictQualify::Default, Exception, convertSQLiteXML::ok, mps_update::status, and AlCaHLTBitMon_QueryRunRegistry::string.

156  {
157  // check for valid pointer
158  if (metaGraphDef == nullptr) {
159  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
160  }
161 
162  // check that the graph has nodes
163  if (metaGraphDef->graph_def().node_size() <= 0) {
164  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
165  }
166 
167  Session* session = createSession(options);
168 
169  // add the graph def from the meta graph
170  Status status;
171  status = session->Create(metaGraphDef->graph_def());
172  if (!status.ok()) {
173  throw cms::Exception("InvalidMetaGraphDef")
174  << "error while attaching metaGraphDef to session: " << status.ToString();
175  }
176 
177  // restore variables using the variable and index files in the export directory
178  // first, find names and paths
179  std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
180  std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
181  std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
182  std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
183  std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
184 
185  // when the index file is missing, there's nothing to do
186  if (!Env::Default()->FileExists(indexFile).ok()) {
187  return session;
188  }
189 
190  // create a tensor to store the variable file
191  Tensor varFileTensor(DT_STRING, TensorShape({}));
192  varFileTensor.scalar<tensorflow::tstring>()() = varFile;
193 
194  // run the restore op
195  status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
196  if (!status.ok()) {
197  throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
198  }
199 
200  return session;
201  }
Session * createSession()
Definition: TensorFlow.cc:137
constexpr bool Default
Definition: SoACommon.h:75

◆ createSession() [4/5]

Session * tensorflow::createSession ( const GraphDef *  graphDef)

Definition at line 203 of file TensorFlow.cc.

References createSession().

203  {
204  Options default_options{};
205  return createSession(graphDef, default_options);
206  }
Session * createSession()
Definition: TensorFlow.cc:137

◆ createSession() [5/5]

Session * tensorflow::createSession ( const GraphDef *  graphDef,
Options &  options 
)

Definition at line 208 of file TensorFlow.cc.

References createSession(), Exception, and mps_update::status.

208  {
209  // check for valid pointer
210  if (graphDef == nullptr) {
211  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
212  }
213 
214  // check that the graph has nodes
215  if (graphDef->node_size() <= 0) {
216  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
217  }
218 
219  // create a new, empty session
220  Session* session = createSession(options);
221 
222  // add the graph def
223  Status status;
224  status = session->Create(*graphDef);
225 
226  // check for success
227  if (!status.ok()) {
228  throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
229  }
230 
231  return session;
232  }
Session * createSession()
Definition: TensorFlow.cc:137

◆ loadGraphDef()

GraphDef * tensorflow::loadGraphDef ( const std::string &  pbFile)

◆ loadMetaGraph()

MetaGraphDef * tensorflow::loadMetaGraph ( const std::string &  exportDir,
const std::string &  tag,
Options &  Options 
)

Definition at line 113 of file TensorFlow.cc.

References loadMetaGraphDef(), and makeGlobalPositionRcd_cfg::tag.

113  {
114  edm::LogInfo("PhysicsTools/TensorFlow")
115  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
116 
117  return loadMetaGraphDef(exportDir, tag, options);
118  }
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
Definition: TensorFlow.cc:91
Log< level::Info, false > LogInfo

◆ loadMetaGraphDef() [1/2]

MetaGraphDef * tensorflow::loadMetaGraphDef ( const std::string &  exportDir,
const std::string &  tag = kSavedModelTagServe 
)

Definition at line 91 of file TensorFlow.cc.

References makeGlobalPositionRcd_cfg::tag.

Referenced by loadMetaGraph().

91  {
92  Options default_options{};
93  return loadMetaGraphDef(exportDir, tag, default_options);
94  }
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
Definition: TensorFlow.cc:91

◆ loadMetaGraphDef() [2/2]

MetaGraphDef * tensorflow::loadMetaGraphDef ( const std::string &  exportDir,
const std::string &  tag,
Options &  options 
)

Definition at line 96 of file TensorFlow.cc.

References Exception, mps_update::status, and makeGlobalPositionRcd_cfg::tag.

96  {
97  // objects to load the graph
98  Status status;
99  RunOptions runOptions;
100  SavedModelBundle bundle;
101 
102  // load the model
103  status = LoadSavedModel(options.getSessionOptions(), runOptions, exportDir, {tag}, &bundle);
104  if (!status.ok()) {
105  throw cms::Exception("InvalidMetaGraphDef")
106  << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
107  }
108 
109  // return a copy of the graph
110  return new MetaGraphDef(bundle.meta_graph_def);
111  }

◆ run() [1/8]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const thread::ThreadPoolOptions &  threadPoolOptions 
)

Definition at line 272 of file TensorFlow.cc.

References checkEmptyInputs(), Exception, PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and mps_update::status.

Referenced by PtAssignmentEngineDxy::call_tensorflow_dxy(), MkFitOutputConverter::computeDNNs(), ticl::PatternRecognitionbyCLUE3D< TILES >::energyRegressionAndID(), ticl::PatternRecognitionbyCA< TILES >::energyRegressionAndID(), ticl::PatternRecognitionbyFastJet< TILES >::energyRegressionAndID(), TrackstersMergeProducerV3::energyRegressionAndID(), TrackstersMergeProducer::energyRegressionAndID(), egammaTools::EgammaDNNHelper::evaluate(), reco::DeepSCGraphEvaluation::evaluate(), TSGForOIDNN::evaluateClassifier(), TauNNId::EvaluateNN(), BJetId::EvaluateNN(), TSGForOIDNN::evaluateRegressor(), DisplacedRegionSeedingVertexProducer::getDiscriminatorValue(), DeepTauId::getPartialPredictions(), DeepTauId::getPredictionsV2(), L2TauNNProducerAlpaka::getTauScore(), L2TauNNProducer::getTauScore(), DeepMETProducer::produce(), BaseMVAValueMapProducer< pat::Muon >::produce(), run(), DTOccupancyTestML::runOccupancyTest(), DeepCoreSeedGenerator::SeedEvaluation(), and HGCalConcentratorAutoEncoderImpl::select().

276  {
277  if (session == nullptr) {
278  throw cms::Exception("InvalidSession") << "cannot run empty session";
279  }
280 
281  // create empty run options
282  RunOptions runOptions;
283 
284  // Check if the inputs are empty
285  if (checkEmptyInputs(inputs))
286  return;
287 
288  // run and check the status
289  Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
290  if (!status.ok()) {
291  throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
292  }
293  }
bool checkEmptyInputs(const NamedTensorList &inputs)
Definition: TensorFlow.cc:259

◆ run() [2/8]

void tensorflow::run ( const Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const thread::ThreadPoolOptions &  threadPoolOptions 
)
inline

Definition at line 122 of file TensorFlow.h.

References PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

126  {
127  // TF takes a non-const session in the run call which is, however, thread-safe and logically
128  // const, thus const_cast is consistent
129  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolOptions);
130  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:181

◆ run() [3/8]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
thread::ThreadPoolInterface *  threadPool 
)

Definition at line 295 of file TensorFlow.cc.

References PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

299  {
300  // create thread pool options
301  thread::ThreadPoolOptions threadPoolOptions;
302  threadPoolOptions.inter_op_threadpool = threadPool;
303  threadPoolOptions.intra_op_threadpool = threadPool;
304 
305  // run
306  run(session, inputs, outputNames, outputs, threadPoolOptions);
307  }

◆ run() [4/8]

void tensorflow::run ( const Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
thread::ThreadPoolInterface *  threadPool 
)
inline

Definition at line 142 of file TensorFlow.h.

References PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

146  {
147  // TF takes a non-const session in the run call which is, however, thread-safe and logically
148  // const, thus const_cast is consistent
149  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPool);
150  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:181

◆ run() [5/8]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)

Definition at line 309 of file TensorFlow.cc.

References Exception, PixelMapPlotter::inputs, tensorflow::NoThreadPool::instance(), tensorflow::TBBThreadPool::instance(), jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

313  {
314  // lookup the thread pool and forward the call accordingly
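315  if (threadPoolName == "no_threads") {
316  run(session, inputs, outputNames, outputs, &NoThreadPool::instance());
317  } else if (threadPoolName == "tbb") {
318  // the TBBThreadPool singleton should be already initialized before with a number of threads
319  run(session, inputs, outputNames, outputs, &TBBThreadPool::instance());
320  } else if (threadPoolName == "tensorflow") {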
321  run(session, inputs, outputNames, outputs, nullptr);
322  } else {
323  throw cms::Exception("UnknownThreadPool")
324  << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
325  }
326  }
static PFTauRenderPlugin instance

◆ run() [6/8]

void tensorflow::run ( const Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)
inline

Definition at line 162 of file TensorFlow.h.

References PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

166  {
167  // TF takes a non-const session in the run call which is, however, thread-safe and logically
168  // const, thus const_cast is consistent
169  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolName);
170  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:181

◆ run() [7/8]

void tensorflow::run ( Session *  session,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)

◆ run() [8/8]

void tensorflow::run ( const Session *  session,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)
inline

Definition at line 181 of file TensorFlow.h.

References jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

184  {
185  // TF takes a non-const session in the run call which is, however, thread-safe and logically
186  // const, thus const_cast is consistent
187  run(const_cast<Session*>(session), outputNames, outputs, threadPoolName);
188  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:181

◆ setLogging()

void tensorflow::setLogging ( const std::string &  level = "3")

Definition at line 81 of file TensorFlow.cc.

References personalPlayback::level.

Referenced by reco::DeepSCGraphEvaluation::DeepSCGraphEvaluation(), DTOccupancyTestML::dqmEndLuminosityBlock(), HGCalConcentratorAutoEncoderImpl::HGCalConcentratorAutoEncoderImpl(), L1NNTauProducer::initializeGlobalCache(), L2TauNNProducerAlpaka::initializeGlobalCache(), L2TauNNProducer::initializeGlobalCache(), and TSGForOIDNN::TSGForOIDNN().

81  {
82  /*
83  * 0 = all messages are logged (default behavior)
84  * 1 = INFO messages are not printed
85  * 2 = INFO and WARNING messages are not printed
86  * 3 = INFO, WARNING, and ERROR messages are not printed
87  */
88  setenv("TF_CPP_MIN_LOG_LEVEL", level.c_str(), 0);
89  }