CMS 3D CMS Logo

TensorFlow.h
Go to the documentation of this file.
1 /*
2  * TensorFlow interface helpers.
3  * Based on TensorFlow C++ API 1.3.
4  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
5  *
6  * Author: Marcel Rieger
7  */
8 
9 #ifndef PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
10 #define PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
11 
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
19 
21 
22 namespace tensorflow {
23 
24  typedef std::pair<std::string, Tensor> NamedTensor;
25  typedef std::vector<NamedTensor> NamedTensorList;
26 
27  // set the tensorflow log level
28  void setLogging(const std::string& level = "3");
29 
30  // updates the config of sessionOptions so that it uses nThreads and if 1, sets the thread pool to
31  // singleThreadPool
32  void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool = "no_threads");
33 
34  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
35  // predefined sessionOptions
36  // transfers ownership
37  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions);
38 
39  // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
40  // nThreads
41  // transfers ownership
42  MetaGraphDef* loadMetaGraph(const std::string& exportDir,
43  const std::string& tag = kSavedModelTagServe,
44  int nThreads = 1);
45 
46  // loads a graph definition saved as a protobuf file at pbFile
47  // transfers ownership
48  GraphDef* loadGraphDef(const std::string& pbFile);
49 
50  // return a new, empty session using predefined sessionOptions
51  // transfers ownership
52  Session* createSession(SessionOptions& sessionOptions);
53 
54  // return a new, empty session with nThreads
55  // transfers ownership
56  Session* createSession(int nThreads = 1);
57 
58  // return a new session that will contain an already loaded meta graph whose exportDir must be given
59  // in order to load and initialize the variables, sessionOptions are predefined
60  // transfers ownership
61  Session* createSession(MetaGraphDef* metaGraph, const std::string& exportDir, SessionOptions& sessionOptions);
62 
63  // return a new session that will contain an already loaded meta graph whose exportDir must be given
64  // in order to load and initialize the variables, threading options are inferred from nThreads
65  // transfers ownership
66  Session* createSession(MetaGraphDef* metaGraph, const std::string& exportDir, int nThreads = 1);
67 
68  // return a new session that will contain an already loaded graph def, sessionOptions are predefined
69  // transfers ownership
70  Session* createSession(GraphDef* graphDef, SessionOptions& sessionOptions);
71 
72  // return a new session that will contain an already loaded graph def, threading options are
73  // inferred from nThreads
74  // transfers ownership
75  Session* createSession(GraphDef* graphDef, int nThreads = 1);
76 
77  // closes a session, calls its destructor, resets the pointer, and returns true on success
78  bool closeSession(Session*& session);
79 
80  // run the session with inputs, outputNames and targetNodes, and store output tensors
81  // throws a cms exception when not successful
82  void run(Session* session,
83  const NamedTensorList& inputs,
84  const std::vector<std::string>& outputNames,
85  const std::vector<std::string>& targetNodes,
86  std::vector<Tensor>* outputs);
87 
88  // run the session with inputNames, inputTensors, outputNames and targetNodes, and store output
89  // tensors
90  // throws a cms exception when not successful
91  void run(Session* session,
92  const std::vector<std::string>& inputNames,
93  const std::vector<Tensor>& inputTensors,
94  const std::vector<std::string>& outputNames,
95  const std::vector<std::string>& targetNodes,
96  std::vector<Tensor>* outputs);
97 
98  // run the session with inputs and outputNames, and store output tensors
99  // throws a cms exception when not successful
100  void run(Session* session,
101  const NamedTensorList& inputs,
102  const std::vector<std::string>& outputNames,
103  std::vector<Tensor>* outputs);
104 
105  // run the session with inputNames, inputTensors and outputNames, and store output tensors
106  // throws a cms exception when not successful
107  void run(Session* session,
108  const std::vector<std::string>& inputNames,
109  const std::vector<Tensor>& inputTensors,
110  const std::vector<std::string>& outputNames,
111  std::vector<Tensor>* outputs);
112 
113 } // namespace tensorflow
114 
115 #endif // PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:71
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:25
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:55
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:31
std::pair< std::string, Tensor > NamedTensor
Definition: TensorFlow.h:24
bool closeSession(Session *&session)
Definition: TensorFlow.cc:161
singleThreadPool
Definition: jets_cff.py:309
void setThreading(SessionOptions &sessionOptions, int nThreads, const std::string &singleThreadPool="no_threads")
Definition: TensorFlow.cc:15
void setLogging(const std::string &level="3")
Definition: TensorFlow.cc:13
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, const std::vector< std::string > &targetNodes, std::vector< Tensor > *outputs)
Definition: TensorFlow.cc:176