NTSession.cc
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 //NOTE: The memory layout of the Node class changes depending on whether NDEBUG
16 // was set when TensorFlow was compiled: the Node class holds two EdgeSet
17 // instances, and EdgeSet adds a data member depending on whether NDEBUG is set.
18 
19 /*
20 This file is an adaptation of the original direct_session.cc file located at
21 https://github.com/tensorflow/tensorflow/blob/v1.5.0/tensorflow/core/common_runtime/direct_session.cc
22 to meet the demands of the software environment developed and used by the CMS collaboration.
23 
24 Changes with respect to the original code are documented in the NTSession.h header file.
25 */
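/* Illustrative usage (a sketch added for documentation, not part of the
   original file): the factory defined below registers this session under the
   target string "no_threads", so client code would select it through the
   standard TensorFlow session API, roughly:

     tensorflow::SessionOptions opts;
     opts.target = "no_threads";  // matched by NTSessionFactory::AcceptsOptions
     tensorflow::Session* session = tensorflow::NewSession(opts);

   Executor closures are then run synchronously on the calling thread (see
   NTSession::SchedClosure below), leaving parallel scheduling to the CMSSW
   framework. */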
26 
27 #if !defined(NDEBUG)
28 #define NDEBUG 1
29 #endif
30 
31 #include "NTSession.h"
32 
33 #include <atomic>
34 #include <string>
35 #include <vector>
36 
37 #include "FWCore/Utilities/interface/thread_safety_macros.h"
38 
39 #include "tensorflow/core/common_runtime/constant_folding.h"
40 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
41 #include "tensorflow/core/common_runtime/device_factory.h"
42 #include "tensorflow/core/common_runtime/executor.h"
43 #include "tensorflow/core/common_runtime/function.h"
44 #include "tensorflow/core/common_runtime/graph_optimizer.h"
45 #include "tensorflow/core/common_runtime/memory_types.h"
46 #include "tensorflow/core/common_runtime/optimization_registry.h"
47 #include "tensorflow/core/common_runtime/step_stats_collector.h"
48 #include "tensorflow/core/framework/function.h"
49 #include "tensorflow/core/framework/graph.pb_text.h"
50 #include "tensorflow/core/framework/graph.pb.h"
51 #include "tensorflow/core/framework/graph_def_util.h"
52 #include "tensorflow/core/framework/log_memory.h"
53 #include "tensorflow/core/framework/node_def.pb.h"
54 #include "tensorflow/core/framework/tensor.h"
55 #include "tensorflow/core/framework/versions.pb.h"
56 #include "tensorflow/core/graph/algorithm.h"
57 #include "tensorflow/core/graph/graph.h"
58 #include "tensorflow/core/graph/graph_constructor.h"
59 #include "tensorflow/core/graph/graph_partition.h"
60 #include "tensorflow/core/graph/subgraph.h"
61 #include "tensorflow/core/graph/tensor_id.h"
62 #include "tensorflow/core/lib/core/errors.h"
63 #include "tensorflow/core/lib/core/notification.h"
64 #include "tensorflow/core/lib/core/refcount.h"
65 #include "tensorflow/core/lib/core/status.h"
66 #include "tensorflow/core/lib/gtl/array_slice.h"
67 #include "tensorflow/core/lib/gtl/stl_util.h"
68 #include "tensorflow/core/lib/monitoring/counter.h"
69 #include "tensorflow/core/lib/strings/numbers.h"
70 #include "tensorflow/core/lib/strings/str_util.h"
71 #include "tensorflow/core/lib/strings/strcat.h"
72 #include "tensorflow/core/platform/cpu_info.h"
73 #include "tensorflow/core/platform/device_tracer.h"
74 #include "tensorflow/core/platform/logging.h"
75 #include "tensorflow/core/platform/mutex.h"
76 #include "tensorflow/core/platform/types.h"
77 #include "tensorflow/core/util/device_name_utils.h"
78 #include "tensorflow/core/util/env_var.h"
79 
80 
81 namespace tensorflow {
82 
83 namespace {
84 
85 CMS_THREAD_SAFE auto* nothreads_session_runs = monitoring::Counter<0>::New(
86  "/tensorflow/core/nothreads_session_runs",
87  "The number of times NTSession::Run() has been called.");
88 
89 
90 // TODO(vrv): Figure out how to unify the many different functions
91 // that generate RendezvousKey, since many of them have to be
92 // consistent with each other.
93 string GetRendezvousKey(const string& tensor_name,
94  const DeviceAttributes& device_info,
95  const FrameAndIter& frame_iter) {
96  return strings::StrCat(device_info.name(), ";",
97  strings::FpToString(device_info.incarnation()), ";",
98  device_info.name(), ";", tensor_name, ";",
99  frame_iter.frame_id, ":", frame_iter.iter_id);
100 }
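// Illustrative example (added for documentation): for tensor "x:0" on a client
// device named "/job:localhost/replica:0/task:0/device:CPU:0" with incarnation
// 1 and FrameAndIter(0, 0), the generated key has the form
//   "<device>;0000000000000001;<device>;x:0;0:0"
// where the second field is strings::FpToString() of the device incarnation.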
101 
102 } // namespace
103 
104 class NTSessionFactory : public SessionFactory {
105  public:
106  NTSessionFactory() {}
107 
108  bool AcceptsOptions(const SessionOptions& options) override {
109  return options.target == "no_threads";
110  }
111 
112  Session* NewSession(const SessionOptions& options) override {
113  // Must do this before the CPU allocator is created.
114  if (options.config.graph_options().build_cost_model() > 0) {
115  EnableCPUAllocatorFullStats(true);
116  }
117  std::vector<Device*> devices;
118  const Status s = DeviceFactory::AddDevices(
119  options, "/job:localhost/replica:0/task:0", &devices);
120  if (!s.ok()) {
121  LOG(ERROR) << s;
122  return nullptr;
123  }
124 
125  NTSession* session =
126  new NTSession(options, new DeviceMgr(devices), this);
127  {
128  mutex_lock l(sessions_lock_);
129  sessions_.push_back(session);
130  }
131  return session;
132  }
133 
134  Status Reset(const SessionOptions& options,
135  const std::vector<string>& containers) override {
136  std::vector<NTSession*> sessions_to_reset;
137  {
138  mutex_lock l(sessions_lock_);
139  // We create a copy to ensure that we don't have a deadlock when
140  // session->Close calls the NTSessionFactory.Deregister, which
141  // acquires sessions_lock_.
142  std::swap(sessions_to_reset, sessions_);
143  }
144  Status s;
145  for (auto session : sessions_to_reset) {
146  s.Update(session->Reset(containers));
147  }
148  // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
149  // it doesn't close the sessions?
150  for (auto session : sessions_to_reset) {
151  s.Update(session->Close());
152  }
153  return s;
154  }
155 
156  void Deregister(const NTSession* session) {
157  mutex_lock l(sessions_lock_);
158  sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
159  sessions_.end());
160  }
161 
162  private:
163  mutex sessions_lock_;
164  std::vector<NTSession*> sessions_ GUARDED_BY(sessions_lock_);
165 };
166 
167 class NTSessionRegistrar {
168  public:
169  NTSessionRegistrar() {
170  SessionFactory::Register("NOTHREADS_SESSION", new NTSessionFactory());
171  }
172 };
173 static NTSessionRegistrar registrar;
174 
175 std::atomic_int_fast64_t NTSession::step_id_counter_(1);
176 
177 // NOTE: On Android with a single device, there is never
178 // a risk of an OpKernel blocking indefinitely:
179 //
180 // 1) No operations do I/O that depends on other simultaneous kernels,
181 //
182 // 2) Recv nodes always complete immediately: The inputs are sent into
183 // the local rendezvous before we start the executor, so the
184 // corresponding recvs will not block.
185 //
186 // Based on these assumptions, we can use the same thread pool for
187 // both "non-blocking" and "blocking" OpKernels on Android.
188 //
189 // This may change down the road when we add support for multiple
190 // devices that run concurrently, in which case we will need to
191 // revisit this decision.
192 // Override to allow CMSSW FWK to schedule
193 void NTSession::SchedClosure(std::function<void()> c) {
194  c();
195 }
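// Note (added for documentation): the upstream DirectSession::SchedClosure
// hands the closure to a thread pool, roughly `pool->Schedule(std::move(c));`.
// Running `c()` inline instead keeps all executor work on the calling thread,
// matching the scheduling model of the CMSSW framework.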
196 
197 NTSession::NTSession(const SessionOptions& options,
198  const DeviceMgr* device_mgr,
199  NTSessionFactory* const factory)
200  : options_(options),
201  device_mgr_(device_mgr),
202  factory_(factory),
203  cancellation_manager_(new CancellationManager()),
204  operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
205  // The default value of sync_on_finish will be flipped soon and this
206  // environment variable will be removed as well.
207  const Status status =
208  ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
209  if (!status.ok()) {
210  LOG(ERROR) << status.error_message();
211  }
212  // NOTE(mrry): We do not need to use a unique string for the session
213  // handle, because NTSession owns its devices. This may change
214  // in future versions.
215  session_handle_ = "no_threads";
216  int devices_added = 0;
217  if (options.config.log_device_placement()) {
218  const string mapping_str = device_mgr_->DeviceMappingString();
219  if (mapping_str.empty()) {
220  printf("Device mapping: no known devices.\n");
221  } else {
222  printf("Device mapping:\n%s", mapping_str.c_str());
223  }
224  LOG(INFO) << "Device mapping:\n" << mapping_str;
225  }
226  for (auto d : device_mgr_->ListDevices()) {
227  devices_.push_back(d);
228  device_set_.AddDevice(d);
229  d->op_segment()->AddHold(session_handle_);
230 
231  // The first device added is special: it is the 'client device' (a
232  // CPU device) from which we feed and fetch Tensors.
233  if (devices_added == 0) {
234  device_set_.set_client_device(d);
235  }
236  ++devices_added;
237  }
238 }
239 
240 NTSession::~NTSession() {
241  if (!closed_) Close().IgnoreError();
242  for (auto& it : partial_runs_) {
243  it.second.reset(nullptr);
244  }
245  for (auto& it : executors_) {
246  it.second.reset();
247  }
248  for (auto d : device_mgr_->ListDevices()) {
249  d->op_segment()->RemoveHold(session_handle_);
250  }
251  delete cancellation_manager_;
252 
253  execution_state_.reset(nullptr);
254  flib_def_.reset(nullptr);
255 }
256 
257 Status NTSession::MaybeInitializeExecutionState(
258  const GraphDef& graph, bool* out_already_initialized) {
259  // If already initialized, do nothing.
260  if (flib_def_ && execution_state_) {
261  *out_already_initialized = true;
262  return Status::OK();
263  }
264  // Set up the per-session execution state.
265  // NOTE(mrry): The function library created here will be used for
266  // all subsequent extensions of the graph.
267  flib_def_.reset(
268  new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
269  GraphExecutionStateOptions options;
270  options.device_set = &device_set_;
271  options.session_options = &options_;
272  // TODO(mrry,suharshs): We explicitly copy `graph` so that
273  // `MakeForBaseGraph()` can take ownership of its
274  // contents. Previously this happened implicitly in calls to the
275  // `GraphExecutionState`. Other sessions call
276  // `MakeForBaseGraph` in such a way that we can destructively read
277  // the passed-in `GraphDef`. In principle we could do the same here,
278  // with a wider refactoring; we might revise the direct session so
279  // that it copies the graph fewer times.
280  GraphDef temp(graph);
281  TF_RETURN_IF_ERROR(
282  GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_));
283  graph_created_ = true;
284  *out_already_initialized = false;
285  return Status::OK();
286 }
287 
288 Status NTSession::Create(const GraphDef& graph) {
289  TF_RETURN_IF_ERROR(init_error_);
290  if (graph.node_size() > 0) {
291  mutex_lock l(graph_def_lock_);
292  if (graph_created_) {
293  return errors::AlreadyExists(
294  "A Graph has already been created for this session.");
295  }
296  return ExtendLocked(graph);
297  }
298  return Status::OK();
299 }
300 
301 Status NTSession::Extend(const GraphDef& graph) {
302  TF_RETURN_IF_ERROR(CheckNotClosed());
303  mutex_lock l(graph_def_lock_);
304  return ExtendLocked(graph);
305 }
306 
307 Status NTSession::ExtendLocked(const GraphDef& graph) {
308  bool already_initialized;
309  // If this is the first call, we can initialize the execution state
310  // with `graph` and do not need to call `Extend()`.
311  TF_RETURN_IF_ERROR(
312  MaybeInitializeExecutionState(graph, &already_initialized));
313  if (already_initialized) {
314  TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
315  std::unique_ptr<GraphExecutionState> state;
316  TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
317  execution_state_.swap(state);
318  }
319  return Status::OK();
320 }
321 
322 ::tensorflow::Status NTSession::Run(const NamedTensorList& inputs,
323  const std::vector<string>& output_names,
324  const std::vector<string>& target_nodes,
325  std::vector<Tensor>* outputs) {
326  RunMetadata run_metadata;
327  return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
328  &run_metadata);
329 }
330 
331 ::tensorflow::Status NTSession::CreateDebuggerState(
332  const DebugOptions& debug_options, int64 session_run_index,
333  int64 executor_step_index, const std::vector<string>& input_names,
334  const std::vector<string>& output_names,
335  const std::vector<string>& target_names,
336  std::unique_ptr<DebuggerStateInterface>* debugger_state) {
337  TF_RETURN_IF_ERROR(
338  DebuggerStateRegistry::CreateState(debug_options, debugger_state));
339  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
340  debug_options.global_step(), session_run_index, executor_step_index,
341  input_names, output_names, target_names));
342  return Status::OK();
343 }
344 
345 ::tensorflow::Status NTSession::DecorateAndPublishGraphForDebug(
346  const DebugOptions& debug_options, Graph* graph, Device* device) {
347  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
348  TF_RETURN_IF_ERROR(
349  DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
350 
351  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
352  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
353  return Status::OK();
354 }
355 
356 Status NTSession::Run(const RunOptions& run_options,
357  const NamedTensorList& inputs,
358  const std::vector<string>& output_names,
359  const std::vector<string>& target_nodes,
360  std::vector<Tensor>* outputs,
361  RunMetadata* run_metadata) {
362  TF_RETURN_IF_ERROR(CheckNotClosed());
363  nothreads_session_runs->GetCell()->IncrementBy(1);
364  {
365  mutex_lock l(graph_def_lock_);
366  if (!graph_created_) {
367  return errors::InvalidArgument(
368  "Session was not created with a graph before Run()!");
369  }
370  }
371 
372  // Extract the input names for this run of the session.
373  std::vector<string> input_tensor_names;
374  input_tensor_names.reserve(inputs.size());
375  for (const auto& it : inputs) {
376  input_tensor_names.push_back(it.first);
377  }
378 
379  // Check if we already have an executor for these arguments.
380  ExecutorsAndKeys* executors_and_keys;
381  RunStateArgs run_state_args(run_options.debug_options());
382 
383  Executor::Args args;
384  args.step_id = step_id_counter_.fetch_add(1);
385 
386  TF_RETURN_IF_ERROR(
387  GetOrCreateExecutors(input_tensor_names, output_names, target_nodes,
388  &executors_and_keys, &run_state_args));
389  const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
390 
391  std::unique_ptr<DebuggerStateInterface> debugger_state;
392  if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
393  TF_RETURN_IF_ERROR(CreateDebuggerState(
394  run_options.debug_options(), args.step_id, executor_step_count,
395  input_tensor_names, output_names, target_nodes, &debugger_state));
396  }
397 
398  // Configure a call frame for the step, which we use to feed and
399  // fetch values to and from the executors.
400  FunctionCallFrame call_frame(executors_and_keys->input_types,
401  executors_and_keys->output_types);
402  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
403  for (const auto& it : inputs) {
404  if (it.second.dtype() == DT_RESOURCE) {
405  Tensor tensor_from_handle;
406  TF_RETURN_IF_ERROR(
407  ResourceHandleToInputTensor(it.second, &tensor_from_handle));
408  feed_args[executors_and_keys->input_name_to_index[it.first]] =
409  tensor_from_handle;
410  } else {
411  feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
412  }
413  }
414  const Status s = call_frame.SetArgs(feed_args);
415  if (errors::IsInternal(s)) {
416  return errors::InvalidArgument(s.error_message());
417  } else if (!s.ok()) {
418  return s;
419  }
420 
421  // Create a run state and start execution.
422  RunState run_state(args.step_id, &devices_);
423  run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
424  CancellationManager step_cancellation_manager;
425  args.call_frame = &call_frame;
426 
427  // Start parallel Executors.
428  const size_t num_executors = executors_and_keys->items.size();
429  ExecutorBarrier* barrier = new ExecutorBarrier(
430  num_executors, run_state.rendez, [&run_state](const Status& ret) {
431  {
432  mutex_lock l(run_state.mu_);
433  run_state.status.Update(ret);
434  }
435  run_state.executors_done.Notify();
436  });
437 
438  args.rendezvous = run_state.rendez;
439  args.cancellation_manager = &step_cancellation_manager;
440 
441  args.session_state = &session_state_;
442  args.tensor_store = &run_state.tensor_store;
443  args.step_container = &run_state.step_container;
444  if (LogMemory::IsEnabled()) {
445  LogMemory::RecordStep(args.step_id, run_state_args.handle);
446  }
447  args.sync_on_finish = sync_on_finish_;
448 
449  const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
450 
451  bool update_cost_model = false;
452  if (options_.config.graph_options().build_cost_model() > 0) {
453  const int64 build_cost_model_every =
454  options_.config.graph_options().build_cost_model();
455  const int64 build_cost_model_after =
456  options_.config.graph_options().build_cost_model_after();
457  int64 measure_step_count = executor_step_count - build_cost_model_after;
458  if (measure_step_count >= 0) {
459  update_cost_model =
460  ((measure_step_count + 1) % build_cost_model_every == 0);
461  }
462  }
463  if (do_trace || update_cost_model ||
464  run_options.report_tensor_allocations_upon_oom()) {
465  run_state.collector.reset(
466  new StepStatsCollector(run_metadata->mutable_step_stats()));
467  args.stats_collector = run_state.collector.get();
468  }
469 
470  std::unique_ptr<DeviceTracer> tracer;
471  if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
472  tracer = CreateDeviceTracer();
473  // tracer may be NULL on platforms without accelerators.
474  if (tracer) {
475  Status s = tracer->Start();
476  if (!s.ok()) {
477  run_state.executors_done.Notify();
478  delete barrier;
479  return s;
480  }
481  }
482  }
483 
484  // Register this step with session's cancellation manager, so that
485  // `Session::Close()` will cancel the step.
486  const CancellationToken cancellation_token =
487  cancellation_manager_->get_cancellation_token();
488  const bool already_cancelled = !cancellation_manager_->RegisterCallback(
489  cancellation_token, [&step_cancellation_manager]() {
490  step_cancellation_manager.StartCancel();
491  });
492  if (already_cancelled) {
493  // NOTE(mrry): If we don't explicitly notify
494  // `run_state.executors_done`, the RunState destructor would
495  // block on this notification.
496  run_state.executors_done.Notify();
497  delete barrier;
498  return errors::Cancelled("Run call was cancelled");
499  }
500 
501  // Pass no thread pool to SchedClosure, consequently disabling TF's own
502  // thread scheduling logic inside the loop.
503  Executor::Args::Runner default_runner = [this](Executor::Args::Closure c) {
504  SchedClosure(std::move(c));
505  };
506  for (const auto& item : executors_and_keys->items) {
507  // TODO(zhengxq): support partial run.
508  // TODO(zhengxq): if the device picks its own threadpool, we need to assign
509  // less threads to the main compute pool by default.
510  // thread::ThreadPool* device_thread_pool =
511  // item.device->tensorflow_device_thread_pool();
512  // if (!device_thread_pool) {
513  // args.runner = default_runner;
514  // } else {
515  // args.runner = [this, device_thread_pool](Executor::Args::Closure c) {
516  // SchedClosure(device_thread_pool, std::move(c));
517  // };
518  // }
519  args.runner = default_runner;
520  item.executor->RunAsync(args, barrier->Get());
521  }
522 
523  WaitForNotification(&run_state, &step_cancellation_manager,
524  run_options.timeout_in_ms() > 0
525  ? run_options.timeout_in_ms()
526  : operation_timeout_in_ms_);
527 
528  if (!cancellation_manager_->DeregisterCallback(cancellation_token)) {
529  // The step has been cancelled: make sure we don't attempt to receive the
530  // outputs as this would make it block forever.
531  mutex_lock l(run_state.mu_);
532  run_state.status.Update(errors::Cancelled("Run call was cancelled"));
533  }
534 
535  if (tracer) {
536  TF_RETURN_IF_ERROR(tracer->Stop());
537  TF_RETURN_IF_ERROR(tracer->Collect(args.stats_collector));
538  }
539 
540  {
541  mutex_lock l(run_state.mu_);
542  TF_RETURN_IF_ERROR(run_state.status);
543  }
544 
545  // Receive outputs.
546  if (outputs) {
547  std::vector<Tensor> sorted_outputs;
548  const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
549  if (errors::IsInternal(s)) {
550  return errors::InvalidArgument(s.error_message());
551  } else if (!s.ok()) {
552  return s;
553  }
554  const bool unique_outputs =
555  output_names.size() == executors_and_keys->output_name_to_index.size();
556  // first_indices[i] = j implies that j is the smallest value for which
557  // output_names[i] == output_names[j].
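 // Example (added for documentation): output_names = {"a", "b", "a"} yields
 // first_indices = {0, 1, 0}, so the duplicate fetch at index 2 reuses the
 // tensor already produced for index 0 instead of consuming a second retval.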
558  std::vector<int> first_indices;
559  if (!unique_outputs) {
560  first_indices.resize(output_names.size());
561  for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
562  for (int j = 0; j <= i; ++j) {
563  if (output_names[i] == output_names[j]) {
564  first_indices[i] = j;
565  break;
566  }
567  }
568  }
569  }
570  outputs->clear();
571  outputs->reserve(sorted_outputs.size());
572  for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
573  const string& output_name = output_names[i];
574  if (first_indices.empty() || first_indices[i] == i) {
575  outputs->emplace_back(
576  std::move(sorted_outputs[executors_and_keys
577  ->output_name_to_index[output_name]]));
578  } else {
579  outputs->push_back((*outputs)[first_indices[i]]);
580  }
581  }
582  }
583 
584  // Save the output tensors of this run we choose to keep.
585  TF_RETURN_IF_ERROR(
586  run_state.tensor_store.SaveTensors(output_names, &session_state_));
587  if (args.stats_collector) {
588  args.stats_collector->Finalize();
589  }
590 
591  // Build and return the cost model as instructed.
592  mutex_lock l(executor_lock_);
593  if (update_cost_model) {
594  // Build the cost model
595  std::unordered_map<string, const Graph*> device_to_graph;
596  for (const PerPartitionExecutorsAndLib& partition :
597  executors_and_keys->items) {
598  const Graph* graph = partition.graph;
599  const string device = partition.flib->device()->name();
600  device_to_graph[device] = graph;
601  }
602  args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);
603 
604  // annotate stats onto cost graph.
605  CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
606  for (const auto& item : executors_and_keys->items) {
607  TF_RETURN_IF_ERROR(
608  cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
609  }
610  }
611 
612  // If requested via RunOptions, output the partition graphs.
613  if (run_options.output_partition_graphs()) {
614  protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
615  run_metadata->mutable_partition_graphs();
616  for (const PerPartitionExecutorsAndLib& exec_and_lib :
617  executors_and_keys->items) {
618  GraphDef* partition_graph_def = partition_graph_defs->Add();
619  exec_and_lib.graph->ToGraphDef(partition_graph_def);
620  }
621  }
622 
623  return Status::OK();
624 }
625 
626 Status NTSession::PRunSetup(const std::vector<string>& input_names,
627  const std::vector<string>& output_names,
628  const std::vector<string>& target_nodes,
629  string* handle) {
630  TF_RETURN_IF_ERROR(CheckNotClosed());
631  {
632  mutex_lock l(graph_def_lock_);
633  if (!graph_created_) {
634  return errors::InvalidArgument(
635  "Session was not created with a graph before PRunSetup()!");
636  }
637  }
638 
639  // Check if we already have an executor for these arguments.
640  ExecutorsAndKeys* executors_and_keys;
641  // TODO(cais): TFDBG support for partial runs.
642  DebugOptions debug_options;
643  RunStateArgs run_state_args(debug_options);
644  run_state_args.is_partial_run = true;
645  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
646  target_nodes, &executors_and_keys,
647  &run_state_args));
648 
649  // Create the run state and save it for future PRun calls.
650  Executor::Args args;
651  args.step_id = step_id_counter_.fetch_add(1);
652  RunState* run_state =
653  new RunState(input_names, output_names, args.step_id, &devices_);
654  run_state->rendez = new IntraProcessRendezvous(device_mgr_.get());
655  {
656  mutex_lock l(executor_lock_);
657  if (!partial_runs_
658  .emplace(run_state_args.handle,
659  std::unique_ptr<RunState>(run_state))
660  .second) {
661  return errors::Internal("The handle '", run_state_args.handle,
662  "' created for this partial run is not unique.");
663  }
664  }
665 
666  // Start parallel Executors.
667  const size_t num_executors = executors_and_keys->items.size();
668  ExecutorBarrier* barrier = new ExecutorBarrier(
669  num_executors, run_state->rendez, [run_state](const Status& ret) {
670  if (!ret.ok()) {
671  mutex_lock l(run_state->mu_);
672  run_state->status.Update(ret);
673  }
674  run_state->executors_done.Notify();
675  });
676 
677  args.rendezvous = run_state->rendez;
678  args.cancellation_manager = cancellation_manager_;
679  args.runner = [this](Executor::Args::Closure c) {
680  SchedClosure(std::move(c));
681  };
682  args.session_state = &session_state_;
683  args.tensor_store = &run_state->tensor_store;
684  args.step_container = &run_state->step_container;
685  if (LogMemory::IsEnabled()) {
686  LogMemory::RecordStep(args.step_id, run_state_args.handle);
687  }
688  args.sync_on_finish = sync_on_finish_;
689 
690  if (options_.config.graph_options().build_cost_model()) {
691  run_state->collector.reset(new StepStatsCollector(nullptr));
692  args.stats_collector = run_state->collector.get();
693  }
694 
695  for (auto& item : executors_and_keys->items) {
696  item.executor->RunAsync(args, barrier->Get());
697  }
698 
699  *handle = run_state_args.handle;
700  return Status::OK();
701 }
702 
703 ::tensorflow::Status NTSession::PRun(const string& handle, const NamedTensorList& inputs,
704  const std::vector<string>& output_names,
705  std::vector<Tensor>* outputs) {
706  TF_RETURN_IF_ERROR(CheckNotClosed());
707  std::vector<string> parts = str_util::Split(handle, ';');
708  const string& key = parts[0];
709  // Get the executors for this partial run.
710  ExecutorsAndKeys* executors_and_keys;
711  RunState* run_state;
712  {
713  mutex_lock l(executor_lock_); // could use reader lock
714  auto exc_it = executors_.find(key);
715  if (exc_it == executors_.end()) {
716  return errors::InvalidArgument(
717  "Must run 'setup' before performing partial runs!");
718  }
719  executors_and_keys = exc_it->second.get();
720 
721  auto prun_it = partial_runs_.find(handle);
722  if (prun_it == partial_runs_.end()) {
723  return errors::InvalidArgument(
724  "Must run 'setup' before performing partial runs!");
725  }
726  run_state = prun_it->second.get();
727 
728  // Make sure that this is a new set of feeds that are still pending.
729  for (const auto& input : inputs) {
730  auto it = run_state->pending_inputs.find(input.first);
731  if (it == run_state->pending_inputs.end()) {
732  return errors::InvalidArgument(
733  "The feed ", input.first,
734  " was not specified in partial_run_setup.");
735  } else if (it->second) {
736  return errors::InvalidArgument("The feed ", input.first,
737  " has already been fed.");
738  }
739  }
740  // Check that this is a new set of fetches that are still pending.
741  for (const auto& output : output_names) {
742  auto it = run_state->pending_outputs.find(output);
743  if (it == run_state->pending_outputs.end()) {
744  return errors::InvalidArgument(
745  "The fetch ", output, " was not specified in partial_run_setup.");
746  } else if (it->second) {
747  return errors::InvalidArgument("The fetch ", output,
748  " has already been fetched.");
749  }
750  }
751  }
752 
753  // Check that this new set of fetches can be computed from all the
754  // feeds we have supplied.
755  TF_RETURN_IF_ERROR(
756  CheckFetch(inputs, output_names, executors_and_keys, run_state));
757 
758  // Send inputs.
759  Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez);
760 
761  // Receive outputs.
762  if (s.ok()) {
763  s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
764  }
765 
766  // Save the output tensors of this run we choose to keep.
767  if (s.ok()) {
768  s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
769  }
770 
771  {
772  mutex_lock l(executor_lock_);
773  // Delete the run state if there is an error or all fetches are done.
774  bool done = true;
775  if (s.ok()) {
776  {
777  mutex_lock l(run_state->mu_);
778  if (!run_state->status.ok()) {
779  LOG(WARNING) << "An error unrelated to this prun has been detected. "
780  << run_state->status;
781  }
782  }
783  for (const auto& input : inputs) {
784  auto it = run_state->pending_inputs.find(input.first);
785  it->second = true;
786  }
787  for (const auto& name : output_names) {
788  auto it = run_state->pending_outputs.find(name);
789  it->second = true;
790  }
791  done = run_state->PendingDone();
792  }
793  if (done) {
794  run_state->rendez->StartAbort(errors::Cancelled("PRun cancellation"));
795  WaitForNotification(run_state, cancellation_manager_, operation_timeout_in_ms_);
796  partial_runs_.erase(handle);
797  }
798  }
799 
800  return s;
801 }
802 
803 Status NTSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
804  Tensor* retrieved_tensor) {
805  if (resource_tensor.dtype() != DT_RESOURCE) {
806  return errors::InvalidArgument(strings::StrCat(
807  "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
808  resource_tensor.dtype()));
809  }
810 
811  const ResourceHandle& resource_handle =
812  resource_tensor.scalar<ResourceHandle>()();
813 
814  if (resource_handle.container() ==
815  SessionState::kTensorHandleResourceTypeName) {
816  return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
817  } else {
818  return errors::InvalidArgument(strings::StrCat(
819  "Invalid resource type hash code: ", resource_handle.hash_code(),
820  "(name: ", resource_handle.name(),
821  " type: ", resource_handle.maybe_type_name(),
822  "). Perhaps a resource tensor was being provided as a feed? That is "
823  "not currently allowed. Please file an issue at "
824  "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
825  "short code snippet that leads to this error message."));
826  }
827 }
828 
829 ::tensorflow::Status NTSession::SendPRunInputs(const std::vector<std::pair<string, Tensor>>& inputs,
830  const ExecutorsAndKeys* executors_and_keys,
831  IntraProcessRendezvous* rendez) {
832  Status s;
833  Rendezvous::ParsedKey parsed;
834  // Insert the input tensors into the local rendezvous by their
835  // rendezvous key.
836  for (const auto& input : inputs) {
837  auto it =
838  executors_and_keys->input_name_to_rendezvous_key.find(input.first);
839  if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
840  return errors::Internal("'", input.first, "' is not a pre-defined feed.");
841  }
842  const string& input_key = it->second;
843 
844  s = Rendezvous::ParseKey(input_key, &parsed);
845  if (!s.ok()) {
846  rendez->StartAbort(s);
847  return s;
848  }
849 
850  if (input.second.dtype() == DT_RESOURCE) {
851  Tensor tensor_from_handle;
852  s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
853  if (s.ok()) {
854  s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
855  }
856  } else {
857  s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
858  }
859 
860  if (!s.ok()) {
861  rendez->StartAbort(s);
862  return s;
863  }
864  }
865  return Status::OK();
866 }
867 
868 ::tensorflow::Status NTSession::RecvPRunOutputs(
869  const std::vector<string>& output_names,
870  const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
871  std::vector<Tensor>* outputs) {
872  Status s;
873  if (!output_names.empty()) {
874  outputs->resize(output_names.size());
875  }
876 
877  Rendezvous::ParsedKey parsed;
878  // Get the outputs from the rendezvous
879  for (size_t output_offset = 0; output_offset < output_names.size();
880  ++output_offset) {
881  const string& output_name = output_names[output_offset];
882  auto it =
883  executors_and_keys->output_name_to_rendezvous_key.find(output_name);
884  if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
885  return errors::Internal("'", output_name,
886  "' is not a pre-defined fetch.");
887  }
888  const string& output_key = it->second;
889  Tensor output_tensor;
890  bool is_dead;
891  IntraProcessRendezvous* rendez = run_state->rendez;
892 
893  s = Rendezvous::ParseKey(output_key, &parsed);
894  if (s.ok()) {
895  // Fetch data from the Rendezvous.
896  s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead,
897  operation_timeout_in_ms_);
898  if (is_dead && s.ok()) {
899  s = errors::InvalidArgument("The tensor returned for ", output_name,
900  " was not valid.");
901  }
902  }
903  if (!s.ok()) {
904  rendez->StartAbort(s);
905  outputs->clear();
906  return s;
907  }
908 
909  (*outputs)[output_offset] = output_tensor;
910  }
911  return Status::OK();
912 }
913 
914 ::tensorflow::Status NTSession::CheckFetch(const std::vector<std::pair<string, Tensor>>& feeds,
915  const std::vector<string>& fetches,
916  const ExecutorsAndKeys* executors_and_keys,
917  const RunState* run_state) {
918  const Graph* graph = executors_and_keys->graph.get();
919  const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;
920 
921  // Build the set of pending feeds that we haven't seen.
922  std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
923  {
924  mutex_lock l(executor_lock_);
925  for (const auto& input : run_state->pending_inputs) {
926  // Skip if the feed has already been fed.
927  if (input.second) continue;
928  TensorId id(ParseTensorName(input.first));
929  auto it = name_to_node->find(id.first);
930  if (it == name_to_node->end()) {
931  return errors::NotFound("Feed ", input.first, ": not found");
932  }
933  pending_feeds.insert(id);
934  }
935  }
936  for (const auto& it : feeds) {
937  TensorId id(ParseTensorName(it.first));
938  pending_feeds.erase(id);
939  }
940 
941  // Initialize the stack with the fetch nodes.
942  std::vector<const Node*> stack;
943  for (const string& fetch : fetches) {
944  TensorId id(ParseTensorName(fetch));
945  auto it = name_to_node->find(id.first);
946  if (it == name_to_node->end()) {
947  return errors::NotFound("Fetch ", fetch, ": not found");
948  }
949  stack.push_back(it->second);
950  }
951 
952  // Any tensor needed for fetches can't be in pending_feeds.
953  std::vector<bool> visited(graph->num_node_ids(), false);
954  while (!stack.empty()) {
955  const Node* n = stack.back();
956  stack.pop_back();
957 
958  for (const Edge* in_edge : n->in_edges()) {
959  const Node* in_node = in_edge->src();
960  if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
961  return errors::InvalidArgument("Fetch ", in_node->name(), ":",
962  in_edge->src_output(),
963  " can't be computed from the feeds"
964  " that have been fed so far.");
965  }
966  if (!visited[in_node->id()]) {
967  visited[in_node->id()] = true;
968  stack.push_back(in_node);
969  }
970  }
971  }
972  return Status::OK();
973 }
974 
975 ::tensorflow::Status NTSession::GetOrCreateExecutors(
976  gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
977  gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
978  RunStateArgs* run_state_args) {
979  int64 handle_name_counter_value = -1;
980  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
981  handle_name_counter_value = handle_name_counter_.fetch_add(1);
982  }
983 
984  string debug_tensor_watches_summary;
985  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
986  debug_tensor_watches_summary = SummarizeDebugTensorWatches(
987  run_state_args->debug_options.debug_tensor_watch_opts());
988  }
989 
990  // Fast lookup path, no sorting.
991  const string key = strings::StrCat(
992  str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
993  str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
994  "/", debug_tensor_watches_summary);
995  // Set the handle, if it's needed to log memory or for partial run.
996  if (handle_name_counter_value >= 0) {
997  run_state_args->handle =
998  strings::StrCat(key, ";", handle_name_counter_value);
999  }
1000 
1001  // See if we already have the executors for this run.
1002  {
1003  mutex_lock l(executor_lock_); // could use reader lock
1004  auto it = executors_.find(key);
1005  if (it != executors_.end()) {
1006  *executors_and_keys = it->second.get();
1007  return Status::OK();
1008  }
1009  }
1010 
1011  // Slow lookup path, the unsorted key missed the cache.
1012  // Sort the inputs and outputs, and look up with the sorted key in case an
1013  // earlier call used a different order of inputs and outputs.
1014  //
1015  // We could consider some other signature instead of sorting that
1016  // preserves the same property to avoid the sort in the future.
1017  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
1018  std::sort(inputs_sorted.begin(), inputs_sorted.end());
1019  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
1020  std::sort(outputs_sorted.begin(), outputs_sorted.end());
1021  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
1022  std::sort(tn_sorted.begin(), tn_sorted.end());
1023 
1024  const string sorted_key = strings::StrCat(
1025  str_util::Join(inputs_sorted, ","), "->",
1026  str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
1027  "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
1028  // Set the handle, if it's needed to log memory or for partial run.
1029  if (handle_name_counter_value >= 0) {
1030  run_state_args->handle =
1031  strings::StrCat(sorted_key, ";", handle_name_counter_value);
1032  }
1033 
1034  // See if we already have the executors for this run.
1035  {
1036  mutex_lock l(executor_lock_);
1037  auto it = executors_.find(sorted_key);
1038  if (it != executors_.end()) {
1039  *executors_and_keys = it->second.get();
1040  // Insert this under the original key.
1041  executors_.emplace(key, it->second);
1042  return Status::OK();
1043  }
1044  }
1045 
1046  // Nothing found, so create the executors and store in the cache.
1047  BuildGraphOptions options;
1048  options.feed_endpoints = inputs_sorted;
1049  options.fetch_endpoints = outputs_sorted;
1050  options.target_nodes = tn_sorted;
1051  options.use_function_convention = !run_state_args->is_partial_run;
1052  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
1053  options.debug_options = run_state_args->debug_options;
1054  }
1055 
1056  std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
1057 
1058  // The executor_lock_ is intentionally released while executor is
1059  // being created.
1060  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
1061  TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &ek->flib_def,
1062  run_state_args, &ek->input_types,
1063  &ek->output_types));
1064 
1065  if (run_state_args->is_partial_run) {
1066  ek->graph = std::move(run_state_args->graph);
1067  std::unordered_set<StringPiece, StringPieceHasher> names;
1068  for (const string& input : inputs) {
1069  TensorId id(ParseTensorName(input));
1070  names.emplace(id.first);
1071  }
1072  for (const string& output : outputs) {
1073  TensorId id(ParseTensorName(output));
1074  names.emplace(id.first);
1075  }
1076  for (Node* n : ek->graph->nodes()) {
1077  if (names.count(n->name()) > 0) {
1078  ek->name_to_node.insert({n->name(), n});
1079  }
1080  }
1081  }
1082  ek->items.reserve(graphs.size());
1083  const auto& optimizer_opts =
1084  options_.config.graph_options().optimizer_options();
1085 
1086  int graph_def_version;
1087  {
1088  mutex_lock l(graph_def_lock_);
1089  graph_def_version =
1090  execution_state_->original_graph_def().versions().producer();
1091  }
1092  ek->proc_flr.reset(new ProcessFunctionLibraryRuntime(
1093  device_mgr_.get(), options_.env, graph_def_version, ek->flib_def.get(),
1094  optimizer_opts));
1095 
1096  GraphOptimizer optimizer(optimizer_opts);
1097  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
1098  const string& partition_name = iter->first;
1099  std::unique_ptr<Graph>& partition_graph = iter->second;
1100 
1101  Device* device;
1102  TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
1103 
1104  ek->items.resize(ek->items.size() + 1);
1105  auto* item = &(ek->items.back());
1106  auto lib = ek->proc_flr->GetFLR(partition_name);
1107  if (lib == nullptr) {
1108  return errors::Internal("Could not find device: ", partition_name);
1109  }
1110  item->flib = lib;
1111 
1112  LocalExecutorParams params;
1113  params.device = device;
1114  params.function_library = lib;
1115  auto opseg = device->op_segment();
1116  params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
1117  OpKernel** kernel) {
1118  // We do not share the kernel via the OpSegment if the node is
1119  // stateless, or a function.
1120  // NOTE(mrry): We must not share function kernels (implemented
1121  // using `CallOp`) between subgraphs, because `CallOp::handle_`
1122  // is tied to a particular subgraph. Even if the function itself
1123  // is stateful, the `CallOp` that invokes it is not.
1124  if (!lib->IsStateful(ndef.op()) ||
1125  lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
1126  return lib->CreateKernel(ndef, kernel);
1127  }
1128  auto create_fn = [lib, &ndef](OpKernel** kernel) {
1129  return lib->CreateKernel(ndef, kernel);
1130  };
1131  // Kernels created for subgraph nodes need to be cached. On
1132  // cache miss, create_fn() is invoked to create a kernel based
1133  // on the function library here + global op registry.
1134  return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
1135  create_fn);
1136  };
1137  params.delete_kernel = [lib](OpKernel* kernel) {
1138  // If the node is stateful, opseg owns it. Otherwise, delete it.
1139  if (kernel && !lib->IsStateful(kernel->type_string())) {
1140  delete kernel;
1141  }
1142  };
1143  params.node_outputs_cb = node_outputs_callback_;
1144 
1145  optimizer.Optimize(lib, options_.env, device, &iter->second,
1146  /*shape_map=*/nullptr);
1147 
1148  // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
1149  if (!options.debug_options.debug_tensor_watch_opts().empty()) {
1150  TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
1151  options.debug_options, partition_graph.get(), params.device));
1152  }
1153 
1154  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
1155  device->name(),
1156  partition_graph.get()));
1157  // NewLocalExecutor takes ownership of partition_graph.
1158  item->graph = partition_graph.get();
1159  item->executor = nullptr;
1160  item->device = device;
1161  Executor* executor;
1162  TF_RETURN_IF_ERROR(
1163  NewLocalExecutor(params, partition_graph.release(), &executor));
1164  item->executor.reset(executor);
1165  }
1166 
1167  // Cache the mapping from input/output names to graph elements to
1168  // avoid recomputing it every time.
1169  if (!run_state_args->is_partial_run) {
1170  // For regular `Run()`, we use the function calling convention, and so
1171  // maintain a mapping from input/output names to
1172  // argument/return-value ordinal index.
1173  for (size_t i = 0; i < inputs_sorted.size(); ++i) {
1174  const string& input = inputs_sorted[i];
1175  ek->input_name_to_index[input] = i;
1176  }
1177  for (size_t i = 0; i < outputs_sorted.size(); ++i) {
1178  const string& output = outputs_sorted[i];
1179  ek->output_name_to_index[output] = i;
1180  }
1181  } else {
1182  // For `PRun()`, we use the rendezvous calling convention, and so
1183  // maintain a mapping from input/output names to rendezvous keys.
1184  //
1185  // We always use the first device as the device name portion of the
1186  // key, even if we're feeding another graph.
1187  for (size_t i = 0; i < inputs_sorted.size(); ++i) {
1188  const string& input = inputs_sorted[i];
1189  ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
1190  input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
1191  }
1192  for (size_t i = 0; i < outputs_sorted.size(); ++i) {
1193  const string& output = outputs_sorted[i];
1194  ek->output_name_to_rendezvous_key[output] =
1195  GetRendezvousKey(output, device_set_.client_device()->attributes(),
1196  FrameAndIter(0, 0));
1197  }
1198  }
1199 
1200  // Reacquire the lock, try to insert into the map.
1201  mutex_lock l(executor_lock_);
1202 
1203  // Another thread may have created the entry before us, in which case we will
1204  // reuse the already created one.
1205  auto insert_result = executors_.emplace(sorted_key, ek);
1206  // Insert the value under the original key, so the fast path lookup will work
1207  // if the user uses the same order of inputs, outputs, and targets again.
1208  executors_.emplace(key, insert_result.first->second);
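 // Example (added for documentation): a first call with inputs {"b:0", "a:0"}
 // is cached under both "b:0,a:0->..." (the fast, unsorted key) and
 // "a:0,b:0->..." (the sorted key); a repeat call with the same ordering hits
 // the fast path, and any other ordering still matches via the sorted key.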
1209  *executors_and_keys = insert_result.first->second.get();
1210 
1211  return Status::OK();
1212 }
1213 
1214 ::tensorflow::Status NTSession::CreateGraphs(
1215  const BuildGraphOptions& subgraph_options,
1216  std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
1217  std::unique_ptr<FunctionLibraryDefinition>* flib_def,
1218  RunStateArgs* run_state_args, DataTypeVector* input_types,
1219  DataTypeVector* output_types) {
1220  mutex_lock l(graph_def_lock_);
1221  std::unique_ptr<ClientGraph> client_graph;
1222 
1223  std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
1224  GraphExecutionState* execution_state = nullptr;
1225  if (options_.config.graph_options().place_pruned_graph()) {
1226  // Because we are placing pruned graphs, we need to create a
1227  // new GraphExecutionState for every new unseen graph,
1228  // and then place it.
1229  GraphExecutionStateOptions prune_options;
1230  prune_options.device_set = &device_set_;
1231  prune_options.session_options = &options_;
1232  prune_options.stateful_placements = stateful_placements_;
1233  TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
1234  execution_state_->original_graph_def().library(), prune_options,
1235  execution_state_->original_graph_def(), subgraph_options,
1236  &temp_exec_state_holder, &client_graph));
1237  execution_state = temp_exec_state_holder.get();
1238  } else {
1239  execution_state = execution_state_.get();
1240  TF_RETURN_IF_ERROR(
1241  execution_state->BuildGraph(subgraph_options, &client_graph));
1242  }
1243 
1244  if (subgraph_options.feed_endpoints.size() !=
1245  client_graph->feed_types.size()) {
1246  return errors::Internal(
1247  "Graph pruning failed: requested number of feed endpoints = ",
1248  subgraph_options.feed_endpoints.size(),
1249  " versus number of pruned feed endpoints = ",
1250  client_graph->feed_types.size());
1251  }
1252  if (subgraph_options.fetch_endpoints.size() !=
1253  client_graph->fetch_types.size()) {
1254  return errors::Internal(
1255  "Graph pruning failed: requested number of fetch endpoints = ",
1256  subgraph_options.fetch_endpoints.size(),
1257  " versus number of pruned fetch endpoints = ",
1258  client_graph->fetch_types.size());
1259  }
1260 
1261  auto current_stateful_placements = execution_state->GetStatefulPlacements();
1262  // Update our current state based on the execution_state's
1263  // placements. If there are any mismatches for a node,
1264  // we should fail, as this should never happen.
1265  for (auto placement_pair : current_stateful_placements) {
1266  const string& node_name = placement_pair.first;
1267  const string& placement = placement_pair.second;
1268  auto iter = stateful_placements_.find(node_name);
1269  if (iter == stateful_placements_.end()) {
1270  stateful_placements_.insert(std::make_pair(node_name, placement));
1271  } else if (iter->second != placement) {
1272  return errors::Internal(
1273  "Stateful placement mismatch. "
1274  "Current assignment of ",
1275  node_name, " to ", iter->second, " does not match ", placement);
1276  }
1277  }
1278 
1279  stateful_placements_ = execution_state->GetStatefulPlacements();
1280 
1281  // Remember the graph in run state if this is a partial run.
1282  if (run_state_args->is_partial_run) {
1283  run_state_args->graph.reset(new Graph(flib_def_.get()));
1284  CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
1285  }
1286 
1287  // Partition the graph across devices.
1288  PartitionOptions popts;
1289  popts.node_to_loc = [](const Node* node) {
1290  assert(node != nullptr);
1291  return node->assigned_device_name();
1292  };
1293  popts.new_name = [this](const string& prefix) {
1294  return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
1295  };
1296  popts.get_incarnation = [](const string& name) {
1297  // The direct session does not have changing incarnation numbers.
1298  // Just return '1'.
1299  return 1;
1300  };
1301  popts.flib_def = &client_graph->graph.flib_def();
1302  popts.control_flow_added = false;
1303 
1304  std::unordered_map<string, GraphDef> partitions;
1305  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
1306 
1307  std::vector<string> device_names;
1308  for (auto device : devices_) {
1309  // Extract the LocalName from the device.
1310  device_names.push_back(DeviceNameUtils::LocalName(device->name()));
1311  }
1312 
1313  // Check for valid partitions.
1314  for (const auto& partition : partitions) {
1315  const string local_partition_name =
1316  DeviceNameUtils::LocalName(partition.first);
1317  if (std::count(device_names.begin(), device_names.end(),
1318  local_partition_name) == 0) {
1319  return errors::InvalidArgument(
1320  "Creating a partition for ", local_partition_name,
1321  " which doesn't exist in the list of available devices. Available "
1322  "devices: ",
1323  str_util::Join(device_names, ","));
1324  }
1325  }
1326 
1327  for (const auto& partition : partitions) {
1328  std::unique_ptr<Graph> device_graph(
1329  new Graph(client_graph->flib_def.get()));
1330  GraphConstructorOptions device_opts;
1331  // There are internal operations (e.g., send/recv) that we now allow.
1332  device_opts.allow_internal_ops = true;
1333  device_opts.expect_device_spec = true;
1334  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
1335  device_graph.get()));
1336  outputs->emplace(partition.first, std::move(device_graph));
1337  }
1338 
1339  GraphOptimizationPassOptions optimization_options;
1340  optimization_options.session_options = &options_;
1341  optimization_options.flib_def = client_graph->flib_def.get();
1342  optimization_options.partition_graphs = outputs;
1343  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
1344  OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
1345 
1346  Status s;
1347  for (auto& partition : *outputs) {
1348  const string& partition_name = partition.first;
1349  std::unique_ptr<Graph>* graph = &partition.second;
1350 
1351  VLOG(2) << "Created " << DebugString(graph->get()) << " for "
1352  << partition_name;
1353 
1354  // Give the device an opportunity to rewrite its subgraph.
1355  Device* d;
1356  s = device_mgr_->LookupDevice(partition_name, &d);
1357  if (!s.ok()) break;
1358  s = d->MaybeRewriteGraph(graph);
1359  if (!s.ok()) {
1360  break;
1361  }
1362  }
1363  *flib_def = std::move(client_graph->flib_def);
1364  std::swap(*input_types, client_graph->feed_types);
1365  std::swap(*output_types, client_graph->fetch_types);
1366  return s;
1367 }
1368 
1369 ::tensorflow::Status NTSession::ListDevices(
1370  std::vector<DeviceAttributes>* response) {
1371  response->clear();
1372  response->reserve(devices_.size());
1373  for (Device* d : devices_) {
1374  const DeviceAttributes& attrs = d->attributes();
1375  response->emplace_back(attrs);
1376  }
1377  return ::tensorflow::Status::OK();
1378 }
1379 
1380 ::tensorflow::Status NTSession::Reset(
1381  const std::vector<string>& containers) {
1382  device_mgr_->ClearContainers(containers);
1383  return ::tensorflow::Status::OK();
1384 }
1385 
1386 ::tensorflow::Status NTSession::Close() {
1387  cancellation_manager_->StartCancel();
1388  {
1389  mutex_lock l(closed_lock_);
1390  if (closed_) return ::tensorflow::Status::OK();
1391  closed_ = true;
1392  }
1393  if (factory_ != nullptr) factory_->Deregister(this);
1394  return ::tensorflow::Status::OK();
1395 }
1396 
1397 NTSession::RunState::RunState(
1398  const std::vector<string>& pending_input_names,
1399  const std::vector<string>& pending_output_names, int64 step_id,
1400  const std::vector<Device*>* devices)
1401  : step_container(step_id, [devices](const string& name) {
1402  for (auto d : *devices) {
1403  if (!d->resource_manager()->Cleanup(name).ok()) {
1404  // Do nothing...
1405  }
1406  }
1407  }) {
1408  // Initially all the feeds and fetches are pending.
1409  for (auto& name : pending_input_names) {
1410  pending_inputs[name] = false;
1411  }
1412  for (auto& name : pending_output_names) {
1413  pending_outputs[name] = false;
1414  }
1415 }
1416 
1417 NTSession::RunState::RunState(int64 step_id,
1418  const std::vector<Device*>* devices)
1419  : RunState({}, {}, step_id, devices) {}
1420 
1421 NTSession::RunState::~RunState() {
1422  if (rendez != nullptr) {
1423  if (!executors_done.HasBeenNotified()) {
1424  rendez->StartAbort(errors::Cancelled("PRun cancellation"));
1425  executors_done.WaitForNotification();
1426  }
1427  rendez->Unref();
1428  }
1429 }
1430 
1431 bool NTSession::RunState::PendingDone() const {
1432  for (const auto& it : pending_inputs) {
1433  if (!it.second) return false;
1434  }
1435  for (const auto& it : pending_outputs) {
1436  if (!it.second) return false;
1437  }
1438  return true;
1439 }
1440 
1441 void NTSession::WaitForNotification(RunState* run_state,
1442  CancellationManager* cm,
1443  int64 timeout_in_ms) {
1444  const Status status =
1445  WaitForNotification(&run_state->executors_done, timeout_in_ms);
1446  if (!status.ok()) {
1447  {
1448  mutex_lock l(run_state->mu_);
1449  run_state->status.Update(status);
1450  }
1451  cm->StartCancel();
1452  // We must wait for the executors to complete, because they have borrowed
1453  // references to `cm` and other per-step state. After this notification, it
1454  // is safe to clean up the step.
1455  run_state->executors_done.WaitForNotification();
1456  }
1457 }
1458 
1459 ::tensorflow::Status NTSession::WaitForNotification(
1460  Notification* notification, int64 timeout_in_ms) {
1461  if (timeout_in_ms > 0) {
1462  const int64 timeout_in_us = timeout_in_ms * 1000;
1463  const bool notified =
1464  WaitForNotificationWithTimeout(notification, timeout_in_us);
1465  if (!notified) {
1466  return Status(error::DEADLINE_EXCEEDED,
1467  "Timed out waiting for notification");
1468  }
1469  } else {
1470  notification->WaitForNotification();
1471  }
1472  return Status::OK();
1473 }
1474 
1475 } // namespace tensorflow