CMS 3D CMS Logo

Classes | Typedefs | Enumerations | Functions
tensorflow Namespace Reference

Classes

class  NoThreadPool
 
struct  Options
 
struct  SessionCache
 
class  TBBThreadPool
 

Typedefs

typedef std::pair< std::string, Tensor > NamedTensor
 
typedef std::vector< NamedTensor > NamedTensorList
 

Enumerations

enum  Backend {
  Backend::cpu, Backend::cuda, Backend::rocm, Backend::intel,
  Backend::best
}
 

Functions

bool closeSession (Session *&session)
 
bool closeSession (const Session *&session)
 
Session * createSession ()
 
Session * createSession (Options &options)
 
Session * createSession (const MetaGraphDef *metaGraphDef, const std::string &exportDir, Options &options)
 
Session * createSession (const GraphDef *graphDef)
 
Session * createSession (const GraphDef *graphDef, Options &options)
 
GraphDef * loadGraphDef (const std::string &pbFile)
 
MetaGraphDef * loadMetaGraph (const std::string &exportDir, const std::string &tag, Options &options)
 
MetaGraphDef * loadMetaGraphDef (const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
 
MetaGraphDef * loadMetaGraphDef (const std::string &exportDir, const std::string &tag, Options &options)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
 
void run (const Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, thread::ThreadPoolInterface *threadPool)
 
void run (const Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, thread::ThreadPoolInterface *threadPool)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (const Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void setLogging (const std::string &level="3")
 

Typedef Documentation

◆ NamedTensor

typedef std::pair<std::string, Tensor> tensorflow::NamedTensor

Definition at line 30 of file TensorFlow.h.

◆ NamedTensorList

typedef std::vector<NamedTensor> tensorflow::NamedTensorList

Definition at line 31 of file TensorFlow.h.

Enumeration Type Documentation

◆ Backend

enum tensorflow::Backend
strong
Enumerator
cpu 
cuda 
rocm 
intel 
best 

Definition at line 28 of file TensorFlow.h.

Function Documentation

◆ closeSession() [1/2]

bool tensorflow::closeSession ( Session *&  session)

◆ closeSession() [2/2]

bool tensorflow::closeSession ( const Session *&  session)

Definition at line 249 of file TensorFlow.cc.

References closeSession().

249  {
250  auto s = const_cast<Session*>(session);
251  bool state = closeSession(s);
252 
253  // reset the pointer
254  session = nullptr;
255 
256  return state;
257  }
bool closeSession(Session *&session)
Definition: TensorFlow.cc:234

◆ createSession() [1/5]

Session * tensorflow::createSession ( )

◆ createSession() [2/5]

Session * tensorflow::createSession ( Options &  options)

Definition at line 142 of file TensorFlow.cc.

References Exception.

142  {
143  // objects to create the session
144  Status status;
145 
146  // create a new, empty session
147  Session* session = nullptr;
148  status = NewSession(options.getSessionOptions(), &session);
149  if (!status.ok()) {
150  throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
151  }
152 
153  return session;
154  }

◆ createSession() [3/5]

Session * tensorflow::createSession ( const MetaGraphDef *  metaGraphDef,
const std::string &  exportDir,
Options &  options 
)

Definition at line 156 of file TensorFlow.cc.

References createSession(), and Exception.

156  {
157  // check for valid pointer
158  if (metaGraphDef == nullptr) {
159  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
160  }
161 
162  // check that the graph has nodes
163  if (metaGraphDef->graph_def().node_size() <= 0) {
164  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
165  }
166 
167  Session* session = createSession(options);
168 
169  // add the graph def from the meta graph
170  Status status;
171  status = session->Create(metaGraphDef->graph_def());
172  if (!status.ok()) {
173  throw cms::Exception("InvalidMetaGraphDef")
174  << "error while attaching metaGraphDef to session: " << status.ToString();
175  }
176 
177  // restore variables using the variable and index files in the export directory
178  // first, find names and paths
179  std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
180  std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
181  std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
182  std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
183  std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
184 
185  // when the index file is missing, there's nothing to do
186  if (!Env::Default()->FileExists(indexFile).ok()) {
187  return session;
188  }
189 
190  // create a tensor to store the variable file
191  Tensor varFileTensor(DT_STRING, TensorShape({}));
192  varFileTensor.scalar<tensorflow::tstring>()() = varFile;
193 
194  // run the restore op
195  status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
196  if (!status.ok()) {
197  throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
198  }
199 
200  return session;
201  }
Session * createSession()
Definition: TensorFlow.cc:137

◆ createSession() [4/5]

Session * tensorflow::createSession ( const GraphDef *  graphDef)

Definition at line 203 of file TensorFlow.cc.

References createSession().

203  {
204  Options default_options{};
205  return createSession(graphDef, default_options);
206  }
Session * createSession()
Definition: TensorFlow.cc:137

◆ createSession() [5/5]

Session * tensorflow::createSession ( const GraphDef *  graphDef,
Options &  options 
)

Definition at line 208 of file TensorFlow.cc.

References createSession(), and Exception.

208  {
209  // check for valid pointer
210  if (graphDef == nullptr) {
211  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
212  }
213 
214  // check that the graph has nodes
215  if (graphDef->node_size() <= 0) {
216  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
217  }
218 
219  // create a new, empty session
220  Session* session = createSession(options);
221 
222  // add the graph def
223  Status status;
224  status = session->Create(*graphDef);
225 
226  // check for success
227  if (!status.ok()) {
228  throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
229  }
230 
231  return session;
232  }
Session * createSession()
Definition: TensorFlow.cc:137

◆ loadGraphDef()

GraphDef * tensorflow::loadGraphDef ( const std::string &  pbFile)

Definition at line 120 of file TensorFlow.cc.

References Exception.

Referenced by BaseMVACache::BaseMVACache(), PtAssignmentEngineDxy::configure(), tensorflow::SessionCache::createSession(), deep_tau::DeepTauCache::DeepTauCache(), DisplacedRegionSeedingVertexProducer::DisplacedRegionSeedingVertexProducer(), DTOccupancyTestML::dqmEndLuminosityBlock(), HGCalConcentratorAutoEncoderImpl::HGCalConcentratorAutoEncoderImpl(), L2TauNNProducer::initializeGlobalCache(), reco::DeepSCGraphEvaluation::initTensorFlowGraphAndSession(), egammaTools::EgammaDNNHelper::initTensorFlowGraphs(), TfGraphDefProducer::produce(), and TSGForOIDNN::TSGForOIDNN().

120  {
121  // objects to load the graph
122  Status status;
123 
124  // load it
125  GraphDef* graphDef = new GraphDef();
126  status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
127 
128  // check for success
129  if (!status.ok()) {
130  throw cms::Exception("InvalidGraphDef")
131  << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
132  }
133 
134  return graphDef;
135  }

◆ loadMetaGraph()

MetaGraphDef * tensorflow::loadMetaGraph ( const std::string &  exportDir,
const std::string &  tag,
Options &  options 
)

Definition at line 113 of file TensorFlow.cc.

References loadMetaGraphDef().

113  {
114  edm::LogInfo("PhysicsTools/TensorFlow")
115  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
116 
117  return loadMetaGraphDef(exportDir, tag, options);
118  }
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
Definition: TensorFlow.cc:91
Log< level::Info, false > LogInfo

◆ loadMetaGraphDef() [1/2]

MetaGraphDef * tensorflow::loadMetaGraphDef ( const std::string &  exportDir,
const std::string &  tag = kSavedModelTagServe 
)

Definition at line 91 of file TensorFlow.cc.

References loadMetaGraphDef().

Referenced by loadMetaGraph().

91  {
92  Options default_options{};
93  return loadMetaGraphDef(exportDir, tag, default_options);
94  }
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
Definition: TensorFlow.cc:91

◆ loadMetaGraphDef() [2/2]

MetaGraphDef * tensorflow::loadMetaGraphDef ( const std::string &  exportDir,
const std::string &  tag,
Options &  options 
)

Definition at line 96 of file TensorFlow.cc.

References Exception.

96  {
97  // objects to load the graph
98  Status status;
99  RunOptions runOptions;
100  SavedModelBundle bundle;
101 
102  // load the model
103  status = LoadSavedModel(options.getSessionOptions(), runOptions, exportDir, {tag}, &bundle);
104  if (!status.ok()) {
105  throw cms::Exception("InvalidMetaGraphDef")
106  << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
107  }
108 
109  // return a copy of the graph
110  return new MetaGraphDef(bundle.meta_graph_def);
111  }

◆ run() [1/8]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const thread::ThreadPoolOptions &  threadPoolOptions 
)

Definition at line 259 of file TensorFlow.cc.

References Exception.

Referenced by PtAssignmentEngineDxy::call_tensorflow_dxy(), MkFitOutputConverter::computeDNNs(), ticl::PatternRecognitionbyCLUE3D< TILES >::energyRegressionAndID(), ticl::PatternRecognitionbyCA< TILES >::energyRegressionAndID(), ticl::PatternRecognitionbyFastJet< TILES >::energyRegressionAndID(), TrackstersMergeProducerV3::energyRegressionAndID(), TrackstersMergeProducer::energyRegressionAndID(), egammaTools::EgammaDNNHelper::evaluate(), reco::DeepSCGraphEvaluation::evaluate(), TSGForOIDNN::evaluateClassifier(), TauNNId::EvaluateNN(), TSGForOIDNN::evaluateRegressor(), DisplacedRegionSeedingVertexProducer::getDiscriminatorValue(), DeepTauId::getPartialPredictions(), DeepTauId::getPredictionsV2(), L2TauNNProducer::getTauScore(), DeepMETProducer::produce(), BaseMVAValueMapProducer< pat::Muon >::produce(), run(), DTOccupancyTestML::runOccupancyTest(), DeepCoreSeedGenerator::SeedEvaluation(), and HGCalConcentratorAutoEncoderImpl::select().

263  {
264  if (session == nullptr) {
265  throw cms::Exception("InvalidSession") << "cannot run empty session";
266  }
267 
268  // create empty run options
269  RunOptions runOptions;
270 
271  // run and check the status
272  Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
273  if (!status.ok()) {
274  throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
275  }
276  }

◆ run() [2/8]

void tensorflow::run ( const Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const thread::ThreadPoolOptions &  threadPoolOptions 
)
inline

Definition at line 120 of file TensorFlow.h.

References run().

124  {
125  // TF takes a non-const session in the run call which is, however, thread-safe and logically
126  // const, thus const_cast is consistent
127  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolOptions);
128  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:179

◆ run() [3/8]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
thread::ThreadPoolInterface *  threadPool 
)

