CMS 3D CMS Logo

TensorFlow.h
Go to the documentation of this file.
1 /*
2  * TensorFlow interface helpers.
3  * Based on TensorFlow C++ API 1.3.
4  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
5  *
6  * Author: Marcel Rieger
7  */
8 
9 #ifndef PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
10 #define PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
11 
12 #include "tensorflow/core/public/session.h"
13 #include "tensorflow/core/framework/tensor.h"
14 #include "tensorflow/cc/saved_model/loader.h"
15 #include "tensorflow/cc/saved_model/tag_constants.h"
16 #include "tensorflow/cc/saved_model/constants.h"
17 #include "tensorflow/core/lib/io/path.h"
18 #include "tensorflow/core/util/tensor_bundle/naming.h"
19 
21 
22 namespace tensorflow
23 {
24 
25 typedef std::pair<std::string, Tensor> NamedTensor;
26 typedef std::vector<NamedTensor> NamedTensorList;
27 
28 // set the tensorflow log level
29 void setLogging(const std::string& level = "3");
30 
31 // updates the config of sessionOptions so that it uses nThreads and if 1, sets the thread pool to
32 // singleThreadPool
33 void setThreading(SessionOptions& sessionOptions, int nThreads,
34  const std::string& singleThreadPool = "no_threads");
35 
36 // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
37 // predefined sessionOptions
38 // transfers ownership
39 MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag,
40  SessionOptions& sessionOptions);
41 
42 // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
43 // nThreads
44 // transfers ownership
45 MetaGraphDef* loadMetaGraph(const std::string& exportDir,
46  const std::string& tag = kSavedModelTagServe, int nThreads = 1);
47 
48 // loads a graph definition saved as a protobuf file at pbFile
49 // transfers ownership
50 GraphDef* loadGraphDef(const std::string& pbFile);
51 
52 // return a new, empty session using predefined sessionOptions
53 // transfers ownership
54 Session* createSession(SessionOptions& sessionOptions);
55 
56 // return a new, empty session with nThreads
57 // transfers ownership
59 
60 // return a new session that will contain an already loaded meta graph whose exportDir must be given
61 // in order to load and initialize the variables, sessionOptions are predefined
62 // transfers ownership
63 Session* createSession(MetaGraphDef* metaGraph, const std::string& exportDir,
64  SessionOptions& sessionOptions);
65 
66 // return a new session that will contain an already loaded meta graph whose exportDir must be given
67 // in order to load and initialize the variables, threading options are inferred from nThreads
68 // transfers ownership
69 Session* createSession(MetaGraphDef* metaGraph, const std::string& exportDir, int nThreads = 1);
70 
71 // return a new session that will contain an already loaded graph def, sessionOptions are predefined
72 // transfers ownership
73 Session* createSession(GraphDef* graphDef, SessionOptions& sessionOptions);
74 
75 // return a new session that will contain an already loaded graph def, threading options are
76 // inferred from nThreads
77 // transfers ownership
78 Session* createSession(GraphDef* graphDef, int nThreads = 1);
79 
80 // closes a session, calls its destructor, resets the pointer, and returns true on success
82 
83 // run the session with inputs, outputNames and targetNodes, and store output tensors
84 // throws a cms exception when not successful
85 void run(Session* session, const NamedTensorList& inputs,
86  const std::vector<std::string>& outputNames, const std::vector<std::string>& targetNodes,
87  std::vector<Tensor>* outputs);
88 
89 // run the session with inputNames, inputTensors, outputNames and targetNodes, and store output
90 // tensors
91 // throws a cms exception when not successful
92 void run(Session* session, const std::vector<std::string>& inputNames,
93  const std::vector<Tensor>& inputTensors, const std::vector<std::string>& outputNames,
94  const std::vector<std::string>& targetNodes, std::vector<Tensor>* outputs);
95 
96 // run the session with inputs and outputNames, and store output tensors
97 // throws a cms exception when not successful
98 void run(Session* session, const NamedTensorList& inputs,
99  const std::vector<std::string>& outputNames, std::vector<Tensor>* outputs);
100 
101 // run the session with inputNames, inputTensors and outputNames, and store output tensors
102 // throws a cms exception when not successful
103 void run(Session* session, const std::vector<std::string>& inputNames,
104  const std::vector<Tensor>& inputTensors, const std::vector<std::string>& outputNames,
105  std::vector<Tensor>* outputs);
106 
107 } // namespace tensorflow
108 
109 #endif // PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:87
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:26
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:68
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:39
std::pair< std::string, Tensor > NamedTensor
Definition: TensorFlow.h:25
bool closeSession(Session *&session)
Definition: TensorFlow.cc:193
singleThreadPool
Definition: jets_cff.py:327
void setThreading(SessionOptions &sessionOptions, int nThreads, const std::string &singleThreadPool="no_threads")
Definition: TensorFlow.cc:19
void setLogging(const std::string &level="3")
Definition: TensorFlow.cc:14
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, const std::vector< std::string > &targetNodes, std::vector< Tensor > *outputs)
Definition: TensorFlow.cc:210