CMS 3D CMS Logo

TensorFlow.h
Go to the documentation of this file.
1 /*
2  * TensorFlow interface helpers.
3  * Based on TensorFlow C++ API 2.1.
4  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
5  *
6  * Author: Marcel Rieger
7  */
8 
9 #ifndef PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
10 #define PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
11 
12 #include "tensorflow/core/framework/tensor.h"
13 #include "tensorflow/core/lib/core/threadpool.h"
14 #include "tensorflow/core/lib/io/path.h"
15 #include "tensorflow/core/public/session.h"
16 #include "tensorflow/core/util/tensor_bundle/naming.h"
17 #include "tensorflow/cc/client/client_session.h"
18 #include "tensorflow/cc/saved_model/loader.h"
19 #include "tensorflow/cc/saved_model/constants.h"
20 #include "tensorflow/cc/saved_model/tag_constants.h"
21 
24 
26 
27 namespace tensorflow {
28 
29  typedef std::pair<std::string, Tensor> NamedTensor;
30  typedef std::vector<NamedTensor> NamedTensorList;
31 
32  // set the tensorflow log level
33  void setLogging(const std::string& level = "3");
34 
35  // updates the config of sessionOptions so that it uses nThreads
36  void setThreading(SessionOptions& sessionOptions, int nThreads = 1);
37 
38  // deprecated
39  // updates the config of sessionOptions so that it uses nThreads, prints a deprecation warning
40  // since the threading configuration is done per run() call as of 2.1
41  void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool);
42 
43  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
44  // predefined sessionOptions
45  // transfers ownership
46  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions);
47 
48  // deprecated in favor of loadMetaGraphDef
49  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions);
50 
51  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
52  // nThreads
53  // transfers ownership
54  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir,
55  const std::string& tag = kSavedModelTagServe,
56  int nThreads = 1);
57 
58  // deprecated in favor of loadMetaGraphDef
59  MetaGraphDef* loadMetaGraph(const std::string& exportDir,
60  const std::string& tag = kSavedModelTagServe,
61  int nThreads = 1);
62 
63  // loads a graph definition saved as a protobuf file at pbFile
64  // transfers ownership
65  GraphDef* loadGraphDef(const std::string& pbFile);
66 
67  // return a new, empty session using predefined sessionOptions
68  // transfers ownership
69  Session* createSession(SessionOptions& sessionOptions);
70 
71  // return a new, empty session with nThreads
72  // transfers ownership
73  Session* createSession(int nThreads = 1);
74 
75  // return a new session that will contain an already loaded meta graph whose exportDir must be
76  // given in order to load and initialize the variables, sessionOptions are predefined
77  // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
78  // transfers ownership
79  Session* createSession(MetaGraphDef* metaGraphDef, const std::string& exportDir, SessionOptions& sessionOptions);
80 
81  // return a new session that will contain an already loaded meta graph whose exportDir must be given
82  // in order to load and initialize the variables, threading options are inferred from nThreads
83  // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
84  // transfers ownership
85  Session* createSession(MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads = 1);
86 
87  // return a new session that will contain an already loaded graph def, sessionOptions are predefined
88  // an error is thrown when graphDef is a nullptr or when the graph has no nodes
89  // transfers ownership
90  Session* createSession(GraphDef* graphDef, SessionOptions& sessionOptions);
91 
92  // return a new session that will contain an already loaded graph def, threading options are
93  // inferred from nThreads
94  // an error is thrown when graphDef is a nullptr or when the graph has no nodes
95  // transfers ownership
96  Session* createSession(GraphDef* graphDef, int nThreads = 1);
97 
98  // closes a session, calls its destructor, resets the pointer, and returns true on success
99  bool closeSession(Session*& session);
100 
101  // run the session with inputs and outputNames, store output tensors, and control the underlying
102  // thread pool using threadPoolOptions
103  // used for thread scheduling with custom thread pool options
104  // throws a cms exception when not successful
105  void run(Session* session,
106  const NamedTensorList& inputs,
107  const std::vector<std::string>& outputNames,
108  std::vector<Tensor>* outputs,
109  const thread::ThreadPoolOptions& threadPoolOptions);
110 
111  // run the session with inputs and outputNames, store output tensors, and control the underlying
112  // thread pool using the given threadPool interface
113  // throws a cms exception when not successful
114  void run(Session* session,
115  const NamedTensorList& inputs,
116  const std::vector<std::string>& outputNames,
117  std::vector<Tensor>* outputs,
118  thread::ThreadPoolInterface* threadPool);
119 
120  // run the session with inputs and outputNames, store output tensors, and control the underlying
121  // thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
122  // throws a cms exception when not successful
123  void run(Session* session,
124  const NamedTensorList& inputs,
125  const std::vector<std::string>& outputNames,
126  std::vector<Tensor>* outputs,
127  const std::string& threadPoolName = "no_threads");
128 
129  // run the session without inputs but only outputNames, store output tensors, and control the
130  // underlying thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
131  // throws a cms exception when not successful
132  void run(Session* session,
133  const std::vector<std::string>& outputNames,
134  std::vector<Tensor>* outputs,
135  const std::string& threadPoolName = "no_threads");
136 
137 } // namespace tensorflow
138 
139 #endif // PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
personalPlayback.level
level
Definition: personalPlayback.py:22
tensorflow::createSession
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
jets_cff.singleThreadPool
singleThreadPool
Definition: jets_cff.py:325
tensorflow::NamedTensor
std::pair< std::string, Tensor > NamedTensor
Definition: TensorFlow.h:29
PatBasicFWLiteJetAnalyzer_Selector_cfg.outputs
outputs
Definition: PatBasicFWLiteJetAnalyzer_Selector_cfg.py:48
tensorflow::setThreading
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
Definition: TensorFlow.cc:17
tensorflow::closeSession
bool closeSession(Session *&session)
Definition: TensorFlow.cc:196
Session
GlobalPosition_Frontier_DevDB_cff.tag
tag
Definition: GlobalPosition_Frontier_DevDB_cff.py:11
runTheMatrix.nThreads
nThreads
Definition: runTheMatrix.py:361
AlCaHLTBitMon_QueryRunRegistry.string
string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
NoThreadPool.h
tensorflow::NamedTensorList
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:30
PixelMapPlotter.inputs
inputs
Definition: PixelMapPlotter.py:490
tensorflow::setLogging
void setLogging(const std::string &level="3")
Definition: TensorFlow.cc:15
TBBThreadPool.h
tensorflow::loadGraphDef
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:68
tensorflow::loadMetaGraph
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:46
tensorflow::run
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:211
tensorflow
Definition: NoThreadPool.h:18
Exception.h
tensorflow::loadMetaGraphDef
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:29
jets_cff.outputNames
outputNames
Definition: jets_cff.py:322