CMS 3D CMS Logo

TensorFlow.cc
Go to the documentation of this file.
1 /*
2  * TensorFlow interface helpers.
3  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
4  *
5  * Author: Marcel Rieger
6  */
7 
12 
13 namespace tensorflow {
14 
15  void Options::setThreading(int nThreads) {
16  _nThreads = nThreads;
17  // set number of threads used for intra and inter operation communication
18  _options.config.set_intra_op_parallelism_threads(nThreads);
19  _options.config.set_inter_op_parallelism_threads(nThreads);
20  }
21 
 23  /*
 24  * The TensorFlow backend configures the available devices using options provided in the sessionOptions proto.
 25  * // Options from https://github.com/tensorflow/tensorflow/blob/c53dab9fbc9de4ea8b1df59041a5ffd3987328c3/tensorflow/core/protobuf/config.proto
 26  *
 27  * If the device_count["GPU"] = 0 GPUs are not used.
 28  * The visible_device_list configuration is used to map the `visible` devices (from CUDA_VISIBLE_DEVICES) to `virtual` devices.
 29  * If Backend::cpu is requested, the GPU device is disallowed by the device_count configuration.
 30  * If Backend::cuda is requested:
 31  * - if ResourceInformation shows an available Nvidia GPU device:
 32  * the device is used with memory_growth configuration (not allocating all cuda memory at once).
 33  * - if no device is present: an exception is raised.
 34  */
 35 
 // NOTE(review): extraction artifact — the function signature (presumably
 // `void Options::setBackend(Backend backend)`, per the cross-reference "TensorFlow.cc:22")
 // and the declaration of `ri` (a ResourceInformation accessor used below) are missing
 // from this dump; confirm against the original source before editing.
 37  if (backend == Backend::cpu) {
 38  // disable GPU usage
 39  (*_options.config.mutable_device_count())["GPU"] = 0;
 40  _options.config.mutable_gpu_options()->set_visible_device_list("");
 41  }
 42  // NVidia GPU
 43  else if (backend == Backend::cuda) {
 44  if (not ri->nvidiaDriverVersion().empty()) {
 45  // Check if one GPU device is visible to TF
 46  // If not, an exception is raised --> this can happen in case of driver version mismatch
 47  // or missing CUDA support in TF compilation
 48  if ((*_options.config.mutable_device_count())["GPU"] == 0) {
 // NOTE(review): the declaration of the exception object `ex` (original line 49)
 // is missing from this dump — presumably an edm::Exception; verify upstream.
 50  ex << "Cuda backend requested, NVIDIA GPU visible to cmssw, but not visible to TensorFlow in the job";
 51  ex.addContext("Calling tensorflow::setBackend()");
 52  throw ex;
 53  }
 54  // Take only the first GPU in the CUDA_VISIBLE_DEVICE list
 55  (*_options.config.mutable_device_count())["GPU"] = 1;
 56  _options.config.mutable_gpu_options()->set_visible_device_list("0");
 57  // Do not allocate all the memory on the GPU at the beginning.
 58  _options.config.mutable_gpu_options()->set_allow_growth(true);
 59  } else {
 // NOTE(review): declaration of `ex` (original line 60) missing from this dump — verify upstream.
 61  ex << "Cuda backend requested, but no NVIDIA GPU available in the job";
 62  ex.addContext("Calling tensorflow::setBackend()");
 63  throw ex;
 64  }
 65  }
 66  // ROCm and Intel GPU are still not supported
 67  else if ((backend == Backend::rocm) || (backend == Backend::intel)) {
 // NOTE(review): declaration of `ex` (original line 68) missing from this dump — verify upstream.
 69  ex << "ROCm/Intel GPU backend requested, but TF is not compiled yet for this platform";
 70  ex.addContext("Calling tensorflow::setBackend()");
 71  throw ex;
 72  }
 73  // Get NVidia GPU if possible or fallback to CPU
 74  else if (backend == Backend::best) {
 75  // Check if a Nvidia GPU is available
 76  if (not ri->nvidiaDriverVersion().empty()) {
 77  // Take only the first GPU in the CUDA_VISIBLE_DEVICE list
 78  (*_options.config.mutable_device_count())["GPU"] = 1;
 79  _options.config.mutable_gpu_options()->set_visible_device_list("0");
 80  // Do not allocate all the memory on the GPU at the beginning.
 81  _options.config.mutable_gpu_options()->set_allow_growth(true);
 82  } else {
 83  // Just CPU support
 84  (*_options.config.mutable_device_count())["GPU"] = 0;
 85  _options.config.mutable_gpu_options()->set_visible_device_list("");
 86  }
 87  }
 88  }
