CMS 3D CMS Logo

TensorFlow.h
Go to the documentation of this file.
1 /*
2  * TensorFlow interface helpers.
3  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
4  *
5  * Author: Marcel Rieger
6  */
7 
8 #ifndef PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
9 #define PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
10 
11 #include "tensorflow/core/framework/tensor.h"
12 #include "tensorflow/core/lib/core/threadpool.h"
13 #include "tensorflow/core/lib/io/path.h"
14 #include "tensorflow/core/public/session.h"
15 #include "tensorflow/core/util/tensor_bundle/naming.h"
16 #include "tensorflow/cc/client/client_session.h"
17 #include "tensorflow/cc/saved_model/loader.h"
18 #include "tensorflow/cc/saved_model/constants.h"
19 #include "tensorflow/cc/saved_model/tag_constants.h"
20 
23 
25 
26 namespace tensorflow {
27 
28  enum class Backend { cpu, cuda, rocm, intel, best };
29 
30  typedef std::pair<std::string, Tensor> NamedTensor;
31  typedef std::vector<NamedTensor> NamedTensorList;
32 
// Options bundles the settings handed to the loadMetaGraphDef/createSession overloads that take an
// Options&: the thread count, the backend, and the tensorflow SessionOptions they configure.
// NOTE(review): this listing skips original header lines 35, 38-40, 43-45 and 53. Since
// getBackend() reads a `_backend` that is not declared in the visible lines, and two `};` lines have
// no visible opening, those lines presumably hold the `_backend` member, the constructors and the
// setBackend declaration. Confirm against the real TensorFlow.h before editing this struct.
33  struct Options {
// number of threads, reported by getNThreads()
34  int _nThreads;
// session options, exposed as a mutable reference by getSessionOptions()
36  SessionOptions _options;
37 
41  };
42 
46  };
47 
48  // updates the config of sessionOptions so that it uses nThreads
49  void setThreading(int nThreads = 1);
50 
51  // Set the backend option cpu/cuda
52  // The gpu memory is set to "allow_growth" to avoid TF getting all the CUDA memory at once.
// NOTE(review): the declaration these two comments describe (original line 53) is missing here.
54 
// plain accessors; getSessionOptions hands out a mutable reference so callers can adjust the
// session config in place
55  SessionOptions& getSessionOptions() { return _options; };
56  int getNThreads() const { return _nThreads; };
57  Backend getBackend() const { return _backend; };
58  };
59 
60  // set the tensorflow log level
61  void setLogging(const std::string& level = "3");
62 
63  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
64  // predefined options
65  // transfers ownership
66  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag = kSavedModelTagServe);
67 
68  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
69  // user provided options
70  // transfers ownership
71  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, Options& options);
72 
73  // deprecated in favor of loadMetaGraphDef
74  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, Options& Options);
75 
76  // loads a graph definition saved as a protobuf file at pbFile
77  // transfers ownership
78  GraphDef* loadGraphDef(const std::string& pbFile);
79 
80  // return a new, empty session using the predefined options
82 
83  // return a new, empty session using user provided options
84  // transfers ownership
86 
87  // return a new session that will contain an already loaded meta graph whose exportDir must be
88  // given in order to load and initialize the variables, sessionOptions are predefined
89  // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
90  // transfers ownership
91  Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, Options& options);
92 
93  // return a new session that will contain an already loaded graph def, sessionOptions are predefined
94  // an error is thrown when graphDef is a nullptr or when the graph has no nodes
95  // transfers ownership
96  Session* createSession(const GraphDef* graphDef);
97 
98  // return a new session that will contain an already loaded graph def, sessionOptions are user defined
99  // an error is thrown when graphDef is a nullptr or when the graph has no nodes
100  // transfers ownership
101  Session* createSession(const GraphDef* graphDef, Options& options);
102 
103  // closes a session, calls its destructor, resets the pointer, and returns true on success
104  bool closeSession(Session*& session);
105 
106  // version of the function above that accepts a const session
107  bool closeSession(const Session*& session);
108 
109  // run the session with inputs and outputNames, store output tensors, and control the underlying
110  // thread pool using threadPoolOptions
111  // used for thread scheduling with custom thread pool options
112  // throws a cms exception when not successful
113  void run(Session* session,
114  const NamedTensorList& inputs,
115  const std::vector<std::string>& outputNames,
116  std::vector<Tensor>* outputs,
117  const thread::ThreadPoolOptions& threadPoolOptions);
118 
119  // version of the function above that accepts a const session
120  inline void run(const Session* session,
121  const NamedTensorList& inputs,
122  const std::vector<std::string>& outputNames,
123  std::vector<Tensor>* outputs,
124  const thread::ThreadPoolOptions& threadPoolOptions) {
125  // TF takes a non-const session in the run call which is, however, thread-safe and logically
126  // const, thus const_cast is consistent
127  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolOptions);
128  }
129 
130  // run the session with inputs and outputNames, store output tensors, and control the underlying
131  // thread pool
132  // throws a cms exception when not successful
133  void run(Session* session,
134  const NamedTensorList& inputs,
135  const std::vector<std::string>& outputNames,
136  std::vector<Tensor>* outputs,
137  thread::ThreadPoolInterface* threadPool);
138 
139  // version of the function above that accepts a const session
140  inline void run(const Session* session,
141  const NamedTensorList& inputs,
142  const std::vector<std::string>& outputNames,
143  std::vector<Tensor>* outputs,
144  thread::ThreadPoolInterface* threadPool) {
145  // TF takes a non-const session in the run call which is, however, thread-safe and logically
146  // const, thus const_cast is consistent
147  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPool);
148  }
149 
150  // run the session with inputs and outputNames, store output tensors, and control the underlying
151  // thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
152  // throws a cms exception when not successful
153  void run(Session* session,
154  const NamedTensorList& inputs,
155  const std::vector<std::string>& outputNames,
156  std::vector<Tensor>* outputs,
157  const std::string& threadPoolName = "no_threads");
158 
159  // version of the function above that accepts a const session
160  inline void run(const Session* session,
161  const NamedTensorList& inputs,
162  const std::vector<std::string>& outputNames,
163  std::vector<Tensor>* outputs,
164  const std::string& threadPoolName = "no_threads") {
165  // TF takes a non-const session in the run call which is, however, thread-safe and logically
166  // const, thus const_cast is consistent
167  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolName);
168  }
169 
170  // run the session without inputs but only outputNames, store output tensors, and control the
171  // underlying thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
172  // throws a cms exception when not successful
173  void run(Session* session,
174  const std::vector<std::string>& outputNames,
175  std::vector<Tensor>* outputs,
176  const std::string& threadPoolName = "no_threads");
177 
178  // version of the function above that accepts a const session
179  inline void run(const Session* session,
180  const std::vector<std::string>& outputNames,
181  std::vector<Tensor>* outputs,
182  const std::string& threadPoolName = "no_threads") {
183  // TF takes a non-const session in the run call which is, however, thread-safe and logically
184  // const, thus const_cast is consistent
185  run(const_cast<Session*>(session), outputNames, outputs, threadPoolName);
186  }
187 
188  // struct that can be used in edm::stream modules for caching a graph and a session instance,
189  // both made atomic for cases where access is required from multiple threads
190  struct SessionCache {
191  std::atomic<GraphDef*> graph;
192  std::atomic<Session*> session;
193 
194  // constructor
196 
197  // initializing constructor, forwarding all arguments to createSession
198  template <typename... Args>
199  SessionCache(const std::string& graphPath, Args&&... sessionArgs) {
200  createSession(graphPath, std::forward<Args>(sessionArgs)...);
201  }
202 
203  // destructor
205 
206  // create the internal graph representation from graphPath and the session object, forwarding
207  // all additional arguments to the central tensorflow::createSession
208  template <typename... Args>
209  void createSession(const std::string& graphPath, Args&&... sessionArgs) {
210  graph.store(loadGraphDef(graphPath));
211  session.store(tensorflow::createSession(graph.load(), std::forward<Args>(sessionArgs)...));
212  }
213 
214  // return a pointer to the const session
215  inline const Session* getSession() const { return session.load(); }
216 
217  // closes and removes the session as well as the graph, and sets the atomic members to nullptr's
218  void closeSession();
219  };
220 
221 } // namespace tensorflow
222 
223 #endif // PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:31
void setBackend(Backend backend=Backend::cpu)
Definition: TensorFlow.cc:22
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:120
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
Definition: TensorFlow.cc:91
Backend getBackend() const
Definition: TensorFlow.h:57
std::pair< std::string, Tensor > NamedTensor
Definition: TensorFlow.h:30
std::atomic< Session * > session
Definition: TensorFlow.h:192
void createSession(const std::string &graphPath, Args &&... sessionArgs)
Definition: TensorFlow.h:209
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:259
bool closeSession(Session *&session)
Definition: TensorFlow.cc:234
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, Options &Options)
Definition: TensorFlow.cc:113
Options(Backend backend)
Definition: TensorFlow.h:38
Session * createSession()
Definition: TensorFlow.cc:137
const Session * getSession() const
Definition: TensorFlow.h:215
void setLogging(const std::string &level="3")
Definition: TensorFlow.cc:81
std::atomic< GraphDef * > graph
Definition: TensorFlow.h:191
SessionCache(const std::string &graphPath, Args &&... sessionArgs)
Definition: TensorFlow.h:199
void setThreading(int nThreads=1)
Definition: TensorFlow.cc:15
SessionOptions & getSessionOptions()
Definition: TensorFlow.h:55
SessionOptions _options
Definition: TensorFlow.h:36
int getNThreads() const
Definition: TensorFlow.h:56