CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
Classes | Typedefs | Functions
tensorflow Namespace Reference

Classes

class  NoThreadPool
 
class  TBBThreadPool
 

Typedefs

typedef std::pair< std::string, Tensor > NamedTensor
 
typedef std::vector< NamedTensor > NamedTensorList
 

Functions

bool closeSession (Session *&session)
 
Session * createSession (SessionOptions &sessionOptions)
 
Session * createSession (int nThreads=1)
 
Session * createSession (const MetaGraphDef *metaGraphDef, const std::string &exportDir, SessionOptions &sessionOptions)
 
Session * createSession (const MetaGraphDef *metaGraphDef, const std::string &exportDir, int nThreads=1)
 
Session * createSession (const GraphDef *graphDef, SessionOptions &sessionOptions)
 
Session * createSession (const GraphDef *graphDef, int nThreads=1)
 
GraphDef * loadGraphDef (const std::string &pbFile)
 
MetaGraphDef * loadMetaGraph (const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
 
MetaGraphDef * loadMetaGraph (const std::string &exportDir, const std::string &tag=kSavedModelTagServe, int nThreads=1)
 
MetaGraphDef * loadMetaGraphDef (const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
 
MetaGraphDef * loadMetaGraphDef (const std::string &exportDir, const std::string &tag=kSavedModelTagServe, int nThreads=1)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, thread::ThreadPoolInterface *threadPool)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void setLogging (const std::string &level="3")
 
void setThreading (SessionOptions &sessionOptions, int nThreads=1)
 
void setThreading (SessionOptions &sessionOptions, int nThreads, const std::string &singleThreadPool)
 

Typedef Documentation

typedef std::pair<std::string, Tensor> tensorflow::NamedTensor

Definition at line 29 of file TensorFlow.h.

typedef std::vector<NamedTensor> tensorflow::NamedTensorList

Definition at line 30 of file TensorFlow.h.

Function Documentation

bool tensorflow::closeSession ( Session *&  session)

Definition at line 198 of file TensorFlow.cc.

References run_AlCaRecoTriggerBitsUpdateWorkflow::session, and mps_update::status.

Referenced by DTOccupancyTestML::dqmEndLuminosityBlock(), GEDPhotonProducer::endStream(), GsfElectronProducer::endStream(), L2TauNNProducer::globalEndJob(), BaseMVACache::~BaseMVACache(), deep_tau::DeepTauCache::~DeepTauCache(), PtAssignmentEngineDxy::~PtAssignmentEngineDxy(), TauNNId::~TauNNId(), TfGraphDefWrapper::~TfGraphDefWrapper(), and TSGForOIDNN::~TSGForOIDNN().

198  {
199  if (session == nullptr) {
200  return true;
201  }
202 
203  // close and delete the session
204  Status status = session->Close();
205  delete session;
206 
207  // reset the pointer
208  session = nullptr;
209 
210  return status.ok();
211  }
list status
Definition: mps_update.py:107
SiPixelHitStatus Status
Session * tensorflow::createSession ( SessionOptions &  sessionOptions)

Definition at line 85 of file TensorFlow.cc.

References Exception, run_AlCaRecoTriggerBitsUpdateWorkflow::session, and mps_update::status.

Referenced by BaseMVACache::BaseMVACache(), PtAssignmentEngineDxy::configure(), createSession(), deep_tau::DeepTauCache::DeepTauCache(), DTOccupancyTestML::dqmEndLuminosityBlock(), egammaTools::EgammaDNNHelper::getSessions(), HGCalConcentratorAutoEncoderImpl::HGCalConcentratorAutoEncoderImpl(), L2TauNNProducer::initializeGlobalCache(), TfGraphDefProducer::produce(), TauNNId::TauNNId(), and TSGForOIDNN::TSGForOIDNN().

85  {
86  // objects to create the session
87  Status status;
88 
89  // create a new, empty session
90  Session* session = nullptr;
91  status = NewSession(sessionOptions, &session);
92  if (!status.ok()) {
93  throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
94  }
95 
96  return session;
97  }
list status
Definition: mps_update.py:107
SiPixelHitStatus Status
Session * tensorflow::createSession ( int  nThreads = 1)

Definition at line 99 of file TensorFlow.cc.

References createSession(), and setThreading().

99  {
100  // create session options and set thread options
101  SessionOptions sessionOptions;
102  setThreading(sessionOptions, nThreads);
103 
104  return createSession(sessionOptions);
105  }
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:17
Session * tensorflow::createSession ( const MetaGraphDef *  metaGraphDef,
const std::string &  exportDir,
SessionOptions &  sessionOptions 
)

Definition at line 107 of file TensorFlow.cc.

References createSession(), Exception, convertSQLiteXML::ok, run_AlCaRecoTriggerBitsUpdateWorkflow::session, mps_update::status, and AlCaHLTBitMon_QueryRunRegistry::string.

109  {
110  // check for valid pointer
111  if (metaGraphDef == nullptr) {
112  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
113  }
114 
115  // check that the graph has nodes
116  if (metaGraphDef->graph_def().node_size() <= 0) {
117  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
118  }
119 
120  Session* session = createSession(sessionOptions);
121 
122  // add the graph def from the meta graph
123  Status status;
124  status = session->Create(metaGraphDef->graph_def());
125  if (!status.ok()) {
126  throw cms::Exception("InvalidMetaGraphDef")
127  << "error while attaching metaGraphDef to session: " << status.ToString();
128  }
129 
130  // restore variables using the variable and index files in the export directory
131  // first, find names and paths
132  std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
133  std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
134  std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
135  std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
136  std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
137 
138  // when the index file is missing, there's nothing to do
139  if (!Env::Default()->FileExists(indexFile).ok()) {
140  return session;
141  }
142 
143  // create a tensor to store the variable file
144  Tensor varFileTensor(DT_STRING, TensorShape({}));
145  varFileTensor.scalar<tensorflow::tstring>()() = varFile;
146 
147  // run the restore op
148  status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
149  if (!status.ok()) {
150  throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
151  }
152 
153  return session;
154  }
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
list status
Definition: mps_update.py:107
SiPixelHitStatus Status
Session * tensorflow::createSession ( const MetaGraphDef *  metaGraphDef,
const std::string &  exportDir,
int  nThreads = 1 
)

Definition at line 156 of file TensorFlow.cc.

References createSession(), and setThreading().

156  {
157  // create session options and set thread options
158  SessionOptions sessionOptions;
159  setThreading(sessionOptions, nThreads);
160 
161  return createSession(metaGraphDef, exportDir, sessionOptions);
162  }
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:17
Session * tensorflow::createSession ( const GraphDef *  graphDef,
SessionOptions &  sessionOptions 
)

Definition at line 164 of file TensorFlow.cc.

References createSession(), Exception, run_AlCaRecoTriggerBitsUpdateWorkflow::session, and mps_update::status.

164  {
165  // check for valid pointer
166  if (graphDef == nullptr) {
167  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
168  }
169 
170  // check that the graph has nodes
171  if (graphDef->node_size() <= 0) {
172  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
173  }
174 
175  // create a new, empty session
176  Session* session = createSession(sessionOptions);
177 
178  // add the graph def
179  Status status;
180  status = session->Create(*graphDef);
181 
182  // check for success
183  if (!status.ok()) {
184  throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
185  }
186 
187  return session;
188  }
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
list status
Definition: mps_update.py:107
SiPixelHitStatus Status
Session * tensorflow::createSession ( const GraphDef *  graphDef,
int  nThreads = 1 
)

Definition at line 190 of file TensorFlow.cc.

References createSession(), and setThreading().

190  {
191  // create session options and set thread options
192  SessionOptions sessionOptions;
193  setThreading(sessionOptions, nThreads);
194 
195  return createSession(graphDef, sessionOptions);
196  }
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:17
GraphDef * tensorflow::loadGraphDef ( const std::string &  pbFile)

Definition at line 68 of file TensorFlow.cc.

References Exception, and mps_update::status.

Referenced by BaseMVACache::BaseMVACache(), PtAssignmentEngineDxy::configure(), deep_tau::DeepTauCache::DeepTauCache(), DTOccupancyTestML::dqmEndLuminosityBlock(), HGCalConcentratorAutoEncoderImpl::HGCalConcentratorAutoEncoderImpl(), DeepMETProducer::initializeGlobalCache(), L1NNTauProducer::initializeGlobalCache(), DeepCoreSeedGenerator::initializeGlobalCache(), L2TauNNProducer::initializeGlobalCache(), egammaTools::EgammaDNNHelper::initTensorFlowGraphs(), TfGraphDefProducer::produce(), and TSGForOIDNN::TSGForOIDNN().

68  {
69  // objects to load the graph
70  Status status;
71 
72  // load it
73  GraphDef* graphDef = new GraphDef();
74  status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
75 
76  // check for success
77  if (!status.ok()) {
78  throw cms::Exception("InvalidGraphDef")
79  << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
80  }
81 
82  return graphDef;
83  }
list status
Definition: mps_update.py:107
SiPixelHitStatus Status
MetaGraphDef * tensorflow::loadMetaGraph ( const std::string &  exportDir,
const std::string &  tag,
SessionOptions &  sessionOptions 
)

Definition at line 46 of file TensorFlow.cc.

References loadMetaGraphDef().

46  {
47  edm::LogInfo("PhysicsTools/TensorFlow")
48  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
49 
50  return loadMetaGraphDef(exportDir, tag, sessionOptions);
51  }
Log< level::Info, false > LogInfo
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:29
MetaGraphDef * tensorflow::loadMetaGraph ( const std::string &  exportDir,
const std::string &  tag = kSavedModelTagServe,
int  nThreads = 1 
)

Definition at line 61 of file TensorFlow.cc.

References loadMetaGraphDef().

61  {
62  edm::LogInfo("PhysicsTools/TensorFlow")
63  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
64 
65  return loadMetaGraphDef(exportDir, tag, nThreads);
66  }
Log< level::Info, false > LogInfo
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:29
MetaGraphDef * tensorflow::loadMetaGraphDef ( const std::string &  exportDir,
const std::string &  tag,
SessionOptions &  sessionOptions 
)

Definition at line 29 of file TensorFlow.cc.

References Exception, and mps_update::status.

Referenced by loadMetaGraph(), and loadMetaGraphDef().

29  {
30  // objects to load the graph
31  Status status;
32  RunOptions runOptions;
33  SavedModelBundle bundle;
34 
35  // load the model
36  status = LoadSavedModel(sessionOptions, runOptions, exportDir, {tag}, &bundle);
37  if (!status.ok()) {
38  throw cms::Exception("InvalidMetaGraphDef")
39  << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
40  }
41 
42  // return a copy of the graph
43  return new MetaGraphDef(bundle.meta_graph_def);
44  }
list status
Definition: mps_update.py:107
SiPixelHitStatus Status
MetaGraphDef * tensorflow::loadMetaGraphDef ( const std::string &  exportDir,
const std::string &  tag = kSavedModelTagServe,
int  nThreads = 1 
)

Definition at line 53 of file TensorFlow.cc.

References loadMetaGraphDef(), and setThreading().

53  {
54  // create session options and set thread options
55  SessionOptions sessionOptions;
56  setThreading(sessionOptions, nThreads);
57 
58  return loadMetaGraphDef(exportDir, tag, sessionOptions);
59  }
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:17
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:29
void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const thread::ThreadPoolOptions &  threadPoolOptions 
)

