CMS 3D CMS Logo

Classes | Typedefs | Enumerations | Functions
tensorflow Namespace Reference

Classes

class  NoThreadPool
 
struct  Options
 
struct  SessionCache
 
class  TBBThreadPool
 

Typedefs

typedef std::pair< std::string, Tensor > NamedTensor
 
typedef std::vector< NamedTensor > NamedTensorList
 

Enumerations

enum  Backend {
  Backend::cpu, Backend::cuda, Backend::rocm, Backend::intel,
  Backend::best
}
 

Functions

bool checkEmptyInputs (const NamedTensorList &inputs)
 
bool closeSession (Session *&session)
 
bool closeSession (const Session *&session)
 
Session * createSession ()
 
Session * createSession (Options &options)
 
Session * createSession (const MetaGraphDef *metaGraphDef, const std::string &exportDir, Options &options)
 
Session * createSession (const GraphDef *graphDef)
 
Session * createSession (const GraphDef *graphDef, Options &options)
 
GraphDef * loadGraphDef (const std::string &pbFile)
 
MetaGraphDef * loadMetaGraph (const std::string &exportDir, const std::string &tag, Options &options)
 
MetaGraphDef * loadMetaGraphDef (const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
 
MetaGraphDef * loadMetaGraphDef (const std::string &exportDir, const std::string &tag, Options &options)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
 
void run (const Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, thread::ThreadPoolInterface *threadPool)
 
void run (const Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, thread::ThreadPoolInterface *threadPool)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (const Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 

Typedef Documentation

◆ NamedTensor

typedef std::pair<std::string, Tensor> tensorflow::NamedTensor

Definition at line 30 of file TensorFlow.h.

◆ NamedTensorList

typedef std::vector<NamedTensor> tensorflow::NamedTensorList

Definition at line 31 of file TensorFlow.h.

Enumeration Type Documentation

◆ Backend

enum tensorflow::Backend
strong
Enumerator
cpu 
cuda 
rocm 
intel 
best 

Definition at line 28 of file TensorFlow.h.

Function Documentation

◆ checkEmptyInputs()

bool tensorflow::checkEmptyInputs ( const NamedTensorList &  inputs)

Definition at line 258 of file TensorFlow.cc.

References input, and PixelMapPlotter::inputs.

Referenced by run().

258  {
259  // check for empty tensors in the inputs
260  bool isEmpty = false;
261  for (const auto& input : inputs) {
262  // Checking using the shape
263  if (input.second.shape().num_elements() == 0) {
264  isEmpty = true;
265  break;
266  }
267  }
268  return isEmpty;
269  }
static std::string const input
Definition: EdmProvDump.cc:50

◆ closeSession() [1/2]

bool tensorflow::closeSession ( Session *&  session)

◆ closeSession() [2/2]

bool tensorflow::closeSession ( const Session *&  session)

Definition at line 248 of file TensorFlow.cc.

References closeSession(), and alignCSCRings::s.

248  {
249  auto s = const_cast<Session*>(session);
250  bool state = closeSession(s);
251 
252  // reset the pointer
253  session = nullptr;
254 
255  return state;
256  }
bool closeSession(Session *&session)
Definition: TensorFlow.cc:233

◆ createSession() [1/5]

Session * tensorflow::createSession ( )

◆ createSession() [2/5]

Session * tensorflow::createSession ( Options &  options)

Definition at line 141 of file TensorFlow.cc.

References Exception, and mps_update::status.

141  {
142  // objects to create the session
143  Status status;
144 
145  // create a new, empty session
146  Session* session = nullptr;
147  status = NewSession(options.getSessionOptions(), &session);
148  if (!status.ok()) {
149  throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
150  }
151 
152  return session;
153  }

◆ createSession() [3/5]

Session * tensorflow::createSession ( const MetaGraphDef *  metaGraphDef,
const std::string &  exportDir,
Options &  options 
)

Definition at line 155 of file TensorFlow.cc.

References createSession(), cms::soa::RestrictQualify::Default, Exception, convertSQLiteXML::ok, mps_update::status, and AlCaHLTBitMon_QueryRunRegistry::string.

155  {
156  // check for valid pointer
157  if (metaGraphDef == nullptr) {
158  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
159  }
160 
161  // check that the graph has nodes
162  if (metaGraphDef->graph_def().node_size() <= 0) {
163  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
164  }
165 
166  Session* session = createSession(options);
167 
168  // add the graph def from the meta graph
169  Status status;
170  status = session->Create(metaGraphDef->graph_def());
171  if (!status.ok()) {
172  throw cms::Exception("InvalidMetaGraphDef")
173  << "error while attaching metaGraphDef to session: " << status.ToString();
174  }
175 
176  // restore variables using the variable and index files in the export directory
177  // first, find names and paths
178  std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
179  std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
180  std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
181  std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
182  std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
183 
184  // when the index file is missing, there's nothing to do
185  if (!Env::Default()->FileExists(indexFile).ok()) {
186  return session;
187  }
188 
189  // create a tensor to store the variable file
190  Tensor varFileTensor(DT_STRING, TensorShape({}));
191  varFileTensor.scalar<tensorflow::tstring>()() = varFile;
192 
193  // run the restore op
194  status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
195  if (!status.ok()) {
196  throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
197  }
198 
199  return session;
200  }
Session * createSession()
Definition: TensorFlow.cc:136
constexpr bool Default
Definition: SoACommon.h:75

◆ createSession() [4/5]

Session * tensorflow::createSession ( const GraphDef *  graphDef)

Definition at line 202 of file TensorFlow.cc.

References createSession().

202  {
203  Options default_options{};
204  return createSession(graphDef, default_options);
205  }
Session * createSession()
Definition: TensorFlow.cc:136

◆ createSession() [5/5]

Session * tensorflow::createSession ( const GraphDef *  graphDef,
Options &  options 
)

Definition at line 207 of file TensorFlow.cc.

References createSession(), Exception, and mps_update::status.

207  {
208  // check for valid pointer
209  if (graphDef == nullptr) {
210  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
211  }
212 
213  // check that the graph has nodes
214  if (graphDef->node_size() <= 0) {
215  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
216  }
217 
218  // create a new, empty session
219  Session* session = createSession(options);
220 
221  // add the graph def
222  Status status;
223  status = session->Create(*graphDef);
224 
225  // check for success
226  if (!status.ok()) {
227  throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
228  }
229 
230  return session;
231  }
Session * createSession()
Definition: TensorFlow.cc:136

◆ loadGraphDef()

GraphDef * tensorflow::loadGraphDef ( const std::string &  pbFile)

Definition at line 119 of file TensorFlow.cc.

References cms::soa::RestrictQualify::Default, Exception, and mps_update::status.

Referenced by BaseMVACache::BaseMVACache(), PtAssignmentEngineDxy::configure(), tensorflow::SessionCache::createSession(), deep_tau::DeepTauCache::DeepTauCache(), DisplacedRegionSeedingVertexProducer::DisplacedRegionSeedingVertexProducer(), DTOccupancyTestML::dqmEndLuminosityBlock(), HGCalConcentratorAutoEncoderImpl::HGCalConcentratorAutoEncoderImpl(), L1NNCaloTauProducer::initializeGlobalCache(), L1NNCaloTauEmulator::initializeGlobalCache(), L2TauNNProducerAlpaka::initializeGlobalCache(), L2TauNNProducer::initializeGlobalCache(), reco::DeepSCGraphEvaluation::initTensorFlowGraphAndSession(), egammaTools::EgammaDNNHelper::initTensorFlowGraphs(), L1TrackVertexAssociationProducer::L1TrackVertexAssociationProducer(), TfGraphDefProducer::produce(), TSGForOIDNN::TSGForOIDNN(), emtf::phase2::EMTFContext::update(), and VertexProducer::VertexProducer().

119  {
120  // objects to load the graph
121  Status status;
122 
123  // load it
124  GraphDef* graphDef = new GraphDef();
125  status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
126 
127  // check for success
128  if (!status.ok()) {
129  throw cms::Exception("InvalidGraphDef")
130  << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
131  }
132 
133  return graphDef;
134  }
constexpr bool Default
Definition: SoACommon.h:75

◆ loadMetaGraph()

MetaGraphDef * tensorflow::loadMetaGraph ( const std::string &  exportDir,
const std::string &  tag,
Options &  options 
)

Definition at line 112 of file TensorFlow.cc.

References loadMetaGraphDef(), and makeGlobalPositionRcd_cfg::tag.

112  {
113  edm::LogInfo("PhysicsTools/TensorFlow")
114  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
115 
116  return loadMetaGraphDef(exportDir, tag, options);
117  }
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
Definition: TensorFlow.cc:90
Log< level::Info, false > LogInfo

◆ loadMetaGraphDef() [1/2]

MetaGraphDef * tensorflow::loadMetaGraphDef ( const std::string &  exportDir,
const std::string &  tag = kSavedModelTagServe 
)

Definition at line 90 of file TensorFlow.cc.

References makeGlobalPositionRcd_cfg::tag.

Referenced by loadMetaGraph().

90  {
91  Options default_options{};
92  return loadMetaGraphDef(exportDir, tag, default_options);
93  }
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
Definition: TensorFlow.cc:90

◆ loadMetaGraphDef() [2/2]

MetaGraphDef * tensorflow::loadMetaGraphDef ( const std::string &  exportDir,
const std::string &  tag,
Options &  options 
)

Definition at line 95 of file TensorFlow.cc.

References Exception, mps_update::status, and makeGlobalPositionRcd_cfg::tag.

95  {
96  // objects to load the graph
97  Status status;
98  RunOptions runOptions;
99  SavedModelBundle bundle;
100 
101  // load the model
102  status = LoadSavedModel(options.getSessionOptions(), runOptions, exportDir, {tag}, &bundle);
103  if (!status.ok()) {
104  throw cms::Exception("InvalidMetaGraphDef")
105  << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
106  }
107 
108  // return a copy of the graph
109  return new MetaGraphDef(bundle.meta_graph_def);
110  }

◆ run() [1/8]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const thread::ThreadPoolOptions &  threadPoolOptions 
)

Definition at line 271 of file TensorFlow.cc.

References checkEmptyInputs(), Exception, PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and mps_update::status.

Referenced by emtf::phase2::algo::ParameterAssignmentLayer::apply(), PtAssignmentEngineDxy::call_tensorflow_dxy(), MkFitOutputConverter::computeDNNs(), ticl::PatternRecognitionbyCLUE3D< TILES >::energyRegressionAndID(), ticl::PatternRecognitionbyCA< TILES >::energyRegressionAndID(), ticl::PatternRecognitionbyFastJet< TILES >::energyRegressionAndID(), TracksterLinksProducer::energyRegressionAndID(), TrackstersMergeProducer::energyRegressionAndID(), egammaTools::EgammaDNNHelper::evaluate(), reco::DeepSCGraphEvaluation::evaluate(), TSGForOIDNN::evaluateClassifier(), TauNNId::EvaluateNN(), JetId::EvaluateNN(), TSGForOIDNN::evaluateRegressor(), DisplacedRegionSeedingVertexProducer::getDiscriminatorValue(), DeepTauId::getPartialPredictions(), DeepTauId::getPredictionsV2(), L2TauNNProducerAlpaka::getTauScore(), L2TauNNProducer::getTauScore(), l1tVertexFinder::VertexFinder::NNVtxEmulation(), L1TrackVertexAssociationProducer::NNTrackWordSelector::operator()(), DeepMETProducer::produce(), L1NNCaloTauProducer::produce(), BaseMVAValueMapProducer< pat::Muon >::produce(), L1NNCaloTauEmulator::produce(), run(), DTOccupancyTestML::runOccupancyTest(), DeepCoreSeedGenerator::SeedEvaluation(), and HGCalConcentratorAutoEncoderImpl::select().

275  {
276  if (session == nullptr) {
277  throw cms::Exception("InvalidSession") << "cannot run empty session";
278  }
279 
280  // create empty run options
281  RunOptions runOptions;
282 
283  // Check if the inputs are empty
285  return;
286 
287  // run and check the status
288  Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
289  if (!status.ok()) {
290  throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
291  }
292  }
bool checkEmptyInputs(const NamedTensorList &inputs)
Definition: TensorFlow.cc:258

◆ run() [2/8]

void tensorflow::run ( const Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const thread::ThreadPoolOptions &  threadPoolOptions 
)
inline

Definition at line 119 of file TensorFlow.h.

References PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

123  {
124  // TF takes a non-const session in the run call which is, however, thread-safe and logically
125  // const, thus const_cast is consistent
126  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolOptions);
127  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:178

◆ run() [3/8]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
thread::ThreadPoolInterface *  threadPool 
)