89 
90  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag) {
91  Options default_options{};
92  return loadMetaGraphDef(exportDir, tag, default_options);
93  }
94 
95  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, Options& options) {
96  // objects to load the graph
97  Status status;
98  RunOptions runOptions;
99  SavedModelBundle bundle;
100 
101  // load the model
102  status = LoadSavedModel(options.getSessionOptions(), runOptions, exportDir, {tag}, &bundle);
103  if (!status.ok()) {
104  throw cms::Exception("InvalidMetaGraphDef")
105  << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
106  }
107 
108  // return a copy of the graph
109  return new MetaGraphDef(bundle.meta_graph_def);
110  }
111 
112  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, Options& options) {
113  edm::LogInfo("PhysicsTools/TensorFlow")
114  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
115 
116  return loadMetaGraphDef(exportDir, tag, options);
117  }
118 
119  GraphDef* loadGraphDef(const std::string& pbFile) {
120  // objects to load the graph
121  Status status;
122 
123  // load it
124  GraphDef* graphDef = new GraphDef();
125  status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
126 
127  // check for success
128  if (!status.ok()) {
129  throw cms::Exception("InvalidGraphDef")
130  << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
131  }
132 
133  return graphDef;
134  }
135 
137  Options default_options{};
138  return createSession(default_options);
139  }
140 
142  // objects to create the session
143  Status status;
144 
145  // create a new, empty session
146  Session* session = nullptr;
147  status = NewSession(options.getSessionOptions(), &session);
148  if (!status.ok()) {
149  throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
150  }
151 
152  return session;
153  }
154 
155  Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, Options& options) {
156  // check for valid pointer
157  if (metaGraphDef == nullptr) {
158  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
159  }
160 
161  // check that the graph has nodes
162  if (metaGraphDef->graph_def().node_size() <= 0) {
163  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
164  }
165 
166  Session* session = createSession(options);
167 
168  // add the graph def from the meta graph
169  Status status;
170  status = session->Create(metaGraphDef->graph_def());
171  if (!status.ok()) {
172  throw cms::Exception("InvalidMetaGraphDef")
173  << "error while attaching metaGraphDef to session: " << status.ToString();
174  }
175 
176  // restore variables using the variable and index files in the export directory
177  // first, find names and paths
178  std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
179  std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
180  std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
181  std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
182  std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
183 
184  // when the index file is missing, there's nothing to do
185  if (!Env::Default()->FileExists(indexFile).ok()) {
186  return session;
187  }
188 
189  // create a tensor to store the variable file
190  Tensor varFileTensor(DT_STRING, TensorShape({}));
191  varFileTensor.scalar<tensorflow::tstring>()() = varFile;
192 
193  // run the restore op
194  status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
195  if (!status.ok()) {
196  throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
197  }
198 
199  return session;
200  }
201 
202  Session* createSession(const GraphDef* graphDef) {
203  Options default_options{};
204  return createSession(graphDef, default_options);
205  }
206 
207  Session* createSession(const GraphDef* graphDef, Options& options) {
208  // check for valid pointer
209  if (graphDef == nullptr) {
210  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
211  }
212 
213  // check that the graph has nodes
214  if (graphDef->node_size() <= 0) {
215  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
216  }
217 
218  // create a new, empty session
219  Session* session = createSession(options);
220 
221  // add the graph def
222  Status status;
223  status = session->Create(*graphDef);
224 
225  // check for success
226  if (!status.ok()) {
227  throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
228  }
229 
230  return session;
231  }
232 
233  bool closeSession(Session*& session) {
234  if (session == nullptr) {
235  return true;
236  }
237 
238  // close and delete the session
239  Status status = session->Close();
240  delete session;
241 
242  // reset the pointer
243  session = nullptr;
244 
245  return status.ok();
246  }
247 
248  bool closeSession(const Session*& session) {
249  auto s = const_cast<Session*>(session);
250  bool state = closeSession(s);
251 
252  // reset the pointer
253  session = nullptr;
254 
255  return state;
256  }
257 
259  // check for empty tensors in the inputs
260  bool isEmpty = false;
261  for (const auto& input : inputs) {
262  // Checking using the shape
263  if (input.second.shape().num_elements() == 0) {
264  isEmpty = true;
265  break;
266  }
267  }
268  return isEmpty;
269  }
270 
271  void run(Session* session,
272  const NamedTensorList& inputs,
273  const std::vector<std::string>& outputNames,
274  std::vector<Tensor>* outputs,
275  const thread::ThreadPoolOptions& threadPoolOptions) {
276  if (session == nullptr) {
277  throw cms::Exception("InvalidSession") << "cannot run empty session";
278  }
279 
280  // create empty run options
281  RunOptions runOptions;
282 
283  // Check if the inputs are empty
285  return;
286 
287  // run and check the status
288  Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
289  if (!status.ok()) {
290  throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
291  }
292  }
293 
294  void run(Session* session,
295  const NamedTensorList& inputs,
296  const std::vector<std::string>& outputNames,
297  std::vector<Tensor>* outputs,
298  thread::ThreadPoolInterface* threadPool) {
299  // create thread pool options
300  thread::ThreadPoolOptions threadPoolOptions;
301  threadPoolOptions.inter_op_threadpool = threadPool;
302  threadPoolOptions.intra_op_threadpool = threadPool;
303 
304  // run
305  run(session, inputs, outputNames, outputs, threadPoolOptions);
306  }
307 
308  void run(Session* session,
309  const NamedTensorList& inputs,
310  const std::vector<std::string>& outputNames,
311  std::vector<Tensor>* outputs,
312  const std::string& threadPoolName) {
313  // lookup the thread pool and forward the call accordingly
314  if (threadPoolName == "no_threads") {
316  } else if (threadPoolName == "tbb") {
317  // the TBBTreadPool singleton should be already initialized before with a number of threads
319  } else if (threadPoolName == "tensorflow") {
320  run(session, inputs, outputNames, outputs, nullptr);
321  } else {
322  throw cms::Exception("UnknownThreadPool")
323  << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
324  }
325  }
326 
327  void run(Session* session,
328  const std::vector<std::string>& outputNames,
329  std::vector<Tensor>* outputs,
330  const std::string& threadPoolName) {
331  run(session, {}, outputNames, outputs, threadPoolName);
332  }
333 
335  // delete the session if set
336  Session* s = session.load();
337  if (s != nullptr) {
339  session.store(nullptr);
340  }
341 
342  // delete the graph if set
343  if (graph.load() != nullptr) {
344  delete graph.load();
345  graph.store(nullptr);
346  }
347  }
348 
349 } // namespace tensorflow
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:31
void setBackend(Backend backend=Backend::cpu)
Definition: TensorFlow.cc:22
virtual std::string const & nvidiaDriverVersion() const =0
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:119
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
Definition: TensorFlow.cc:90
static std::string const input
Definition: EdmProvDump.cc:50
std::atomic< Session * > session
Definition: TensorFlow.h:191
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:271
bool closeSession(Session *&session)
Definition: TensorFlow.cc:233
bool checkEmptyInputs(const NamedTensorList &inputs)
Definition: TensorFlow.cc:258
static TBBThreadPool & instance(int nThreads=-1)
Definition: TBBThreadPool.h:24
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, Options &Options)
Definition: TensorFlow.cc:112
Session * createSession()
Definition: TensorFlow.cc:136
Log< level::Info, false > LogInfo
std::atomic< GraphDef * > graph
Definition: TensorFlow.h:190
constexpr bool Default
Definition: SoACommon.h:75
void addContext(std::string const &context)
Definition: Exception.cc:169
void setThreading(int nThreads=1)
Definition: TensorFlow.cc:15
static NoThreadPool & instance()
Definition: NoThreadPool.h:22
SessionOptions _options
Definition: TensorFlow.h:36