// setThreading(): set the number of intra- and inter-op threads
_options.config.set_intra_op_parallelism_threads(nThreads);
_options.config.set_inter_op_parallelism_threads(nThreads);
// setBackend(), CPU backend: hide all GPUs from the session
(*_options.config.mutable_device_count())["GPU"] = 0;
_options.config.mutable_gpu_options()->set_visible_device_list("");
// CUDA backend with a GPU available: expose the first visible device
(*_options.config.mutable_device_count())["GPU"] = 1;
_options.config.mutable_gpu_options()->set_visible_device_list("0");
// do not allocate all of the GPU memory up front
_options.config.mutable_gpu_options()->set_allow_growth(true);
// CUDA backend requested but no GPU available: throw
ex << "Cuda backend requested, but no NVIDIA GPU available in the job";
ex.addContext("Calling tensorflow::setBackend()");
// ROCm / Intel GPU backends: not supported by this TF build
ex << "ROCm/Intel GPU backend requested, but TF is not compiled yet for this platform";
ex.addContext("Calling tensorflow::setBackend()");
// automatic backend selection: use the GPU when one is available ...
(*_options.config.mutable_device_count())["GPU"] = 1;
_options.config.mutable_gpu_options()->set_visible_device_list("0");
_options.config.mutable_gpu_options()->set_allow_growth(true);
// ... otherwise fall back to the CPU-only configuration
(*_options.config.mutable_device_count())["GPU"] = 0;
_options.config.mutable_gpu_options()->set_visible_device_list("");
// setLogging(): set TF's C++ log threshold via TF_CPP_MIN_LOG_LEVEL (the trailing 0 keeps an existing value)
setenv("TF_CPP_MIN_LOG_LEVEL", level.c_str(), 0);
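// Usage sketch for the configuration helpers above. Illustrative only: it assumes the
// PhysicsTools/TensorFlow interface header is included and that setThreading()/setBackend()
// are members of Options, as the _options member above suggests; Backend::cpu is the declared default.
void configureTF() {
  tensorflow::setLogging("2");  // suppress TF INFO and WARNING messages (only if the env var is unset)

  tensorflow::Options options;
  options.setThreading(1);                       // one intra- and one inter-op thread
  options.setBackend(tensorflow::Backend::cpu);  // a GPU backend would throw if no device is available
}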
// loadMetaGraphDef(): load the SavedModel bundle for the requested tag
RunOptions runOptions;
SavedModelBundle bundle;

status = LoadSavedModel(options.getSessionOptions(), runOptions, exportDir, {tag}, &bundle);
if (!status.ok()) {
  throw cms::Exception("InvalidMetaGraphDef")
      << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
}
// return a copy of the loaded MetaGraphDef
return new MetaGraphDef(bundle.meta_graph_def);
// loadMetaGraph(): deprecated alias that forwards to loadMetaGraphDef()
edm::LogInfo("PhysicsTools/TensorFlow")
    << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
// loadGraphDef(): read a constant GraphDef from a protobuf file
GraphDef* graphDef = new GraphDef();
status = ReadBinaryProto(Env::Default(), pbFile, graphDef);

if (!status.ok()) {
  throw cms::Exception("InvalidGraphDef")
      << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
}
// createSession(): creating the (empty) session failed
throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
// createSession() from a MetaGraphDef: validate the input first
if (metaGraphDef == nullptr) {
  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
}

if (metaGraphDef->graph_def().node_size() <= 0) {
  throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
}
// attach the graph contained in the MetaGraphDef to the session
status = session->Create(metaGraphDef->graph_def());
if (!status.ok()) {
  throw cms::Exception("InvalidMetaGraphDef")
      << "error while attaching metaGraphDef to session: " << status.ToString();
}
std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
// the restore op expects the variable file path as a scalar string tensor
Tensor varFileTensor(DT_STRING, TensorShape({}));
varFileTensor.scalar<tensorflow::tstring>()() = varFile;
// run the restore op to load the variable values into the session
status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
if (!status.ok()) {
  throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
}
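// Usage sketch for the SavedModel path above: load the MetaGraphDef of an export directory and
// create a session from it (variables are restored as shown above when an index file is present).
// Illustrative only; the createSession(metaGraphDef, exportDir, options) overload is inferred from
// the implementation shown here and the export directory path is hypothetical.
void loadSavedModelExample(tensorflow::Options& options) {
  const std::string exportDir = "/path/to/savedmodel";  // hypothetical export directory

  tensorflow::MetaGraphDef* metaGraphDef = tensorflow::loadMetaGraphDef(exportDir);  // default tag
  tensorflow::Session* session = tensorflow::createSession(metaGraphDef, exportDir, options);

  // ... run inference, then clean up
  tensorflow::closeSession(session);
  delete metaGraphDef;
}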
// createSession() from a GraphDef: validate the input first
if (graphDef == nullptr) {
  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
}

if (graphDef->node_size() <= 0) {
  throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
}
// attach the graphDef to the session
status = session->Create(*graphDef);
if (!status.ok()) {
  throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
}
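// Usage sketch for the constant-graph path: load a frozen GraphDef from a protobuf file and attach
// it to a new session. Illustrative only; the graph file name is hypothetical and the
// createSession(graphDef, options) overload is inferred from the implementation shown above.
void loadGraphExample(tensorflow::Options& options) {
  tensorflow::GraphDef* graphDef = tensorflow::loadGraphDef("/path/to/constantgraph.pb");
  tensorflow::Session* session = tensorflow::createSession(graphDef, options);

  // ... run inference, then clean up
  tensorflow::closeSession(session);
  delete graphDef;
}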
// closeSession(): nothing to do when the session pointer is already null
if (session == nullptr) {
// const overload: cast away constness and forward to closeSession(Session*&)
auto s = const_cast<Session*>(session);
// run(): evaluate the session with an explicit thread pool configuration
void run(Session* session,
         const NamedTensorList& inputs,
         const std::vector<std::string>& outputNames,
         std::vector<Tensor>* outputs,
         const thread::ThreadPoolOptions& threadPoolOptions) {
  if (session == nullptr) {
    throw cms::Exception("InvalidSession") << "cannot run empty session";
  }

  RunOptions runOptions;
  // ... the Session::Run call using runOptions and threadPoolOptions follows (elided in this excerpt)
// run() overload taking a ThreadPoolInterface*: use it for both inter- and intra-op work
         thread::ThreadPoolInterface* threadPool) {
  thread::ThreadPoolOptions threadPoolOptions;
  threadPoolOptions.inter_op_threadpool = threadPool;
  threadPoolOptions.intra_op_threadpool = threadPool;
// run() overload taking a thread pool name: dispatch to the matching implementation
if (threadPoolName == "no_threads") {
  // forward to NoThreadPool::instance()
} else if (threadPoolName == "tbb") {
  // forward to TBBThreadPool::instance()
} else if (threadPoolName == "tensorflow") {
  // let TF use its internal thread pool
} else {
  throw cms::Exception("UnknownThreadPool")
      << "thread pool implementation '" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
}
// cached graph cleanup: delete the graph if set and reset the atomic pointer
if (graph.load() != nullptr) {
  delete graph.load();
  graph.store(nullptr);
}
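// Usage sketch for the cached graph/session pair above. Illustrative only: the SessionCache name,
// its (graphPath, options) constructor, and the getSession() accessor are assumptions about the
// class owning the atomic graph/session members; only its cleanup logic is shown above.
void cacheExample(tensorflow::Options& options) {
  tensorflow::SessionCache cache("/path/to/constantgraph.pb", options);  // hypothetical graph path
  const tensorflow::Session* session = cache.getSession();

  // ... use the session for inference; graph and session are released when the cache is destroyed
}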
// Declarations referenced in the excerpts above:
typedef std::vector<NamedTensor> NamedTensorList;
void setBackend(Backend backend = Backend::cpu);
GraphDef* loadGraphDef(const std::string& pbFile);
MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag = kSavedModelTagServe);
std::atomic<Session*> session;
void run(Session* session,
         const NamedTensorList& inputs,
         const std::vector<std::string>& outputNames,
         std::vector<Tensor>* outputs,
         const thread::ThreadPoolOptions& threadPoolOptions);
bool closeSession(Session*& session);
static TBBThreadPool& instance(int nThreads = -1);
MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, Options& options);
Session* createSession();
Log<level::Info, false> LogInfo;
void setLogging(const std::string& level = "3");
std::atomic<GraphDef*> graph;
void addContext(std::string const& context);
void setThreading(int nThreads = 1);
static NoThreadPool& instance();