Definition at line 278 of file TensorFlow.cc.

References run().

282  {
283  // create thread pool options
284  thread::ThreadPoolOptions threadPoolOptions;
285  threadPoolOptions.inter_op_threadpool = threadPool;
286  threadPoolOptions.intra_op_threadpool = threadPool;
287 
288  // run
289  run(session, inputs, outputNames, outputs, threadPoolOptions);
290  }

◆ run() [4/8]

void tensorflow::run ( const Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
thread::ThreadPoolInterface *  threadPool 
)
inline

Definition at line 140 of file TensorFlow.h.

References run().

144  {
145  // TF takes a non-const session in the run call which is, however, thread-safe and logically
146  // const, thus const_cast is consistent
147  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPool);
148  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:179

◆ run() [5/8]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)

Definition at line 292 of file TensorFlow.cc.

References Exception, tensorflow::NoThreadPool::instance(), tensorflow::TBBThreadPool::instance(), and run().

296  {
297  // lookup the thread pool and forward the call accordingly
298  if (threadPoolName == "no_threads") {
299  run(session, inputs, outputNames, outputs, &NoThreadPool::instance());
300  } else if (threadPoolName == "tbb") {
301  // the TBBTreadPool singleton should be already initialized before with a number of threads
302  run(session, inputs, outputNames, outputs, &TBBThreadPool::instance());
303  } else if (threadPoolName == "tensorflow") {
304  run(session, inputs, outputNames, outputs, nullptr);
305  } else {
306  throw cms::Exception("UnknownThreadPool")
307  << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
308  }
309  }

