CMS 3D CMS Logo

TensorFlow.cc
Go to the documentation of this file.
1 /*
2  * TensorFlow interface helpers.
3  * Based on TensorFlow C++ API 2.1.
4  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
5  *
6  * Author: Marcel Rieger
7  */
8 
10 
12 
13 namespace tensorflow {
14 
15  void setLogging(const std::string& level) { setenv("TF_CPP_MIN_LOG_LEVEL", level.c_str(), 0); }
16 
17  void setThreading(SessionOptions& sessionOptions, int nThreads) {
18  // set number of threads used for intra and inter operation communication
19  sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
20  sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
21  }
22 
23  void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool) {
24  edm::LogInfo("PhysicsTools/TensorFlow") << "setting the thread pool via tensorflow::setThreading() is deprecated";
25 
26  setThreading(sessionOptions, nThreads);
27  }
28 
29  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
30  // objects to load the graph
31  Status status;
32  RunOptions runOptions;
33  SavedModelBundle bundle;
34 
35  // load the model
36  status = LoadSavedModel(sessionOptions, runOptions, exportDir, {tag}, &bundle);
37  if (!status.ok()) {
38  throw cms::Exception("InvalidMetaGraphDef")
39  << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
40  }
41 
42  // return a copy of the graph
43  return new MetaGraphDef(bundle.meta_graph_def);
44  }
45 
46  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
47  edm::LogInfo("PhysicsTools/TensorFlow")
48  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
49 
50  return loadMetaGraphDef(exportDir, tag, sessionOptions);
51  }
52 
53  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, int nThreads) {
54  // create session options and set thread options
55  SessionOptions sessionOptions;
56  setThreading(sessionOptions, nThreads);
57 
58  return loadMetaGraphDef(exportDir, tag, sessionOptions);
59  }
60 
61  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, int nThreads) {
62  edm::LogInfo("PhysicsTools/TensorFlow")
63  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
64 
65  return loadMetaGraphDef(exportDir, tag, nThreads);
66  }
67 
68  GraphDef* loadGraphDef(const std::string& pbFile) {
69  // objects to load the graph
70  Status status;
71 
72  // load it
73  GraphDef* graphDef = new GraphDef();
74  status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
75 
76  // check for success
77  if (!status.ok()) {
78  throw cms::Exception("InvalidGraphDef")
79  << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
80  }
81 
82  return graphDef;
83  }
84 
85  Session* createSession(SessionOptions& sessionOptions) {
86  // objects to create the session
87  Status status;
88 
89  // create a new, empty session
90  Session* session = nullptr;
91  status = NewSession(sessionOptions, &session);
92  if (!status.ok()) {
93  throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
94  }
95 
96  return session;
97  }
98 
99  Session* createSession(int nThreads) {
100  // create session options and set thread options
101  SessionOptions sessionOptions;
102  setThreading(sessionOptions, nThreads);
103 
104  return createSession(sessionOptions);
105  }
106 
107  Session* createSession(const MetaGraphDef* metaGraphDef,
108  const std::string& exportDir,
109  SessionOptions& sessionOptions) {
110  // check for valid pointer
111  if (metaGraphDef == nullptr) {
112  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
113  }
114 
115  // check that the graph has nodes
116  if (metaGraphDef->graph_def().node_size() <= 0) {
117  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
118  }
119 
120  Session* session = createSession(sessionOptions);
121 
122  // add the graph def from the meta graph
123  Status status;
124  status = session->Create(metaGraphDef->graph_def());
125  if (!status.ok()) {
126  throw cms::Exception("InvalidMetaGraphDef")
127  << "error while attaching metaGraphDef to session: " << status.ToString();
128  }
129 
130  // restore variables using the variable and index files in the export directory
131  // first, find names and paths
132  std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
133  std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
134  std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
135  std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
136  std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
137 
138  // when the index file is missing, there's nothing to do
139  if (!Env::Default()->FileExists(indexFile).ok()) {
140  return session;
141  }
142 
143  // create a tensor to store the variable file
144  Tensor varFileTensor(DT_STRING, TensorShape({}));
145  varFileTensor.scalar<tensorflow::tstring>()() = varFile;
146 
147  // run the restore op
148  status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
149  if (!status.ok()) {
150  throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
151  }
152 
153  return session;
154  }
155 
156  Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads) {
157  // create session options and set thread options
158  SessionOptions sessionOptions;
159  setThreading(sessionOptions, nThreads);
160 
161  return createSession(metaGraphDef, exportDir, sessionOptions);
162  }
163 
164  Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions) {
165  // check for valid pointer
166  if (graphDef == nullptr) {
167  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
168  }
169 
170  // check that the graph has nodes
171  if (graphDef->node_size() <= 0) {
172  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
173  }
174 
175  // create a new, empty session
176  Session* session = createSession(sessionOptions);
177 
178  // add the graph def
179  Status status;
180  status = session->Create(*graphDef);
181 
182  // check for success
183  if (!status.ok()) {
184  throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
185  }
186 
187  return session;
188  }
189 
190  Session* createSession(const GraphDef* graphDef, int nThreads) {
191  // create session options and set thread options
192  SessionOptions sessionOptions;
193  setThreading(sessionOptions, nThreads);
194 
195  return createSession(graphDef, sessionOptions);
196  }
197 
198  bool closeSession(Session*& session) {
199  if (session == nullptr) {
200  return true;
201  }
202 
203  // close and delete the session
204  Status status = session->Close();
205  delete session;
206 
207  // reset the pointer
208  session = nullptr;
209 
210  return status.ok();
211  }
212 
213  void run(Session* session,
214  const NamedTensorList& inputs,
215  const std::vector<std::string>& outputNames,
216  std::vector<Tensor>* outputs,
217  const thread::ThreadPoolOptions& threadPoolOptions) {
218  if (session == nullptr) {
219  throw cms::Exception("InvalidSession") << "cannot run empty session";
220  }
221 
222  // create empty run options
223  RunOptions runOptions;
224 
225  // run and check the status
226  Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
227  if (!status.ok()) {
228  throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
229  }
230  }
231 
232  void run(Session* session,
233  const NamedTensorList& inputs,
234  const std::vector<std::string>& outputNames,
235  std::vector<Tensor>* outputs,
236  thread::ThreadPoolInterface* threadPool) {
237  // create thread pool options
238  thread::ThreadPoolOptions threadPoolOptions;
239  threadPoolOptions.inter_op_threadpool = threadPool;
240  threadPoolOptions.intra_op_threadpool = threadPool;
241 
242  // run
243  run(session, inputs, outputNames, outputs, threadPoolOptions);
244  }
245 
246  void run(Session* session,
247  const NamedTensorList& inputs,
248  const std::vector<std::string>& outputNames,
249  std::vector<Tensor>* outputs,
250  const std::string& threadPoolName) {
251  // lookup the thread pool and forward the call accordingly
252  if (threadPoolName == "no_threads") {
254  } else if (threadPoolName == "tbb") {
255  // the TBBTreadPool singleton should be already initialized before with a number of threads
257  } else if (threadPoolName == "tensorflow") {
258  run(session, inputs, outputNames, outputs, nullptr);
259  } else {
260  throw cms::Exception("UnknownThreadPool")
261  << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
262  }
263  }
264 
265  void run(Session* session,
266  const std::vector<std::string>& outputNames,
267  std::vector<Tensor>* outputs,
268  const std::string& threadPoolName) {
269  run(session, {}, outputNames, outputs, threadPoolName);
270  }
271 
272 } // namespace tensorflow
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:30
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:68
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:46
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:17
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:213
bool closeSession(Session *&session)
Definition: TensorFlow.cc:198
static TBBThreadPool & instance(int nThreads=-1)
Definition: TBBThreadPool.h:24
Log< level::Info, false > LogInfo
void setLogging(const std::string &level="3")
Definition: TensorFlow.cc:15
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:29
static NoThreadPool & instance()
Definition: NoThreadPool.h:22