TBBSession.cc

1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 //NOTE: The memory layout of the Node class changes depending on whether NDEBUG was
16 // set when TensorFlow was compiled. The reason is that the Node class holds two
17 // EdgeSet instances, and EdgeSet gains an extra data member depending on NDEBUG.
18 
19 /*
20 This file is an adaptation of the original direct_session.cc file located at
21 https://github.com/tensorflow/tensorflow/blob/v1.3.0/tensorflow/core/common_runtime/direct_session.cc
22 to meet the demands of the software environment developed and used by the CMS collaboration.
23 
24 Changes with respect to the original code are documented in the TBBSession.h header file.
25 */
26 
27 #if !defined(NDEBUG)
28 #define NDEBUG 1
29 #endif
30 
31 #include "TBBSession.h"
32 
33 #include <atomic>
34 #include <string>
35 #include <vector>
36 
37 
38 #include "FWCore/Utilities/interface/thread_safety_macros.h"
39 
40 #include "tbb/task_group.h"
41 
42 #include "tensorflow/core/common_runtime/constant_folding.h"
43 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
44 #include "tensorflow/core/common_runtime/device_factory.h"
45 #include "tensorflow/core/common_runtime/executor.h"
46 #include "tensorflow/core/common_runtime/function.h"
47 #include "tensorflow/core/common_runtime/graph_optimizer.h"
48 #include "tensorflow/core/common_runtime/memory_types.h"
49 #include "tensorflow/core/common_runtime/optimization_registry.h"
50 #include "tensorflow/core/common_runtime/simple_placer.h"
51 #include "tensorflow/core/common_runtime/step_stats_collector.h"
52 #include "tensorflow/core/framework/function.h"
53 #include "tensorflow/core/framework/graph.pb_text.h"
54 #include "tensorflow/core/framework/graph.pb.h"
55 #include "tensorflow/core/framework/graph_def_util.h"
56 #include "tensorflow/core/framework/log_memory.h"
57 #include "tensorflow/core/framework/node_def.pb.h"
58 #include "tensorflow/core/framework/tensor.h"
59 #include "tensorflow/core/framework/versions.pb.h"
60 #include "tensorflow/core/graph/algorithm.h"
61 #include "tensorflow/core/graph/graph.h"
62 #include "tensorflow/core/graph/graph_constructor.h"
63 #include "tensorflow/core/graph/graph_partition.h"
64 #include "tensorflow/core/graph/subgraph.h"
65 #include "tensorflow/core/graph/tensor_id.h"
66 #include "tensorflow/core/lib/core/errors.h"
67 #include "tensorflow/core/lib/core/notification.h"
68 #include "tensorflow/core/lib/core/refcount.h"
69 #include "tensorflow/core/lib/core/status.h"
70 #include "tensorflow/core/lib/gtl/array_slice.h"
71 #include "tensorflow/core/lib/gtl/stl_util.h"
72 #include "tensorflow/core/lib/monitoring/counter.h"
73 #include "tensorflow/core/lib/strings/numbers.h"
74 #include "tensorflow/core/lib/strings/str_util.h"
75 #include "tensorflow/core/lib/strings/strcat.h"
76 #include "tensorflow/core/platform/cpu_info.h"
77 #include "tensorflow/core/platform/logging.h"
78 #include "tensorflow/core/platform/mutex.h"
79 #include "tensorflow/core/platform/types.h"
80 #include "tensorflow/core/util/device_name_utils.h"
81 #include "tensorflow/core/util/env_var.h"
82 
83 #if GOOGLE_CUDA
84 #include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
85 #endif // GOOGLE_CUDA
86 
87 namespace tensorflow {
88 
89 namespace {
90 
91 CMS_THREAD_SAFE auto* tbb_session_runs = monitoring::Counter<0>::New(
92  "/tensorflow/core/tbb_session_runs",
93  "The number of times TBBSession::Run() has been called.");
94 
95 
96 // TODO(vrv): Figure out how to unify the many different functions
97 // that generate RendezvousKey, since many of them have to be
98 // consistent with each other.
99 string GetRendezvousKey(const string& tensor_name,
100  const DeviceAttributes& device_info,
101  const FrameAndIter& frame_iter) {
102  return strings::StrCat(device_info.name(), ";",
103  strings::FpToString(device_info.incarnation()), ";",
104  device_info.name(), ";", tensor_name, ";",
105  frame_iter.frame_id, ":", frame_iter.iter_id);
106 }
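// Illustration (not part of the original file): for a feed named "x:0" on the
// client CPU device, the key built above would look roughly like
//   /job:localhost/replica:0/task:0/cpu:0;<incarnation>;/job:localhost/replica:0/task:0/cpu:0;x:0;0:0
// i.e. source device, device incarnation, destination device, tensor name and
// frame:iteration, the layout expected by the local Rendezvous.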
107 
108 } // namespace
109 
110 class TBBSessionFactory : public SessionFactory {
111  public:
112  TBBSessionFactory() {}
113 
114  bool AcceptsOptions(const SessionOptions& options) override {
115  return options.target == "tbb";
116  }
117 
118  Session* NewSession(const SessionOptions& options) override {
119  // Must do this before the CPU allocator is created.
120  if (options.config.graph_options().build_cost_model() > 0) {
121  EnableCPUAllocatorFullStats(true);
122  }
123  std::vector<Device*> devices;
124  Status s = DeviceFactory::AddDevices(
125  options, "/job:localhost/replica:0/task:0", &devices);
126  if (!s.ok()) {
127  LOG(ERROR) << s;
128  return nullptr;
129  }
130 
131  TBBSession* session =
132  new TBBSession(options, new DeviceMgr(devices), this);
133  {
134  mutex_lock l(sessions_lock_);
135  sessions_.push_back(session);
136  }
137  return session;
138  }
139 
140  Status Reset(const SessionOptions& options,
141  const std::vector<string>& containers) override {
142  std::vector<TBBSession*> sessions_to_reset;
143  {
144  mutex_lock l(sessions_lock_);
145  // We create a copy to ensure that we don't have a deadlock when
146  // session->Close calls the TBBSessionFactory.Deregister, which
147  // acquires sessions_lock_.
148  std::swap(sessions_to_reset, sessions_);
149  }
150  Status s;
151  for (auto session : sessions_to_reset) {
152  s.Update(session->Reset(containers));
153  }
154  // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
155  // it doesn't close the sessions?
156  for (auto session : sessions_to_reset) {
157  s.Update(session->Close());
158  }
159  return s;
160  }
161 
162  void Deregister(const TBBSession* session) {
163  mutex_lock l(sessions_lock_);
164  sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
165  sessions_.end());
166  }
167 
168  private:
169  mutex sessions_lock_;
170  std::vector<TBBSession*> sessions_ GUARDED_BY(sessions_lock_);
171 };
172 
173 class TBBSessionRegistrar {
174  public:
175  TBBSessionRegistrar() {
176  SessionFactory::Register("TBB_SESSION", new TBBSessionFactory());
177  }
178 };
179 static TBBSessionRegistrar registrar;
180 
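// Usage sketch (illustrative, not part of the original file): because the
// registrar above registers the factory under the "tbb" target, client code
// obtains this session type through the normal TensorFlow C++ API, e.g.
//
//   tensorflow::SessionOptions opts;
//   opts.target = "tbb";  // matched by TBBSessionFactory::AcceptsOptions()
//   tensorflow::Session* session = tensorflow::NewSession(opts);
//
// Any other target string falls through to a different registered factory.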
181 std::atomic_int_fast64_t TBBSession::step_id_counter_(1);
182 
183 // NOTE: On Android with a single device, there is never
184 // a risk of an OpKernel blocking indefinitely:
185 //
186 // 1) No operations do I/O that depends on other simultaneous kernels,
187 //
188 // 2) Recv nodes always complete immediately: The inputs are sent into
189 // the local rendezvous before we start the executor, so the
190 // corresponding recvs will not block.
191 //
192 // Based on these assumptions, we can use the same thread pool for
193 // both "non-blocking" and "blocking" OpKernels on Android.
194 //
195 // This may change down the road when we add support for multiple
196 // devices that run concurrently, in which case we will need to
197 // revisit this decision.
198 void TBBSession::SchedClosure(tbb::task_arena& arena, tbb::task_group& g, std::function<void()> c) {
199  arena.execute( [&g,&c] () {g.run( c ); } );
200 }
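// Standalone sketch of the arena/group pattern used by SchedClosure and by
// Run() below (illustrative only):
//
//   tbb::task_arena arena;
//   tbb::task_group group;
//   arena.execute([&] { group.run([] { /* executor closure */ }); });  // SchedClosure
//   // ... later, before the group goes out of scope:
//   arena.execute([&] { group.wait(); });
//
// Executing both run() and wait() inside the same arena keeps this work, and
// the thread that joins it, isolated from unrelated TBB tasks.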
201 
202 TBBSession::TBBSession(const SessionOptions& options,
203  const DeviceMgr* device_mgr,
204  TBBSessionFactory* const factory)
205  : options_(options),
206  device_mgr_(device_mgr),
207  factory_(factory),
208  cancellation_manager_(new CancellationManager()),
209  operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
210  // The default value of sync_on_finish will be flipped soon and this
211  // environment variable will be removed as well.
212  Status status =
213  ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
214  if (!status.ok()) {
215  LOG(ERROR) << status.error_message();
216  }
217  // NOTE(mrry): We do not need to use a unique string for the session
218  // handle, because TBBSession owns its devices. This may change
219  // in future versions.
220  session_handle_ = "tbb";
221  int devices_added = 0;
222  if (options.config.log_device_placement()) {
223  const string mapping_str = device_mgr_->DeviceMappingString();
224  if (mapping_str.empty()) {
225  printf("Device mapping: no known devices.\n");
226  } else {
227  printf("Device mapping:\n%s", mapping_str.c_str());
228  }
229  LOG(INFO) << "Device mapping:\n" << mapping_str;
230  }
231  for (auto d : device_mgr_->ListDevices()) {
232  devices_.push_back(d);
233  device_set_.AddDevice(d);
234  d->op_segment()->AddHold(session_handle_);
235 
236  // The first device added is special: it is the 'client device' (a
237  // CPU device) from which we feed and fetch Tensors.
238  if (devices_added == 0) {
239  device_set_.set_client_device(d);
240  }
241  ++devices_added;
242  }
243 }
244 
245 TBBSession::~TBBSession() {
246  if (!closed_) Close().IgnoreError();
247  for (auto& it : partial_runs_) {
248  it.second.reset(nullptr);
249  }
250  for (auto& it : executors_) {
251  it.second.reset();
252  }
253  for (auto d : device_mgr_->ListDevices()) {
254  d->op_segment()->RemoveHold(session_handle_);
255  }
256  delete cancellation_manager_;
257 
258  execution_state_.reset(nullptr);
259  flib_def_.reset(nullptr);
260 }
261 
262 Status TBBSession::MaybeInitializeExecutionState(
263  const GraphDef& graph, bool* out_already_initialized) {
264  // If already initialized, do nothing.
265  if (flib_def_ && execution_state_) {
266  *out_already_initialized = true;
267  return Status::OK();
268  }
269  // Set up the per-session execution state.
270  // NOTE(mrry): The function library created here will be used for
271  // all subsequent extensions of the graph.
272  flib_def_.reset(
273  new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
274  SimpleGraphExecutionStateOptions options;
275  options.device_set = &device_set_;
276  options.session_options = &options_;
277  // TODO(mrry,suharshs): We explicitly copy `graph` so that
278  // `MakeForBaseGraph()` can take ownership of its
279  // contents. Previously this happened implicitly in calls to the
280  // `SimpleGraphExecutionState`. Other sessions call
281  // `MakeForBaseGraph` in such a way that we can destructively read
282  // the passed-in `GraphDef`. In principle we could do the same here,
283  // with a wider refactoring; we might revise the direct session so
284  // that it copies the graph fewer times.
285  GraphDef temp(graph);
286  TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
287  &temp, options, &execution_state_));
288  graph_created_ = true;
289  *out_already_initialized = false;
290  return Status::OK();
291 }
292 
293 Status TBBSession::Create(const GraphDef& graph) {
294  TF_RETURN_IF_ERROR(init_error_);
295  if (graph.node_size() > 0) {
296  mutex_lock l(graph_def_lock_);
297  if (graph_created_) {
298  return errors::AlreadyExists(
299  "A Graph has already been created for this session.");
300  }
301  return ExtendLocked(graph);
302  }
303  return Status::OK();
304 }
305 
306 Status TBBSession::Extend(const GraphDef& graph) {
307  TF_RETURN_IF_ERROR(CheckNotClosed());
308  mutex_lock l(graph_def_lock_);
309  return ExtendLocked(graph);
310 }
311 
312 Status TBBSession::ExtendLocked(const GraphDef& graph) {
313  bool already_initialized;
314  // If this is the first call, we can initialize the execution state
315  // with `graph` and do not need to call `Extend()`.
316  TF_RETURN_IF_ERROR(
317  MaybeInitializeExecutionState(graph, &already_initialized));
318  if (already_initialized) {
319  TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
320  std::unique_ptr<SimpleGraphExecutionState> state;
321  TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
322  execution_state_.swap(state);
323  }
324  return Status::OK();
325 }
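// Illustrative call sequence (not part of the original file): Create() supplies
// the initial GraphDef and initializes the execution state; later Extend()
// calls merge additional nodes into the same session, e.g.
//
//   tensorflow::Status s = session->Create(graph_def);
//   if (s.ok()) s = session->Extend(extra_nodes_def);  // extra_nodes_def is hypothetical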
326 
327 Status TBBSession::Run(const NamedTensorList& inputs,
328  const std::vector<string>& output_names,
329  const std::vector<string>& target_nodes,
330  std::vector<Tensor>* outputs) {
331  RunMetadata run_metadata;
332  return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
333  &run_metadata);
334 }
335 
336 Status TBBSession::CreateDebuggerState(
337  const DebugOptions& debug_options, int64 session_run_index,
338  int64 executor_step_index, const std::vector<string>& input_names,
339  const std::vector<string>& output_names,
340  const std::vector<string>& target_names,
341  std::unique_ptr<DebuggerStateInterface>* debugger_state) {
342  TF_RETURN_IF_ERROR(
343  DebuggerStateRegistry::CreateState(debug_options, debugger_state));
344  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
345  debug_options.global_step(), session_run_index, executor_step_index,
346  input_names, output_names, target_names));
347  return Status::OK();
348 }
349 
350 Status TBBSession::DecorateAndPublishGraphForDebug(
351  const DebugOptions& debug_options, Graph* graph, Device* device) {
352  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
353  TF_RETURN_IF_ERROR(
354  DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
355 
356  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
357  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
358  return Status::OK();
359 }
360 
361 Status TBBSession::Run(const RunOptions& run_options,
362  const NamedTensorList& inputs,
363  const std::vector<string>& output_names,
364  const std::vector<string>& target_nodes,
365  std::vector<Tensor>* outputs,
366  RunMetadata* run_metadata) {
367  TF_RETURN_IF_ERROR(CheckNotClosed());
368  tbb_session_runs->GetCell()->IncrementBy(1);
369  {
370  mutex_lock l(graph_def_lock_);
371  if (!graph_created_) {
372  return errors::InvalidArgument(
373  "Session was not created with a graph before Run()!");
374  }
375  }
376 
377  // Extract the input names for this run of the session.
378  std::vector<string> input_tensor_names;
379  input_tensor_names.reserve(inputs.size());
380  for (const auto& it : inputs) {
381  input_tensor_names.push_back(it.first);
382  }
383 
384 
385  // Check if we already have an executor for these arguments.
386  ExecutorsAndKeys* executors_and_keys;
387  RunStateArgs run_state_args(run_options.debug_options());
388 
389  Executor::Args args;
390  args.step_id = step_id_counter_.fetch_add(1);
391 
392  TF_RETURN_IF_ERROR(
393  GetOrCreateExecutors(input_tensor_names, output_names, target_nodes,
394  &executors_and_keys, &run_state_args));
395  const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
396 
397  std::unique_ptr<DebuggerStateInterface> debugger_state;
398  if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
399  TF_RETURN_IF_ERROR(CreateDebuggerState(
400  run_options.debug_options(), args.step_id, executor_step_count,
401  input_tensor_names, output_names, target_nodes, &debugger_state));
402  }
403 
404  // Configure a call frame for the step, which we use to feed and
405  // fetch values to and from the executors.
406  FunctionCallFrame call_frame(executors_and_keys->input_types,
407  executors_and_keys->output_types);
408  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
409  for (const auto& it : inputs) {
410  if (it.second.dtype() == DT_RESOURCE) {
411  Tensor tensor_from_handle;
412  TF_RETURN_IF_ERROR(
413  ResourceHandleToInputTensor(it.second, &tensor_from_handle));
414  feed_args[executors_and_keys->input_name_to_index[it.first]] =
415  tensor_from_handle;
416  } else {
417  feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
418  }
419  }
420  Status s = call_frame.SetArgs(feed_args);
421  if (errors::IsInternal(s)) {
422  return errors::InvalidArgument(s.error_message());
423  } else if (!s.ok()) {
424  return s;
425  }
426 
427  // Create a run state and start execution.
428  RunState run_state(args.step_id, &devices_);
429  run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
430  CancellationManager step_cancellation_manager;
431  args.call_frame = &call_frame;
432 
433  // Use a task_arena to avoid having unrelated tasks start
434  // running on this thread (which could lead to deadlocks)
435  tbb::task_arena taskArena;
436  tbb::task_group taskGroup;
437  // we are required to always call wait() on the task_group before it is destroyed
438  auto doneWithTaskGroup = [&taskArena, &taskGroup](void *) { taskArena.execute([&taskGroup]() { taskGroup.wait();}); };
439  std::unique_ptr<tbb::task_group, decltype(doneWithTaskGroup) > guard(&taskGroup, doneWithTaskGroup);
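// The unique_ptr above is used purely as a scope guard: if any of the early
// returns below fire, its custom deleter runs taskGroup.wait() inside the
// arena, satisfying TBB's requirement that a task_group be waited on before
// it is destroyed. On the normal path the guard is released and
// WaitForNotification() performs the wait instead.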
440 
441  // Start parallel Executors.
442  const size_t num_executors = executors_and_keys->items.size();
443  ExecutorBarrier* barrier = new ExecutorBarrier(
444  num_executors, run_state.rendez, [&run_state](const Status& ret) {
445  {
446  mutex_lock l(run_state.mu_);
447  run_state.status.Update(ret);
448  }
449  run_state.executors_done.Notify();
450  });
451 
452  args.rendezvous = run_state.rendez;
453  args.cancellation_manager = &step_cancellation_manager;
454  args.runner = [this, &taskArena, &taskGroup](Executor::Args::Closure c) {
455  SchedClosure(taskArena, taskGroup, std::move(c));
456  };
457  args.session_state = &session_state_;
458  args.tensor_store = &run_state.tensor_store;
459  args.step_container = &run_state.step_container;
460  if (LogMemory::IsEnabled()) {
461  LogMemory::RecordStep(args.step_id, run_state_args.handle);
462  }
463  args.sync_on_finish = sync_on_finish_;
464 
465  const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
466 
467  bool update_cost_model = false;
468  if (options_.config.graph_options().build_cost_model() > 0) {
469  const int64 build_cost_model_every =
470  options_.config.graph_options().build_cost_model();
471  const int64 build_cost_model_after =
472  options_.config.graph_options().build_cost_model_after();
473  int64 measure_step_count = executor_step_count - build_cost_model_after;
474  if (measure_step_count >= 0) {
475  update_cost_model =
476  ((measure_step_count + 1) % build_cost_model_every == 0);
477  }
478  }
479  if (do_trace || update_cost_model) {
480  run_state.collector.reset(
481  new StepStatsCollector(run_metadata->mutable_step_stats()));
482  args.stats_collector = run_state.collector.get();
483  }
484 
485 #if GOOGLE_CUDA
486  std::unique_ptr<GPUTracer> tracer;
487  if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
488  tracer.reset(CreateGPUTracer());
489  // tracer will be NULL on non-GPU platforms.
490  // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
491  if (tracer) tracer->Start().IgnoreError();
492  }
493 #endif // GOOGLE_CUDA
494 
495  // Register this step with session's cancellation manager, so that
496  // `Session::Close()` will cancel the step.
497  CancellationToken cancellation_token =
498  cancellation_manager_->get_cancellation_token();
499  bool already_cancelled = !cancellation_manager_->RegisterCallback(
500  cancellation_token, [&step_cancellation_manager]() {
501  step_cancellation_manager.StartCancel();
502  });
503  if (already_cancelled) {
504  // NOTE(mrry): If we don't explicitly notify
505  // `run_state.executors_done`, the RunState destructor would
506  // block on this notification.
507  run_state.executors_done.Notify();
508  delete barrier;
509  return errors::Cancelled("Run call was cancelled");
510  }
511 
512  for (const auto& item : executors_and_keys->items) {
513  item.executor->RunAsync(args, barrier->Get());
514  }
515 
516  //WaitForNotification will handle calling wait on taskGroup
517  guard.release();
518  WaitForNotification(taskArena, taskGroup, &run_state, &step_cancellation_manager,
519  run_options.timeout_in_ms() > 0
520  ? run_options.timeout_in_ms()
521  : operation_timeout_in_ms_);
522 
523  if (!cancellation_manager_->DeregisterCallback(cancellation_token)) {
524  // The step has been cancelled: make sure we don't attempt to receive the
525  // outputs as this would make it block forever.
526  mutex_lock l(run_state.mu_);
527  run_state.status.Update(errors::Cancelled("Run call was cancelled"));
528  }
529 
530 #if GOOGLE_CUDA
531  if (tracer) {
532  // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
533  tracer->Stop().IgnoreError();
534  tracer->Collect(args.stats_collector).IgnoreError();
535  }
536 #endif // GOOGLE_CUDA
537 
538  {
539  mutex_lock l(run_state.mu_);
540  TF_RETURN_IF_ERROR(run_state.status);
541  }
542 
543  // Receive outputs.
544  if (outputs) {
545  std::vector<Tensor> sorted_outputs;
546  Status s = call_frame.ConsumeRetvals(&sorted_outputs);
547  if (errors::IsInternal(s)) {
548  return errors::InvalidArgument(s.error_message());
549  } else if (!s.ok()) {
550  return s;
551  }
552  const bool unique_outputs =
553  output_names.size() == executors_and_keys->output_name_to_index.size();
554  // first_indices[i] = j implies that j is the smallest value for which
555  // output_names[i] == output_names[j].
556  std::vector<int> first_indices;
557  if (!unique_outputs) {
558  first_indices.resize(output_names.size());
559  for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
560  for (int j = 0; j <= i; ++j) {
561  if (output_names[i] == output_names[j]) {
562  first_indices[i] = j;
563  break;
564  }
565  }
566  }
567  }
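// Example (illustrative): output_names = {"a:0", "b:0", "a:0"} yields
// first_indices = {0, 1, 0}, so the duplicate third output is later copied
// from outputs[0] instead of consuming sorted_outputs twice.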
568  outputs->clear();
569  outputs->reserve(sorted_outputs.size());
570  for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
571  const string& output_name = output_names[i];
572  if (first_indices.empty() || first_indices[i] == i) {
573  outputs->emplace_back(
574  std::move(sorted_outputs[executors_and_keys
575  ->output_name_to_index[output_name]]));
576  } else {
577  outputs->push_back((*outputs)[first_indices[i]]);
578  }
579  }
580  }
581 
582  // Save the output tensors of this run we choose to keep.
583  TF_RETURN_IF_ERROR(
584  run_state.tensor_store.SaveTensors(output_names, &session_state_));
585 
586  // Build and return the cost model as instructed.
587  mutex_lock l(executor_lock_);
588  if (update_cost_model) {
589  // Build the cost model
590  std::unordered_map<string, const Graph*> device_to_graph;
591  for (const PerPartitionExecutorsAndLib& partition :
592  executors_and_keys->items) {
593  const Graph* graph = partition.graph;
594  const string device = partition.flib->device()->name();
595  device_to_graph[device] = graph;
596  }
597  args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);
598 
599  // annotate stats onto cost graph.
600  CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
601  for (const auto& item : executors_and_keys->items) {
602  TF_RETURN_IF_ERROR(
603  cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
604  }
605  }
606 
607  // If requested via RunOptions, output the partition graphs.
608  if (run_options.output_partition_graphs()) {
609  protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
610  run_metadata->mutable_partition_graphs();
611  for (const PerPartitionExecutorsAndLib& exec_and_lib :
612  executors_and_keys->items) {
613  GraphDef* partition_graph_def = partition_graph_defs->Add();
614  exec_and_lib.graph->ToGraphDef(partition_graph_def);
615  }
616  }
617 
618  return Status::OK();
619 }
620 
621 
622 Status TBBSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
623  Tensor* retrieved_tensor) {
624  if (resource_tensor.dtype() != DT_RESOURCE) {
625  return errors::InvalidArgument(strings::StrCat(
626  "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
627  resource_tensor.dtype()));
628  }
629 
630  ResourceHandle resource_handle = resource_tensor.scalar<ResourceHandle>()();
631 
632  if (resource_handle.container() ==
633  SessionState::kTensorHandleResourceTypeName) {
634  return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
635  } else {
636  return errors::InvalidArgument(strings::StrCat(
637  "Invalid resource type hash code: ", resource_handle.hash_code(),
638  "(name: ", resource_handle.name(),
639  " type: ", resource_handle.maybe_type_name(), ")"));
640  }
641 }
642 
643 Status TBBSession::GetOrCreateExecutors(
644  gtl::ArraySlice<string> inputs,
645  gtl::ArraySlice<string> outputs, gtl::ArraySlice<string> target_nodes,
646  ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args) {
647  int64 handle_name_counter_value = -1;
648  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
649  handle_name_counter_value = handle_name_counter_.fetch_add(1);
650  }
651 
652  string debug_tensor_watches_summary;
653  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
654  debug_tensor_watches_summary = SummarizeDebugTensorWatches(
655  run_state_args->debug_options.debug_tensor_watch_opts());
656  }
657 
658  // Fast lookup path, no sorting.
659  const string key = strings::StrCat(
660  str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
661  str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
662  "/", debug_tensor_watches_summary);
663  // Set the handle, if it's needed to log memory or for partial run.
664  if (handle_name_counter_value >= 0) {
665  run_state_args->handle =
666  strings::StrCat(key, ";", handle_name_counter_value);
667  }
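// Example of the fast-path key built above (names are hypothetical):
// inputs {"a:0","b:0"}, outputs {"out:0"}, targets {"init"}, no partial run
// and no debug watches give
//   "a:0,b:0->out:0/init/0/"
// The sorted key computed below has the same layout with each list sorted.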
668 
669  // See if we already have the executors for this run.
670  {
671  mutex_lock l(executor_lock_); // could use reader lock
672  auto it = executors_.find(key);
673  if (it != executors_.end()) {
674  *executors_and_keys = it->second.get();
675  return Status::OK();
676  }
677  }
678 
679  // Slow lookup path, the unsorted key missed the cache.
680  // Sort the inputs and outputs, and look up with the sorted key in case an
681  // earlier call used a different order of inputs and outputs.
682  //
683  // We could consider some other signature instead of sorting that
684  // preserves the same property to avoid the sort in the future.
685  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
686  std::sort(inputs_sorted.begin(), inputs_sorted.end());
687  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
688  std::sort(outputs_sorted.begin(), outputs_sorted.end());
689  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
690  std::sort(tn_sorted.begin(), tn_sorted.end());
691 
692  const string sorted_key = strings::StrCat(
693  str_util::Join(inputs_sorted, ","), "->",
694  str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
695  "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
696  // Set the handle, if it's needed to log memory or for partial run.
697  if (handle_name_counter_value >= 0) {
698  run_state_args->handle =
699  strings::StrCat(sorted_key, ";", handle_name_counter_value);
700  }
701 
702  // See if we already have the executors for this run.
703  {
704  mutex_lock l(executor_lock_);
705  auto it = executors_.find(sorted_key);
706  if (it != executors_.end()) {
707  *executors_and_keys = it->second.get();
708  // Insert this under the original key.
709  executors_.emplace(key, it->second);
710  return Status::OK();
711  }
712  }
713 
714  // Nothing found, so create the executors and store in the cache.
715  BuildGraphOptions options;
716  options.feed_endpoints = inputs_sorted;
717  options.fetch_endpoints = outputs_sorted;
718  options.target_nodes = tn_sorted;
719  options.use_function_convention = !run_state_args->is_partial_run;
720  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
721  options.debug_options = run_state_args->debug_options;
722  }
723 
724  std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
725 
726  // The executor_lock_ is intentionally released while executor is
727  // being created.
728  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
729  TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &ek->flib_def,
730  run_state_args, &ek->input_types,
731  &ek->output_types));
732 
733  if (run_state_args->is_partial_run) {
734  ek->graph = std::move(run_state_args->graph);
735  std::unordered_set<StringPiece, StringPiece::Hasher> names;
736  for (const string& input : inputs) {
737  TensorId id(ParseTensorName(input));
738  names.emplace(id.first);
739  }
740  for (const string& output : outputs) {
741  TensorId id(ParseTensorName(output));
742  names.emplace(id.first);
743  }
744  for (Node* n : ek->graph->nodes()) {
745  if (names.count(n->name()) > 0) {
746  ek->name_to_node.insert({n->name(), n});
747  }
748  }
749  }
750  ek->items.reserve(graphs.size());
751  const auto& optimizer_opts =
752  options_.config.graph_options().optimizer_options();
753  GraphOptimizer optimizer(optimizer_opts);
754  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
755  const string& partition_name = iter->first;
756  std::unique_ptr<Graph>& partition_graph = iter->second;
757  const int graph_def_version = partition_graph->versions().producer();
758 
759  Device* device;
760  TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
761 
762  ek->items.resize(ek->items.size() + 1);
763  auto* item = &(ek->items.back());
764  item->flib.reset(NewFunctionLibraryRuntime(
765  device_mgr_.get(), options_.env, device, graph_def_version,
766  ek->flib_def.get(), optimizer_opts));
767 
768  LocalExecutorParams params;
769  params.device = device;
770  params.function_library = item->flib.get();
771  auto lib = item->flib.get();
772  auto opseg = device->op_segment();
773  params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
774  OpKernel** kernel) {
775  // Caches the kernel only if the node is stateful.
776  if (!lib->IsStateful(ndef.op())) {
777  return lib->CreateKernel(ndef, kernel);
778  }
779  auto create_fn = [lib, &ndef](OpKernel** kernel) {
780  return lib->CreateKernel(ndef, kernel);
781  };
782  // Kernels created for subgraph nodes need to be cached. On
783  // cache miss, create_fn() is invoked to create a kernel based
784  // on the function library here + global op registry.
785  return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
786  create_fn);
787  };
788  params.delete_kernel = [lib](OpKernel* kernel) {
789  // If the node is stateful, opseg owns it. Otherwise, delete it.
790  if (kernel && !lib->IsStateful(kernel->type_string())) {
791  delete kernel;
792  }
793  };
794  params.node_outputs_cb = node_outputs_callback_;
795 
796  optimizer.Optimize(lib, options_.env, device, &iter->second);
797 
798  // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
799  if (!options.debug_options.debug_tensor_watch_opts().empty()) {
800  TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
801  options.debug_options, partition_graph.get(), params.device));
802  }
803 
804  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
805  device->name(),
806  partition_graph.get()));
807  // NewLocalExecutor takes ownership of partition_graph.
808  item->graph = partition_graph.get();
809  item->executor = nullptr;
810  Executor* executor;
811  TF_RETURN_IF_ERROR(
812  NewLocalExecutor(params, partition_graph.release(), &executor));
813  item->executor.reset(executor);
814  }
815 
816  // Cache the mapping from input/output names to graph elements to
817  // avoid recomputing it every time.
818  if (!run_state_args->is_partial_run) {
819  // For regular `Run()`, we use the function calling convention, and so
820  // maintain a mapping from input/output names to
821  // argument/return-value ordinal index.
822  for (size_t i = 0; i < inputs_sorted.size(); ++i) {
823  const string& input = inputs_sorted[i];
824  ek->input_name_to_index[input] = i;
825  }
826  for (size_t i = 0; i < outputs_sorted.size(); ++i) {
827  const string& output = outputs_sorted[i];
828  ek->output_name_to_index[output] = i;
829  }
830  } else {
831  // For `PRun()`, we use the rendezvous calling convention, and so
832  // maintain a mapping from input/output names to rendezvous keys.
833  //
834  // We always use the first device as the device name portion of the
835  // key, even if we're feeding another graph.
836  for (size_t i = 0; i < inputs_sorted.size(); ++i) {
837  const string& input = inputs_sorted[i];
838  ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
839  input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
840  }
841  for (size_t i = 0; i < outputs_sorted.size(); ++i) {
842  const string& output = outputs_sorted[i];
843  ek->output_name_to_rendezvous_key[output] =
844  GetRendezvousKey(output, device_set_.client_device()->attributes(),
845  FrameAndIter(0, 0));
846  }
847  }
848 
849  // Reacquire the lock, try to insert into the map.
850  mutex_lock l(executor_lock_);
851 
852  // Another thread may have created the entry before us, in which case we will
853  // reuse the already created one.
854  auto insert_result = executors_.emplace(sorted_key, ek);
855  // Insert the value under the original key, so the fast path lookup will work
856  // if the user uses the same order of inputs, outputs, and targets again.
857  executors_.emplace(key, insert_result.first->second);
858  *executors_and_keys = insert_result.first->second.get();
859 
860  return Status::OK();
861 }
862 
863 Status TBBSession::CreateGraphs(
864  const BuildGraphOptions& subgraph_options,
865  std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
866  std::unique_ptr<FunctionLibraryDefinition>* flib_def,
867  RunStateArgs* run_state_args, DataTypeVector* input_types,
868  DataTypeVector* output_types) {
869  mutex_lock l(graph_def_lock_);
870  std::unique_ptr<SimpleClientGraph> client_graph;
871 
872  std::unique_ptr<SimpleGraphExecutionState> temp_exec_state_holder;
873  SimpleGraphExecutionState* execution_state = nullptr;
874  if (options_.config.graph_options().place_pruned_graph()) {
875  // Because we are placing pruned graphs, we need to create a
876  // new SimpleGraphExecutionState for every new unseen graph,
877  // and then place it.
878  SimpleGraphExecutionStateOptions prune_options;
879  prune_options.device_set = &device_set_;
880  prune_options.session_options = &options_;
881  prune_options.stateful_placements = stateful_placements_;
882  TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForPrunedGraph(
883  execution_state_->original_graph_def().library(), prune_options,
884  execution_state_->original_graph_def(), subgraph_options,
885  &temp_exec_state_holder, &client_graph));
886  execution_state = temp_exec_state_holder.get();
887  } else {
888  execution_state = execution_state_.get();
889  TF_RETURN_IF_ERROR(
890  execution_state->BuildGraph(subgraph_options, &client_graph));
891  }
892 
893  if (subgraph_options.feed_endpoints.size() !=
894  client_graph->feed_types.size()) {
895  return errors::Internal(
896  "Graph pruning failed: requested number of feed endpoints = ",
897  subgraph_options.feed_endpoints.size(),
898  " versus number of pruned feed endpoints = ",
899  client_graph->feed_types.size());
900  }
901  if (subgraph_options.fetch_endpoints.size() !=
902  client_graph->fetch_types.size()) {
903  return errors::Internal(
904  "Graph pruning failed: requested number of fetch endpoints = ",
905  subgraph_options.fetch_endpoints.size(),
906  " versus number of pruned fetch endpoints = ",
907  client_graph->fetch_types.size());
908  }
909 
910  auto current_stateful_placements = execution_state->GetStatefulPlacements();
911  // Update our current state based on the execution_state's
912  // placements. If there are any mismatches for a node,
913  // we should fail, as this should never happen.
914  for (auto placement_pair : current_stateful_placements) {
915  const string& node_name = placement_pair.first;
916  const string& placement = placement_pair.second;
917  auto iter = stateful_placements_.find(node_name);
918  if (iter == stateful_placements_.end()) {
919  stateful_placements_.insert(std::make_pair(node_name, placement));
920  } else if (iter->second != placement) {
921  return errors::Internal(
922  "Stateful placement mismatch. "
923  "Current assignment of ",
924  node_name, " to ", iter->second, " does not match ", placement);
925  }
926  }
927 
928  stateful_placements_ = execution_state->GetStatefulPlacements();
929 
930  // Remember the graph in run state if this is a partial run.
931  if (run_state_args->is_partial_run) {
932  run_state_args->graph.reset(new Graph(flib_def_.get()));
933  CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
934  }
935 
936  // Partition the graph across devices.
937  PartitionOptions popts;
938  popts.node_to_loc = [](const Node* node) {
939  assert(node != nullptr);
940  return node->assigned_device_name();
941  };
942  popts.new_name = [this](const string& prefix) {
943  return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
944  };
945  popts.get_incarnation = [](const string& name) {
946  // The direct session does not have changing incarnation numbers.
947  // Just return '1'.
948  return 1;
949  };
950  popts.control_flow_added = false;
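// For illustration: node_to_loc() groups nodes by their assigned device, so
// Partition() below emits one GraphDef per device, and new_name("foo") would
// return something like "foo/_7", using the monotonically increasing
// edge_name_counter_.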
951 
952  std::unordered_map<string, GraphDef> partitions;
953  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
954 
955  std::vector<string> device_names;
956  for (auto device : devices_) {
957  // Extract the LocalName from the device.
958  device_names.push_back(DeviceNameUtils::LocalName(device->name()));
959  }
960 
961  // Check for valid partitions.
962  for (const auto& partition : partitions) {
963  const string local_partition_name =
964  DeviceNameUtils::LocalName(partition.first);
965  if (std::count(device_names.begin(), device_names.end(),
966  local_partition_name) == 0) {
967  return errors::InvalidArgument(
968  "Creating a partition for ", local_partition_name,
969  " which doesn't exist in the list of available devices. Available "
970  "devices: ",
971  str_util::Join(device_names, ","));
972  }
973  }
974 
975  for (const auto& partition : partitions) {
976  std::unique_ptr<Graph> device_graph(
977  new Graph(client_graph->flib_def.get()));
978  GraphConstructorOptions device_opts;
979  // There are internal operations (e.g., send/recv) that we now allow.
980  device_opts.allow_internal_ops = true;
981  device_opts.expect_device_spec = true;
982  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
983  device_graph.get()));
984  outputs->emplace(partition.first, std::move(device_graph));
985  }
986 
987  GraphOptimizationPassOptions optimization_options;
988  optimization_options.session_options = &options_;
989  optimization_options.flib_def = client_graph->flib_def.get();
990  optimization_options.partition_graphs = outputs;
991  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
992  OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
993 
994  Status s;
995  for (auto& partition : *outputs) {
996  const string& partition_name = partition.first;
997  std::unique_ptr<Graph>* graph = &partition.second;
998 
999  VLOG(2) << "Created " << DebugString(graph->get()) << " for "
1000  << partition_name;
1001 
1002  // Give the device an opportunity to rewrite its subgraph.
1003  Device* d;
1004  s = device_mgr_->LookupDevice(partition_name, &d);
1005  if (!s.ok()) break;
1006  // TODO(pbar) The library is currently shared and immutable. There
1007  // may be possible use cases where a device may want to modify
1008  // function definitions - in which case the library would need to be
1009  // replicated per device.
1010  s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph);
1011  if (!s.ok()) {
1012  break;
1013  }
1014  }
1015  *flib_def = std::move(client_graph->flib_def);
1016  std::swap(*input_types, client_graph->feed_types);
1017  std::swap(*output_types, client_graph->fetch_types);
1018  return s;
1019 }
1020 
1021 ::tensorflow::Status TBBSession::ListDevices(
1022  std::vector<DeviceAttributes>* response) {
1023  response->clear();
1024  response->reserve(devices_.size());
1025  for (Device* d : devices_) {
1026  const DeviceAttributes& attrs = d->attributes();
1027  response->emplace_back(attrs);
1028  }
1029  return ::tensorflow::Status::OK();
1030 }
1031 
1032 ::tensorflow::Status TBBSession::Reset(
1033  const std::vector<string>& containers) {
1034  device_mgr_->ClearContainers(containers);
1035  return ::tensorflow::Status::OK();
1036 }
1037 
1038 ::tensorflow::Status TBBSession::Close() {
1039  cancellation_manager_->StartCancel();
1040  {
1041  mutex_lock l(closed_lock_);
1042  if (closed_) return ::tensorflow::Status::OK();
1043  closed_ = true;
1044  }
1045  if (factory_ != nullptr) factory_->Deregister(this);
1046  return ::tensorflow::Status::OK();
1047 }
1048 
1049 TBBSession::RunState::RunState(
1050  const std::vector<string>& pending_input_names,
1051  const std::vector<string>& pending_output_names, int64 step_id,
1052  const std::vector<Device*>* devices)
1053  : step_container(step_id, [devices](const string& name) {
1054  for (auto d : *devices) {
1055  if (!d->resource_manager()->Cleanup(name).ok()) {
1056  // Do nothing...
1057  }
1058  }
1059  }) {
1060  // Initially all the feeds and fetches are pending.
1061  for (auto& name : pending_input_names) {
1062  pending_inputs[name] = false;
1063  }
1064  for (auto& name : pending_output_names) {
1065  pending_outputs[name] = false;
1066  }
1067 }
1068 
1069 TBBSession::RunState::RunState(int64 step_id,
1070  const std::vector<Device*>* devices)
1071  : RunState({}, {}, step_id, devices) {}
1072 
1073 TBBSession::RunState::~RunState() {
1074  if (rendez != nullptr) {
1075  if (!executors_done.HasBeenNotified()) {
1076  rendez->StartAbort(errors::Cancelled("PRun cancellation"));
1077  executors_done.WaitForNotification();
1078  }
1079  rendez->Unref();
1080  }
1081 }
1082 
1083 bool TBBSession::RunState::PendingDone() const {
1084  for (const auto& it : pending_inputs) {
1085  if (!it.second) return false;
1086  }
1087  for (const auto& it : pending_outputs) {
1088  if (!it.second) return false;
1089  }
1090  return true;
1091 }
1092 
1093 void TBBSession::WaitForNotification(tbb::task_arena& arena,
1094  tbb::task_group& taskGroup,
1095  RunState* run_state,
1096  CancellationManager* cm,
1097  int64 timeout_in_ms) {
1098  // Doing the wait in the arena adds this thread to the arena
1099  // and therefore tasks associated to the group can run on this thread
1100  arena.execute([&taskGroup]() { taskGroup.wait();} );
1101 
1102  Status status =
1103  WaitForNotification(&run_state->executors_done, timeout_in_ms);
1104  if (!status.ok()) {
1105  {
1106  mutex_lock l(run_state->mu_);
1107  run_state->status.Update(status);
1108  }
1109  cm->StartCancel();
1110  // We must wait for the executors to complete, because they have borrowed
1111  // references to `cm` and other per-step state. After this notification, it
1112  // is safe to clean up the step.
1113  run_state->executors_done.WaitForNotification();
1114  }
1115 }
1116 
1117 ::tensorflow::Status TBBSession::WaitForNotification(
1118  Notification* notification, int64 timeout_in_ms) {
1119  if (timeout_in_ms > 0) {
1120  int64 timeout_in_us = timeout_in_ms * 1000;
1121  bool notified = WaitForNotificationWithTimeout(notification, timeout_in_us);
1122  if (!notified) {
1123  return Status(error::DEADLINE_EXCEEDED,
1124  "Timed out waiting for notification");
1125  }
1126  } else {
1127  notification->WaitForNotification();
1128  }
1129  return Status::OK();
1130 }
1131 
1132 } // namespace tensorflow