CMS 3D CMS Logo

TensorFlow.cc
Go to the documentation of this file.
1 /*
2  * TensorFlow interface helpers.
3  * Based on TensorFlow C++ API 1.3.
4  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
5  *
6  * Author: Marcel Rieger
7  */
8 
10 
11 namespace tensorflow
12 {
13 
// setLogging: exports TF_CPP_MIN_LOG_LEVEL=<level>, the environment variable TensorFlow reads for
// its C++ log verbosity.
// NOTE(review): the signature line (listing line 14) is missing from this rendering; the declaration
// listed on this page is `void setLogging(const std::string &level="3")`.
15 {
// overwrite flag 0: if TF_CPP_MIN_LOG_LEVEL is already set in the environment, it is left unchanged
16  setenv("TF_CPP_MIN_LOG_LEVEL", level.c_str(), 0);
17 }
18 
// setThreading: sets both the intra-op and inter-op parallelism of sessionOptions to nThreads.
// When exactly one thread is requested and singleThreadPool is non-empty, the pool name must be
// "no_threads" or "tbb" (otherwise cms::Exception("UnknownThreadPool") is thrown); it is then stored
// as sessionOptions.target.
// NOTE(review): listing line 20 (the third parameter) is missing from this rendering; the declaration
// listed on this page is `const std::string &singleThreadPool="no_threads"`.
19 void setThreading(SessionOptions& sessionOptions, int nThreads,
21 {
22  // set number of threads used for intra and inter operation communication
23  sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
24  sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
25 
26  // when exactly one thread is requested use a custom thread pool
27  if (nThreads == 1 && !singleThreadPool.empty())
28  {
29  // check for known thread pools
30  if (singleThreadPool != "no_threads" && singleThreadPool != "tbb")
31  {
32  throw cms::Exception("UnknownThreadPool")
33  << "thread pool '" << singleThreadPool << "' unknown, use 'no_threads' or 'tbb'";
34  }
// NOTE(review): the target string presumably selects a custom single-thread session backend
// registered under that name elsewhere — confirm against the registered session factories
35  sessionOptions.target = singleThreadPool;
36  }
37 }
38 
39 MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag,
40  SessionOptions& sessionOptions)
41 {
42  // objects to load the graph
43  Status status;
44  RunOptions runOptions;
45  SavedModelBundle bundle;
46 
47  // load the model
48  status = LoadSavedModel(sessionOptions, runOptions, exportDir, { tag }, &bundle);
49  if (!status.ok())
50  {
51  throw cms::Exception("InvalidMetaGraph")
52  << "error while loading meta graph: " << status.ToString();
53  }
54 
55  // return a copy of the graph
56  return new MetaGraphDef(bundle.meta_graph_def);
57 }
58 
59 MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, int nThreads)
60 {
61  // create session options and set thread options
62  SessionOptions sessionOptions;
63  setThreading(sessionOptions, nThreads);
64 
65  return loadMetaGraph(exportDir, tag, sessionOptions);
66 }
67 
68 GraphDef* loadGraphDef(const std::string& pbFile)
69 {
70  // objects to load the graph
71  Status status;
72 
73  // load it
74  GraphDef* graphDef = new GraphDef();
75  status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
76 
77  // check for success
78  if (!status.ok())
79  {
80  throw cms::Exception("InvalidGraphDef")
81  << "error while loading graph def: " << status.ToString();
82  }
83 
84  return graphDef;
85 }
86 
87 Session* createSession(SessionOptions& sessionOptions)
88 {
89  // objects to create the session
90  Status status;
91 
92  // create a new, empty session
93  Session* session = nullptr;
94  status = NewSession(sessionOptions, &session);
95  if (!status.ok())
96  {
97  throw cms::Exception("InvalidSession")
98  << "error while creating session: " << status.ToString();
99  }
100 
101  return session;
102 }
103 
// createSession(nThreads): builds default SessionOptions, applies setThreading(nThreads) with its
// default single-thread pool, and returns the new, empty session created by the SessionOptions overload.
// NOTE(review): the signature line (listing line 104) is missing from this rendering.
105 {
106  // create session options and set thread options
107  SessionOptions sessionOptions;
108  setThreading(sessionOptions, nThreads);
109 
110  return createSession(sessionOptions);
111 }
112 
113 Session* createSession(MetaGraphDef* metaGraph, const std::string& exportDir,
114  SessionOptions& sessionOptions)
115 {
116  Session* session = createSession(sessionOptions);
117 
118  // add the graph def from the meta graph
119  Status status;
120  status = session->Create(metaGraph->graph_def());
121  if (!status.ok())
122  {
123  throw cms::Exception("InvalidSession")
124  << "error while attaching meta graph to session: " << status.ToString();
125  }
126 
127  // restore variables using the variable and index files in the export directory
128  // first, find names and paths
129  std::string varFileTensorName = metaGraph->saver_def().filename_tensor_name();
130  std::string restoreOpName = metaGraph->saver_def().restore_op_name();
131  std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
132  std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
133  std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
134 
135  // when the index file is missing, there's nothing to do
136  if (!Env::Default()->FileExists(indexFile).ok())
137  {
138  return session;
139  }
140 
141  // create a tensor to store the variable file
142  Tensor varFileTensor(DT_STRING, TensorShape({}));
143  varFileTensor.scalar<std::string>()() = varFile;
144 
145  // run the restore op
146  status = session->Run({ { varFileTensorName, varFileTensor } }, {}, { restoreOpName }, nullptr);
147  if (!status.ok())
148  {
149  throw cms::Exception("InvalidSession")
150  << "error while restoring variables in session: " << status.ToString();
151  }
152 
153  return session;
154 }
155 
156 Session* createSession(MetaGraphDef* metaGraph, const std::string& exportDir, int nThreads)
157 {
158  // create session options and set thread options
159  SessionOptions sessionOptions;
160  setThreading(sessionOptions, nThreads);
161 
162  return createSession(metaGraph, exportDir, sessionOptions);
163 }
164 
165 Session* createSession(GraphDef* graphDef, SessionOptions& sessionOptions)
166 {
167  // create a new, empty session
168  Session* session = createSession(sessionOptions);
169 
170  // add the graph def
171  Status status;
172  status = session->Create(*graphDef);
173 
174  // check for success
175  if (!status.ok())
176  {
177  throw cms::Exception("InvalidSession")
178  << "error while attaching graph def to session: " << status.ToString();
179  }
180 
181  return session;
182 }
183 
184 Session* createSession(GraphDef* graphDef, int nThreads)
185 {
186  // create session options and set thread options
187  SessionOptions sessionOptions;
188  setThreading(sessionOptions, nThreads);
189 
190  return createSession(graphDef, sessionOptions);
191 }
192 
// closeSession: closes and deletes *session, then resets the caller's pointer to nullptr.
// Returns true for an already-null session; otherwise returns whether Close() succeeded.
// The session is deleted even when Close() reports an error.
// NOTE(review): the signature line (listing line 193) is missing from this rendering; the declaration
// listed on this page is `bool closeSession(Session *&session)`.
194 {
195  if (session == nullptr)
196  {
197  return true;
198  }
199 
200  // close and delete the session
201  Status status = session->Close();
202  delete session;
203 
204  // reset the pointer
205  session = nullptr;
206 
207  return status.ok();
208 }
209 
211  const std::vector<std::string>& outputNames, const std::vector<std::string>& targetNodes,
212  std::vector<Tensor>* outputs)
213 {
// run(session, inputs, outputNames, targetNodes, outputs): feeds the named input tensors, evaluates
// outputNames into *outputs and runs targetNodes. Throws cms::Exception("InvalidSession") for a null
// session and cms::Exception("InvalidRun") when Session::Run fails.
// NOTE(review): the first signature line (listing line 210) is missing from this rendering; the
// declaration listed on this page starts `void run(Session *session, const NamedTensorList &inputs, ...`.
214  if (session == nullptr)
215  {
216  throw cms::Exception("InvalidSession") << "cannot run empty session";
217  }
218 
219  // run and check the status
220  Status status = session->Run(inputs, outputNames, targetNodes, outputs);
221  if (!status.ok())
222  {
223  throw cms::Exception("InvalidRun")
224  << "error while running session: " << status.ToString();
225  }
226 }
227 
// run with separate name/tensor lists: pairs inputNames[i] with inputTensors[i] into NamedTensor
// entries and forwards to the NamedTensorList overload. Throws cms::Exception("InvalidInput") when
// the two lists differ in length.
228 void run(Session* session, const std::vector<std::string>& inputNames,
229  const std::vector<Tensor>& inputTensors, const std::vector<std::string>& outputNames,
230  const std::vector<std::string>& targetNodes, std::vector<Tensor>* outputs)
231 {
232  if (inputNames.size() != inputTensors.size())
233  {
234  throw cms::Exception("InvalidInput") << "numbers of input names and tensors not equal";
235  }
236 
// NOTE(review): listing line 237 is missing from this rendering; presumably it declares the local
// `inputs` list (a NamedTensorList) that the loop below fills — confirm against the source.
238  for (size_t i = 0; i < inputNames.size(); i++)
239  {
240  inputs.push_back(NamedTensor(inputNames[i], inputTensors[i]));
241  }
242 
243  run(session, inputs, outputNames, targetNodes, outputs);
244 }
245 
247  const std::vector<std::string>& outputNames, std::vector<Tensor>* outputs)
248 {
// run without target nodes: forwards to the overload taking targetNodes with an empty list.
// NOTE(review): the first signature line (listing line 246) is missing from this rendering; the call
// below passes `inputs` as its second argument, so that parameter presumably is a NamedTensorList.
249  run(session, inputs, outputNames, {}, outputs);
250 }
251 
252 void run(Session* session, const std::vector<std::string>& inputNames,
253  const std::vector<Tensor>& inputTensors, const std::vector<std::string>& outputNames,
254  std::vector<Tensor>* outputs)
255 {
256  run(session, inputNames, inputTensors, outputNames, {}, outputs);
257 }
258 
259 } // namespace tensorflow
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:87
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:26
#define Default
Definition: vmac.h:110
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:68
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:39
std::pair< std::string, Tensor > NamedTensor
Definition: TensorFlow.h:25
bool closeSession(Session *&session)
Definition: TensorFlow.cc:193
singleThreadPool
Definition: jets_cff.py:317
void setThreading(SessionOptions &sessionOptions, int nThreads, const std::string &singleThreadPool="no_threads")
Definition: TensorFlow.cc:19
void setLogging(const std::string &level="3")
Definition: TensorFlow.cc:14
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, const std::vector< std::string > &targetNodes, std::vector< Tensor > *outputs)
Definition: TensorFlow.cc:210