18 sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
19 sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
23 edm::LogInfo(
"PhysicsTools/TensorFlow") <<
"setting the thread pool via tensorflow::setThreading() is deprecated";
31 RunOptions runOptions;
32 SavedModelBundle bundle;
35 status = LoadSavedModel(sessionOptions, runOptions, exportDir, {
tag}, &bundle);
38 <<
"error while loading metaGraphDef from '" << exportDir <<
"': " <<
status.ToString();
42 return new MetaGraphDef(bundle.meta_graph_def);
47 <<
"tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
54 SessionOptions sessionOptions;
62 <<
"tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
72 GraphDef* graphDef =
new GraphDef();
78 <<
"error while loading graphDef from '" << pbFile <<
"': " <<
status.ToString();
90 status = NewSession(sessionOptions, &session);
100 SessionOptions sessionOptions;
108 SessionOptions& sessionOptions) {
110 if (metaGraphDef ==
nullptr) {
111 throw cms::Exception(
"InvalidMetaGraphDef") <<
"error while creating session: metaGraphDef is nullptr";
115 if (metaGraphDef->graph_def().node_size() <= 0) {
116 throw cms::Exception(
"InvalidMetaGraphDef") <<
"error while creating session: graphDef has no nodes";
123 status = session->Create(metaGraphDef->graph_def());
126 <<
"error while attaching metaGraphDef to session: " <<
status.ToString();
131 std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
132 std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
133 std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
134 std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
135 std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
143 Tensor varFileTensor(DT_STRING, TensorShape({}));
144 varFileTensor.scalar<tensorflow::tstring>()() = varFile;
147 status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName},
nullptr);
149 throw cms::Exception(
"InvalidSession") <<
"error while restoring variables in session: " <<
status.ToString();
157 SessionOptions sessionOptions;
160 return createSession(metaGraphDef, exportDir, sessionOptions);
165 if (graphDef ==
nullptr) {
166 throw cms::Exception(
"InvalidGraphDef") <<
"error while creating session: graphDef is nullptr";
170 if (graphDef->node_size() <= 0) {
171 throw cms::Exception(
"InvalidGraphDef") <<
"error while creating session: graphDef has no nodes";
179 status = session->Create(*graphDef);
183 throw cms::Exception(
"InvalidSession") <<
"error while attaching graphDef to session: " <<
status.ToString();
191 SessionOptions sessionOptions;
198 if (session ==
nullptr) {
213 auto s =
const_cast<Session*
>(session);
226 const thread::ThreadPoolOptions& threadPoolOptions) {
227 if (session ==
nullptr) {
228 throw cms::Exception(
"InvalidSession") <<
"cannot run empty session";
232 RunOptions runOptions;
245 thread::ThreadPoolInterface* threadPool) {
247 thread::ThreadPoolOptions threadPoolOptions;
248 threadPoolOptions.inter_op_threadpool = threadPool;
249 threadPoolOptions.intra_op_threadpool = threadPool;
261 if (threadPoolName ==
"no_threads") {
263 }
else if (threadPoolName ==
"tbb") {
266 }
else if (threadPoolName ==
"tensorflow") {
270 <<
"thread pool implementation'" << threadPoolName <<
"' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
290 if (
graph.load() !=
nullptr) {
292 graph.store(
nullptr);
Session * createSession(SessionOptions &sessionOptions)
std::vector< NamedTensor > NamedTensorList
GraphDef * loadGraphDef(const std::string &pbFile)
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
std::atomic< Session * > session
void setThreading(SessionOptions &sessionOptions, int nThreads=1)
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
bool closeSession(Session *&session)
static TBBThreadPool & instance(int nThreads=-1)
Log< level::Info, false > LogInfo
void setLogging(const std::string &level="3")
std::atomic< GraphDef * > graph
MetaGraphDef * loadMetaGraphDef(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
static NoThreadPool & instance()