CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
TensorFlow.h
Go to the documentation of this file.
1 /*
2  * TensorFlow interface helpers.
3  * Based on TensorFlow C++ API 2.1.
4  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
5  *
6  * Author: Marcel Rieger
7  */
8 
9 #ifndef PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
10 #define PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
11 
12 #include "tensorflow/core/framework/tensor.h"
13 #include "tensorflow/core/lib/core/threadpool.h"
14 #include "tensorflow/core/lib/io/path.h"
15 #include "tensorflow/core/public/session.h"
16 #include "tensorflow/core/util/tensor_bundle/naming.h"
17 #include "tensorflow/cc/client/client_session.h"
18 #include "tensorflow/cc/saved_model/loader.h"
19 #include "tensorflow/cc/saved_model/constants.h"
20 #include "tensorflow/cc/saved_model/tag_constants.h"
21 
24 
26 
27 namespace tensorflow {
28 
29  typedef std::pair<std::string, Tensor> NamedTensor;  // a string paired with a Tensor (presumably the tensor's name)
30  typedef std::vector<NamedTensor> NamedTensorList;  // list of NamedTensor pairs, used as the inputs argument of run()
31 
32  // sets the tensorflow log level, given as the string level (defaults to "3")
33  void setLogging(const std::string& level = "3");
34 
35  // updates the config of sessionOptions in place so that it uses nThreads (defaults to 1)
36  void setThreading(SessionOptions& sessionOptions, int nThreads = 1);
37 
38  // deprecated overload
39  // updates the config of sessionOptions so that it uses nThreads and prints a deprecation warning,
40  // since the threading configuration is done per run() call as of 2.1
41  void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool);
42 
43  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
44  // predefined sessionOptions
45  // transfers ownership of the returned pointer to the caller
46  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions);
47 
48  // deprecated in favor of loadMetaGraphDef, which takes the same arguments
49  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions);
50 
51  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag
52  // (defaults to kSavedModelTagServe) and nThreads (defaults to 1)
53  // transfers ownership of the returned pointer to the caller
54  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir,
55  const std::string& tag = kSavedModelTagServe,
56  int nThreads = 1);
57 
58  // deprecated in favor of loadMetaGraphDef, which takes the same arguments and defaults
59  MetaGraphDef* loadMetaGraph(const std::string& exportDir,
60  const std::string& tag = kSavedModelTagServe,
61  int nThreads = 1);
62 
63  // loads a graph definition saved as a protobuf file at pbFile
64  // transfers ownership of the returned pointer to the caller
65  GraphDef* loadGraphDef(const std::string& pbFile);
66 
67  // returns a new, empty session using predefined sessionOptions
68  // transfers ownership of the returned pointer to the caller
69  Session* createSession(SessionOptions& sessionOptions);
70 
71  // returns a new, empty session with nThreads (defaults to 1)
72  // transfers ownership of the returned pointer to the caller
73  Session* createSession(int nThreads = 1);
74 
75  // returns a new session that will contain an already loaded meta graph whose exportDir must be
76  // given in order to load and initialize the variables, sessionOptions are predefined
77  // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
78  // transfers ownership of the returned pointer to the caller
79  Session* createSession(const MetaGraphDef* metaGraphDef,
80  const std::string& exportDir,
81  SessionOptions& sessionOptions);
82 
83  // returns a new session that will contain an already loaded meta graph whose exportDir must be given
84  // in order to load and initialize the variables, threading options are inferred from nThreads
85  // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
86  // transfers ownership of the returned pointer to the caller
87  Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads = 1);
88 
89  // returns a new session that will contain an already loaded graph def, sessionOptions are predefined
90  // an error is thrown when graphDef is a nullptr or when the graph has no nodes
91  // transfers ownership of the returned pointer to the caller
92  Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions);
93 
94  // returns a new session that will contain an already loaded graph def, threading options are
95  // inferred from nThreads (defaults to 1)
96  // an error is thrown when graphDef is a nullptr or when the graph has no nodes
97  // transfers ownership of the returned pointer to the caller
98  Session* createSession(const GraphDef* graphDef, int nThreads = 1);
99 
100  // closes a session, calls its destructor, resets the pointer (taken by reference), and returns true on success
101  bool closeSession(Session*& session);
102 
103  // runs the session with inputs and outputNames, stores the output tensors in outputs, and controls
104  // the underlying thread pool using threadPoolOptions
105  // intended for thread scheduling with custom thread pool options
106  // throws a cms exception when not successful
107  void run(Session* session,
108  const NamedTensorList& inputs,
109  const std::vector<std::string>& outputNames,
110  std::vector<Tensor>* outputs,
111  const thread::ThreadPoolOptions& threadPoolOptions);
112 
113  // runs the session with inputs and outputNames, stores the output tensors in outputs, and controls
114  // the underlying thread pool using the given threadPool
115  // throws a cms exception when not successful
116  void run(Session* session,
117  const NamedTensorList& inputs,
118  const std::vector<std::string>& outputNames,
119  std::vector<Tensor>* outputs,
120  thread::ThreadPoolInterface* threadPool);
121 
122  // runs the session with inputs and outputNames, stores the output tensors in outputs, and controls the
123  // underlying thread pool using a threadPoolName ("no_threads" (default), "tbb", or "tensorflow")
124  // throws a cms exception when not successful
125  void run(Session* session,
126  const NamedTensorList& inputs,
127  const std::vector<std::string>& outputNames,
128  std::vector<Tensor>* outputs,
129  const std::string& threadPoolName = "no_threads");
130 
131  // runs the session without inputs but only outputNames, stores the output tensors in outputs, and controls
132  // the underlying thread pool using a threadPoolName ("no_threads" (default), "tbb", or "tensorflow")
133  // throws a cms exception when not successful
134  void run(Session* session,
135  const std::vector<std::string>& outputNames,
136  std::vector<Tensor>* outputs,
137  const std::string& threadPoolName = "no_threads");
138 
139 } // namespace tensorflow
140 
141 #endif // PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:30
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:68
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:46
std::pair< std::string, Tensor > NamedTensor
Definition: TensorFlow.h:29
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:17
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:213
bool closeSession(Session *&session)
Definition: TensorFlow.cc:198
void setLogging(const std::string &level="3")
Definition: TensorFlow.cc:15
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:29
tuple level
Definition: testEve_cfg.py:47