◆ run() [6/8]

void tensorflow::run ( const Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)
inline

Definition at line 160 of file TensorFlow.h.

References run().

164  {
165  // TF takes a non-const session in the run call which is, however, thread-safe and logically
166  // const, thus const_cast is consistent
167  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolName);
168  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:179

◆ run() [7/8]

void tensorflow::run ( Session *  session,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)

◆ run() [8/8]

void tensorflow::run ( const Session *  session,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)
inline

Definition at line 179 of file TensorFlow.h.

References run().

182  {
183  // TF takes a non-const session in the run call which is, however, thread-safe and logically
184  // const, thus const_cast is consistent
185  run(const_cast<Session*>(session), outputNames, outputs, threadPoolName);
186  }
void run(const Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
Definition: TensorFlow.h:179

◆ setLogging()

void tensorflow::setLogging ( const std::string &  level = "3")

Definition at line 81 of file TensorFlow.cc.

Referenced by reco::DeepSCGraphEvaluation::DeepSCGraphEvaluation(), DTOccupancyTestML::dqmEndLuminosityBlock(), HGCalConcentratorAutoEncoderImpl::HGCalConcentratorAutoEncoderImpl(), L1NNTauProducer::initializeGlobalCache(), L2TauNNProducer::initializeGlobalCache(), and TSGForOIDNN::TSGForOIDNN().

81  {
82  /*
83  * 0 = all messages are logged (default behavior)
84  * 1 = INFO messages are not printed
85  * 2 = INFO and WARNING messages are not printed
86  * 3 = INFO, WARNING, and ERROR messages are not printed
87  */
88  setenv("TF_CPP_MIN_LOG_LEVEL", level.c_str(), 0);
89  }