Definition at line 294 of file TensorFlow.cc.

References PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

298  {
299  // create thread pool options
300  thread::ThreadPoolOptions threadPoolOptions;
301  threadPoolOptions.inter_op_threadpool = threadPool;
302  threadPoolOptions.intra_op_threadpool = threadPool;
303 
304  // run
305  run(session, inputs, outputNames, outputs, threadPoolOptions);
306  }

◆ run() [4/8]

void tensorflow::run ( const Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
thread::ThreadPoolInterface *  threadPool 
)
inline

Definition at line 139 of file TensorFlow.h.

References PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

143  {
144  // TF takes a non-const session in the run call which is, however, thread-safe and logically
145  // const, thus const_cast is consistent
146  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPool);
147  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:178

◆ run() [5/8]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)

Definition at line 308 of file TensorFlow.cc.

References Exception, PixelMapPlotter::inputs, tensorflow::NoThreadPool::instance(), tensorflow::TBBThreadPool::instance(), jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

312  {
313  // lookup the thread pool and forward the call accordingly
314  if (threadPoolName == "no_threads") {
316  } else if (threadPoolName == "tbb") {
317  // the TBBTreadPool singleton should be already initialized before with a number of threads
319  } else if (threadPoolName == "tensorflow") {
320  run(session, inputs, outputNames, outputs, nullptr);
321  } else {
322  throw cms::Exception("UnknownThreadPool")
323  << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
324  }
325  }
static PFTauRenderPlugin instance

◆ run() [6/8]

void tensorflow::run ( const Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)
inline

Definition at line 159 of file TensorFlow.h.

References PixelMapPlotter::inputs, jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

163  {
164  // TF takes a non-const session in the run call which is, however, thread-safe and logically
165  // const, thus const_cast is consistent
166  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolName);
167  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:178

◆ run() [7/8]

void tensorflow::run ( Session *  session,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)

◆ run() [8/8]

void tensorflow::run ( const Session *  session,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)
inline

Definition at line 178 of file TensorFlow.h.

References jetsAK4_CHS_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

181  {
182  // TF takes a non-const session in the run call which is, however, thread-safe and logically
183  // const, thus const_cast is consistent
184  run(const_cast<Session*>(session), outputNames, outputs, threadPoolName);
185  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:178