13 namespace tensorflow {
19 sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
20 sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
24 edm::LogInfo(
"PhysicsTools/TensorFlow") <<
"setting the thread pool via tensorflow::setThreading() is deprecated";
32 RunOptions runOptions;
33 SavedModelBundle bundle;
36 status = LoadSavedModel(sessionOptions, runOptions, exportDir, {tag}, &bundle);
39 <<
"error while loading metaGraphDef from '" << exportDir <<
"': " << status.ToString();
43 return new MetaGraphDef(bundle.meta_graph_def);
48 <<
"tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
55 SessionOptions sessionOptions;
63 <<
"tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
73 GraphDef* graphDef =
new GraphDef();
74 status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
79 <<
"error while loading graphDef from '" << pbFile <<
"': " << status.ToString();
91 status = NewSession(sessionOptions, &session);
93 throw cms::Exception(
"InvalidSession") <<
"error while creating session: " << status.ToString();
101 SessionOptions sessionOptions;
109 SessionOptions& sessionOptions) {
111 if (metaGraphDef ==
nullptr) {
112 throw cms::Exception(
"InvalidMetaGraphDef") <<
"error while creating session: metaGraphDef is nullptr";
116 if (metaGraphDef->graph_def().node_size() <= 0) {
117 throw cms::Exception(
"InvalidMetaGraphDef") <<
"error while creating session: graphDef has no nodes";
124 status = session->Create(metaGraphDef->graph_def());
127 <<
"error while attaching metaGraphDef to session: " << status.ToString();
132 std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
133 std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
134 std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
135 std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
136 std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
139 if (!Env::Default()->FileExists(indexFile).
ok()) {
144 Tensor varFileTensor(DT_STRING, TensorShape({}));
145 varFileTensor.scalar<tensorflow::tstring>()() = varFile;
148 status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName},
nullptr);
150 throw cms::Exception(
"InvalidSession") <<
"error while restoring variables in session: " << status.ToString();
158 SessionOptions sessionOptions;
161 return createSession(metaGraphDef, exportDir, sessionOptions);
166 if (graphDef ==
nullptr) {
167 throw cms::Exception(
"InvalidGraphDef") <<
"error while creating session: graphDef is nullptr";
171 if (graphDef->node_size() <= 0) {
172 throw cms::Exception(
"InvalidGraphDef") <<
"error while creating session: graphDef has no nodes";
180 status = session->Create(*graphDef);
184 throw cms::Exception(
"InvalidSession") <<
"error while attaching graphDef to session: " << status.ToString();
192 SessionOptions sessionOptions;
199 if (session ==
nullptr) {
215 const std::vector<std::string>& outputNames,
216 std::vector<Tensor>* outputs,
217 const thread::ThreadPoolOptions& threadPoolOptions) {
218 if (session ==
nullptr) {
219 throw cms::Exception(
"InvalidSession") <<
"cannot run empty session";
223 RunOptions runOptions;
226 Status status = session->Run(runOptions, inputs, outputNames, {}, outputs,
nullptr, threadPoolOptions);
228 throw cms::Exception(
"InvalidRun") <<
"error while running session: " << status.ToString();
234 const std::vector<std::string>& outputNames,
235 std::vector<Tensor>* outputs,
236 thread::ThreadPoolInterface* threadPool) {
238 thread::ThreadPoolOptions threadPoolOptions;
239 threadPoolOptions.inter_op_threadpool = threadPool;
240 threadPoolOptions.intra_op_threadpool = threadPool;
243 run(session, inputs, outputNames, outputs, threadPoolOptions);
248 const std::vector<std::string>& outputNames,
249 std::vector<Tensor>* outputs,
252 if (threadPoolName ==
"no_threads") {
254 }
else if (threadPoolName ==
"tbb") {
257 }
else if (threadPoolName ==
"tensorflow") {
258 run(session, inputs, outputNames, outputs,
nullptr);
261 <<
"thread pool implementation'" << threadPoolName <<
"' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
266 const std::vector<std::string>& outputNames,
267 std::vector<Tensor>* outputs,
269 run(session, {}, outputNames, outputs, threadPoolName);
Session * createSession(SessionOptions &sessionOptions)
std::vector< NamedTensor > NamedTensorList
GraphDef * loadGraphDef(const std::string &pbFile)
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
bool closeSession(Session *&session)
static TBBThreadPool & instance(int nThreads=-1)
Log< level::Info, false > LogInfo
void setLogging(const std::string &level="3")
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
static NoThreadPool & instance()