CMS 3D CMS Logo

TensorFlow.cc
Go to the documentation of this file.
1 /*
2  * TensorFlow interface helpers.
3  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
4  *
5  * Author: Marcel Rieger
6  */
7 
9 
11 
12 namespace tensorflow {
13 
void setLogging(const std::string& level) {
  // Exports the TensorFlow C++ minimum log level.
  // A zero "overwrite" flag means a TF_CPP_MIN_LOG_LEVEL that is already present in the
  // environment takes precedence over `level`.
  const int overwriteExisting = 0;
  setenv("TF_CPP_MIN_LOG_LEVEL", level.c_str(), overwriteExisting);
}
15 
16  void setThreading(SessionOptions& sessionOptions, int nThreads) {
17  // set number of threads used for intra and inter operation communication
18  sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
19  sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
20  }
21 
22  void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool) {
23  edm::LogInfo("PhysicsTools/TensorFlow") << "setting the thread pool via tensorflow::setThreading() is deprecated";
24 
25  setThreading(sessionOptions, nThreads);
26  }
27 
28  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
29  // objects to load the graph
30  Status status;
31  RunOptions runOptions;
32  SavedModelBundle bundle;
33 
34  // load the model
35  status = LoadSavedModel(sessionOptions, runOptions, exportDir, {tag}, &bundle);
36  if (!status.ok()) {
37  throw cms::Exception("InvalidMetaGraphDef")
38  << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
39  }
40 
41  // return a copy of the graph
42  return new MetaGraphDef(bundle.meta_graph_def);
43  }
44 
45  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
46  edm::LogInfo("PhysicsTools/TensorFlow")
47  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
48 
49  return loadMetaGraphDef(exportDir, tag, sessionOptions);
50  }
51 
52  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, int nThreads) {
53  // create session options and set thread options
54  SessionOptions sessionOptions;
55  setThreading(sessionOptions, nThreads);
56 
57  return loadMetaGraphDef(exportDir, tag, sessionOptions);
58  }
59 
60  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, int nThreads) {
61  edm::LogInfo("PhysicsTools/TensorFlow")
62  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
63 
64  return loadMetaGraphDef(exportDir, tag, nThreads);
65  }
66 
67  GraphDef* loadGraphDef(const std::string& pbFile) {
68  // objects to load the graph
69  Status status;
70 
71  // load it
72  GraphDef* graphDef = new GraphDef();
73  status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
74 
75  // check for success
76  if (!status.ok()) {
77  throw cms::Exception("InvalidGraphDef")
78  << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
79  }
80 
81  return graphDef;
82  }
83 
84  Session* createSession(SessionOptions& sessionOptions) {
85  // objects to create the session
86  Status status;
87 
88  // create a new, empty session
89  Session* session = nullptr;
90  status = NewSession(sessionOptions, &session);
91  if (!status.ok()) {
92  throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
93  }
94 
95  return session;
96  }
97 
98  Session* createSession(int nThreads) {
99  // create session options and set thread options
100  SessionOptions sessionOptions;
101  setThreading(sessionOptions, nThreads);
102 
103  return createSession(sessionOptions);
104  }
105 
106  Session* createSession(const MetaGraphDef* metaGraphDef,
107  const std::string& exportDir,
108  SessionOptions& sessionOptions) {
109  // check for valid pointer
110  if (metaGraphDef == nullptr) {
111  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
112  }
113 
114  // check that the graph has nodes
115  if (metaGraphDef->graph_def().node_size() <= 0) {
116  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
117  }
118 
119  Session* session = createSession(sessionOptions);
120 
121  // add the graph def from the meta graph
122  Status status;
123  status = session->Create(metaGraphDef->graph_def());
124  if (!status.ok()) {
125  throw cms::Exception("InvalidMetaGraphDef")
126  << "error while attaching metaGraphDef to session: " << status.ToString();
127  }
128 
129  // restore variables using the variable and index files in the export directory
130  // first, find names and paths
131  std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
132  std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
133  std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
134  std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
135  std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
136 
137  // when the index file is missing, there's nothing to do
138  if (!Env::Default()->FileExists(indexFile).ok()) {
139  return session;
140  }
141 
142  // create a tensor to store the variable file
143  Tensor varFileTensor(DT_STRING, TensorShape({}));
144  varFileTensor.scalar<tensorflow::tstring>()() = varFile;
145 
146  // run the restore op
147  status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
148  if (!status.ok()) {
149  throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
150  }
151 
152  return session;
153  }
154 
155  Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads) {
156  // create session options and set thread options
157  SessionOptions sessionOptions;
158  setThreading(sessionOptions, nThreads);
159 
160  return createSession(metaGraphDef, exportDir, sessionOptions);
161  }
162 
163  Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions) {
164  // check for valid pointer
165  if (graphDef == nullptr) {
166  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
167  }
168 
169  // check that the graph has nodes
170  if (graphDef->node_size() <= 0) {
171  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
172  }
173 
174  // create a new, empty session
175  Session* session = createSession(sessionOptions);
176 
177  // add the graph def
178  Status status;
179  status = session->Create(*graphDef);
180 
181  // check for success
182  if (!status.ok()) {
183  throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
184  }
185 
186  return session;
187  }
188 
189  Session* createSession(const GraphDef* graphDef, int nThreads) {
190  // create session options and set thread options
191  SessionOptions sessionOptions;
192  setThreading(sessionOptions, nThreads);
193 
194  return createSession(graphDef, sessionOptions);
195  }
196 
197  bool closeSession(Session*& session) {
198  if (session == nullptr) {
199  return true;
200  }
201 
202  // close and delete the session
203  Status status = session->Close();
204  delete session;
205 
206  // reset the pointer
207  session = nullptr;
208 
209  return status.ok();
210  }
211 
212  bool closeSession(const Session*& session) {
213  auto s = const_cast<Session*>(session);
214  bool state = closeSession(s);
215 
216  // reset the pointer
217  session = nullptr;
218 
219  return state;
220  }
221 
222  void run(Session* session,
223  const NamedTensorList& inputs,
224  const std::vector<std::string>& outputNames,
225  std::vector<Tensor>* outputs,
226  const thread::ThreadPoolOptions& threadPoolOptions) {
227  if (session == nullptr) {
228  throw cms::Exception("InvalidSession") << "cannot run empty session";
229  }
230 
231  // create empty run options
232  RunOptions runOptions;
233 
234  // run and check the status
235  Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
236  if (!status.ok()) {
237  throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
238  }
239  }
240 
241  void run(Session* session,
242  const NamedTensorList& inputs,
243  const std::vector<std::string>& outputNames,
244  std::vector<Tensor>* outputs,
245  thread::ThreadPoolInterface* threadPool) {
246  // create thread pool options
247  thread::ThreadPoolOptions threadPoolOptions;
248  threadPoolOptions.inter_op_threadpool = threadPool;
249  threadPoolOptions.intra_op_threadpool = threadPool;
250 
251  // run
252  run(session, inputs, outputNames, outputs, threadPoolOptions);
253  }
254 
255  void run(Session* session,
256  const NamedTensorList& inputs,
257  const std::vector<std::string>& outputNames,
258  std::vector<Tensor>* outputs,
259  const std::string& threadPoolName) {
260  // lookup the thread pool and forward the call accordingly
261  if (threadPoolName == "no_threads") {
263  } else if (threadPoolName == "tbb") {
264  // the TBBTreadPool singleton should be already initialized before with a number of threads
266  } else if (threadPoolName == "tensorflow") {
267  run(session, inputs, outputNames, outputs, nullptr);
268  } else {
269  throw cms::Exception("UnknownThreadPool")
270  << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
271  }
272  }
273 
274  void run(Session* session,
275  const std::vector<std::string>& outputNames,
276  std::vector<Tensor>* outputs,
277  const std::string& threadPoolName) {
278  run(session, {}, outputNames, outputs, threadPoolName);
279  }
280 
// Releases the cached session and graph held in the `session` and `graph` atomics of the
// enclosing object (its declaration is not visible here) and resets both to nullptr.
282  // delete the session if set
283  Session* s = session.load();
284  if (s != nullptr) {
// NOTE(review): in this view `s` is only unset below and never closed or deleted, which would
// leak the session; presumably a closeSession(s) call belongs here — confirm against the full source.
286  session.store(nullptr);
287  }
288 
289  // delete the graph if set
// NOTE(review): graph.load() is read separately for the check and the delete; this is only
// safe if nothing else modifies `graph` concurrently — TODO confirm the intended usage.
290  if (graph.load() != nullptr) {
291  delete graph.load();
292  graph.store(nullptr);
293  }
294  }
295 
296 } // namespace tensorflow
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:84
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:29
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:67
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:45
std::atomic< Session * > session
Definition: TensorFlow.h:188
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:16
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:222
bool closeSession(Session *&session)
Definition: TensorFlow.cc:197
static TBBThreadPool & instance(int nThreads=-1)
Definition: TBBThreadPool.h:24
Log< level::Info, false > LogInfo
void setLogging(const std::string &level="3")
Definition: TensorFlow.cc:14
std::atomic< GraphDef * > graph
Definition: TensorFlow.h:187
constexpr bool Default
Definition: SoACommon.h:73
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:28
static NoThreadPool & instance()
Definition: NoThreadPool.h:22