Definition at line 213 of file TensorFlow.cc.

References Exception, and mps_update::status.

Referenced by PtAssignmentEngineDxy::call_tensorflow_dxy(), ticl::PatternRecognitionbyCLUE3D< TILES >::energyRegressionAndID(), ticl::PatternRecognitionbyCA< TILES >::energyRegressionAndID(), ticl::PatternRecognitionbyFastJet< TILES >::energyRegressionAndID(), TrackstersMergeProducer::energyRegressionAndID(), egammaTools::EgammaDNNHelper::evaluate(), TSGForOIDNN::evaluateClassifier(), TauNNId::EvaluateNN(), TSGForOIDNN::evaluateRegressor(), DeepTauId::getPartialPredictions(), DPFIsolation::getPredictions(), DeepTauId::getPredictionsV1(), DeepTauId::getPredictionsV2(), L2TauNNProducer::getTauScore(), DeepMETProducer::produce(), BaseMVAValueMapProducer< T >::produce(), run(), DTOccupancyTestML::runOccupancyTest(), DeepCoreSeedGenerator::SeedEvaluation(), and HGCalConcentratorAutoEncoderImpl::select().

217  {
218  if (session == nullptr) {
219  throw cms::Exception("InvalidSession") << "cannot run empty session";
220  }
221 
222  // create empty run options
223  RunOptions runOptions;
224 
225  // run and check the status
226  Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
227  if (!status.ok()) {
228  throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
229  }
230  }
list status
Definition: mps_update.py:107
SiPixelHitStatus Status
void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
thread::ThreadPoolInterface *  threadPool 
)

