CMS 3D CMS Logo

TensorFlow.h
Go to the documentation of this file.
1 /*
2  * TensorFlow interface helpers.
3  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
4  *
5  * Author: Marcel Rieger
6  */
7 
8 #ifndef PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
9 #define PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
10 
#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
20 
23 
25 
26 namespace tensorflow {
27 
// Compute backends that can be selected through Options / setBackend.
enum class Backend {
  cpu = 0,
  cuda = 1,
  rocm = 2,
  intel = 3,
  best = 4,
};
29 
30  typedef std::pair<std::string, Tensor> NamedTensor;
31  typedef std::vector<NamedTensor> NamedTensorList;
32 
33  struct Options {
34  int _nThreads;
36  SessionOptions _options;
37 
41  };
42 
46  };
47 
// Adjust the configuration of the held SessionOptions so that nThreads threads are used.
void setThreading(int nThreads = 1);
50 
// Set the backend option (any Backend enumerator; the default is Backend::cpu).
// GPU memory is set to "allow_growth" to prevent TF from claiming all of the CUDA memory at once.
54 
55  SessionOptions& getSessionOptions() { return _options; };
56  int getNThreads() const { return _nThreads; };
57  Backend getBackend() const { return _backend; };
58  };
59 
// Configure TensorFlow's log level, given as a string ("3" unless specified otherwise).
void setLogging(const std::string& level = "3");
62 
63  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
64  // predefined options
65  // transfers ownership
66  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag = kSavedModelTagServe);
67 
68  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
69  // user provided options
70  // transfers ownership
71  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, Options& options);
72 
73  // deprecated in favor of loadMetaGraphDef
74  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, Options& Options);
75 
76  // loads a graph definition saved as a protobuf file at pbFile
77  // transfers ownership
78  GraphDef* loadGraphDef(const std::string& pbFile);
79 
80  // return a new, empty session using the predefined options
82 
83  // return a new, empty session using user provided options
84  // transfers ownership
86 
87  // return a new session that will contain an already loaded meta graph whose exportDir must be
88  // given in order to load and initialize the variables, sessionOptions are predefined
89  // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
90  // transfers ownership
91  Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, Options& options);
92 
93  // return a new session that will contain an already loaded graph def, sessionOptions are predefined
94  // an error is thrown when graphDef is a nullptr or when the graph has no nodes
95  // transfers ownership
96  Session* createSession(const GraphDef* graphDef);
97 
98  // return a new session that will contain an already loaded graph def, sessionOptions are user defined
99  // an error is thrown when graphDef is a nullptr or when the graph has no nodes
100  // transfers ownership
101  Session* createSession(const GraphDef* graphDef, Options& options);
102 
103  // closes a session, calls its destructor, resets the pointer, and returns true on success
104  bool closeSession(Session*& session);
105 
106  // version of the function above that accepts a const session
107  bool closeSession(const Session*& session);
108 
110 
111  // run the session with inputs and outputNames, store output tensors, and control the underlying
112  // thread pool using threadPoolOptions
113  // used for thread scheduling with custom thread pool options
114  // throws a cms exception when not successful
115  void run(Session* session,
116  const NamedTensorList& inputs,
117  const std::vector<std::string>& outputNames,
118  std::vector<Tensor>* outputs,
119  const thread::ThreadPoolOptions& threadPoolOptions);
120 
121  // version of the function above that accepts a const session
122  inline void run(const Session* session,
123  const NamedTensorList& inputs,
124  const std::vector<std::string>& outputNames,
125  std::vector<Tensor>* outputs,
126  const thread::ThreadPoolOptions& threadPoolOptions) {
127  // TF takes a non-const session in the run call which is, however, thread-safe and logically
128  // const, thus const_cast is consistent
129  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolOptions);
130  }
131 
132  // run the session with inputs and outputNames, store output tensors, and control the underlying
133  // thread pool
134  // throws a cms exception when not successful
135  void run(Session* session,
136  const NamedTensorList& inputs,
137  const std::vector<std::string>& outputNames,
138  std::vector<Tensor>* outputs,
139  thread::ThreadPoolInterface* threadPool);
140 
141  // version of the function above that accepts a const session
142  inline void run(const Session* session,
143  const NamedTensorList& inputs,
144  const std::vector<std::string>& outputNames,
145  std::vector<Tensor>* outputs,
146  thread::ThreadPoolInterface* threadPool) {
147  // TF takes a non-const session in the run call which is, however, thread-safe and logically
148  // const, thus const_cast is consistent
149  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPool);
150  }
151 
152  // run the session with inputs and outputNames, store output tensors, and control the underlying
153  // thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
154  // throws a cms exception when not successful
155  void run(Session* session,
156  const NamedTensorList& inputs,
157  const std::vector<std::string>& outputNames,
158  std::vector<Tensor>* outputs,
159  const std::string& threadPoolName = "no_threads");
160 
161  // version of the function above that accepts a const session
162  inline void run(const Session* session,
163  const NamedTensorList& inputs,
164  const std::vector<std::string>& outputNames,
165  std::vector<Tensor>* outputs,
166  const std::string& threadPoolName = "no_threads") {
167  // TF takes a non-const session in the run call which is, however, thread-safe and logically
168  // const, thus const_cast is consistent
169  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolName);
170  }
171 
172  // run the session without inputs but only outputNames, store output tensors, and control the
173  // underlying thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
174  // throws a cms exception when not successful
175  void run(Session* session,
176  const std::vector<std::string>& outputNames,
177  std::vector<Tensor>* outputs,
178  const std::string& threadPoolName = "no_threads");
179 
180  // version of the function above that accepts a const session
181  inline void run(const Session* session,
182  const std::vector<std::string>& outputNames,
183  std::vector<Tensor>* outputs,
184  const std::string& threadPoolName = "no_threads") {
185  // TF takes a non-const session in the run call which is, however, thread-safe and logically
186  // const, thus const_cast is consistent
187  run(const_cast<Session*>(session), outputNames, outputs, threadPoolName);
188  }
189 
190  // struct that can be used in edm::stream modules for caching a graph and a session instance,
191  // both made atomic for cases where access is required from multiple threads
192  struct SessionCache {
193  std::atomic<GraphDef*> graph;
194  std::atomic<Session*> session;
195 
196  // constructor
198 
199  // initializing constructor, forwarding all arguments to createSession
200  template <typename... Args>
201  SessionCache(const std::string& graphPath, Args&&... sessionArgs) {
202  createSession(graphPath, std::forward<Args>(sessionArgs)...);
203  }
204 
205  // destructor
207 
208  // create the internal graph representation from graphPath and the session object, forwarding
209  // all additional arguments to the central tensorflow::createSession
210  template <typename... Args>
211  void createSession(const std::string& graphPath, Args&&... sessionArgs) {
212  graph.store(loadGraphDef(graphPath));
213  session.store(tensorflow::createSession(graph.load(), std::forward<Args>(sessionArgs)...));
214  }
215 
216  // return a pointer to the const session
217  inline const Session* getSession() const { return session.load(); }
218 
// Close and remove both the cached session and graph, then reset the atomic members to nullptr.
void closeSession();
221  };
222 
223 } // namespace tensorflow
224 
225 #endif // PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:31
void setBackend(Backend backend=Backend::cpu)
Definition: TensorFlow.cc:22
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:129
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag=kSavedModelTagServe)
Definition: TensorFlow.cc:100
Backend getBackend() const
Definition: TensorFlow.h:57
std::pair< std::string, Tensor > NamedTensor
Definition: TensorFlow.h:30
std::atomic< Session * > session
Definition: TensorFlow.h:194
void createSession(const std::string &graphPath, Args &&... sessionArgs)
Definition: TensorFlow.h:211
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:281
bool closeSession(Session *&session)
Definition: TensorFlow.cc:243
bool checkEmptyInputs(const NamedTensorList &inputs)
Definition: TensorFlow.cc:268
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, Options &Options)
Definition: TensorFlow.cc:122
Options(Backend backend)
Definition: TensorFlow.h:38
Session * createSession()
Definition: TensorFlow.cc:146
const Session * getSession() const
Definition: TensorFlow.h:217
void setLogging(const std::string &level="3")
Definition: TensorFlow.cc:90
std::atomic< GraphDef * > graph
Definition: TensorFlow.h:193
SessionCache(const std::string &graphPath, Args &&... sessionArgs)
Definition: TensorFlow.h:201
void setThreading(int nThreads=1)
Definition: TensorFlow.cc:15
SessionOptions & getSessionOptions()
Definition: TensorFlow.h:55
SessionOptions _options
Definition: TensorFlow.h:36
int getNThreads() const
Definition: TensorFlow.h:56