16 setenv(
"TF_CPP_MIN_LOG_LEVEL", level.c_str(), 0);
23 sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
24 sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
27 if (nThreads == 1 && !singleThreadPool.empty())
30 if (singleThreadPool !=
"no_threads" && singleThreadPool !=
"tbb")
33 <<
"thread pool '" << singleThreadPool <<
"' unknown, use 'no_threads' or 'tbb'";
40 SessionOptions& sessionOptions)
44 RunOptions runOptions;
45 SavedModelBundle bundle;
48 status = LoadSavedModel(sessionOptions, runOptions, exportDir, { tag }, &bundle);
52 <<
"error while loading meta graph: " << status.ToString();
56 return new MetaGraphDef(bundle.meta_graph_def);
62 SessionOptions sessionOptions;
74 GraphDef* graphDef =
new GraphDef();
75 status = ReadBinaryProto(
Env::Default(), pbFile, graphDef);
81 <<
"error while loading graph def: " << status.ToString();
94 status = NewSession(sessionOptions, &session);
98 <<
"error while creating session: " << status.ToString();
107 SessionOptions sessionOptions;
114 SessionOptions& sessionOptions)
120 status = session->Create(metaGraph->graph_def());
124 <<
"error while attaching meta graph to session: " << status.ToString();
129 std::string varFileTensorName = metaGraph->saver_def().filename_tensor_name();
130 std::string restoreOpName = metaGraph->saver_def().restore_op_name();
131 std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
132 std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
133 std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
142 Tensor varFileTensor(DT_STRING, TensorShape({}));
146 status = session->Run({ { varFileTensorName, varFileTensor } }, {}, { restoreOpName },
nullptr);
150 <<
"error while restoring variables in session: " << status.ToString();
159 SessionOptions sessionOptions;
172 status = session->Create(*graphDef);
178 <<
"error while attaching graph def to session: " << status.ToString();
187 SessionOptions sessionOptions;
195 if (session ==
nullptr)
211 const std::vector<std::string>&
outputNames,
const std::vector<std::string>& targetNodes,
214 if (session ==
nullptr)
216 throw cms::Exception(
"InvalidSession") <<
"cannot run empty session";
220 Status status = session->Run(inputs, outputNames, targetNodes, outputs);
224 <<
"error while running session: " << status.ToString();
229 const std::vector<Tensor>& inputTensors,
const std::vector<std::string>&
outputNames,
230 const std::vector<std::string>& targetNodes, std::vector<Tensor>*
outputs)
232 if (inputNames.size() != inputTensors.size())
234 throw cms::Exception(
"InvalidInput") <<
"numbers of input names and tensors not equal";
238 for (
size_t i = 0;
i < inputNames.size();
i++)
240 inputs.push_back(
NamedTensor(inputNames[
i], inputTensors[i]));
243 run(session, inputs, outputNames, targetNodes, outputs);
249 run(session, inputs, outputNames, {},
outputs);
253 const std::vector<Tensor>& inputTensors,
const std::vector<std::string>&
outputNames,
256 run(session, inputNames, inputTensors, outputNames, {},
outputs);
Session * createSession(SessionOptions &sessionOptions)
std::vector< NamedTensor > NamedTensorList
GraphDef * loadGraphDef(const std::string &pbFile)
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
std::pair< std::string, Tensor > NamedTensor
bool closeSession(Session *&session)
void setThreading(SessionOptions &sessionOptions, int nThreads, const std::string &singleThreadPool="no_threads")
void setLogging(const std::string &level="3")
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, const std::vector< std::string > &targetNodes, std::vector< Tensor > *outputs)