CMS 3D CMS Logo

TensorFlow.h
Go to the documentation of this file.
1 /*
2  * TensorFlow interface helpers.
3  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
4  *
5  * Author: Marcel Rieger
6  */
7 
8 #ifndef PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
9 #define PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
10 
#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
20 
23 
25 
26 namespace tensorflow {
27 
28  typedef std::pair<std::string, Tensor> NamedTensor;
29  typedef std::vector<NamedTensor> NamedTensorList;
30 
31  // set the tensorflow log level
32  void setLogging(const std::string& level = "3");
33 
34  // updates the config of sessionOptions so that it uses nThreads
35  void setThreading(SessionOptions& sessionOptions, int nThreads = 1);
36 
37  // deprecated
38  // updates the config of sessionOptions so that it uses nThreads, prints a deprecation warning
39  // since the threading configuration is done per run() call as of 2.1
40  void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool);
41 
42  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
43  // predefined sessionOptions
44  // transfers ownership
45  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions);
46 
47  // deprecated in favor of loadMetaGraphDef
48  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions);
49 
50  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
51  // nThreads
52  // transfers ownership
53  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir,
54  const std::string& tag = kSavedModelTagServe,
55  int nThreads = 1);
56 
57  // deprecated in favor of loadMetaGraphDef
58  MetaGraphDef* loadMetaGraph(const std::string& exportDir,
59  const std::string& tag = kSavedModelTagServe,
60  int nThreads = 1);
61 
62  // loads a graph definition saved as a protobuf file at pbFile
63  // transfers ownership
64  GraphDef* loadGraphDef(const std::string& pbFile);
65 
66  // return a new, empty session using predefined sessionOptions
67  // transfers ownership
68  Session* createSession(SessionOptions& sessionOptions);
69 
70  // return a new, empty session with nThreads
71  // transfers ownership
72  Session* createSession(int nThreads = 1);
73 
74  // return a new session that will contain an already loaded meta graph whose exportDir must be
75  // given in order to load and initialize the variables, sessionOptions are predefined
76  // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
77  // transfers ownership
78  Session* createSession(const MetaGraphDef* metaGraphDef,
79  const std::string& exportDir,
80  SessionOptions& sessionOptions);
81 
82  // return a new session that will contain an already loaded meta graph whose exportDir must be given
83  // in order to load and initialize the variables, threading options are inferred from nThreads
84  // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
85  // transfers ownership
86  Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads = 1);
87 
88  // return a new session that will contain an already loaded graph def, sessionOptions are predefined
89  // an error is thrown when graphDef is a nullptr or when the graph has no nodes
90  // transfers ownership
91  Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions);
92 
93  // return a new session that will contain an already loaded graph def, threading options are
94  // inferred from nThreads
95  // an error is thrown when graphDef is a nullptr or when the graph has no nodes
96  // transfers ownership
97  Session* createSession(const GraphDef* graphDef, int nThreads = 1);
98 
99  // closes a session, calls its destructor, resets the pointer, and returns true on success
100  bool closeSession(Session*& session);
101 
102  // version of the function above that accepts a const session
103  bool closeSession(const Session*& session);
104 
105  // run the session with inputs and outputNames, store output tensors, and control the underlying
106  // thread pool using threadPoolOptions
107  // used for thread scheduling with custom thread pool options
108  // throws a cms exception when not successful
109  void run(Session* session,
110  const NamedTensorList& inputs,
111  const std::vector<std::string>& outputNames,
112  std::vector<Tensor>* outputs,
113  const thread::ThreadPoolOptions& threadPoolOptions);
114 
115  // version of the function above that accepts a const session
116  inline void run(const Session* session,
117  const NamedTensorList& inputs,
118  const std::vector<std::string>& outputNames,
119  std::vector<Tensor>* outputs,
120  const thread::ThreadPoolOptions& threadPoolOptions) {
121  // TF takes a non-const session in the run call which is, however, thread-safe and logically
122  // const, thus const_cast is consistent
123  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolOptions);
124  }
125 
126  // run the session with inputs and outputNames, store output tensors, and control the underlying
127  // thread pool
128  // throws a cms exception when not successful
129  void run(Session* session,
130  const NamedTensorList& inputs,
131  const std::vector<std::string>& outputNames,
132  std::vector<Tensor>* outputs,
133  thread::ThreadPoolInterface* threadPool);
134 
135  // version of the function above that accepts a const session
136  inline void run(const Session* session,
137  const NamedTensorList& inputs,
138  const std::vector<std::string>& outputNames,
139  std::vector<Tensor>* outputs,
140  thread::ThreadPoolInterface* threadPool) {
141  // TF takes a non-const session in the run call which is, however, thread-safe and logically
142  // const, thus const_cast is consistent
143  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPool);
144  }
145 
146  // run the session with inputs and outputNames, store output tensors, and control the underlying
147  // thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
148  // throws a cms exception when not successful
149  void run(Session* session,
150  const NamedTensorList& inputs,
151  const std::vector<std::string>& outputNames,
152  std::vector<Tensor>* outputs,
153  const std::string& threadPoolName = "no_threads");
154 
155  // version of the function above that accepts a const session
156  inline void run(const Session* session,
157  const NamedTensorList& inputs,
158  const std::vector<std::string>& outputNames,
159  std::vector<Tensor>* outputs,
160  const std::string& threadPoolName = "no_threads") {
161  // TF takes a non-const session in the run call which is, however, thread-safe and logically
162  // const, thus const_cast is consistent
163  run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolName);
164  }
165 
166  // run the session without inputs but only outputNames, store output tensors, and control the
167  // underlying thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
168  // throws a cms exception when not successful
169  void run(Session* session,
170  const std::vector<std::string>& outputNames,
171  std::vector<Tensor>* outputs,
172  const std::string& threadPoolName = "no_threads");
173 
174  // version of the function above that accepts a const session
175  inline void run(const Session* session,
176  const std::vector<std::string>& outputNames,
177  std::vector<Tensor>* outputs,
178  const std::string& threadPoolName = "no_threads") {
179  // TF takes a non-const session in the run call which is, however, thread-safe and logically
180  // const, thus const_cast is consistent
181  run(const_cast<Session*>(session), outputNames, outputs, threadPoolName);
182  }
183 
184  // struct that can be used in edm::stream modules for caching a graph and a session instance,
185  // both made atomic for cases where access is required from multiple threads
186  struct SessionCache {
187  std::atomic<GraphDef*> graph;
188  std::atomic<Session*> session;
189 
190  // constructor
192 
193  // initializing constructor, forwarding all arguments to createSession
194  template <typename... Args>
195  SessionCache(const std::string& graphPath, Args&&... sessionArgs) {
196  createSession(graphPath, std::forward<Args>(sessionArgs)...);
197  }
198 
199  // destructor
201 
202  // create the internal graph representation from graphPath and the session object, forwarding
203  // all additional arguments to the central tensorflow::createSession
204  template <typename... Args>
205  void createSession(const std::string& graphPath, Args&&... sessionArgs) {
206  graph.store(loadGraphDef(graphPath));
207  session.store(tensorflow::createSession(graph.load(), std::forward<Args>(sessionArgs)...));
208  }
209 
210  // return a pointer to the const session
211  inline const Session* getSession() const { return session.load(); }
212 
213  // closes and removes the session as well as the graph, and sets the atomic members to nullptr's
214  void closeSession();
215  };
216 
217 } // namespace tensorflow
218 
219 #endif // PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:84
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:29
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:67
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:45
std::pair< std::string, Tensor > NamedTensor
Definition: TensorFlow.h:28
std::atomic< Session * > session
Definition: TensorFlow.h:188
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:16
void createSession(const std::string &graphPath, Args &&... sessionArgs)
Definition: TensorFlow.h:205
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:222
bool closeSession(Session *&session)
Definition: TensorFlow.cc:197
const Session * getSession() const
Definition: TensorFlow.h:211
void setLogging(const std::string &level="3")
Definition: TensorFlow.cc:14
std::atomic< GraphDef * > graph
Definition: TensorFlow.h:187
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:28
SessionCache(const std::string &graphPath, Args &&... sessionArgs)
Definition: TensorFlow.h:195