CMS 3D CMS Logo

Classes | Typedefs | Functions
tensorflow Namespace Reference

Classes

class  NoThreadPool
 
class  TBBThreadPool
 

Typedefs

typedef std::pair< std::string, Tensor > NamedTensor
 
typedef std::vector< NamedTensor > NamedTensorList
 

Functions

bool closeSession (Session *&session)
 
Session * createSession (GraphDef *graphDef, int nThreads=1)
 
Session * createSession (GraphDef *graphDef, SessionOptions &sessionOptions)
 
Session * createSession (int nThreads=1)
 
Session * createSession (MetaGraphDef *metaGraphDef, const std::string &exportDir, int nThreads=1)
 
Session * createSession (MetaGraphDef *metaGraphDef, const std::string &exportDir, SessionOptions &sessionOptions)
 
Session * createSession (SessionOptions &sessionOptions)
 
GraphDef * loadGraphDef (const std::string &pbFile)
 
MetaGraphDef * loadMetaGraph (const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
 
MetaGraphDef * loadMetaGraph (const std::string &exportDir, const std::string &tag=kSavedModelTagServe, int nThreads=1)
 
MetaGraphDef * loadMetaGraphDef (const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
 
MetaGraphDef * loadMetaGraphDef (const std::string &exportDir, const std::string &tag=kSavedModelTagServe, int nThreads=1)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, thread::ThreadPoolInterface *threadPool)
 
void run (Session *session, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const std::string &threadPoolName="no_threads")
 
void setLogging (const std::string &level="3")
 
void setThreading (SessionOptions &sessionOptions, int nThreads, const std::string &singleThreadPool)
 
void setThreading (SessionOptions &sessionOptions, int nThreads=1)
 

Typedef Documentation

◆ NamedTensor

typedef std::pair<std::string, Tensor> tensorflow::NamedTensor

Definition at line 29 of file TensorFlow.h.

◆ NamedTensorList

typedef std::vector<NamedTensor> tensorflow::NamedTensorList

Definition at line 30 of file TensorFlow.h.

Function Documentation

◆ closeSession()

bool tensorflow::closeSession ( Session *&  session)

Definition at line 196 of file TensorFlow.cc.

196  {
197  if (session == nullptr) {
198  return true;
199  }
200 
201  // close and delete the session
202  Status status = session->Close();
203  delete session;
204 
205  // reset the pointer
206  session = nullptr;
207 
208  return status.ok();
209  }

References btagGenBb_cfi::Status, and mps_update::status.

Referenced by DTOccupancyTestML::dqmEndLuminosityBlock(), deep_tau::DeepTauCache::~DeepTauCache(), and DeepVertexTFJetTagsProducer::~DeepVertexTFJetTagsProducer().

◆ createSession() [1/6]

Session * tensorflow::createSession ( GraphDef *  graphDef,
int  nThreads = 1 
)

Definition at line 188 of file TensorFlow.cc.

188  {
189  // create session options and set thread options
190  SessionOptions sessionOptions;
191  setThreading(sessionOptions, nThreads);
192 
193  return createSession(graphDef, sessionOptions);
194  }

References createSession(), runTheMatrix::nThreads, and setThreading().

◆ createSession() [2/6]

Session * tensorflow::createSession ( GraphDef *  graphDef,
SessionOptions &  sessionOptions 
)

Definition at line 162 of file TensorFlow.cc.

162  {
163  // check for valid pointer
164  if (graphDef == nullptr) {
165  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
166  }
167 
168  // check that the graph has nodes
169  if (graphDef->node_size() <= 0) {
170  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
171  }
172 
173  // create a new, empty session
174  Session* session = createSession(sessionOptions);
175 
176  // add the graph def
177  Status status;
178  status = session->Create(*graphDef);
179 
180  // check for success
181  if (!status.ok()) {
182  throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
183  }
184 
185  return session;
186  }

References createSession(), Exception, btagGenBb_cfi::Status, and mps_update::status.

◆ createSession() [3/6]

Session * tensorflow::createSession ( int  nThreads = 1)

Definition at line 99 of file TensorFlow.cc.

99  {
100  // create session options and set thread options
101  SessionOptions sessionOptions;
102  setThreading(sessionOptions, nThreads);
103 
104  return createSession(sessionOptions);
105  }

References createSession(), runTheMatrix::nThreads, and setThreading().

◆ createSession() [4/6]

Session * tensorflow::createSession ( MetaGraphDef *  metaGraphDef,
const std::string &  exportDir,
int  nThreads = 1 
)

Definition at line 154 of file TensorFlow.cc.

154  {
155  // create session options and set thread options
156  SessionOptions sessionOptions;
157  setThreading(sessionOptions, nThreads);
158 
159  return createSession(metaGraphDef, exportDir, sessionOptions);
160  }

References createSession(), runTheMatrix::nThreads, and setThreading().

◆ createSession() [5/6]

Session * tensorflow::createSession ( MetaGraphDef *  metaGraphDef,
const std::string &  exportDir,
SessionOptions &  sessionOptions 
)

Definition at line 107 of file TensorFlow.cc.

107  {
108  // check for valid pointer
109  if (metaGraphDef == nullptr) {
110  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
111  }
112 
113  // check that the graph has nodes
114  if (metaGraphDef->graph_def().node_size() <= 0) {
115  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
116  }
117 
118  Session* session = createSession(sessionOptions);
119 
120  // add the graph def from the meta graph
121  Status status;
122  status = session->Create(metaGraphDef->graph_def());
123  if (!status.ok()) {
124  throw cms::Exception("InvalidMetaGraphDef")
125  << "error while attaching metaGraphDef to session: " << status.ToString();
126  }
127 
128  // restore variables using the variable and index files in the export directory
129  // first, find names and paths
130  std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
131  std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
132  std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
133  std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
134  std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
135 
136  // when the index file is missing, there's nothing to do
137  if (!Env::Default()->FileExists(indexFile).ok()) {
138  return session;
139  }
140 
141  // create a tensor to store the variable file
142  Tensor varFileTensor(DT_STRING, TensorShape({}));
143  varFileTensor.scalar<std::string>()() = varFile;
144 
145  // run the restore op
146  status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
147  if (!status.ok()) {
148  throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
149  }
150 
151  return session;
152  }

References createSession(), Default, Exception, convertSQLiteXML::ok, btagGenBb_cfi::Status, mps_update::status, and AlCaHLTBitMon_QueryRunRegistry::string.

◆ createSession() [6/6]

Session * tensorflow::createSession ( SessionOptions &  sessionOptions)

Definition at line 85 of file TensorFlow.cc.

85  {
86  // objects to create the session
87  Status status;
88 
89  // create a new, empty session
90  Session* session = nullptr;
91  status = NewSession(sessionOptions, &session);
92  if (!status.ok()) {
93  throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
94  }
95 
96  return session;
97  }

References Exception, btagGenBb_cfi::Status, and mps_update::status.

Referenced by BaseMVAValueMapProducer< pat::Jet >::BaseMVAValueMapProducer(), createSession(), deep_tau::DeepTauCache::DeepTauCache(), DeepVertexTFJetTagsProducer::DeepVertexTFJetTagsProducer(), DTOccupancyTestML::dqmEndLuminosityBlock(), ticl::PatternRecognitionbyCA::PatternRecognitionbyCA(), and TrackstersMergeProducer::TrackstersMergeProducer().

◆ loadGraphDef()

GraphDef * tensorflow::loadGraphDef ( const std::string &  pbFile)

Definition at line 68 of file TensorFlow.cc.

68  {
69  // objects to load the graph
70  Status status;
71 
72  // load it
73  GraphDef* graphDef = new GraphDef();
74  status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
75 
76  // check for success
77  if (!status.ok()) {
78  throw cms::Exception("InvalidGraphDef")
79  << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
80  }
81 
82  return graphDef;
83  }

References Default, Exception, btagGenBb_cfi::Status, and mps_update::status.

Referenced by BaseMVAValueMapProducer< pat::Jet >::BaseMVAValueMapProducer(), deep_tau::DeepTauCache::DeepTauCache(), DTOccupancyTestML::dqmEndLuminosityBlock(), DeepMETProducer::initializeGlobalCache(), TrackstersMergeProducer::initializeGlobalCache(), TrackstersProducer::initializeGlobalCache(), and DeepVertexTFJetTagsProducer::initializeGlobalCache().

◆ loadMetaGraph() [1/2]

MetaGraphDef * tensorflow::loadMetaGraph ( const std::string &  exportDir,
const std::string &  tag,
SessionOptions &  sessionOptions 
)

Definition at line 46 of file TensorFlow.cc.

46  {
47  edm::LogInfo("PhysicsTools/TensorFlow")
48  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
49 
50  return loadMetaGraphDef(exportDir, tag, sessionOptions);
51  }

References loadMetaGraphDef(), and GlobalPosition_Frontier_DevDB_cff::tag.

◆ loadMetaGraph() [2/2]

MetaGraphDef * tensorflow::loadMetaGraph ( const std::string &  exportDir,
const std::string &  tag = kSavedModelTagServe,
int  nThreads = 1 
)

Definition at line 61 of file TensorFlow.cc.

61  {
62  edm::LogInfo("PhysicsTools/TensorFlow")
63  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
64 
65  return loadMetaGraphDef(exportDir, tag, nThreads);
66  }

References loadMetaGraphDef(), runTheMatrix::nThreads, and GlobalPosition_Frontier_DevDB_cff::tag.

◆ loadMetaGraphDef() [1/2]

MetaGraphDef * tensorflow::loadMetaGraphDef ( const std::string &  exportDir,
const std::string &  tag,
SessionOptions &  sessionOptions 
)

Definition at line 29 of file TensorFlow.cc.

29  {
30  // objects to load the graph
31  Status status;
32  RunOptions runOptions;
33  SavedModelBundle bundle;
34 
35  // load the model
36  status = LoadSavedModel(sessionOptions, runOptions, exportDir, {tag}, &bundle);
37  if (!status.ok()) {
38  throw cms::Exception("InvalidMetaGraphDef")
39  << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
40  }
41 
42  // return a copy of the graph
43  return new MetaGraphDef(bundle.meta_graph_def);
44  }

References Exception, btagGenBb_cfi::Status, mps_update::status, and GlobalPosition_Frontier_DevDB_cff::tag.

Referenced by loadMetaGraph(), and loadMetaGraphDef().

◆ loadMetaGraphDef() [2/2]

MetaGraphDef * tensorflow::loadMetaGraphDef ( const std::string &  exportDir,
const std::string &  tag = kSavedModelTagServe,
int  nThreads = 1 
)

Definition at line 53 of file TensorFlow.cc.

53  {
54  // create session options and set thread options
55  SessionOptions sessionOptions;
56  setThreading(sessionOptions, nThreads);
57 
58  return loadMetaGraphDef(exportDir, tag, sessionOptions);
59  }

References loadMetaGraphDef(), runTheMatrix::nThreads, setThreading(), and GlobalPosition_Frontier_DevDB_cff::tag.

◆ run() [1/4]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)

Definition at line 244 of file TensorFlow.cc.

248  {
249  // lookup the thread pool and forward the call accordingly
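250  if (threadPoolName == "no_threads") {
251  run(session, inputs, outputNames, outputs, &NoThreadPool::instance());
252  } else if (threadPoolName == "tbb") {
253  // the TBBTreadPool singleton should be already initialized before with a number of threads
254  run(session, inputs, outputNames, outputs, &TBBThreadPool::instance());
255  } else if (threadPoolName == "tensorflow") {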
256  run(session, inputs, outputNames, outputs, nullptr);
257  } else {
258  throw cms::Exception("UnknownThreadPool")
259  << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
260  }
261  }

References Exception, PixelMapPlotter::inputs, tensorflow::NoThreadPool::instance(), tensorflow::TBBThreadPool::instance(), jets_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

◆ run() [2/4]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const thread::ThreadPoolOptions &  threadPoolOptions 
)

Definition at line 211 of file TensorFlow.cc.

215  {
216  if (session == nullptr) {
217  throw cms::Exception("InvalidSession") << "cannot run empty session";
218  }
219 
220  // create empty run options
221  RunOptions runOptions;
222 
223  // run and check the status
224  Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
225  if (!status.ok()) {
226  throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
227  }
228  }

References Exception, PixelMapPlotter::inputs, jets_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, btagGenBb_cfi::Status, and mps_update::status.

Referenced by ticl::PatternRecognitionbyCA::energyRegressionAndID(), TrackstersMergeProducer::energyRegressionAndID(), DeepTauId::getPartialPredictions(), DPFIsolation::getPredictions(), DeepTauId::getPredictionsV1(), DeepTauId::getPredictionsV2(), DeepMETProducer::produce(), DeepVertexTFJetTagsProducer::produce(), BaseMVAValueMapProducer< pat::Jet >::produce(), run(), and DTOccupancyTestML::runOccupancyTest().

◆ run() [3/4]

void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
thread::ThreadPoolInterface *  threadPool 
)

Definition at line 230 of file TensorFlow.cc.

234  {
235  // create thread pool options
236  thread::ThreadPoolOptions threadPoolOptions;
237  threadPoolOptions.inter_op_threadpool = threadPool;
238  threadPoolOptions.intra_op_threadpool = threadPool;
239 
240  // run
241  run(session, inputs, outputNames, outputs, threadPoolOptions);
242  }

References PixelMapPlotter::inputs, jets_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

◆ run() [4/4]

void tensorflow::run ( Session *  session,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs,
const std::string &  threadPoolName = "no_threads" 
)

Definition at line 263 of file TensorFlow.cc.

266  {
267  run(session, {}, outputNames, outputs, threadPoolName);
268  }

References jets_cff::outputNames, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

◆ setLogging()

void tensorflow::setLogging ( const std::string &  level = "3")

◆ setThreading() [1/2]

void tensorflow::setThreading ( SessionOptions &  sessionOptions,
int  nThreads,
const std::string &  singleThreadPool 
)

Definition at line 23 of file TensorFlow.cc.

23  {
24  edm::LogInfo("PhysicsTools/TensorFlow") << "setting the thread pool via tensorflow::setThreading() is deprecated";
25 
26  setThreading(sessionOptions, nThreads);
27  }

References runTheMatrix::nThreads, and setThreading().

◆ setThreading() [2/2]

void tensorflow::setThreading ( SessionOptions &  sessionOptions,
int  nThreads = 1 
)

Definition at line 17 of file TensorFlow.cc.

17  {
18  // set number of threads used for intra and inter operation communication
19  sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
20  sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
21  }

References runTheMatrix::nThreads.

Referenced by createSession(), deep_tau::DeepTauCache::DeepTauCache(), loadMetaGraphDef(), and setThreading().

personalPlayback.level
level
Definition: personalPlayback.py:22
tensorflow::createSession
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
mps_update.status
status
Definition: mps_update.py:69
PatBasicFWLiteJetAnalyzer_Selector_cfg.outputs
outputs
Definition: PatBasicFWLiteJetAnalyzer_Selector_cfg.py:48
edm::LogInfo
Definition: MessageLogger.h:254
tensorflow::setThreading
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:17
btagGenBb_cfi.Status
Status
Definition: btagGenBb_cfi.py:4
convertSQLiteXML.ok
bool ok
Definition: convertSQLiteXML.py:98
Session
GlobalPosition_Frontier_DevDB_cff.tag
tag
Definition: GlobalPosition_Frontier_DevDB_cff.py:11
runTheMatrix.nThreads
nThreads
Definition: runTheMatrix.py:344
AlCaHLTBitMon_QueryRunRegistry.string
string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
PixelMapPlotter.inputs
inputs
Definition: PixelMapPlotter.py:490
instance
static PFTauRenderPlugin instance
Definition: PFTauRenderPlugin.cc:70
writedatasetfile.run
run
Definition: writedatasetfile.py:27
Exception
Definition: hltDiff.cc:246
Default
#define Default
Definition: vmac.h:110
tensorflow::loadMetaGraphDef
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:29
jets_cff.outputNames
outputNames
Definition: jets_cff.py:294