Definition at line 232 of file TensorFlow.cc.

References run().

236  {
237  // create thread pool options
238  thread::ThreadPoolOptions threadPoolOptions;
239  threadPoolOptions.inter_op_threadpool = threadPool;
240  threadPoolOptions.intra_op_threadpool = threadPool;
241 
242  // run
243  run(session, inputs, outputNames, outputs, threadPoolOptions);
244  }
void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)

Definition at line 246 of file TensorFlow.cc.

References Exception, tensorflow::NoThreadPool::instance(), tensorflow::TBBThreadPool::instance(), and run().

250  {
251  // lookup the thread pool and forward the call accordingly
252  if (threadPoolName == "no_threads") {
253  run(session, inputs, outputNames, outputs, &NoThreadPool::instance());
254  } else if (threadPoolName == "tbb") {
255  // the TBBTreadPool singleton should be already initialized before with a number of threads
256  run(session, inputs, outputNames, outputs, &TBBThreadPool::instance());
257  } else if (threadPoolName == "tensorflow") {
258  run(session, inputs, outputNames, outputs, nullptr);
259  } else {
260  throw cms::Exception("UnknownThreadPool")
261  << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
262  }
263  }
static PFTauRenderPlugin instance
void tensorflow::run ( Session *  session,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)

Definition at line 265 of file TensorFlow.cc.

References run().

268  {
269  run(session, {}, outputNames, outputs, threadPoolName);
270  }
void tensorflow::setLogging ( const std::string &  level = "3")
void tensorflow::setThreading ( SessionOptions &  sessionOptions,
int  nThreads = 1 
)

Definition at line 17 of file TensorFlow.cc.

Referenced by createSession(), deep_tau::DeepTauCache::DeepTauCache(), loadMetaGraphDef(), and setThreading().

17  {
18  // set number of threads used for intra and inter operation communication
19  sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
20  sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
21  }
void tensorflow::setThreading ( SessionOptions &  sessionOptions,
int  nThreads,
const std::string &  singleThreadPool 
)

Definition at line 23 of file TensorFlow.cc.

References setThreading().

23  {
24  edm::LogInfo("PhysicsTools/TensorFlow") << "setting the thread pool via tensorflow::setThreading() is deprecated";
25 
26  setThreading(sessionOptions, nThreads);
27  }
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:17
Log< level::Info, false > LogInfo