19 sessionOptions.config.set_intra_op_parallelism_threads(
nThreads);
20 sessionOptions.config.set_inter_op_parallelism_threads(
nThreads);
24 edm::LogInfo(
"PhysicsTools/TensorFlow") <<
"setting the thread pool via tensorflow::setThreading() is deprecated";
32 RunOptions runOptions;
33 SavedModelBundle bundle;
36 status = LoadSavedModel(sessionOptions, runOptions, exportDir, {
tag}, &bundle);
39 <<
"error while loading metaGraphDef from '" << exportDir <<
"': " <<
status.ToString();
43 return new MetaGraphDef(bundle.meta_graph_def);
48 <<
"tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
55 SessionOptions sessionOptions;
63 <<
"tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
73 GraphDef* graphDef =
new GraphDef();
74 status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
79 <<
"error while loading graphDef from '" << pbFile <<
"': " <<
status.ToString();
91 status = NewSession(sessionOptions, &session);
101 SessionOptions sessionOptions;
109 if (metaGraphDef ==
nullptr) {
110 throw cms::Exception(
"InvalidMetaGraphDef") <<
"error while creating session: metaGraphDef is nullptr";
114 if (metaGraphDef->graph_def().node_size() <= 0) {
115 throw cms::Exception(
"InvalidMetaGraphDef") <<
"error while creating session: graphDef has no nodes";
122 status = session->Create(metaGraphDef->graph_def());
125 <<
"error while attaching metaGraphDef to session: " <<
status.ToString();
130 std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
131 std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
132 std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
133 std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
134 std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
137 if (!Env::Default()->FileExists(indexFile).
ok()) {
142 Tensor varFileTensor(DT_STRING, TensorShape({}));
143 varFileTensor.scalar<tensorflow::tstring>()() = varFile;
146 status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName},
nullptr);
148 throw cms::Exception(
"InvalidSession") <<
"error while restoring variables in session: " <<
status.ToString();
156 SessionOptions sessionOptions;
159 return createSession(metaGraphDef, exportDir, sessionOptions);
164 if (graphDef ==
nullptr) {
165 throw cms::Exception(
"InvalidGraphDef") <<
"error while creating session: graphDef is nullptr";
169 if (graphDef->node_size() <= 0) {
170 throw cms::Exception(
"InvalidGraphDef") <<
"error while creating session: graphDef has no nodes";
178 status = session->Create(*graphDef);
182 throw cms::Exception(
"InvalidSession") <<
"error while attaching graphDef to session: " <<
status.ToString();
190 SessionOptions sessionOptions;
197 if (session ==
nullptr) {
215 const thread::ThreadPoolOptions& threadPoolOptions) {
216 if (session ==
nullptr) {
217 throw cms::Exception(
"InvalidSession") <<
"cannot run empty session";
221 RunOptions runOptions;
234 thread::ThreadPoolInterface* threadPool) {
236 thread::ThreadPoolOptions threadPoolOptions;
237 threadPoolOptions.inter_op_threadpool = threadPool;
238 threadPoolOptions.intra_op_threadpool = threadPool;
250 if (threadPoolName ==
"no_threads") {
252 }
else if (threadPoolName ==
"tbb") {
255 }
else if (threadPoolName ==
"tensorflow") {
259 <<
"thread pool implementation'" << threadPoolName <<
"' unknown, use 'no_threads', 'tbb', or 'tensorflow'";