CMS 3D CMS Logo

TensorFlow.h
Go to the documentation of this file.
1 /*
2  * TensorFlow interface helpers.
3  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
4  *
5  * Author: Marcel Rieger
6  */
7 
8 #ifndef PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
9 #define PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
10 
11 #include "tensorflow/core/framework/tensor.h"
12 #include "tensorflow/core/lib/core/threadpool.h"
13 #include "tensorflow/core/lib/io/path.h"
14 #include "tensorflow/core/public/session.h"
15 #include "tensorflow/core/util/tensor_bundle/naming.h"
16 #include "tensorflow/cc/client/client_session.h"
17 #include "tensorflow/cc/saved_model/loader.h"
18 #include "tensorflow/cc/saved_model/constants.h"
19 #include "tensorflow/cc/saved_model/tag_constants.h"
20 
23 
25 
26 namespace tensorflow {
27 
28  enum class Backend { cpu, cuda, rocm, intel, best };
29 
30  typedef std::pair<std::string, Tensor> NamedTensor;
31  typedef std::vector<NamedTensor> NamedTensorList;
32 
33  struct Options {
34  int _nThreads;
36  SessionOptions _options;
37 
41  };
42 
46  };
47 
48  // updates the config of sessionOptions so that it uses nThreads
49  void setThreading(int nThreads = 1);
50 
51  // Set the backend option cpu/cuda
52  // The gpu memory is set to "allow_growth" to avoid TF getting all the CUDA memory at once.
54 
55  SessionOptions& getSessionOptions() { return _options; };
56  int getNThreads() const { return _nThreads; };
57  Backend getBackend() const { return _backend; };
58  };
59 
60  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
61  // predefined options
62  // transfers ownership
63  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag = kSavedModelTagServe);
64 
65  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
66  // user provided options
67  // transfers ownership
68  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, Options& options);
69 
70  // deprecated in favor of loadMetaGraphDef
71  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, Options& Options);
72 
73  // loads a graph definition saved as a protobuf file at pbFile
74  // transfers ownership
75  GraphDef* loadGraphDef(const std::string& pbFile);
76 
77  // return a new, empty session using the predefined options
79 
80  // return a new, empty session using user provided options
81  // transfers ownership
83 
84  // return a new session that will contain an already loaded meta graph whose exportDir must be
85  // given in order to load and initialize the variables, sessionOptions are user defined
86  // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
87  // transfers ownership
88  Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, Options& options);
89 
90  // return a new session that will contain an already loaded graph def, sessionOptions are predefined
91  // an error is thrown when graphDef is a nullptr or when the graph has no nodes
92  // transfers ownership
93  Session* createSession(const GraphDef* graphDef);
94 
95  // return a new session that will contain an already loaded graph def, sessionOptions are user defined
96  // an error is thrown when graphDef is a nullptr or when the graph has no nodes
97  // transfers ownership
98  Session* createSession(const GraphDef* graphDef, Options& options);
99 
100  // closes a session, calls its destructor, resets the pointer, and returns true on success
101  bool closeSession(Session*& session);
102 
103  // version of the function above that accepts a const session
104  bool closeSession(const Session*& session);
105 
107 
108  // run the session with inputs and outputNames, store output tensors, and control the underlying
109  // thread pool using threadPoolOptions
110  // used for thread scheduling with custom thread pool options
111  // throws a cms exception when not successful
112  void run(Session* session,
113  const NamedTensorList& inputs,
114  const std::vector<std::string>& outputNames,
115  std::vector<Tensor>* outputs,
116  const thread::ThreadPoolOptions& threadPoolOptions);
117 
118  // const-session overload of the run() declared directly above (inputs, outputNames, outputs, threadPoolOptions)
119  inline void run(const Session* session,
120  const NamedTensorList& inputs,
121  const std::vector<std::string>& outputNames,
122  std::vector<Tensor>* outputs,
123  const thread::ThreadPoolOptions& threadPoolOptions) {
124  // forwards all arguments unchanged to the non-const overload; per the original author's note, TF's run
125  // call takes a non-const session but is thread-safe and logically const, which is why const_cast is used
126  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolOptions);
127  }
128 
129  // run the session with inputs and outputNames, store output tensors, and control the underlying
130  // thread pool
131  // throws a cms exception when not successful
132  void run(Session* session,
133  const NamedTensorList& inputs,
134  const std::vector<std::string>& outputNames,
135  std::vector<Tensor>* outputs,
136  thread::ThreadPoolInterface* threadPool);
137 
138  // const-session overload of the run() declared directly above (inputs, outputNames, outputs, threadPool)
139  inline void run(const Session* session,
140  const NamedTensorList& inputs,
141  const std::vector<std::string>& outputNames,
142  std::vector<Tensor>* outputs,
143  thread::ThreadPoolInterface* threadPool) {
144  // forwards all arguments unchanged to the non-const overload; per the original author's note, TF's run
145  // call takes a non-const session but is thread-safe and logically const, which is why const_cast is used
146  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPool);
147  }
148 
149  // run the session with inputs and outputNames, store output tensors, and control the underlying
150  // thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
151  // throws a cms exception when not successful
152  void run(Session* session,
153  const NamedTensorList& inputs,
154  const std::vector<std::string>& outputNames,
155  std::vector<Tensor>* outputs,
156  const std::string& threadPoolName = "no_threads");
157 
158  // const-session overload of the run() declared directly above; threadPoolName defaults to "no_threads"
159  inline void run(const Session* session,
160  const NamedTensorList& inputs,
161  const std::vector<std::string>& outputNames,
162  std::vector<Tensor>* outputs,
163  const std::string& threadPoolName = "no_threads") {
164  // forwards all arguments unchanged to the non-const overload; per the original author's note, TF's run
165  // call takes a non-const session but is thread-safe and logically const, which is why const_cast is used
166  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolName);
167  }
168 
169  // run the session without inputs but only outputNames, store output tensors, and control the
170  // underlying thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
171  // throws a cms exception when not successful
172  void run(Session* session,
173  const std::vector<std::string>& outputNames,
174  std::vector<Tensor>* outputs,
175  const std::string& threadPoolName = "no_threads");
176 
177  // const-session overload of the input-less run() declared directly above; threadPoolName defaults to "no_threads"
178  inline void run(const Session* session,
179  const std::vector<std::string>& outputNames,
180  std::vector<Tensor>* outputs,
181  const std::string& threadPoolName = "no_threads") {
182  // forwards all arguments unchanged to the non-const overload; per the original author's note, TF's run
183  // call takes a non-const session but is thread-safe and logically const, which is why const_cast is used
184  run(const_cast<Session*>(session), outputNames, outputs, threadPoolName);
185  }
186 
187  // cache for one GraphDef pointer and one Session pointer, intended for use in edm::stream modules;
188  // both members are std::atomic for cases where access is required from multiple threads
189  struct SessionCache {
190  std::atomic<GraphDef*> graph;
191  std::atomic<Session*> session;
192 
193  // constructor (its declaration, listing line 194, is not reproduced in this listing)
195 
196  // initializing constructor: passes graphPath and all further arguments on to the member createSession below
197  template <typename... Args>
198  SessionCache(const std::string& graphPath, Args&&... sessionArgs) {
199  createSession(graphPath, std::forward<Args>(sessionArgs)...);
200  }
201 
202  // destructor (its declaration, listing line 203, is not reproduced in this listing)
204 
205  // stores the result of loadGraphDef(graphPath) in graph, then stores the session returned by tensorflow::createSession
206  // (called with that graph plus the forwarded arguments); NOTE(review): existing pointers are overwritten, not released - confirm
207  template <typename... Args>
208  void createSession(const std::string& graphPath, Args&&... sessionArgs) {
209  graph.store(loadGraphDef(graphPath));
210  session.store(tensorflow::createSession(graph.load(), std::forward<Args>(sessionArgs)...));
211  }
212 
213  // return the currently stored session pointer as a pointer to const Session
214  inline const Session* getSession() const { return session.load(); }
215 
216  // closes and removes the session as well as the graph, and sets both atomic members to nullptr
217  void closeSession();
218  };
219 
220 } // namespace tensorflow
221 
222 #endif // PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:31
void setBackend(Backend backend=Backend::cpu)
Definition: TensorFlow.cc:22
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:119
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
Definition: TensorFlow.cc:90
Backend getBackend() const
Definition: TensorFlow.h:57
std::pair< std::string, Tensor > NamedTensor
Definition: TensorFlow.h:30
std::atomic< Session * > session
Definition: TensorFlow.h:191
void createSession(const std::string &graphPath, Args &&... sessionArgs)
Definition: TensorFlow.h:208
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:271
bool closeSession(Session *&session)
Definition: TensorFlow.cc:233
bool checkEmptyInputs(const NamedTensorList &inputs)
Definition: TensorFlow.cc:258
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, Options &Options)
Definition: TensorFlow.cc:112
Options(Backend backend)
Definition: TensorFlow.h:38
Session * createSession()
Definition: TensorFlow.cc:136
const Session * getSession() const
Definition: TensorFlow.h:214
std::atomic< GraphDef * > graph
Definition: TensorFlow.h:190
SessionCache(const std::string &graphPath, Args &&... sessionArgs)
Definition: TensorFlow.h:198
void setThreading(int nThreads=1)
Definition: TensorFlow.cc:15
SessionOptions & getSessionOptions()
Definition: TensorFlow.h:55
SessionOptions _options
Definition: TensorFlow.h:36
int getNThreads() const
Definition: TensorFlow.h:56