// TensorFlow.cc — extracted from the generated CMSSW documentation page for this file.
/*
 * TensorFlow interface helpers.
 * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
 *
 * Author: Marcel Rieger
 */

// NOTE(review): the original #include block (lines 8-11) was dropped by the extraction — restore from the source file.
namespace tensorflow {

15  void Options::setThreading(int nThreads) {
16  _nThreads = nThreads;
17  // set number of threads used for intra and inter operation communication
18  _options.config.set_intra_op_parallelism_threads(nThreads);
19  _options.config.set_inter_op_parallelism_threads(nThreads);
20  }
21 
23  /*
24  * The TensorFlow backend configures the available devices using options provided in the sessionOptions proto.
25  * // Options from https://github.com/tensorflow/tensorflow/blob/c53dab9fbc9de4ea8b1df59041a5ffd3987328c3/tensorflow/core/protobuf/config.proto
26  *
27  * If the device_count["GPU"] = 0 GPUs are not used.
28  * The visible_device_list configuration is used to map the `visible` devices (from CUDA_VISIBLE_DEVICES) to `virtual` devices.
29  * If Backend::cpu is request, the GPU device is disallowed by device_count configuration.
30  * If Backend::cuda is request:
31  * - if ResourceInformation shows an available Nvidia GPU device:
32  * the device is used with memory_growth configuration (not allocating all cuda memory at once).
33  * - if no device is present: an exception is raised.
34  */
35 
37  if (backend == Backend::cpu) {
38  // disable GPU usage
39  (*_options.config.mutable_device_count())["GPU"] = 0;
40  _options.config.mutable_gpu_options()->set_visible_device_list("");
41  }
42  // NVidia GPU
43  else if (backend == Backend::cuda) {
44  if (not ri->nvidiaDriverVersion().empty()) {
45  // Check if one GPU device is visible to TF
46  // If not, an exception is raised --> this can happen in case of driver version mismatch
47  // or missing CUDA support in TF compilation
48  if ((*_options.config.mutable_device_count())["GPU"] == 0) {
50  ex << "Cuda backend requested, NVIDIA GPU visible to cmssw, but not visible to TensorFlow in the job";
51  ex.addContext("Calling tensorflow::setBackend()");
52  throw ex;
53  }
54  // Take only the first GPU in the CUDA_VISIBLE_DEVICE list
55  (*_options.config.mutable_device_count())["GPU"] = 1;
56  _options.config.mutable_gpu_options()->set_visible_device_list("0");
57  // Do not allocate all the memory on the GPU at the beginning.
58  _options.config.mutable_gpu_options()->set_allow_growth(true);
59  } else {
61  ex << "Cuda backend requested, but no NVIDIA GPU available in the job";
62  ex.addContext("Calling tensorflow::setBackend()");
63  throw ex;
64  }
65  }
66  // ROCm and Intel GPU are still not supported
67  else if ((backend == Backend::rocm) || (backend == Backend::intel)) {
69  ex << "ROCm/Intel GPU backend requested, but TF is not compiled yet for this platform";
70  ex.addContext("Calling tensorflow::setBackend()");
71  throw ex;
72  }
73  // Get NVidia GPU if possible or fallback to CPU
74  else if (backend == Backend::best) {
75  // Check if a Nvidia GPU is availabl
76  if (not ri->nvidiaDriverVersion().empty()) {
77  // Take only the first GPU in the CUDA_VISIBLE_DEVICE list
78  (*_options.config.mutable_device_count())["GPU"] = 1;
79  _options.config.mutable_gpu_options()->set_visible_device_list("0");
80  // Do not allocate all the memory on the GPU at the beginning.
81  _options.config.mutable_gpu_options()->set_allow_growth(true);
82  } else {
83  // Just CPU support
84  (*_options.config.mutable_device_count())["GPU"] = 0;
85  _options.config.mutable_gpu_options()->set_visible_device_list("");
86  }
87  }
88  }
89 
90  void setLogging(const std::string& level) {
91  /*
92  * 0 = all messages are logged (default behavior)
93  * 1 = INFO messages are not printed
94  * 2 = INFO and WARNING messages are not printed
95  * 3 = INFO, WARNING, and ERROR messages are not printed
96  */
97  setenv("TF_CPP_MIN_LOG_LEVEL", level.c_str(), 0);
98  }
99 
100  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag) {
101  Options default_options{};
102  return loadMetaGraphDef(exportDir, tag, default_options);
103  }
104 
105  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, Options& options) {
106  // objects to load the graph
107  Status status;
108  RunOptions runOptions;
109  SavedModelBundle bundle;
110 
111  // load the model
112  status = LoadSavedModel(options.getSessionOptions(), runOptions, exportDir, {tag}, &bundle);
113  if (!status.ok()) {
114  throw cms::Exception("InvalidMetaGraphDef")
115  << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
116  }
117 
118  // return a copy of the graph
119  return new MetaGraphDef(bundle.meta_graph_def);
120  }
121 
122  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, Options& options) {
123  edm::LogInfo("PhysicsTools/TensorFlow")
124  << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
125 
126  return loadMetaGraphDef(exportDir, tag, options);
127  }
128 
129  GraphDef* loadGraphDef(const std::string& pbFile) {
130  // objects to load the graph
131  Status status;
132 
133  // load it
134  GraphDef* graphDef = new GraphDef();
135  status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
136 
137  // check for success
138  if (!status.ok()) {
139  throw cms::Exception("InvalidGraphDef")
140  << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
141  }
142 
143  return graphDef;
144  }
145 
147  Options default_options{};
148  return createSession(default_options);
149  }
150 
152  // objects to create the session
153  Status status;
154 
155  // create a new, empty session
156  Session* session = nullptr;
157  status = NewSession(options.getSessionOptions(), &session);
158  if (!status.ok()) {
159  throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
160  }
161 
162  return session;
163  }
164 
165  Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, Options& options) {
166  // check for valid pointer
167  if (metaGraphDef == nullptr) {
168  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
169  }
170 
171  // check that the graph has nodes
172  if (metaGraphDef->graph_def().node_size() <= 0) {
173  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
174  }
175 
176  Session* session = createSession(options);
177 
178  // add the graph def from the meta graph
179  Status status;
180  status = session->Create(metaGraphDef->graph_def());
181  if (!status.ok()) {
182  throw cms::Exception("InvalidMetaGraphDef")
183  << "error while attaching metaGraphDef to session: " << status.ToString();
184  }
185 
186  // restore variables using the variable and index files in the export directory
187  // first, find names and paths
188  std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
189  std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
190  std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
191  std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
192  std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
193 
194  // when the index file is missing, there's nothing to do
195  if (!Env::Default()->FileExists(indexFile).ok()) {
196  return session;
197  }
198 
199  // create a tensor to store the variable file
200  Tensor varFileTensor(DT_STRING, TensorShape({}));
201  varFileTensor.scalar<tensorflow::tstring>()() = varFile;
202 
203  // run the restore op
204  status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
205  if (!status.ok()) {
206  throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
207  }
208 
209  return session;
210  }
211 
212  Session* createSession(const GraphDef* graphDef) {
213  Options default_options{};
214  return createSession(graphDef, default_options);
215  }
216 
217  Session* createSession(const GraphDef* graphDef, Options& options) {
218  // check for valid pointer
219  if (graphDef == nullptr) {
220  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
221  }
222 
223  // check that the graph has nodes
224  if (graphDef->node_size() <= 0) {
225  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
226  }
227 
228  // create a new, empty session
229  Session* session = createSession(options);
230 
231  // add the graph def
232  Status status;
233  status = session->Create(*graphDef);
234 
235  // check for success
236  if (!status.ok()) {
237  throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
238  }
239 
240  return session;
241  }
242 
243  bool closeSession(Session*& session) {
244  if (session == nullptr) {
245  return true;
246  }
247 
248  // close and delete the session
249  Status status = session->Close();
250  delete session;
251 
252  // reset the pointer
253  session = nullptr;
254 
255  return status.ok();
256  }
257 
258  bool closeSession(const Session*& session) {
259  auto s = const_cast<Session*>(session);
260  bool state = closeSession(s);
261 
262  // reset the pointer
263  session = nullptr;
264 
265  return state;
266  }
267 
269  // check for empty tensors in the inputs
270  bool isEmpty = false;
271  for (const auto& input : inputs) {
272  // Checking using the shape
273  if (input.second.shape().num_elements() == 0) {
274  isEmpty = true;
275  break;
276  }
277  }
278  return isEmpty;
279  }
280 
281  void run(Session* session,
282  const NamedTensorList& inputs,
283  const std::vector<std::string>& outputNames,
284  std::vector<Tensor>* outputs,
285  const thread::ThreadPoolOptions& threadPoolOptions) {
286  if (session == nullptr) {
287  throw cms::Exception("InvalidSession") << "cannot run empty session";
288  }
289 
290  // create empty run options
291  RunOptions runOptions;
292 
293  // Check if the inputs are empty
295  return;
296 
297  // run and check the status
298  Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
299  if (!status.ok()) {
300  throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
301  }
302  }
303 
304  void run(Session* session,
305  const NamedTensorList& inputs,
306  const std::vector<std::string>& outputNames,
307  std::vector<Tensor>* outputs,
308  thread::ThreadPoolInterface* threadPool) {
309  // create thread pool options
310  thread::ThreadPoolOptions threadPoolOptions;
311  threadPoolOptions.inter_op_threadpool = threadPool;
312  threadPoolOptions.intra_op_threadpool = threadPool;
313 
314  // run
315  run(session, inputs, outputNames, outputs, threadPoolOptions);
316  }
317 
318  void run(Session* session,
319  const NamedTensorList& inputs,
320  const std::vector<std::string>& outputNames,
321  std::vector<Tensor>* outputs,
322  const std::string& threadPoolName) {
323  // lookup the thread pool and forward the call accordingly
324  if (threadPoolName == "no_threads") {
326  } else if (threadPoolName == "tbb") {
327  // the TBBTreadPool singleton should be already initialized before with a number of threads
329  } else if (threadPoolName == "tensorflow") {
330  run(session, inputs, outputNames, outputs, nullptr);
331  } else {
332  throw cms::Exception("UnknownThreadPool")
333  << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
334  }
335  }
336 
337  void run(Session* session,
338  const std::vector<std::string>& outputNames,
339  std::vector<Tensor>* outputs,
340  const std::string& threadPoolName) {
341  run(session, {}, outputNames, outputs, threadPoolName);
342  }
343 
345  // delete the session if set
346  Session* s = session.load();
347  if (s != nullptr) {
349  session.store(nullptr);
350  }
351 
352  // delete the graph if set
353  if (graph.load() != nullptr) {
354  delete graph.load();
355  graph.store(nullptr);
356  }
357  }
358 
359 } // namespace tensorflow
/* Doxygen cross-reference index (generated; not part of the source file):
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:31
void setBackend(Backend backend=Backend::cpu)
Definition: TensorFlow.cc:22
virtual std::string const & nvidiaDriverVersion() const =0
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:129
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
Definition: TensorFlow.cc:100
static std::string const input
Definition: EdmProvDump.cc:50
std::atomic< Session * > session
Definition: TensorFlow.h:194
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:281
bool closeSession(Session *&session)
Definition: TensorFlow.cc:243
bool checkEmptyInputs(const NamedTensorList &inputs)
Definition: TensorFlow.cc:268
static TBBThreadPool & instance(int nThreads=-1)
Definition: TBBThreadPool.h:24
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, Options &Options)
Definition: TensorFlow.cc:122
Session * createSession()
Definition: TensorFlow.cc:146
Log< level::Info, false > LogInfo
void setLogging(const std::string &level="3")
Definition: TensorFlow.cc:90
std::atomic< GraphDef * > graph
Definition: TensorFlow.h:193
constexpr bool Default
Definition: SoACommon.h:75
void addContext(std::string const &context)
Definition: Exception.cc:169
void setThreading(int nThreads=1)
Definition: TensorFlow.cc:15
static NoThreadPool & instance()
Definition: NoThreadPool.h:22
SessionOptions _options
Definition: TensorFlow.h:36
*/