CMS 3D CMS Logo

TensorFlow.cc
Go to the documentation of this file.
1 /*
2  * TensorFlow interface helpers.
3  * Based on TensorFlow C++ API 2.1.
4  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
5  *
6  * Author: Marcel Rieger
7  */
8 
10 
12 
13 namespace tensorflow {
14 
// Sets the TF_CPP_MIN_LOG_LEVEL environment variable to `level`.
// An already-present value is NOT replaced (overwrite flag is 0), so a level exported
// by the user in the environment takes precedence over the one passed here.
void setLogging(const std::string& level) {
  const char* const envName = "TF_CPP_MIN_LOG_LEVEL";
  const int overwrite = 0;
  setenv(envName, level.c_str(), overwrite);
}
16 
17  void setThreading(SessionOptions& sessionOptions, int nThreads) {
18  // set number of threads used for intra and inter operation communication
19  sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
20  sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
21  }
22 
23  void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool) {
24  edm::LogInfo("PhysicsTools/TensorFlow") << "setting the thread pool via tensorflow::setThreading() is deprecated";
25 
26  setThreading(sessionOptions, nThreads);
27  }
28 
29  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
30  // objects to load the graph
31  Status status;
32  RunOptions runOptions;
33  SavedModelBundle bundle;
34 
35  // load the model
36  status = LoadSavedModel(sessionOptions, runOptions, exportDir, {tag}, &bundle);
37  if (!status.ok()) {
38  throw cms::Exception("InvalidMetaGraphDef")
39  << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
40  }
41 
42  // return a copy of the graph
43  return new MetaGraphDef(bundle.meta_graph_def);
44  }
45 
46  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
47  edm::LogInfo("PhysicsTools/TensorFlow")
48  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
49 
50  return loadMetaGraphDef(exportDir, tag, sessionOptions);
51  }
52 
53  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, int nThreads) {
54  // create session options and set thread options
55  SessionOptions sessionOptions;
56  setThreading(sessionOptions, nThreads);
57 
58  return loadMetaGraphDef(exportDir, tag, sessionOptions);
59  }
60 
61  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, int nThreads) {
62  edm::LogInfo("PhysicsTools/TensorFlow")
63  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
64 
65  return loadMetaGraphDef(exportDir, tag, nThreads);
66  }
67 
68  GraphDef* loadGraphDef(const std::string& pbFile) {
69  // objects to load the graph
70  Status status;
71 
72  // load it
73  GraphDef* graphDef = new GraphDef();
74  status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
75 
76  // check for success
77  if (!status.ok()) {
78  throw cms::Exception("InvalidGraphDef")
79  << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
80  }
81 
82  return graphDef;
83  }
84 
85  Session* createSession(SessionOptions& sessionOptions) {
86  // objects to create the session
87  Status status;
88 
89  // create a new, empty session
90  Session* session = nullptr;
91  status = NewSession(sessionOptions, &session);
92  if (!status.ok()) {
93  throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
94  }
95 
96  return session;
97  }
98 
100  // create session options and set thread options
101  SessionOptions sessionOptions;
102  setThreading(sessionOptions, nThreads);
103 
104  return createSession(sessionOptions);
105  }
106 
107  Session* createSession(MetaGraphDef* metaGraphDef, const std::string& exportDir, SessionOptions& sessionOptions) {
108  // check for valid pointer
109  if (metaGraphDef == nullptr) {
110  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
111  }
112 
113  // check that the graph has nodes
114  if (metaGraphDef->graph_def().node_size() <= 0) {
115  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
116  }
117 
118  Session* session = createSession(sessionOptions);
119 
120  // add the graph def from the meta graph
121  Status status;
122  status = session->Create(metaGraphDef->graph_def());
123  if (!status.ok()) {
124  throw cms::Exception("InvalidMetaGraphDef")
125  << "error while attaching metaGraphDef to session: " << status.ToString();
126  }
127 
128  // restore variables using the variable and index files in the export directory
129  // first, find names and paths
130  std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
131  std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
132  std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
133  std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
134  std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
135 
136  // when the index file is missing, there's nothing to do
137  if (!Env::Default()->FileExists(indexFile).ok()) {
138  return session;
139  }
140 
141  // create a tensor to store the variable file
142  Tensor varFileTensor(DT_STRING, TensorShape({}));
143  varFileTensor.scalar<std::string>()() = varFile;
144 
145  // run the restore op
146  status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
147  if (!status.ok()) {
148  throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
149  }
150 
151  return session;
152  }
153 
154  Session* createSession(MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads) {
155  // create session options and set thread options
156  SessionOptions sessionOptions;
157  setThreading(sessionOptions, nThreads);
158 
159  return createSession(metaGraphDef, exportDir, sessionOptions);
160  }
161 
162  Session* createSession(GraphDef* graphDef, SessionOptions& sessionOptions) {
163  // check for valid pointer
164  if (graphDef == nullptr) {
165  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
166  }
167 
168  // check that the graph has nodes
169  if (graphDef->node_size() <= 0) {
170  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
171  }
172 
173  // create a new, empty session
174  Session* session = createSession(sessionOptions);
175 
176  // add the graph def
177  Status status;
178  status = session->Create(*graphDef);
179 
180  // check for success
181  if (!status.ok()) {
182  throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
183  }
184 
185  return session;
186  }
187 
188  Session* createSession(GraphDef* graphDef, int nThreads) {
189  // create session options and set thread options
190  SessionOptions sessionOptions;
191  setThreading(sessionOptions, nThreads);
192 
193  return createSession(graphDef, sessionOptions);
194  }
195 
196  bool closeSession(Session*& session) {
197  if (session == nullptr) {
198  return true;
199  }
200 
201  // close and delete the session
202  Status status = session->Close();
203  delete session;
204 
205  // reset the pointer
206  session = nullptr;
207 
208  return status.ok();
209  }
210 
211  void run(Session* session,
212  const NamedTensorList& inputs,
213  const std::vector<std::string>& outputNames,
214  std::vector<Tensor>* outputs,
215  const thread::ThreadPoolOptions& threadPoolOptions) {
216  if (session == nullptr) {
217  throw cms::Exception("InvalidSession") << "cannot run empty session";
218  }
219 
220  // create empty run options
221  RunOptions runOptions;
222 
223  // run and check the status
224  Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
225  if (!status.ok()) {
226  throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
227  }
228  }
229 
230  void run(Session* session,
231  const NamedTensorList& inputs,
232  const std::vector<std::string>& outputNames,
233  std::vector<Tensor>* outputs,
234  thread::ThreadPoolInterface* threadPool) {
235  // create thread pool options
236  thread::ThreadPoolOptions threadPoolOptions;
237  threadPoolOptions.inter_op_threadpool = threadPool;
238  threadPoolOptions.intra_op_threadpool = threadPool;
239 
240  // run
241  run(session, inputs, outputNames, outputs, threadPoolOptions);
242  }
243 
244  void run(Session* session,
245  const NamedTensorList& inputs,
246  const std::vector<std::string>& outputNames,
247  std::vector<Tensor>* outputs,
248  const std::string& threadPoolName) {
249  // lookup the thread pool and forward the call accordingly
250  if (threadPoolName == "no_threads") {
252  } else if (threadPoolName == "tbb") {
253  // the TBBTreadPool singleton should be already initialized before with a number of threads
255  } else if (threadPoolName == "tensorflow") {
256  run(session, inputs, outputNames, outputs, nullptr);
257  } else {
258  throw cms::Exception("UnknownThreadPool")
259  << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
260  }
261  }
262 
263  void run(Session* session,
264  const std::vector<std::string>& outputNames,
265  std::vector<Tensor>* outputs,
266  const std::string& threadPoolName) {
267  run(session, {}, outputNames, outputs, threadPoolName);
268  }
269 
270 } // namespace tensorflow
personalPlayback.level
level
Definition: personalPlayback.py:22
tensorflow::createSession
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
jets_cff.singleThreadPool
singleThreadPool
Definition: jets_cff.py:297
TensorFlow.h
MessageLogger.h
mps_update.status
status
Definition: mps_update.py:69
PatBasicFWLiteJetAnalyzer_Selector_cfg.outputs
outputs
Definition: PatBasicFWLiteJetAnalyzer_Selector_cfg.py:48
edm::LogInfo
Definition: MessageLogger.h:254
tensorflow::setThreading
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:17
btagGenBb_cfi.Status
Status
Definition: btagGenBb_cfi.py:4
convertSQLiteXML.ok
bool ok
Definition: convertSQLiteXML.py:98
tensorflow::NoThreadPool::instance
static NoThreadPool & instance()
Definition: NoThreadPool.h:22
tensorflow::closeSession
bool closeSession(Session *&session)
Definition: TensorFlow.cc:196
Session
GlobalPosition_Frontier_DevDB_cff.tag
tag
Definition: GlobalPosition_Frontier_DevDB_cff.py:11
runTheMatrix.nThreads
nThreads
Definition: runTheMatrix.py:344
AlCaHLTBitMon_QueryRunRegistry.string
string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
tensorflow::NamedTensorList
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:30
PixelMapPlotter.inputs
inputs
Definition: PixelMapPlotter.py:490
tensorflow::setLogging
void setLogging(const std::string &level="3")
Definition: TensorFlow.cc:15
tensorflow::TBBThreadPool::instance
static TBBThreadPool & instance(int nThreads=-1)
Definition: TBBThreadPool.h:24
tensorflow::loadGraphDef
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:68
Exception
Definition: hltDiff.cc:246
tensorflow::loadMetaGraph
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:46
tensorflow::run
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:211
Default
#define Default
Definition: vmac.h:110
tensorflow
Definition: NoThreadPool.h:18
tensorflow::loadMetaGraphDef
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:29
jets_cff.outputNames
outputNames
Definition: jets_cff.py:294