TBBSession.cc

1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 //NOTE: The memory layout of the Node class changes depending on whether NDEBUG was
16 // set when tensorflow was compiled. The reason is that the Node class holds two EdgeSet
17 // class instances and EdgeSet adds member data if NDEBUG is set
18 
19 /*
20 This file is an adaptation of the original direct_session.cc file located at
21 https://github.com/tensorflow/tensorflow/blob/v1.6.0/tensorflow/core/common_runtime/direct_session.cc
22 to meet the demands of the software environment developed and used by the CMS collaboration.
23 
24 Changes with respect to the original code are documented in the TBBSession.h header file.
25 */
26 
27 #if !defined(NDEBUG)
28 #define NDEBUG 1
29 #endif
30 
31 #include "TBBSession.h"
32 
33 #include <atomic>
34 #include <string>
35 #include <vector>
36 
37 #include "tbb/task_group.h"
38 #include "tbb/task_arena.h"
39 #include "FWCore/Utilities/interface/thread_safety_macros.h"
40 
41 #include "tensorflow/core/common_runtime/constant_folding.h"
42 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
43 #include "tensorflow/core/common_runtime/device_factory.h"
44 #include "tensorflow/core/common_runtime/executor.h"
45 #include "tensorflow/core/common_runtime/function.h"
46 #include "tensorflow/core/common_runtime/graph_optimizer.h"
47 #include "tensorflow/core/common_runtime/memory_types.h"
48 #include "tensorflow/core/common_runtime/optimization_registry.h"
49 #include "tensorflow/core/common_runtime/step_stats_collector.h"
50 #include "tensorflow/core/framework/function.h"
51 #include "tensorflow/core/framework/graph.pb_text.h"
52 #include "tensorflow/core/framework/graph.pb.h"
53 #include "tensorflow/core/framework/graph_def_util.h"
54 #include "tensorflow/core/framework/log_memory.h"
55 #include "tensorflow/core/framework/node_def.pb.h"
56 #include "tensorflow/core/framework/tensor.h"
57 #include "tensorflow/core/framework/versions.pb.h"
58 #include "tensorflow/core/graph/algorithm.h"
59 #include "tensorflow/core/graph/graph.h"
60 #include "tensorflow/core/graph/graph_constructor.h"
61 #include "tensorflow/core/graph/graph_partition.h"
62 #include "tensorflow/core/graph/subgraph.h"
63 #include "tensorflow/core/graph/tensor_id.h"
64 #include "tensorflow/core/lib/core/errors.h"
65 #include "tensorflow/core/lib/core/notification.h"
66 #include "tensorflow/core/lib/core/refcount.h"
67 #include "tensorflow/core/lib/core/status.h"
68 #include "tensorflow/core/lib/gtl/array_slice.h"
69 #include "tensorflow/core/lib/gtl/stl_util.h"
70 #include "tensorflow/core/lib/monitoring/counter.h"
71 #include "tensorflow/core/lib/strings/numbers.h"
72 #include "tensorflow/core/lib/strings/str_util.h"
73 #include "tensorflow/core/lib/strings/strcat.h"
74 #include "tensorflow/core/platform/cpu_info.h"
75 #include "tensorflow/core/platform/device_tracer.h"
76 #include "tensorflow/core/platform/logging.h"
77 #include "tensorflow/core/platform/mutex.h"
78 #include "tensorflow/core/platform/types.h"
79 #include "tensorflow/core/util/device_name_utils.h"
80 #include "tensorflow/core/util/env_var.h"
81 
82 namespace tensorflow {
83 
84 namespace {
85 
86 CMS_THREAD_SAFE auto* tbb_session_runs = monitoring::Counter<0>::New(
87  "/tensorflow/core/tbb_session_runs",
88  "The number of times TBBSession::Run() has been called.");
89 
90 // TODO(vrv): Figure out how to unify the many different functions
91 // that generate RendezvousKey, since many of them have to be
92 // consistent with each other.
93 string GetRendezvousKey(const string& tensor_name,
94  const DeviceAttributes& device_info,
95  const FrameAndIter& frame_iter) {
96  return strings::StrCat(device_info.name(), ";",
97  strings::FpToString(device_info.incarnation()), ";",
98  device_info.name(), ";", tensor_name, ";",
99  frame_iter.frame_id, ":", frame_iter.iter_id);
100 }
101 
102 } // namespace
103 
104 class TBBSessionFactory : public SessionFactory {
105  public:
106  TBBSessionFactory() {}
107 
108  bool AcceptsOptions(const SessionOptions& options) override {
109  return options.target == "tbb";
110  }
111 
112  Session* NewSession(const SessionOptions& options) override {
113  // Must do this before the CPU allocator is created.
114  if (options.config.graph_options().build_cost_model() > 0) {
115  EnableCPUAllocatorFullStats(true);
116  }
117  std::vector<Device*> devices;
118  const Status s = DeviceFactory::AddDevices(
119  options, "/job:localhost/replica:0/task:0", &devices);
120  if (!s.ok()) {
121  LOG(ERROR) << s;
122  return nullptr;
123  }
124 
125  TBBSession* session =
126  new TBBSession(options, new DeviceMgr(devices), this);
127  {
128  mutex_lock l(sessions_lock_);
129  sessions_.push_back(session);
130  }
131  return session;
132  }
133 
134  Status Reset(const SessionOptions& options,
135  const std::vector<string>& containers) override {
136  std::vector<TBBSession*> sessions_to_reset;
137  {
138  mutex_lock l(sessions_lock_);
139  // We create a copy to ensure that we don't have a deadlock when
140  // session->Close calls the TBBSessionFactory.Deregister, which
141  // acquires sessions_lock_.
142  std::swap(sessions_to_reset, sessions_);
143  }
144  Status s;
145  for (auto session : sessions_to_reset) {
146  s.Update(session->Reset(containers));
147  }
148  // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
149  // it doesn't close the sessions?
150  for (auto session : sessions_to_reset) {
151  s.Update(session->Close());
152  }
153  return s;
154  }
155 
156  void Deregister(const TBBSession* session) {
157  mutex_lock l(sessions_lock_);
158  sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
159  sessions_.end());
160  }
161 
162  private:
163  mutex sessions_lock_;
164  std::vector<TBBSession*> sessions_ GUARDED_BY(sessions_lock_);
165 };
166 
167 class TBBSessionRegistrar {
168  public:
169  TBBSessionRegistrar() {
170  SessionFactory::Register("TBB_SESSION", new TBBSessionFactory());
171  }
172 };
173 static TBBSessionRegistrar registrar;
174 
175 std::atomic_int_fast64_t TBBSession::step_id_counter_(1);
176 
177 // NOTE: On Android with a single device, there is never
178 // a risk of an OpKernel blocking indefinitely:
179 //
180 // 1) No operations do I/O that depends on other simultaneous kernels,
181 //
182 // 2) Recv nodes always complete immediately: The inputs are sent into
183 // the local rendezvous before we start the executor, so the
184 // corresponding recvs will not block.
185 //
186 // Based on these assumptions, we can use the same thread pool for
187 // both "non-blocking" and "blocking" OpKernels on Android.
188 //
189 // This may change down the road when we add support for multiple
190 // devices that run concurrently, in which case we will need to
191 // revisit this decision.
192 // Override to allow CMSSW FWK to schedule
193 void TBBSession::SchedClosure(tbb::task_arena& arena, tbb::task_group& g, std::function<void()> c) {
194  arena.execute( [&g, &c] () {g.run( c ); } );
195 }
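// Closures passed to SchedClosure() run as tasks of `g` inside `arena`; the caller
// must later block via arena.execute([&g] { g.wait(); }), as WaitForNotification()
// does below. A minimal sketch of the underlying TBB pattern:
//
//   tbb::task_arena arena;
//   tbb::task_group group;
//   arena.execute([&group] { group.run([] { /* work */ }); });  // schedule a task
//   arena.execute([&group] { group.wait(); });                  // wait inside the arena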
196 
197 TBBSession::TBBSession(const SessionOptions& options,
198  const DeviceMgr* device_mgr,
199  TBBSessionFactory* const factory)
200  : options_(options),
201  device_mgr_(device_mgr),
202  factory_(factory),
203  cancellation_manager_(new CancellationManager()),
204  operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
205  // The default value of sync_on_finish will be flipped soon and this
206  // environment variable will be removed as well.
207  const Status status =
208  ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
209  if (!status.ok()) {
210  LOG(ERROR) << status.error_message();
211  }
212  // NOTE(mrry): We do not need to use a unique string for the session
213  // handle, because TBBSession owns its devices. This may change
214  // in future versions.
215  session_handle_ = "tbb";
216  int devices_added = 0;
217  if (options.config.log_device_placement()) {
218  const string mapping_str = device_mgr_->DeviceMappingString();
219  if (mapping_str.empty()) {
220  printf("Device mapping: no known devices.\n");
221  } else {
222  printf("Device mapping:\n%s", mapping_str.c_str());
223  }
224  LOG(INFO) << "Device mapping:\n" << mapping_str;
225  }
226  for (auto d : device_mgr_->ListDevices()) {
227  devices_.push_back(d);
228  device_set_.AddDevice(d);
229  d->op_segment()->AddHold(session_handle_);
230 
231  // The first device added is special: it is the 'client device' (a
232  // CPU device) from which we feed and fetch Tensors.
233  if (devices_added == 0) {
234  device_set_.set_client_device(d);
235  }
236  ++devices_added;
237  }
238 }
239 
240 TBBSession::~TBBSession() {
241  if (!closed_) Close().IgnoreError();
242  for (auto& it : partial_runs_) {
243  it.second.reset(nullptr);
244  }
245  for (auto& it : executors_) {
246  it.second.reset();
247  }
248  for (auto d : device_mgr_->ListDevices()) {
249  d->op_segment()->RemoveHold(session_handle_);
250  }
251  for (auto d : device_mgr_->ListDevices()) {
252  d->ClearResourceMgr();
253  }
254  functions_.clear();
255  delete cancellation_manager_;
256 
257  execution_state_.reset(nullptr);
258  flib_def_.reset(nullptr);
259 }
260 
261 Status TBBSession::MaybeInitializeExecutionState(
262  const GraphDef& graph, bool* out_already_initialized) {
263  // If already initialized, do nothing.
264  if (flib_def_ && execution_state_) {
265  *out_already_initialized = true;
266  return Status::OK();
267  }
268  // Set up the per-session execution state.
269  // NOTE(mrry): The function library created here will be used for
270  // all subsequent extensions of the graph.
271  flib_def_.reset(
272  new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
273  GraphExecutionStateOptions options;
274  options.device_set = &device_set_;
275  options.session_options = &options_;
276  // TODO(mrry,suharshs): We explicitly copy `graph` so that
277  // `MakeForBaseGraph()` can take ownership of its
278  // contents. Previously this happened implicitly in calls to the
279  // `GraphExecutionState`. Other sessions call
280  // `MakeForBaseGraph` in such a way that we can destructively read
281  // the passed-in `GraphDef`. In principle we could do the same here,
282  // with a wider refactoring; we might revise the direct session so
283  // that it copies the graph fewer times.
284  GraphDef temp(graph);
285  TF_RETURN_IF_ERROR(
286  GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_));
287  graph_created_ = true;
288  *out_already_initialized = false;
289  return Status::OK();
290 }
291 
292 Status TBBSession::Create(const GraphDef& graph) {
293  TF_RETURN_IF_ERROR(init_error_);
294  if (graph.node_size() > 0) {
295  mutex_lock l(graph_def_lock_);
296  if (graph_created_) {
297  return errors::AlreadyExists(
298  "A Graph has already been created for this session.");
299  }
300  return ExtendLocked(graph);
301  }
302  return Status::OK();
303 }
304 
305 Status TBBSession::Extend(const GraphDef& graph) {
306  TF_RETURN_IF_ERROR(CheckNotClosed());
307  mutex_lock l(graph_def_lock_);
308  return ExtendLocked(graph);
309 }
310 
311 Status TBBSession::ExtendLocked(const GraphDef& graph) {
312  bool already_initialized;
313  // If this is the first call, we can initialize the execution state
314  // with `graph` and do not need to call `Extend()`.
315  TF_RETURN_IF_ERROR(
316  MaybeInitializeExecutionState(graph, &already_initialized));
317  if (already_initialized) {
318  TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
319  std::unique_ptr<GraphExecutionState> state;
320  TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
321  execution_state_.swap(state);
322  }
323  return Status::OK();
324 }
325 
326 Status TBBSession::Run(const NamedTensorList& inputs,
327  const std::vector<string>& output_names,
328  const std::vector<string>& target_nodes,
329  std::vector<Tensor>* outputs) {
330  RunMetadata run_metadata;
331  return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
332  &run_metadata);
333 }
334 
335 Status TBBSession::CreateDebuggerState(
336  const DebugOptions& debug_options, int64 session_run_index,
337  int64 executor_step_index, const std::vector<string>& input_names,
338  const std::vector<string>& output_names,
339  const std::vector<string>& target_names,
340  std::unique_ptr<DebuggerStateInterface>* debugger_state) {
341  TF_RETURN_IF_ERROR(
342  DebuggerStateRegistry::CreateState(debug_options, debugger_state));
343  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
344  debug_options.global_step(), session_run_index, executor_step_index,
345  input_names, output_names, target_names));
346  return Status::OK();
347 }
348 
349 Status TBBSession::DecorateAndPublishGraphForDebug(
350  const DebugOptions& debug_options, Graph* graph, Device* device) {
351  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
352  TF_RETURN_IF_ERROR(
353  DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
354 
355  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
356  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
357  return Status::OK();
358 }
359 
360 Status TBBSession::Run(const RunOptions& run_options,
361  const NamedTensorList& inputs,
362  const std::vector<string>& output_names,
363  const std::vector<string>& target_nodes,
364  std::vector<Tensor>* outputs,
365  RunMetadata* run_metadata) {
366  TF_RETURN_IF_ERROR(CheckNotClosed());
367  tbb_session_runs->GetCell()->IncrementBy(1);
368  {
369  mutex_lock l(graph_def_lock_);
370  if (!graph_created_) {
371  return errors::InvalidArgument(
372  "Session was not created with a graph before Run()!");
373  }
374  }
375 
376  // Extract the input names for this run of the session.
377  std::vector<string> input_tensor_names;
378  input_tensor_names.reserve(inputs.size());
379  for (const auto& it : inputs) {
380  input_tensor_names.push_back(it.first);
381  }
382 
383  // Check if we already have an executor for these arguments.
384  ExecutorsAndKeys* executors_and_keys;
385  RunStateArgs run_state_args(run_options.debug_options());
386 
387  Executor::Args args;
388  args.step_id = step_id_counter_.fetch_add(1);
389 
390  TF_RETURN_IF_ERROR(
391  GetOrCreateExecutors(input_tensor_names, output_names,
392  target_nodes, &executors_and_keys,
393  &run_state_args));
394  const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
395 
396  std::unique_ptr<DebuggerStateInterface> debugger_state;
397  if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
398  TF_RETURN_IF_ERROR(CreateDebuggerState(
399  run_options.debug_options(), args.step_id, executor_step_count,
400  input_tensor_names, output_names, target_nodes, &debugger_state));
401  }
402 
403  // Configure a call frame for the step, which we use to feed and
404  // fetch values to and from the executors.
405  FunctionCallFrame call_frame(executors_and_keys->input_types,
406  executors_and_keys->output_types);
407  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
408  for (const auto& it : inputs) {
409  if (it.second.dtype() == DT_RESOURCE) {
410  Tensor tensor_from_handle;
411  TF_RETURN_IF_ERROR(
412  ResourceHandleToInputTensor(it.second, &tensor_from_handle));
413  feed_args[executors_and_keys->input_name_to_index[it.first]] =
414  tensor_from_handle;
415  } else {
416  feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
417  }
418  }
419  const Status s = call_frame.SetArgs(feed_args);
420  if (errors::IsInternal(s)) {
421  return errors::InvalidArgument(s.error_message());
422  } else if (!s.ok()) {
423  return s;
424  }
425 
426  // Create a run state and start execution.
427  RunState run_state(args.step_id, &devices_);
428  run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
429  CancellationManager step_cancellation_manager;
430  args.call_frame = &call_frame;
431 
432  // Use a task_arena to avoid having unrelated tasks start
433  // running on this thread (which could cause deadlocks)
434  tbb::task_arena taskArena;
435  tbb::task_group taskGroup;
436  // we are required to always call wait() on the task_group before its destructor runs
437  auto doneWithTaskGroup = [&taskArena, &taskGroup](void *) { taskArena.execute([&taskGroup]() { taskGroup.wait();}); };
438  std::unique_ptr<tbb::task_group, decltype(doneWithTaskGroup) > guard(&taskGroup, doneWithTaskGroup);
439 
440  // Start parallel Executors.
441  const size_t num_executors = executors_and_keys->items.size();
442  ExecutorBarrier* barrier = new ExecutorBarrier(
443  num_executors, run_state.rendez, [&run_state](const Status& ret) {
444  {
445  mutex_lock l(run_state.mu_);
446  run_state.status.Update(ret);
447  }
448  run_state.executors_done.Notify();
449  });
450 
451  args.rendezvous = run_state.rendez;
452  args.cancellation_manager = &step_cancellation_manager;
453 
454  args.session_state = &session_state_;
455  args.tensor_store = &run_state.tensor_store;
456  args.step_container = &run_state.step_container;
457  if (LogMemory::IsEnabled()) {
458  LogMemory::RecordStep(args.step_id, run_state_args.handle);
459  }
460  args.sync_on_finish = sync_on_finish_;
461 
462  const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
463 
464  bool update_cost_model = false;
465  if (options_.config.graph_options().build_cost_model() > 0) {
466  const int64 build_cost_model_every =
467  options_.config.graph_options().build_cost_model();
468  const int64 build_cost_model_after =
469  options_.config.graph_options().build_cost_model_after();
470  int64 measure_step_count = executor_step_count - build_cost_model_after;
471  if (measure_step_count >= 0) {
472  update_cost_model =
473  ((measure_step_count + 1) % build_cost_model_every == 0);
474  }
475  }
476  if (do_trace || update_cost_model ||
477  run_options.report_tensor_allocations_upon_oom()) {
478  run_state.collector.reset(
479  new StepStatsCollector(run_metadata->mutable_step_stats()));
480  args.stats_collector = run_state.collector.get();
481  }
482 
483  std::unique_ptr<DeviceTracer> tracer;
484  if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
485  tracer = CreateDeviceTracer();
486  // tracer may be NULL on platforms without accelerators.
487  if (tracer) {
488  Status s = tracer->Start();
489  if (!s.ok()) {
490  run_state.executors_done.Notify();
491  delete barrier;
492  return s;
493  }
494  }
495  }
496 
497  // Register this step with session's cancellation manager, so that
498  // `Session::Close()` will cancel the step.
499  const CancellationToken cancellation_token =
500  cancellation_manager_->get_cancellation_token();
501  const bool already_cancelled = !cancellation_manager_->RegisterCallback(
502  cancellation_token, [&step_cancellation_manager]() {
503  step_cancellation_manager.StartCancel();
504  });
505  if (already_cancelled) {
506  // NOTE(mrry): If we don't explicitly notify
507  // `run_state.executors_done`, the RunState destructor would
508  // block on this notification.
509  run_state.executors_done.Notify();
510  delete barrier;
511  return errors::Cancelled("Run call was cancelled");
512  }
513 
514  // pass taskArena and taskGroup to SchedClosure
515  // consequently, disable TF's own thread logic inside the loop
516  Executor::Args::Runner default_runner = [this, &taskArena, &taskGroup](Executor::Args::Closure c) {
517  SchedClosure(taskArena, taskGroup, std::move(c));
518  };
519  for (const auto& item : executors_and_keys->items) {
520  // TODO(zhengxq): support partial run.
521  // TODO(zhengxq): if the device picks its own threadpool, we need to assign
522  // less threads to the main compute pool by default.
523  // thread::ThreadPool* device_thread_pool =
524  // item.device->tensorflow_device_thread_pool();
525  // if (!device_thread_pool) {
526  // args.runner = default_runner;
527  // } else {
528  // args.runner = [this, device_thread_pool](Executor::Args::Closure c) {
529  // SchedClosure(device_thread_pool, std::move(c));
530  // };
531  // }
532  args.runner = default_runner;
533  item.executor->RunAsync(args, barrier->Get());
534  }
535 
536  // WaitForNotification will handle calling wait on taskGroup
537  guard.release();
538  WaitForNotification(taskArena, taskGroup, &run_state, &step_cancellation_manager,
539  run_options.timeout_in_ms() > 0
540  ? run_options.timeout_in_ms()
541  : operation_timeout_in_ms_);
542 
543  if (!cancellation_manager_->DeregisterCallback(cancellation_token)) {
544  // The step has been cancelled: make sure we don't attempt to receive the
545  // outputs as this would make it block forever.
546  mutex_lock l(run_state.mu_);
547  run_state.status.Update(errors::Cancelled("Run call was cancelled"));
548  }
549 
550  if (tracer) {
551  TF_RETURN_IF_ERROR(tracer->Stop());
552  TF_RETURN_IF_ERROR(tracer->Collect(args.stats_collector));
553  }
554 
555  {
556  mutex_lock l(run_state.mu_);
557  TF_RETURN_IF_ERROR(run_state.status);
558  }
559 
560  // Receive outputs.
561  if (outputs) {
562  std::vector<Tensor> sorted_outputs;
563  const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
564  if (errors::IsInternal(s)) {
565  return errors::InvalidArgument(s.error_message());
566  } else if (!s.ok()) {
567  return s;
568  }
569  const bool unique_outputs =
570  output_names.size() == executors_and_keys->output_name_to_index.size();
571  // first_indices[i] = j implies that j is the smallest value for which
572  // output_names[i] == output_names[j].
573  std::vector<int> first_indices;
574  if (!unique_outputs) {
575  first_indices.resize(output_names.size());
576  for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
577  for (int j = 0; j <= i; ++j) {
578  if (output_names[i] == output_names[j]) {
579  first_indices[i] = j;
580  break;
581  }
582  }
583  }
584  }
585  outputs->clear();
586  outputs->reserve(sorted_outputs.size());
587  for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
588  const string& output_name = output_names[i];
589  if (first_indices.empty() || first_indices[i] == i) {
590  outputs->emplace_back(
591  std::move(sorted_outputs[executors_and_keys
592  ->output_name_to_index[output_name]]));
593  } else {
594  outputs->push_back((*outputs)[first_indices[i]]);
595  }
596  }
597  }
598 
599  // Save the output tensors of this run we choose to keep.
600  TF_RETURN_IF_ERROR(
601  run_state.tensor_store.SaveTensors(output_names, &session_state_));
602  if (args.stats_collector) {
603  args.stats_collector->Finalize();
604  }
605 
606  // Build and return the cost model as instructed.
607  mutex_lock l(executor_lock_);
608  if (update_cost_model) {
609  // Build the cost model
610  std::unordered_map<string, const Graph*> device_to_graph;
611  for (const PerPartitionExecutorsAndLib& partition :
612  executors_and_keys->items) {
613  const Graph* graph = partition.graph;
614  const string device = partition.flib->device()->name();
615  device_to_graph[device] = graph;
616  }
617  args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);
618 
619  // annotate stats onto cost graph.
620  CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
621  for (const auto& item : executors_and_keys->items) {
622  TF_RETURN_IF_ERROR(
623  cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
624  }
625  }
626 
627  // If requested via RunOptions, output the partition graphs.
628  if (run_options.output_partition_graphs()) {
629  protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
630  run_metadata->mutable_partition_graphs();
631  for (const PerPartitionExecutorsAndLib& exec_and_lib :
632  executors_and_keys->items) {
633  GraphDef* partition_graph_def = partition_graph_defs->Add();
634  exec_and_lib.graph->ToGraphDef(partition_graph_def);
635  }
636  }
637 
638  return Status::OK();
639 }
640 
641 Status TBBSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
642  Tensor* retrieved_tensor) {
643  if (resource_tensor.dtype() != DT_RESOURCE) {
644  return errors::InvalidArgument(strings::StrCat(
645  "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
646  resource_tensor.dtype()));
647  }
648 
649  const ResourceHandle& resource_handle =
650  resource_tensor.scalar<ResourceHandle>()();
651 
652  if (resource_handle.container() ==
653  SessionState::kTensorHandleResourceTypeName) {
654  return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
655  } else {
656  return errors::InvalidArgument(strings::StrCat(
657  "Invalid resource type hash code: ", resource_handle.hash_code(),
658  "(name: ", resource_handle.name(),
659  " type: ", resource_handle.maybe_type_name(),
660  "). Perhaps a resource tensor was being provided as a feed? That is "
661  "not currently allowed. Please file an issue at "
662  "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
663  "short code snippet that leads to this error message."));
664  }
665 }
666 
667 Status TBBSession::GetOrCreateExecutors(
668  gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
669  gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
670  RunStateArgs* run_state_args) {
671  int64 handle_name_counter_value = -1;
672  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
673  handle_name_counter_value = handle_name_counter_.fetch_add(1);
674  }
675 
676  string debug_tensor_watches_summary;
677  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
678  debug_tensor_watches_summary = SummarizeDebugTensorWatches(
679  run_state_args->debug_options.debug_tensor_watch_opts());
680  }
681 
682  // Fast lookup path, no sorting.
683  const string key = strings::StrCat(
684  str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
685  str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
686  "/", debug_tensor_watches_summary);
687  // Set the handle, if it's needed to log memory or for partial run.
688  if (handle_name_counter_value >= 0) {
689  run_state_args->handle =
690  strings::StrCat(key, ";", handle_name_counter_value);
691  }
692 
693  // See if we already have the executors for this run.
694  {
695  mutex_lock l(executor_lock_); // could use reader lock
696  auto it = executors_.find(key);
697  if (it != executors_.end()) {
698  *executors_and_keys = it->second.get();
699  return Status::OK();
700  }
701  }
702 
703  // Slow lookup path, the unsorted key missed the cache.
704  // Sort the inputs and outputs, and look up with the sorted key in case an
705  // earlier call used a different order of inputs and outputs.
706  //
707  // We could consider some other signature instead of sorting that
708  // preserves the same property to avoid the sort in the future.
709  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
710  std::sort(inputs_sorted.begin(), inputs_sorted.end());
711  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
712  std::sort(outputs_sorted.begin(), outputs_sorted.end());
713  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
714  std::sort(tn_sorted.begin(), tn_sorted.end());
715 
716  const string sorted_key = strings::StrCat(
717  str_util::Join(inputs_sorted, ","), "->",
718  str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
719  "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
720  // Set the handle, if it's needed to log memory or for partial run.
721  if (handle_name_counter_value >= 0) {
722  run_state_args->handle =
723  strings::StrCat(sorted_key, ";", handle_name_counter_value);
724  }
725 
726  // See if we already have the executors for this run.
727  {
728  mutex_lock l(executor_lock_);
729  auto it = executors_.find(sorted_key);
730  if (it != executors_.end()) {
731  *executors_and_keys = it->second.get();
732  // Insert this under the original key.
733  executors_.emplace(key, it->second);
734  return Status::OK();
735  }
736  }
737 
738  // Nothing found, so create the executors and store in the cache.
739  BuildGraphOptions options;
740  options.feed_endpoints = inputs_sorted;
741  options.fetch_endpoints = outputs_sorted;
742  options.target_nodes = tn_sorted;
743  options.use_function_convention = !run_state_args->is_partial_run;
744  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
745  options.debug_options = run_state_args->debug_options;
746  }
747 
748  std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
749  std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
750 
751  // The executor_lock_ is intentionally released while executor is
752  // being created.
753  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
754  TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def,
755  run_state_args, &ek->input_types,
756  &ek->output_types));
757 
758  if (run_state_args->is_partial_run) {
759  ek->graph = std::move(run_state_args->graph);
760  std::unordered_set<StringPiece, StringPieceHasher> names;
761  for (const string& input : inputs) {
762  TensorId id(ParseTensorName(input));
763  names.emplace(id.first);
764  }
765  for (const string& output : outputs) {
766  TensorId id(ParseTensorName(output));
767  names.emplace(id.first);
768  }
769  for (Node* n : ek->graph->nodes()) {
770  if (names.count(n->name()) > 0) {
771  ek->name_to_node.insert({n->name(), n});
772  }
773  }
774  }
775  ek->items.reserve(graphs.size());
776  const auto& optimizer_opts =
777  options_.config.graph_options().optimizer_options();
778 
779  int graph_def_version;
780  {
781  mutex_lock l(graph_def_lock_);
782  graph_def_version =
783  execution_state_->original_graph_def().versions().producer();
784  }
785  func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
786  device_mgr_.get(), options_.env, graph_def_version,
787  func_info->flib_def.get(), optimizer_opts));
788 
789  GraphOptimizer optimizer(optimizer_opts);
790  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
791  const string& partition_name = iter->first;
792  std::unique_ptr<Graph>& partition_graph = iter->second;
793 
794  Device* device;
795  TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
796 
797  ek->items.resize(ek->items.size() + 1);
798  auto* item = &(ek->items.back());
799  auto lib = func_info->proc_flr->GetFLR(partition_name);
800  if (lib == nullptr) {
801  return errors::Internal("Could not find device: ", partition_name);
802  }
803  item->flib = lib;
804 
805  LocalExecutorParams params;
806  params.device = device;
807  params.function_library = lib;
808  auto opseg = device->op_segment();
809  params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
810  OpKernel** kernel) {
811  // We do not share the kernel via the OpSegment if the node is
812  // stateless, or a function.
813  // NOTE(mrry): We must not share function kernels (implemented
814  // using `CallOp`) between subgraphs, because `CallOp::handle_`
815  // is tied to a particular subgraph. Even if the function itself
816  // is stateful, the `CallOp` that invokes it is not.
817  if (!lib->IsStateful(ndef.op()) ||
818  lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
819  return lib->CreateKernel(ndef, kernel);
820  }
821  auto create_fn = [lib, &ndef](OpKernel** kernel) {
822  return lib->CreateKernel(ndef, kernel);
823  };
824  // Kernels created for subgraph nodes need to be cached. On
825  // cache miss, create_fn() is invoked to create a kernel based
826  // on the function library here + global op registry.
827  return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
828  create_fn);
829  };
830  params.delete_kernel = [lib](OpKernel* kernel) {
831  // If the node is stateful, opseg owns it. Otherwise, delete it.
832  if (kernel && !lib->IsStateful(kernel->type_string())) {
833  delete kernel;
834  }
835  };
836  params.node_outputs_cb = node_outputs_callback_;
837 
838  optimizer.Optimize(lib, options_.env, device, &iter->second,
839  /*shape_map=*/nullptr);
840 
841  // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
842  if (!options.debug_options.debug_tensor_watch_opts().empty()) {
843  TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
844  options.debug_options, partition_graph.get(), params.device));
845  }
846 
847  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
848  device->name(),
849  partition_graph.get()));
850  // NewLocalExecutor takes ownership of partition_graph.
851  item->graph = partition_graph.get();
852  item->executor = nullptr;
853  item->device = device;
854  Executor* executor;
855  TF_RETURN_IF_ERROR(
856  NewLocalExecutor(params, partition_graph.release(), &executor));
857  item->executor.reset(executor);
858  }
859 
860  // Cache the mapping from input/output names to graph elements to
861  // avoid recomputing it every time.
862  if (!run_state_args->is_partial_run) {
863  // For regular `Run()`, we use the function calling convention, and so
864  // maintain a mapping from input/output names to
865  // argument/return-value ordinal index.
866  for (size_t i = 0; i < inputs_sorted.size(); ++i) {
867  const string& input = inputs_sorted[i];
868  ek->input_name_to_index[input] = i;
869  }
870  for (size_t i = 0; i < outputs_sorted.size(); ++i) {
871  const string& output = outputs_sorted[i];
872  ek->output_name_to_index[output] = i;
873  }
874  } else {
875  // For `PRun()`, we use the rendezvous calling convention, and so
876  // maintain a mapping from input/output names to rendezvous keys.
877  //
878  // We always use the first device as the device name portion of the
879  // key, even if we're feeding another graph.
880  for (size_t i = 0; i < inputs_sorted.size(); ++i) {
881  const string& input = inputs_sorted[i];
882  ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
883  input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
884  }
885  for (size_t i = 0; i < outputs_sorted.size(); ++i) {
886  const string& output = outputs_sorted[i];
887  ek->output_name_to_rendezvous_key[output] =
888  GetRendezvousKey(output, device_set_.client_device()->attributes(),
889  FrameAndIter(0, 0));
890  }
891  }
892 
893  // Reacquire the lock, try to insert into the map.
894  mutex_lock l(executor_lock_);
895  functions_.push_back(std::move(func_info));
896 
897  // Another thread may have created the entry before us, in which case we will
898  // reuse the already created one.
899  auto insert_result = executors_.emplace(sorted_key, ek);
900  // Insert the value under the original key, so the fast path lookup will work
901  // if the user uses the same order of inputs, outputs, and targets again.
902  executors_.emplace(key, insert_result.first->second);
903  *executors_and_keys = insert_result.first->second.get();
904 
905  return Status::OK();
906 }
907 
908 Status TBBSession::CreateGraphs(
909  const BuildGraphOptions& subgraph_options,
910  std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
911  std::unique_ptr<FunctionLibraryDefinition>* flib_def,
912  RunStateArgs* run_state_args, DataTypeVector* input_types,
913  DataTypeVector* output_types) {
914  mutex_lock l(graph_def_lock_);
915  std::unique_ptr<ClientGraph> client_graph;
916 
917  std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
918  GraphExecutionState* execution_state = nullptr;
919  if (options_.config.graph_options().place_pruned_graph()) {
920  // Because we are placing pruned graphs, we need to create a
921  // new GraphExecutionState for every new unseen graph,
922  // and then place it.
923  GraphExecutionStateOptions prune_options;
924  prune_options.device_set = &device_set_;
925  prune_options.session_options = &options_;
926  prune_options.stateful_placements = stateful_placements_;
927  TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
928  execution_state_->original_graph_def().library(), prune_options,
929  execution_state_->original_graph_def(), subgraph_options,
930  &temp_exec_state_holder, &client_graph));
931  execution_state = temp_exec_state_holder.get();
932  } else {
933  execution_state = execution_state_.get();
934  TF_RETURN_IF_ERROR(
935  execution_state->BuildGraph(subgraph_options, &client_graph));
936  }
937 
938  if (subgraph_options.feed_endpoints.size() !=
939  client_graph->feed_types.size()) {
940  return errors::Internal(
941  "Graph pruning failed: requested number of feed endpoints = ",
942  subgraph_options.feed_endpoints.size(),
943  " versus number of pruned feed endpoints = ",
944  client_graph->feed_types.size());
945  }
946  if (subgraph_options.fetch_endpoints.size() !=
947  client_graph->fetch_types.size()) {
948  return errors::Internal(
949  "Graph pruning failed: requested number of fetch endpoints = ",
950  subgraph_options.fetch_endpoints.size(),
951  " versus number of pruned fetch endpoints = ",
952  client_graph->fetch_types.size());
953  }
954 
955  auto current_stateful_placements = execution_state->GetStatefulPlacements();
956  // Update our current state based on the execution_state's
957  // placements. If there are any mismatches for a node,
958  // we should fail, as this should never happen.
959  for (auto placement_pair : current_stateful_placements) {
960  const string& node_name = placement_pair.first;
961  const string& placement = placement_pair.second;
962  auto iter = stateful_placements_.find(node_name);
963  if (iter == stateful_placements_.end()) {
964  stateful_placements_.insert(std::make_pair(node_name, placement));
965  } else if (iter->second != placement) {
966  return errors::Internal(
967  "Stateful placement mismatch. "
968  "Current assignment of ",
969  node_name, " to ", iter->second, " does not match ", placement);
970  }
971  }
972 
973  stateful_placements_ = execution_state->GetStatefulPlacements();
974 
975  // Remember the graph in run state if this is a partial run.
976  if (run_state_args->is_partial_run) {
977  run_state_args->graph.reset(new Graph(flib_def_.get()));
978  CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
979  }
980 
981  // Partition the graph across devices.
982  PartitionOptions popts;
983  popts.node_to_loc = [](const Node* node) {
984  assert(node != nullptr);
985  return node->assigned_device_name();
986  };
987  popts.new_name = [this](const string& prefix) {
988  return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
989  };
990  popts.get_incarnation = [](const string& name) {
991  // The direct session does not have changing incarnation numbers.
992  // Just return '1'.
993  return 1;
994  };
995  popts.flib_def = &client_graph->graph.flib_def();
996  popts.control_flow_added = false;
997 
998  std::unordered_map<string, GraphDef> partitions;
999  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
1000 
1001  std::vector<string> device_names;
1002  for (auto device : devices_) {
1003  // Extract the LocalName from the device.
1004  device_names.push_back(DeviceNameUtils::LocalName(device->name()));
1005  }
1006 
1007  // Check for valid partitions.
1008  for (const auto& partition : partitions) {
1009  const string local_partition_name =
1010  DeviceNameUtils::LocalName(partition.first);
1011  if (std::count(device_names.begin(), device_names.end(),
1012  local_partition_name) == 0) {
1013  return errors::InvalidArgument(
1014  "Creating a partition for ", local_partition_name,
1015  " which doesn't exist in the list of available devices. Available "
1016  "devices: ",
1017  str_util::Join(device_names, ","));
1018  }
1019  }
1020 
1021  for (const auto& partition : partitions) {
1022  std::unique_ptr<Graph> device_graph(
1023  new Graph(client_graph->flib_def.get()));
1024  GraphConstructorOptions device_opts;
1025  // There are internal operations (e.g., send/recv) that we now allow.
1026  device_opts.allow_internal_ops = true;
1027  device_opts.expect_device_spec = true;
1028  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
1029  device_graph.get()));
1030  outputs->emplace(partition.first, std::move(device_graph));
1031  }
1032 
1033  GraphOptimizationPassOptions optimization_options;
1034  optimization_options.session_options = &options_;
1035  optimization_options.flib_def = client_graph->flib_def.get();
1036  optimization_options.partition_graphs = outputs;
1037  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
1038  OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
1039 
1040  Status s;
1041  for (auto& partition : *outputs) {
1042  const string& partition_name = partition.first;
1043  std::unique_ptr<Graph>* graph = &partition.second;
1044 
1045  VLOG(2) << "Created " << DebugString(graph->get()) << " for "
1046  << partition_name;
1047 
1048  // Give the device an opportunity to rewrite its subgraph.
1049  Device* d;
1050  s = device_mgr_->LookupDevice(partition_name, &d);
1051  if (!s.ok()) break;
1052  s = d->MaybeRewriteGraph(graph);
1053  if (!s.ok()) {
1054  break;
1055  }
1056  }
1057  *flib_def = std::move(client_graph->flib_def);
1058  std::swap(*input_types, client_graph->feed_types);
1059  std::swap(*output_types, client_graph->fetch_types);
1060  return s;
1061 }
1062 
1063 ::tensorflow::Status TBBSession::ListDevices(
1064  std::vector<DeviceAttributes>* response) {
1065  response->clear();
1066  response->reserve(devices_.size());
1067  for (Device* d : devices_) {
1068  const DeviceAttributes& attrs = d->attributes();
1069  response->emplace_back(attrs);
1070  }
1071  return ::tensorflow::Status::OK();
1072 }
1073 
1074 ::tensorflow::Status TBBSession::Reset(
1075  const std::vector<string>& containers) {
1076  device_mgr_->ClearContainers(containers);
1077  return ::tensorflow::Status::OK();
1078 }
1079 
1080 ::tensorflow::Status TBBSession::Close() {
1081  cancellation_manager_->StartCancel();
1082  {
1083  mutex_lock l(closed_lock_);
1084  if (closed_) return ::tensorflow::Status::OK();
1085  closed_ = true;
1086  }
1087  if (factory_ != nullptr) factory_->Deregister(this);
1088  return ::tensorflow::Status::OK();
1089 }
1090 
1091 TBBSession::RunState::RunState(
1092  const std::vector<string>& pending_input_names,
1093  const std::vector<string>& pending_output_names, int64 step_id,
1094  const std::vector<Device*>* devices)
1095  : step_container(step_id, [devices](const string& name) {
1096  for (auto d : *devices) {
1097  if (!d->resource_manager()->Cleanup(name).ok()) {
1098  // Do nothing...
1099  }
1100  }
1101  }) {
1102  // Initially all the feeds and fetches are pending.
1103  for (auto& name : pending_input_names) {
1104  pending_inputs[name] = false;
1105  }
1106  for (auto& name : pending_output_names) {
1107  pending_outputs[name] = false;
1108  }
1109 }
1110 
1111 TBBSession::RunState::RunState(int64 step_id,
1112  const std::vector<Device*>* devices)
1113  : RunState({}, {}, step_id, devices) {}
1114 
1115 TBBSession::RunState::~RunState() {
1116  if (rendez != nullptr) {
1117  if (!executors_done.HasBeenNotified()) {
1118  rendez->StartAbort(errors::Cancelled("PRun cancellation"));
1119  executors_done.WaitForNotification();
1120  }
1121  rendez->Unref();
1122  }
1123 }
1124 
1125 bool TBBSession::RunState::PendingDone() const {
1126  for (const auto& it : pending_inputs) {
1127  if (!it.second) return false;
1128  }
1129  for (const auto& it : pending_outputs) {
1130  if (!it.second) return false;
1131  }
1132  return true;
1133 }
1134 
1135 void TBBSession::WaitForNotification(tbb::task_arena& arena, tbb::task_group& taskGroup,
1136  RunState* run_state, CancellationManager* cm, int64 timeout_in_ms) {
1137  // Doing the wait in the arena adds this thread to the arena
1138  // and therefore tasks associated with the group can run on this thread
1139  arena.execute([&taskGroup]() { taskGroup.wait();} );
1140 
1141  const Status status =
1142  WaitForNotification(&run_state->executors_done, timeout_in_ms);
1143  if (!status.ok()) {
1144  {
1145  mutex_lock l(run_state->mu_);
1146  run_state->status.Update(status);
1147  }
1148  cm->StartCancel();
1149  // We must wait for the executors to complete, because they have borrowed
1150  // references to `cm` and other per-step state. After this notification, it
1151  // is safe to clean up the step.
1152  run_state->executors_done.WaitForNotification();
1153  }
1154 }
1155 
1156 ::tensorflow::Status TBBSession::WaitForNotification(
1157  Notification* notification, int64 timeout_in_ms) {
1158  if (timeout_in_ms > 0) {
1159  const int64 timeout_in_us = timeout_in_ms * 1000;
1160  const bool notified =
1161  WaitForNotificationWithTimeout(notification, timeout_in_us);
1162  if (!notified) {
1163  return Status(error::DEADLINE_EXCEEDED,
1164  "Timed out waiting for notification");
1165  }
1166  } else {
1167  notification->WaitForNotification();
1168  }
1169  return Status::OK();
1170 }
1171 
1172 } // namespace tensorflow
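
For orientation, a minimal sketch of how client code would obtain this session type through the factory registered above, assuming the standard TensorFlow C++ Session API is available and using hypothetical node names "input" and "output" for illustration only:

 #include <memory>
 #include <vector>
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/public/session.h"
 
 tensorflow::Status RunOnce(const tensorflow::GraphDef& graph_def,
                            std::vector<tensorflow::Tensor>* outputs) {
   tensorflow::SessionOptions opts;
   opts.target = "tbb";  // matches TBBSessionFactory::AcceptsOptions()
   std::unique_ptr<tensorflow::Session> session(tensorflow::NewSession(opts));
   if (session == nullptr)
     return tensorflow::errors::Internal("NewSession returned nullptr");
   TF_RETURN_IF_ERROR(session->Create(graph_def));
   // Feed a single float into the hypothetical "input" node, fetch "output".
   tensorflow::Tensor x(tensorflow::DT_FLOAT, tensorflow::TensorShape({1}));
   return session->Run({{"input", x}}, {"output"}, {}, outputs);
 }

This requires the translation unit above to be linked in, so that the static TBBSessionRegistrar has registered the factory before NewSession() is called.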