
NTSession.cc
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 //NOTE: The memory layout of the Node class changes depending on whether NDEBUG
16 // was set when TensorFlow was compiled: the Node class holds two EdgeSet
17 // instances, and EdgeSet adds a data member if NDEBUG is set.
18 
19 /*
20 This file is an adaptation of the original direct_session.cc file located at
21 https://github.com/tensorflow/tensorflow/blob/v1.6.0/tensorflow/core/common_runtime/direct_session.cc
22 to meet the demands of the software environment developed and used by the CMS collaboration.
23 
24 Changes with respect to the original code are documented in the NTSession.h header file.
25 */
26 
27 #if !defined(NDEBUG)
28 #define NDEBUG 1
29 #endif
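// A minimal sketch (not TensorFlow's actual EdgeSet) of why the forced NDEBUG
// above matters: if a header conditionally declares a data member, object size
// and layout differ between translation units compiled with different NDEBUG
// settings, which silently breaks the one-definition rule.
//
//   class EdgeSetLike {
//    public:
//     void insert(int e) { RegisterMutation(); /* ... */ }
//    private:
//   #ifdef NDEBUG
//     void RegisterMutation() {}
//   #else
//     void RegisterMutation() { ++mutations_; }
//     unsigned int mutations_ = 0;  // member present in only one build mode
//   #endif
//   };
//
// Forcing NDEBUG here keeps this file's view of Node/EdgeSet consistent with
// the flags used when the TensorFlow library itself was built.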
30 
31 #include "NTSession.h"
32 
33 #include <atomic>
34 #include <string>
35 #include <vector>
36 
38 
39 #include "tensorflow/core/common_runtime/constant_folding.h"
40 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
41 #include "tensorflow/core/common_runtime/device_factory.h"
42 #include "tensorflow/core/common_runtime/executor.h"
43 #include "tensorflow/core/common_runtime/function.h"
44 #include "tensorflow/core/common_runtime/graph_optimizer.h"
45 #include "tensorflow/core/common_runtime/memory_types.h"
46 #include "tensorflow/core/common_runtime/optimization_registry.h"
47 #include "tensorflow/core/common_runtime/step_stats_collector.h"
48 #include "tensorflow/core/framework/function.h"
49 #include "tensorflow/core/framework/graph.pb_text.h"
50 #include "tensorflow/core/framework/graph.pb.h"
51 #include "tensorflow/core/framework/graph_def_util.h"
52 #include "tensorflow/core/framework/log_memory.h"
53 #include "tensorflow/core/framework/node_def.pb.h"
54 #include "tensorflow/core/framework/tensor.h"
55 #include "tensorflow/core/framework/versions.pb.h"
56 #include "tensorflow/core/graph/algorithm.h"
57 #include "tensorflow/core/graph/graph.h"
58 #include "tensorflow/core/graph/graph_constructor.h"
59 #include "tensorflow/core/graph/graph_partition.h"
60 #include "tensorflow/core/graph/subgraph.h"
61 #include "tensorflow/core/graph/tensor_id.h"
62 #include "tensorflow/core/lib/core/errors.h"
63 #include "tensorflow/core/lib/core/notification.h"
64 #include "tensorflow/core/lib/core/refcount.h"
65 #include "tensorflow/core/lib/core/status.h"
66 #include "tensorflow/core/lib/gtl/array_slice.h"
67 #include "tensorflow/core/lib/gtl/stl_util.h"
68 #include "tensorflow/core/lib/monitoring/counter.h"
69 #include "tensorflow/core/lib/strings/numbers.h"
70 #include "tensorflow/core/lib/strings/str_util.h"
71 #include "tensorflow/core/lib/strings/strcat.h"
72 #include "tensorflow/core/platform/cpu_info.h"
73 #include "tensorflow/core/platform/device_tracer.h"
74 #include "tensorflow/core/platform/logging.h"
75 #include "tensorflow/core/platform/mutex.h"
76 #include "tensorflow/core/platform/types.h"
77 #include "tensorflow/core/util/device_name_utils.h"
78 #include "tensorflow/core/util/env_var.h"
79 
80 namespace tensorflow {
81 
82 namespace {
83 
84 CMS_THREAD_SAFE auto* nothreads_session_runs = monitoring::Counter<0>::New(
85  "/tensorflow/core/nothreads_session_runs",
86  "The number of times NTSession::Run() has been called.");
87 
88 
89 // TODO(vrv): Figure out how to unify the many different functions
90 // that generate RendezvousKey, since many of them have to be
91 // consistent with each other.
92 string GetRendezvousKey(const string& tensor_name,
93  const DeviceAttributes& device_info,
94  const FrameAndIter& frame_iter) {
95  return strings::StrCat(device_info.name(), ";",
96  strings::FpToString(device_info.incarnation()), ";",
97  device_info.name(), ";", tensor_name, ";",
98  frame_iter.frame_id, ":", frame_iter.iter_id);
99 }
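// A hedged illustration (hypothetical device attributes) of the key format
// produced above; FpToString renders the device incarnation as a hex string:
//
//   GetRendezvousKey("x:0", cpu0_attrs, FrameAndIter(0, 0))
//     -> "/job:localhost/replica:0/task:0/device:CPU:0;0000000000000001;"
//        "/job:localhost/replica:0/task:0/device:CPU:0;x:0;0:0"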
100 
101 } // namespace
102 
103 class NTSessionFactory : public SessionFactory {
104  public:
105  NTSessionFactory() {}
106 
107  bool AcceptsOptions(const SessionOptions& options) override {
108  return options.target == "no_threads";
109  }
110 
111  Session* NewSession(const SessionOptions& options) override {
112  // Must do this before the CPU allocator is created.
113  if (options.config.graph_options().build_cost_model() > 0) {
114  EnableCPUAllocatorFullStats(true);
115  }
116  std::vector<Device*> devices;
117  const Status s = DeviceFactory::AddDevices(
118  options, "/job:localhost/replica:0/task:0", &devices);
119  if (!s.ok()) {
120  LOG(ERROR) << s;
121  return nullptr;
122  }
123 
124  NTSession* session =
125  new NTSession(options, new DeviceMgr(devices), this);
126  {
127  mutex_lock l(sessions_lock_);
128  sessions_.push_back(session);
129  }
130  return session;
131  }
132 
133  Status Reset(const SessionOptions& options,
134  const std::vector<string>& containers) override {
135  std::vector<NTSession*> sessions_to_reset;
136  {
137  mutex_lock l(sessions_lock_);
138  // We create a copy to ensure that we don't have a deadlock when
139  // session->Close calls NTSessionFactory::Deregister, which
140  // acquires sessions_lock_.
141  std::swap(sessions_to_reset, sessions_);
142  }
143  Status s;
144  for (auto session : sessions_to_reset) {
145  s.Update(session->Reset(containers));
146  }
147  // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
148  // it doesn't close the sessions?
149  for (auto session : sessions_to_reset) {
150  s.Update(session->Close());
151  }
152  return s;
153  }
154 
155  void Deregister(const NTSession* session) {
156  mutex_lock l(sessions_lock_);
157  sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
158  sessions_.end());
159  }
160 
161  private:
162  mutex sessions_lock_;
163  std::vector<NTSession*> sessions_ GUARDED_BY(sessions_lock_);
164 };
165 
166 class NTSessionRegistrar {
167  public:
168  NTSessionRegistrar() {
169  SessionFactory::Register("NOTHREADS_SESSION", new NTSessionFactory());
170  }
171 };
172 static NTSessionRegistrar registrar;
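// A hedged usage sketch: AcceptsOptions() above matches
// options.target == "no_threads", so a client selects this session type
// through the standard factory API:
//
//   tensorflow::SessionOptions opts;
//   opts.target = "no_threads";
//   tensorflow::Session* session = tensorflow::NewSession(opts);
//   // session now points to an NTSession instance.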
173 
174 std::atomic_int_fast64_t NTSession::step_id_counter_(1);
175 
176 // NOTE: On Android with a single device, there is never
177 // a risk of an OpKernel blocking indefinitely:
178 //
179 // 1) No operations do I/O that depends on other simultaneous kernels,
180 //
181 // 2) Recv nodes always complete immediately: The inputs are sent into
182 // the local rendezvous before we start the executor, so the
183 // corresponding recvs will not block.
184 //
185 // Based on these assumptions, we can use the same thread pool for
186 // both "non-blocking" and "blocking" OpKernels on Android.
187 //
188 // This may change down the road when we add support for multiple
189 // devices that run concurrently, in which case we will need to
190 // revisit this decision.
191 // Override to allow CMSSW FWK to schedule
192 void NTSession::SchedClosure(std::function<void()> c) {
193  c();
194 }
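// For contrast, a sketch of the upstream DirectSession runner that this
// override replaces (dispatch to a TF-owned thread pool, paraphrased from
// direct_session.cc):
//
//   void DirectSession::SchedClosure(thread::ThreadPool* pool,
//                                    std::function<void()> c) {
//     pool->Schedule(std::move(c));
//   }
//
// Running c() inline instead leaves all scheduling to the CMSSW framework
// thread that called Run().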
195 
196 NTSession::NTSession(const SessionOptions& options,
197  const DeviceMgr* device_mgr,
198  NTSessionFactory* const factory)
199  : options_(options),
200  device_mgr_(device_mgr),
201  factory_(factory),
202  cancellation_manager_(new CancellationManager()),
203  operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
204  // The default value of sync_on_finish will be flipped soon and this
205  // environment variable will be removed as well.
206  const Status status =
207  ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
208  if (!status.ok()) {
209  LOG(ERROR) << status.error_message();
210  }
211  // NOTE(mrry): We do not need to use a unique string for the session
212  // handle, because NTSession owns its devices. This may change
213  // in future versions.
214  session_handle_ = "no_threads";
215  int devices_added = 0;
216  if (options.config.log_device_placement()) {
217  const string mapping_str = device_mgr_->DeviceMappingString();
218  if (mapping_str.empty()) {
219  printf("Device mapping: no known devices.\n");
220  } else {
221  printf("Device mapping:\n%s", mapping_str.c_str());
222  }
223  LOG(INFO) << "Device mapping:\n" << mapping_str;
224  }
225  for (auto d : device_mgr_->ListDevices()) {
226  devices_.push_back(d);
227  device_set_.AddDevice(d);
228  d->op_segment()->AddHold(session_handle_);
229 
230  // The first device added is special: it is the 'client device' (a
231  // CPU device) from which we feed and fetch Tensors.
232  if (devices_added == 0) {
233  device_set_.set_client_device(d);
234  }
235  ++devices_added;
236  }
237 }
238 
239 NTSession::~NTSession() {
240  if (!closed_) Close().IgnoreError();
241  for (auto& it : partial_runs_) {
242  it.second.reset(nullptr);
243  }
244  for (auto& it : executors_) {
245  it.second.reset();
246  }
247  for (auto d : device_mgr_->ListDevices()) {
248  d->op_segment()->RemoveHold(session_handle_);
249  }
250  for (auto d : device_mgr_->ListDevices()) {
251  d->ClearResourceMgr();
252  }
253  functions_.clear();
254  delete cancellation_manager_;
255 
256  execution_state_.reset(nullptr);
257  flib_def_.reset(nullptr);
258 }
259 
260 Status NTSession::MaybeInitializeExecutionState(
261  const GraphDef& graph, bool* out_already_initialized) {
262  // If already initialized, do nothing.
263  if (flib_def_ && execution_state_) {
264  *out_already_initialized = true;
265  return Status::OK();
266  }
267  // Set up the per-session execution state.
268  // NOTE(mrry): The function library created here will be used for
269  // all subsequent extensions of the graph.
270  flib_def_.reset(
271  new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
272  GraphExecutionStateOptions options;
273  options.device_set = &device_set_;
274  options.session_options = &options_;
275  // TODO(mrry,suharshs): We explicitly copy `graph` so that
276  // `MakeForBaseGraph()` can take ownership of its
277  // contents. Previously this happened implicitly in calls to the
278  // `GraphExecutionState`. Other sessions call
279  // `MakeForBaseGraph` in such a way that we can destructively read
280  // the passed-in `GraphDef`. In principle we could do the same here,
281  // with a wider refactoring; we might revise the direct session so
282  // that it copies the graph fewer times.
283  GraphDef temp(graph);
284  TF_RETURN_IF_ERROR(
285  GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_));
286  graph_created_ = true;
287  *out_already_initialized = false;
288  return Status::OK();
289 }
290 
291 Status NTSession::Create(const GraphDef& graph) {
292  TF_RETURN_IF_ERROR(init_error_);
293  if (graph.node_size() > 0) {
294  mutex_lock l(graph_def_lock_);
295  if (graph_created_) {
296  return errors::AlreadyExists(
297  "A Graph has already been created for this session.");
298  }
299  return ExtendLocked(graph);
300  }
301  return Status::OK();
302 }
303 
304 Status NTSession::Extend(const GraphDef& graph) {
305  TF_RETURN_IF_ERROR(CheckNotClosed());
306  mutex_lock l(graph_def_lock_);
307  return ExtendLocked(graph);
308 }
309 
310 Status NTSession::ExtendLocked(const GraphDef& graph) {
311  bool already_initialized;
312  // If this is the first call, we can initialize the execution state
313  // with `graph` and do not need to call `Extend()`.
314  TF_RETURN_IF_ERROR(
315  MaybeInitializeExecutionState(graph, &already_initialized));
316  if (already_initialized) {
317  TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
318  std::unique_ptr<GraphExecutionState> state;
319  TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
320  execution_state_.swap(state);
321  }
322  return Status::OK();
323 }
324 
325 Status NTSession::Run(const NamedTensorList& inputs,
326  const std::vector<string>& output_names,
327  const std::vector<string>& target_nodes,
328  std::vector<Tensor>* outputs) {
329  RunMetadata run_metadata;
330  return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
331  &run_metadata);
332 }
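// A hedged end-to-end usage sketch (hypothetical graph and tensor names)
// showing the typical client sequence against this Session implementation:
//
//   std::vector<tensorflow::Tensor> outputs;
//   TF_CHECK_OK(session->Create(graph_def));              // NTSession::Create
//   TF_CHECK_OK(session->Run({{"input", input_tensor}},   // feeds
//                            {"output"},                  // fetches
//                            {},                          // target nodes
//                            &outputs));                  // NTSession::Run
//   TF_CHECK_OK(session->Close());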
333 
334 Status NTSession::CreateDebuggerState(
335  const DebugOptions& debug_options, int64 session_run_index,
336  int64 executor_step_index, const std::vector<string>& input_names,
337  const std::vector<string>& output_names,
338  const std::vector<string>& target_names,
339  std::unique_ptr<DebuggerStateInterface>* debugger_state) {
340  TF_RETURN_IF_ERROR(
341  DebuggerStateRegistry::CreateState(debug_options, debugger_state));
342  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
343  debug_options.global_step(), session_run_index, executor_step_index,
344  input_names, output_names, target_names));
345  return Status::OK();
346 }
347 
348 Status NTSession::DecorateAndPublishGraphForDebug(
349  const DebugOptions& debug_options, Graph* graph, Device* device) {
350  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
351  TF_RETURN_IF_ERROR(
352  DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
353 
354  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
355  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
356  return Status::OK();
357 }
358 
359 Status NTSession::Run(const RunOptions& run_options,
360  const NamedTensorList& inputs,
361  const std::vector<string>& output_names,
362  const std::vector<string>& target_nodes,
363  std::vector<Tensor>* outputs,
364  RunMetadata* run_metadata) {
365  TF_RETURN_IF_ERROR(CheckNotClosed());
366  nothreads_session_runs->GetCell()->IncrementBy(1);
367  {
368  mutex_lock l(graph_def_lock_);
369  if (!graph_created_) {
370  return errors::InvalidArgument(
371  "Session was not created with a graph before Run()!");
372  }
373  }
374 
375  // Extract the input names for this run of the session.
376  std::vector<string> input_tensor_names;
377  input_tensor_names.reserve(inputs.size());
378  for (const auto& it : inputs) {
379  input_tensor_names.push_back(it.first);
380  }
381 
382  // Check if we already have an executor for these arguments.
383  ExecutorsAndKeys* executors_and_keys;
384  RunStateArgs run_state_args(run_options.debug_options());
385 
386  Executor::Args args;
387  args.step_id = step_id_counter_.fetch_add(1);
388 
389  TF_RETURN_IF_ERROR(
390  GetOrCreateExecutors(input_tensor_names, output_names,
391  target_nodes, &executors_and_keys,
392  &run_state_args));
393  const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
394 
395  std::unique_ptr<DebuggerStateInterface> debugger_state;
396  if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
397  TF_RETURN_IF_ERROR(CreateDebuggerState(
398  run_options.debug_options(), args.step_id, executor_step_count,
399  input_tensor_names, output_names, target_nodes, &debugger_state));
400  }
401 
402  // Configure a call frame for the step, which we use to feed and
403  // fetch values to and from the executors.
404  FunctionCallFrame call_frame(executors_and_keys->input_types,
405  executors_and_keys->output_types);
406  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
407  for (const auto& it : inputs) {
408  if (it.second.dtype() == DT_RESOURCE) {
409  Tensor tensor_from_handle;
410  TF_RETURN_IF_ERROR(
411  ResourceHandleToInputTensor(it.second, &tensor_from_handle));
412  feed_args[executors_and_keys->input_name_to_index[it.first]] =
413  tensor_from_handle;
414  } else {
415  feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
416  }
417  }
418  const Status s = call_frame.SetArgs(feed_args);
419  if (errors::IsInternal(s)) {
420  return errors::InvalidArgument(s.error_message());
421  } else if (!s.ok()) {
422  return s;
423  }
424 
425  // Create a run state and start execution.
426  RunState run_state(args.step_id, &devices_);
427  run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
428  CancellationManager step_cancellation_manager;
429  args.call_frame = &call_frame;
430 
431  // Start parallel Executors.
432  const size_t num_executors = executors_and_keys->items.size();
433  ExecutorBarrier* barrier = new ExecutorBarrier(
434  num_executors, run_state.rendez, [&run_state](const Status& ret) {
435  {
436  mutex_lock l(run_state.mu_);
437  run_state.status.Update(ret);
438  }
439  run_state.executors_done.Notify();
440  });
441 
442  args.rendezvous = run_state.rendez;
443  args.cancellation_manager = &step_cancellation_manager;
444 
445  args.session_state = &session_state_;
446  args.tensor_store = &run_state.tensor_store;
447  args.step_container = &run_state.step_container;
448  if (LogMemory::IsEnabled()) {
449  LogMemory::RecordStep(args.step_id, run_state_args.handle);
450  }
451  args.sync_on_finish = sync_on_finish_;
452 
453  const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
454 
455  bool update_cost_model = false;
456  if (options_.config.graph_options().build_cost_model() > 0) {
457  const int64 build_cost_model_every =
458  options_.config.graph_options().build_cost_model();
459  const int64 build_cost_model_after =
460  options_.config.graph_options().build_cost_model_after();
461  int64 measure_step_count = executor_step_count - build_cost_model_after;
462  if (measure_step_count >= 0) {
463  update_cost_model =
464  ((measure_step_count + 1) % build_cost_model_every == 0);
465  }
466  }
467  if (do_trace || update_cost_model ||
468  run_options.report_tensor_allocations_upon_oom()) {
469  run_state.collector.reset(
470  new StepStatsCollector(run_metadata->mutable_step_stats()));
471  args.stats_collector = run_state.collector.get();
472  }
473 
474  std::unique_ptr<DeviceTracer> tracer;
475  if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
476  tracer = CreateDeviceTracer();
477  // tracer may be NULL on platforms without accelerators.
478  if (tracer) {
479  Status s = tracer->Start();
480  if (!s.ok()) {
481  run_state.executors_done.Notify();
482  delete barrier;
483  return s;
484  }
485  }
486  }
487 
488  // Register this step with session's cancellation manager, so that
489  // `Session::Close()` will cancel the step.
490  const CancellationToken cancellation_token =
491  cancellation_manager_->get_cancellation_token();
492  const bool already_cancelled = !cancellation_manager_->RegisterCallback(
493  cancellation_token, [&step_cancellation_manager]() {
494  step_cancellation_manager.StartCancel();
495  });
496  if (already_cancelled) {
497  // NOTE(mrry): If we don't explicitly notify
498  // `run_state.executors_done`, the RunState destructor would
499  // block on this notification.
500  run_state.executors_done.Notify();
501  delete barrier;
502  return errors::Cancelled("Run call was cancelled");
503  }
504 
505  // Pass the closure to SchedClosure with no thread-pool argument and
506  // thereby disable TF's own threading logic inside the loop below.
507  Executor::Args::Runner default_runner = [this](Executor::Args::Closure c) {
508  SchedClosure(std::move(c));
509  };
510  for (const auto& item : executors_and_keys->items) {
511  // TODO(zhengxq): support partial run.
512  // TODO(zhengxq): if the device picks its own threadpool, we need to assign
513  // less threads to the main compute pool by default.
514  // thread::ThreadPool* device_thread_pool =
515  // item.device->tensorflow_device_thread_pool();
516  // if (!device_thread_pool) {
517  // args.runner = default_runner;
518  // } else {
519  // args.runner = [this, device_thread_pool](Executor::Args::Closure c) {
520  // SchedClosure(device_thread_pool, std::move(c));
521  // };
522  // }
523  args.runner = default_runner;
524  item.executor->RunAsync(args, barrier->Get());
525  }
526 
527  WaitForNotification(&run_state, &step_cancellation_manager,
528  run_options.timeout_in_ms() > 0
529  ? run_options.timeout_in_ms()
530  : operation_timeout_in_ms_);
531 
532  if (!cancellation_manager_->DeregisterCallback(cancellation_token)) {
533  // The step has been cancelled: make sure we don't attempt to receive the
534  // outputs as this would make it block forever.
535  mutex_lock l(run_state.mu_);
536  run_state.status.Update(errors::Cancelled("Run call was cancelled"));
537  }
538 
539  if (tracer) {
540  TF_RETURN_IF_ERROR(tracer->Stop());
541  TF_RETURN_IF_ERROR(tracer->Collect(args.stats_collector));
542  }
543 
544  {
545  mutex_lock l(run_state.mu_);
546  TF_RETURN_IF_ERROR(run_state.status);
547  }
548 
549  // Receive outputs.
550  if (outputs) {
551  std::vector<Tensor> sorted_outputs;
552  const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
553  if (errors::IsInternal(s)) {
554  return errors::InvalidArgument(s.error_message());
555  } else if (!s.ok()) {
556  return s;
557  }
558  const bool unique_outputs =
559  output_names.size() == executors_and_keys->output_name_to_index.size();
560  // first_indices[i] = j implies that j is the smallest value for which
561  // output_names[i] == output_names[j].
562  std::vector<int> first_indices;
563  if (!unique_outputs) {
564  first_indices.resize(output_names.size());
565  for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
566  for (int j = 0; j <= i; ++j) {
567  if (output_names[i] == output_names[j]) {
568  first_indices[i] = j;
569  break;
570  }
571  }
572  }
573  }
574  outputs->clear();
575  outputs->reserve(sorted_outputs.size());
576  for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
577  const string& output_name = output_names[i];
578  if (first_indices.empty() || first_indices[i] == i) {
579  outputs->emplace_back(
580  std::move(sorted_outputs[executors_and_keys
581  ->output_name_to_index[output_name]]));
582  } else {
583  outputs->push_back((*outputs)[first_indices[i]]);
584  }
585  }
586  }
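// A worked example of the first_indices deduplication above (hypothetical
// fetch names): for output_names = {"a", "b", "a"}, unique_outputs is false
// and first_indices becomes {0, 1, 0}. Tensors "a" and "b" are moved out of
// sorted_outputs exactly once; the third entry is then copied from
// (*outputs)[0] instead of consuming a second, already moved-from copy.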
587 
588  // Save the output tensors of this run we choose to keep.
589  TF_RETURN_IF_ERROR(
590  run_state.tensor_store.SaveTensors(output_names, &session_state_));
591  if (args.stats_collector) {
592  args.stats_collector->Finalize();
593  }
594 
595  // Build and return the cost model as instructed.
596  mutex_lock l(executor_lock_);
597  if (update_cost_model) {
598  // Build the cost model
599  std::unordered_map<string, const Graph*> device_to_graph;
600  for (const PerPartitionExecutorsAndLib& partition :
601  executors_and_keys->items) {
602  const Graph* graph = partition.graph;
603  const string device = partition.flib->device()->name();
604  device_to_graph[device] = graph;
605  }
606  args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);
607 
608  // annotate stats onto cost graph.
609  CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
610  for (const auto& item : executors_and_keys->items) {
611  TF_RETURN_IF_ERROR(
612  cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
613  }
614  }
615 
616  // If requested via RunOptions, output the partition graphs.
617  if (run_options.output_partition_graphs()) {
618  protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
619  run_metadata->mutable_partition_graphs();
620  for (const PerPartitionExecutorsAndLib& exec_and_lib :
621  executors_and_keys->items) {
622  GraphDef* partition_graph_def = partition_graph_defs->Add();
623  exec_and_lib.graph->ToGraphDef(partition_graph_def);
624  }
625  }
626 
627  return Status::OK();
628 }
629 
630 Status NTSession::PRunSetup(const std::vector<string>& input_names,
631  const std::vector<string>& output_names,
632  const std::vector<string>& target_nodes,
633  string* handle) {
634  TF_RETURN_IF_ERROR(CheckNotClosed());
635  {
636  mutex_lock l(graph_def_lock_);
637  if (!graph_created_) {
638  return errors::InvalidArgument(
639  "Session was not created with a graph before PRunSetup()!");
640  }
641  }
642 
643  // Check if we already have an executor for these arguments.
644  ExecutorsAndKeys* executors_and_keys;
645  // TODO(cais): TFDBG support for partial runs.
646  DebugOptions debug_options;
647  RunStateArgs run_state_args(debug_options);
648  run_state_args.is_partial_run = true;
649  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
650  target_nodes, &executors_and_keys,
651  &run_state_args));
652 
653  // Create the run state and save it for future PRun calls.
654  Executor::Args args;
655  args.step_id = step_id_counter_.fetch_add(1);
656  RunState* run_state =
657  new RunState(input_names, output_names, args.step_id, &devices_);
658  run_state->rendez = new IntraProcessRendezvous(device_mgr_.get());
659  {
660  mutex_lock l(executor_lock_);
661  if (!partial_runs_
662  .emplace(run_state_args.handle,
663  std::unique_ptr<RunState>(run_state))
664  .second) {
665  return errors::Internal("The handle '", run_state_args.handle,
666  "' created for this partial run is not unique.");
667  }
668  }
669 
670  // Start parallel Executors.
671  const size_t num_executors = executors_and_keys->items.size();
672  ExecutorBarrier* barrier = new ExecutorBarrier(
673  num_executors, run_state->rendez, [run_state](const Status& ret) {
674  if (!ret.ok()) {
675  mutex_lock l(run_state->mu_);
676  run_state->status.Update(ret);
677  }
678  run_state->executors_done.Notify();
679  });
680 
681  args.rendezvous = run_state->rendez;
682  args.cancellation_manager = cancellation_manager_;
683  args.runner = [this](Executor::Args::Closure c) {
684  SchedClosure(std::move(c));
685  };
686  args.session_state = &session_state_;
687  args.tensor_store = &run_state->tensor_store;
688  args.step_container = &run_state->step_container;
689  if (LogMemory::IsEnabled()) {
690  LogMemory::RecordStep(args.step_id, run_state_args.handle);
691  }
692  args.sync_on_finish = sync_on_finish_;
693 
694  if (options_.config.graph_options().build_cost_model()) {
695  run_state->collector.reset(new StepStatsCollector(nullptr));
696  args.stats_collector = run_state->collector.get();
697  }
698 
699  for (auto& item : executors_and_keys->items) {
700  item.executor->RunAsync(args, barrier->Get());
701  }
702 
703  *handle = run_state_args.handle;
704  return Status::OK();
705 }
706 
707 Status NTSession::PRun(const string& handle, const NamedTensorList& inputs,
708  const std::vector<string>& output_names,
709  std::vector<Tensor>* outputs) {
710  TF_RETURN_IF_ERROR(CheckNotClosed());
711  std::vector<string> parts = str_util::Split(handle, ';');
712  const string& key = parts[0];
713  // Get the executors for this partial run.
714  ExecutorsAndKeys* executors_and_keys;
715  RunState* run_state;
716  {
717  mutex_lock l(executor_lock_); // could use reader lock
718  auto exc_it = executors_.find(key);
719  if (exc_it == executors_.end()) {
720  return errors::InvalidArgument(
721  "Must run 'setup' before performing partial runs!");
722  }
723  executors_and_keys = exc_it->second.get();
724 
725  auto prun_it = partial_runs_.find(handle);
726  if (prun_it == partial_runs_.end()) {
727  return errors::InvalidArgument(
728  "Must run 'setup' before performing partial runs!");
729  }
730  run_state = prun_it->second.get();
731 
732  // Make sure that this is a new set of feeds that are still pending.
733  for (const auto& input : inputs) {
734  auto it = run_state->pending_inputs.find(input.first);
735  if (it == run_state->pending_inputs.end()) {
736  return errors::InvalidArgument(
737  "The feed ", input.first,
738  " was not specified in partial_run_setup.");
739  } else if (it->second) {
740  return errors::InvalidArgument("The feed ", input.first,
741  " has already been fed.");
742  }
743  }
744  // Check that this is a new set of fetches that are still pending.
745  for (const auto& output : output_names) {
746  auto it = run_state->pending_outputs.find(output);
747  if (it == run_state->pending_outputs.end()) {
748  return errors::InvalidArgument(
749  "The fetch ", output, " was not specified in partial_run_setup.");
750  } else if (it->second) {
751  return errors::InvalidArgument("The fetch ", output,
752  " has already been fetched.");
753  }
754  }
755  }
756 
757  // Check that this new set of fetches can be computed from all the
758  // feeds we have supplied.
759  TF_RETURN_IF_ERROR(
760  CheckFetch(inputs, output_names, executors_and_keys, run_state));
761 
762  // Send inputs.
763  Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez);
764 
765  // Receive outputs.
766  if (s.ok()) {
767  s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
768  }
769 
770  // Save the output tensors of this run we choose to keep.
771  if (s.ok()) {
772  s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
773  }
774 
775  {
776  mutex_lock l(executor_lock_);
777  // Delete the run state if there is an error or all fetches are done.
778  bool done = true;
779  if (s.ok()) {
780  {
781  mutex_lock l(run_state->mu_);
782  if (!run_state->status.ok()) {
783  LOG(WARNING) << "An error unrelated to this prun has been detected. "
784  << run_state->status;
785  }
786  }
787  for (const auto& input : inputs) {
788  auto it = run_state->pending_inputs.find(input.first);
789  it->second = true;
790  }
791  for (const auto& name : output_names) {
792  auto it = run_state->pending_outputs.find(name);
793  it->second = true;
794  }
795  done = run_state->PendingDone();
796  }
797  if (done) {
798  WaitForNotification(run_state, cancellation_manager_,
799  operation_timeout_in_ms_);
800  partial_runs_.erase(handle);
801  }
802  }
803 
804  return s;
805 }
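// A hedged usage sketch of the PRunSetup()/PRun() pair above (hypothetical
// feed and fetch names), supplying feeds incrementally:
//
//   string handle;
//   TF_CHECK_OK(session->PRunSetup({"a", "b"}, {"out1", "out2"}, {}, &handle));
//   std::vector<tensorflow::Tensor> outs1, outs2;
//   TF_CHECK_OK(session->PRun(handle, {{"a", tensor_a}}, {"out1"}, &outs1));
//   TF_CHECK_OK(session->PRun(handle, {{"b", tensor_b}}, {"out2"}, &outs2));
//
// Each feed may be supplied and each fetch consumed at most once; the run
// state is erased once every pending feed and fetch is done.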
806 
807 Status NTSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
808  Tensor* retrieved_tensor) {
809  if (resource_tensor.dtype() != DT_RESOURCE) {
810  return errors::InvalidArgument(strings::StrCat(
811  "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
812  resource_tensor.dtype()));
813  }
814 
815  const ResourceHandle& resource_handle =
816  resource_tensor.scalar<ResourceHandle>()();
817 
818  if (resource_handle.container() ==
819  SessionState::kTensorHandleResourceTypeName) {
820  return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
821  } else {
822  return errors::InvalidArgument(strings::StrCat(
823  "Invalid resource type hash code: ", resource_handle.hash_code(),
824  "(name: ", resource_handle.name(),
825  " type: ", resource_handle.maybe_type_name(),
826  "). Perhaps a resource tensor was being provided as a feed? That is "
827  "not currently allowed. Please file an issue at "
828  "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
829  "short code snippet that leads to this error message."));
830  }
831 }
832 
833 Status NTSession::SendPRunInputs(const NamedTensorList& inputs,
834  const ExecutorsAndKeys* executors_and_keys,
835  IntraProcessRendezvous* rendez) {
836  Status s;
837  Rendezvous::ParsedKey parsed;
838  // Insert the input tensors into the local rendezvous by their
839  // rendezvous key.
840  for (const auto& input : inputs) {
841  auto it =
842  executors_and_keys->input_name_to_rendezvous_key.find(input.first);
843  if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
844  return errors::Internal("'", input.first, "' is not a pre-defined feed.");
845  }
846  const string& input_key = it->second;
847 
848  s = Rendezvous::ParseKey(input_key, &parsed);
849  if (!s.ok()) {
850  rendez->StartAbort(s);
851  return s;
852  }
853 
854  if (input.second.dtype() == DT_RESOURCE) {
855  Tensor tensor_from_handle;
856  s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
857  if (s.ok()) {
858  s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
859  }
860  } else {
861  s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
862  }
863 
864  if (!s.ok()) {
865  rendez->StartAbort(s);
866  return s;
867  }
868  }
869  return Status::OK();
870 }
871 
872 Status NTSession::RecvPRunOutputs(
873  const std::vector<string>& output_names,
874  const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
875  std::vector<Tensor>* outputs) {
876  Status s;
877  if (!output_names.empty()) {
878  outputs->resize(output_names.size());
879  }
880 
881  Rendezvous::ParsedKey parsed;
882  // Get the outputs from the rendezvous
883  for (size_t output_offset = 0; output_offset < output_names.size();
884  ++output_offset) {
885  const string& output_name = output_names[output_offset];
886  auto it =
887  executors_and_keys->output_name_to_rendezvous_key.find(output_name);
888  if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
889  return errors::Internal("'", output_name,
890  "' is not a pre-defined fetch.");
891  }
892  const string& output_key = it->second;
893  Tensor output_tensor;
894  bool is_dead;
895  IntraProcessRendezvous* rendez = run_state->rendez;
896 
897  s = Rendezvous::ParseKey(output_key, &parsed);
898  if (s.ok()) {
899  // Fetch data from the Rendezvous.
900  s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead,
901  operation_timeout_in_ms_);
902  if (is_dead && s.ok()) {
903  s = errors::InvalidArgument("The tensor returned for ", output_name,
904  " was not valid.");
905  }
906  }
907  if (!s.ok()) {
908  rendez->StartAbort(s);
909  outputs->clear();
910  return s;
911  }
912 
913  (*outputs)[output_offset] = output_tensor;
914  }
915  return Status::OK();
916 }
917 
918 Status NTSession::CheckFetch(const NamedTensorList& feeds,
919  const std::vector<string>& fetches,
920  const ExecutorsAndKeys* executors_and_keys,
921  const RunState* run_state) {
922  const Graph* graph = executors_and_keys->graph.get();
923  const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;
924 
925  // Build the set of pending feeds that we haven't seen.
926  std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
927  {
928  mutex_lock l(executor_lock_);
929  for (const auto& input : run_state->pending_inputs) {
930  // Skip if the feed has already been fed.
931  if (input.second) continue;
932  TensorId id(ParseTensorName(input.first));
933  auto it = name_to_node->find(id.first);
934  if (it == name_to_node->end()) {
935  return errors::NotFound("Feed ", input.first, ": not found");
936  }
937  pending_feeds.insert(id);
938  }
939  }
940  for (const auto& it : feeds) {
941  TensorId id(ParseTensorName(it.first));
942  pending_feeds.erase(id);
943  }
944 
945  // Initialize the stack with the fetch nodes.
946  std::vector<const Node*> stack;
947  for (const string& fetch : fetches) {
948  TensorId id(ParseTensorName(fetch));
949  auto it = name_to_node->find(id.first);
950  if (it == name_to_node->end()) {
951  return errors::NotFound("Fetch ", fetch, ": not found");
952  }
953  stack.push_back(it->second);
954  }
955 
956  // Any tensor needed for fetches can't be in pending_feeds.
957  std::vector<bool> visited(graph->num_node_ids(), false);
958  while (!stack.empty()) {
959  const Node* n = stack.back();
960  stack.pop_back();
961 
962  for (const Edge* in_edge : n->in_edges()) {
963  const Node* in_node = in_edge->src();
964  if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
965  return errors::InvalidArgument("Fetch ", in_node->name(), ":",
966  in_edge->src_output(),
967  " can't be computed from the feeds"
968  " that have been fed so far.");
969  }
970  if (!visited[in_node->id()]) {
971  visited[in_node->id()] = true;
972  stack.push_back(in_node);
973  }
974  }
975  }
976  return Status::OK();
977 }
978 
979 Status NTSession::GetOrCreateExecutors(
980  gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
981  gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
982  RunStateArgs* run_state_args) {
983  int64 handle_name_counter_value = -1;
984  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
985  handle_name_counter_value = handle_name_counter_.fetch_add(1);
986  }
987 
988  string debug_tensor_watches_summary;
989  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
990  debug_tensor_watches_summary = SummarizeDebugTensorWatches(
991  run_state_args->debug_options.debug_tensor_watch_opts());
992  }
993 
994  // Fast lookup path, no sorting.
995  const string key = strings::StrCat(
996  str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
997  str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
998  "/", debug_tensor_watches_summary);
999  // Set the handle, if it's needed to log memory or for partial run.
1000  if (handle_name_counter_value >= 0) {
1001  run_state_args->handle =
1002  strings::StrCat(key, ";", handle_name_counter_value);
1003  }
1004 
1005  // See if we already have the executors for this run.
1006  {
1007  mutex_lock l(executor_lock_); // could use reader lock
1008  auto it = executors_.find(key);
1009  if (it != executors_.end()) {
1010  *executors_and_keys = it->second.get();
1011  return Status::OK();
1012  }
1013  }
1014 
1015  // Slow lookup path, the unsorted key missed the cache.
1016  // Sort the inputs and outputs, and look up with the sorted key in case an
1017  // earlier call used a different order of inputs and outputs.
1018  //
1019  // We could consider some other signature instead of sorting that
1020  // preserves the same property to avoid the sort in the future.
1021  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
1022  std::sort(inputs_sorted.begin(), inputs_sorted.end());
1023  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
1024  std::sort(outputs_sorted.begin(), outputs_sorted.end());
1025  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
1026  std::sort(tn_sorted.begin(), tn_sorted.end());
1027 
1028  const string sorted_key = strings::StrCat(
1029  str_util::Join(inputs_sorted, ","), "->",
1030  str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
1031  "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
1032  // Set the handle, if it's needed to log memory or for partial run.
1033  if (handle_name_counter_value >= 0) {
1034  run_state_args->handle =
1035  strings::StrCat(sorted_key, ";", handle_name_counter_value);
1036  }
1037 
1038  // See if we already have the executors for this run.
1039  {
1040  mutex_lock l(executor_lock_);
1041  auto it = executors_.find(sorted_key);
1042  if (it != executors_.end()) {
1043  *executors_and_keys = it->second.get();
1044  // Insert this under the original key.
1045  executors_.emplace(key, it->second);
1046  return Status::OK();
1047  }
1048  }
1049 
1050  // Nothing found, so create the executors and store in the cache.
1051  BuildGraphOptions options;
1052  options.feed_endpoints = inputs_sorted;
1053  options.fetch_endpoints = outputs_sorted;
1054  options.target_nodes = tn_sorted;
1055  options.use_function_convention = !run_state_args->is_partial_run;
1056  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
1057  options.debug_options = run_state_args->debug_options;
1058  }
1059 
1060  std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
1061  std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
1062 
1063  // The executor_lock_ is intentionally released while executor is
1064  // being created.
1065  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
1066  TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def,
1067  run_state_args, &ek->input_types,
1068  &ek->output_types));
1069 
1070  if (run_state_args->is_partial_run) {
1071  ek->graph = std::move(run_state_args->graph);
1072  std::unordered_set<StringPiece, StringPieceHasher> names;
1073  for (const string& input : inputs) {
1074  TensorId id(ParseTensorName(input));
1075  names.emplace(id.first);
1076  }
1077  for (const string& output : outputs) {
1078  TensorId id(ParseTensorName(output));
1079  names.emplace(id.first);
1080  }
1081  for (Node* n : ek->graph->nodes()) {
1082  if (names.count(n->name()) > 0) {
1083  ek->name_to_node.insert({n->name(), n});
1084  }
1085  }
1086  }
1087  ek->items.reserve(graphs.size());
1088  const auto& optimizer_opts =
1089  options_.config.graph_options().optimizer_options();
1090 
1091  int graph_def_version;
1092  {
1093  mutex_lock l(graph_def_lock_);
1094  graph_def_version =
1095  execution_state_->original_graph_def().versions().producer();
1096  }
1097  func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
1098  device_mgr_.get(), options_.env, graph_def_version,
1099  func_info->flib_def.get(), optimizer_opts));
1100 
1101  GraphOptimizer optimizer(optimizer_opts);
1102  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
1103  const string& partition_name = iter->first;
1104  std::unique_ptr<Graph>& partition_graph = iter->second;
1105 
1106  Device* device;
1107  TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
1108 
1109  ek->items.resize(ek->items.size() + 1);
1110  auto* item = &(ek->items.back());
1111  auto lib = func_info->proc_flr->GetFLR(partition_name);
1112  if (lib == nullptr) {
1113  return errors::Internal("Could not find device: ", partition_name);
1114  }
1115  item->flib = lib;
1116 
1117  LocalExecutorParams params;
1118  params.device = device;
1119  params.function_library = lib;
1120  auto opseg = device->op_segment();
1121  params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
1122  OpKernel** kernel) {
1123  // We do not share the kernel via the OpSegment if the node is
1124  // stateless, or a function.
1125  // NOTE(mrry): We must not share function kernels (implemented
1126  // using `CallOp`) between subgraphs, because `CallOp::handle_`
1127  // is tied to a particular subgraph. Even if the function itself
1128  // is stateful, the `CallOp` that invokes it is not.
1129  if (!lib->IsStateful(ndef.op()) ||
1130  lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
1131  return lib->CreateKernel(ndef, kernel);
1132  }
1133  auto create_fn = [lib, &ndef](OpKernel** kernel) {
1134  return lib->CreateKernel(ndef, kernel);
1135  };
1136  // Kernels created for subgraph nodes need to be cached. On
1137  // cache miss, create_fn() is invoked to create a kernel based
1138  // on the function library here + global op registry.
1139  return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
1140  create_fn);
1141  };
1142  params.delete_kernel = [lib](OpKernel* kernel) {
1143  // If the node is stateful, opseg owns it. Otherwise, delete it.
1144  if (kernel && !lib->IsStateful(kernel->type_string())) {
1145  delete kernel;
1146  }
1147  };
1148  params.node_outputs_cb = node_outputs_callback_;
1149 
1150  optimizer.Optimize(lib, options_.env, device, &iter->second,
1151  /*shape_map=*/nullptr);
1152 
1153  // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
1154  if (!options.debug_options.debug_tensor_watch_opts().empty()) {
1155  TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
1156  options.debug_options, partition_graph.get(), params.device));
1157  }
1158 
1159  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
1160  device->name(),
1161  partition_graph.get()));
1162  // NewLocalExecutor takes ownership of partition_graph.
1163  item->graph = partition_graph.get();
1164  item->executor = nullptr;
1165  item->device = device;
1166  Executor* executor;
1167  TF_RETURN_IF_ERROR(
1168  NewLocalExecutor(params, partition_graph.release(), &executor));
1169  item->executor.reset(executor);
1170  }
1171 
1172  // Cache the mapping from input/output names to graph elements to
1173  // avoid recomputing it every time.
1174  if (!run_state_args->is_partial_run) {
1175  // For regular `Run()`, we use the function calling convention, and so
1176  // maintain a mapping from input/output names to
1177  // argument/return-value ordinal index.
1178  for (size_t i = 0; i < inputs_sorted.size(); ++i) {
1179  const string& input = inputs_sorted[i];
1180  ek->input_name_to_index[input] = i;
1181  }
1182  for (size_t i = 0; i < outputs_sorted.size(); ++i) {
1183  const string& output = outputs_sorted[i];
1184  ek->output_name_to_index[output] = i;
1185  }
1186  } else {
1187  // For `PRun()`, we use the rendezvous calling convention, and so
1188  // maintain a mapping from input/output names to rendezvous keys.
1189  //
1190  // We always use the first device as the device name portion of the
1191  // key, even if we're feeding another graph.
1192  for (size_t i = 0; i < inputs_sorted.size(); ++i) {
1193  const string& input = inputs_sorted[i];
1194  ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
1195  input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
1196  }
1197  for (size_t i = 0; i < outputs_sorted.size(); ++i) {
1198  const string& output = outputs_sorted[i];
1199  ek->output_name_to_rendezvous_key[output] =
1200  GetRendezvousKey(output, device_set_.client_device()->attributes(),
1201  FrameAndIter(0, 0));
1202  }
1203  }
1204 
1205  // Reacquire the lock, try to insert into the map.
1206  mutex_lock l(executor_lock_);
1207  functions_.push_back(std::move(func_info));
1208 
1209  // Another thread may have created the entry before us, in which case we will
1210  // reuse the already created one.
1211  auto insert_result = executors_.emplace(sorted_key, ek);
1212  // Insert the value under the original key, so the fast path lookup will work
1213  // if the user uses the same order of inputs, outputs, and targets again.
1214  executors_.emplace(key, insert_result.first->second);
1215  *executors_and_keys = insert_result.first->second.get();
1216 
1217  return Status::OK();
1218 }
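// A sketch of the two-key cache strategy implemented above (hypothetical
// feeds/fetches): the unsorted key gives a fast hit when a caller repeats the
// exact argument order, while the sorted key lets permutations of the same
// feeds/fetches share one ExecutorsAndKeys instance.
//
//   key        = "y,x->out/tgt/0/"  // argument order as given by the caller
//   sorted_key = "x,y->out/tgt/0/"  // canonical order after sorting
//   executors_.emplace(sorted_key, ek);  // canonical entry
//   executors_.emplace(key, ek);         // alias for the fast path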
1219 
1220 Status NTSession::CreateGraphs(
1221  const BuildGraphOptions& subgraph_options,
1222  std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
1223  std::unique_ptr<FunctionLibraryDefinition>* flib_def,
1224  RunStateArgs* run_state_args, DataTypeVector* input_types,
1225  DataTypeVector* output_types) {
1226  mutex_lock l(graph_def_lock_);
1227  std::unique_ptr<ClientGraph> client_graph;
1228 
1229  std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
1230  GraphExecutionState* execution_state = nullptr;
1231  if (options_.config.graph_options().place_pruned_graph()) {
1232  // Because we are placing pruned graphs, we need to create a
1233  // new GraphExecutionState for every new unseen graph,
1234  // and then place it.
1235  GraphExecutionStateOptions prune_options;
1236  prune_options.device_set = &device_set_;
1237  prune_options.session_options = &options_;
1238  prune_options.stateful_placements = stateful_placements_;
1239  TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
1240  execution_state_->original_graph_def().library(), prune_options,
1241  execution_state_->original_graph_def(), subgraph_options,
1242  &temp_exec_state_holder, &client_graph));
1243  execution_state = temp_exec_state_holder.get();
1244  } else {
1245  execution_state = execution_state_.get();
1246  TF_RETURN_IF_ERROR(
1247  execution_state->BuildGraph(subgraph_options, &client_graph));
1248  }
1249 
1250  if (subgraph_options.feed_endpoints.size() !=
1251  client_graph->feed_types.size()) {
1252  return errors::Internal(
1253  "Graph pruning failed: requested number of feed endpoints = ",
1254  subgraph_options.feed_endpoints.size(),
1255  " versus number of pruned feed endpoints = ",
1256  client_graph->feed_types.size());
1257  }
1258  if (subgraph_options.fetch_endpoints.size() !=
1259  client_graph->fetch_types.size()) {
1260  return errors::Internal(
1261  "Graph pruning failed: requested number of fetch endpoints = ",
1262  subgraph_options.fetch_endpoints.size(),
1263  " versus number of pruned fetch endpoints = ",
1264  client_graph->fetch_types.size());
1265  }
1266 
1267  auto current_stateful_placements = execution_state->GetStatefulPlacements();
1268  // Update our current state based on the execution_state's
1269  // placements. If there are any mismatches for a node,
1270  // we should fail, as this should never happen.
1271  for (auto placement_pair : current_stateful_placements) {
1272  const string& node_name = placement_pair.first;
1273  const string& placement = placement_pair.second;
1274  auto iter = stateful_placements_.find(node_name);
1275  if (iter == stateful_placements_.end()) {
1276  stateful_placements_.insert(std::make_pair(node_name, placement));
1277  } else if (iter->second != placement) {
1278  return errors::Internal(
1279  "Stateful placement mismatch. "
1280  "Current assignment of ",
1281  node_name, " to ", iter->second, " does not match ", placement);
1282  }
1283  }
1284 
1285  stateful_placements_ = execution_state->GetStatefulPlacements();
1286 
1287  // Remember the graph in run state if this is a partial run.
1288  if (run_state_args->is_partial_run) {
1289  run_state_args->graph.reset(new Graph(flib_def_.get()));
1290  CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
1291  }
1292 
1293  // Partition the graph across devices.
1294  PartitionOptions popts;
1295  popts.node_to_loc = [](const Node* node) {
1296  assert(node != nullptr);
1297  return node->assigned_device_name();
1298  };
1299  popts.new_name = [this](const string& prefix) {
1300  return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
1301  };
1302  popts.get_incarnation = [](const string& name) {
1303  // The direct session does not have changing incarnation numbers.
1304  // Just return '1'.
1305  return 1;
1306  };
1307  popts.flib_def = &client_graph->graph.flib_def();
1308  popts.control_flow_added = false;
1309 
1310  std::unordered_map<string, GraphDef> partitions;
1311  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
1312 
1313  std::vector<string> device_names;
1314  for (auto device : devices_) {
1315  // Extract the LocalName from the device.
1316  device_names.push_back(DeviceNameUtils::LocalName(device->name()));
1317  }
1318 
1319  // Check for valid partitions.
1320  for (const auto& partition : partitions) {
1321  const string local_partition_name =
1322  DeviceNameUtils::LocalName(partition.first);
1323  if (std::count(device_names.begin(), device_names.end(),
1324  local_partition_name) == 0) {
1325  return errors::InvalidArgument(
1326  "Creating a partition for ", local_partition_name,
1327  " which doesn't exist in the list of available devices. Available "
1328  "devices: ",
1329  str_util::Join(device_names, ","));
1330  }
1331  }
1332 
1333  for (const auto& partition : partitions) {
1334  std::unique_ptr<Graph> device_graph(
1335  new Graph(client_graph->flib_def.get()));
1336  GraphConstructorOptions device_opts;
1337  // There are internal operations (e.g., send/recv) that we now allow.
1338  device_opts.allow_internal_ops = true;
1339  device_opts.expect_device_spec = true;
1340  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
1341  device_graph.get()));
1342  outputs->emplace(partition.first, std::move(device_graph));
1343  }
1344 
1345  GraphOptimizationPassOptions optimization_options;
1346  optimization_options.session_options = &options_;
1347  optimization_options.flib_def = client_graph->flib_def.get();
1348  optimization_options.partition_graphs = outputs;
1349  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
1350  OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
1351 
1352  Status s;
1353  for (auto& partition : *outputs) {
1354  const string& partition_name = partition.first;
1355  std::unique_ptr<Graph>* graph = &partition.second;
1356 
1357  VLOG(2) << "Created " << DebugString(graph->get()) << " for "
1358  << partition_name;
1359 
1360  // Give the device an opportunity to rewrite its subgraph.
1361  Device* d;
1362  s = device_mgr_->LookupDevice(partition_name, &d);
1363  if (!s.ok()) break;
1364  s = d->MaybeRewriteGraph(graph);
1365  if (!s.ok()) {
1366  break;
1367  }
1368  }
1369  *flib_def = std::move(client_graph->flib_def);
1370  std::swap(*input_types, client_graph->feed_types);
1371  std::swap(*output_types, client_graph->fetch_types);
1372  return s;
1373 }
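// A hedged illustration of the partitioning step above (hypothetical
// placement): node_to_loc groups nodes by Node::assigned_device_name(), so a
// graph placed across CPU:0 and GPU:0 yields two entries in `partitions`,
//
//   "/job:localhost/replica:0/task:0/device:CPU:0" -> GraphDef of CPU nodes
//   "/job:localhost/replica:0/task:0/device:GPU:0" -> GraphDef of GPU nodes
//
// plus the send/recv node pairs that Partition() inserts across the boundary.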
1374 
1375 ::tensorflow::Status NTSession::ListDevices(
1376  std::vector<DeviceAttributes>* response) {
1377  response->clear();
1378  response->reserve(devices_.size());
1379  for (Device* d : devices_) {
1380  const DeviceAttributes& attrs = d->attributes();
1381  response->emplace_back(attrs);
1382  }
1383  return ::tensorflow::Status::OK();
1384 }
1385 
1386 ::tensorflow::Status NTSession::Reset(
1387  const std::vector<string>& containers) {
1388  device_mgr_->ClearContainers(containers);
1389  return ::tensorflow::Status::OK();
1390 }
1391 
1392 ::tensorflow::Status NTSession::Close() {
1393  cancellation_manager_->StartCancel();
1394  {
1395  mutex_lock l(closed_lock_);
1396  if (closed_) return ::tensorflow::Status::OK();
1397  closed_ = true;
1398  }
1399  if (factory_ != nullptr) factory_->Deregister(this);
1400  return ::tensorflow::Status::OK();
1401 }
1402 
1403 NTSession::RunState::RunState(
1404  const std::vector<string>& pending_input_names,
1405  const std::vector<string>& pending_output_names, int64 step_id,
1406  const std::vector<Device*>* devices)
1407  : step_container(step_id, [devices](const string& name) {
1408  for (auto d : *devices) {
1409  if (!d->resource_manager()->Cleanup(name).ok()) {
1410  // Do nothing...
1411  }
1412  }
1413  }) {
1414  // Initially all the feeds and fetches are pending.
1415  for (auto& name : pending_input_names) {
1416  pending_inputs[name] = false;
1417  }
1418  for (auto& name : pending_output_names) {
1419  pending_outputs[name] = false;
1420  }
1421 }
1422 
1423 NTSession::RunState::RunState(int64 step_id,
1424  const std::vector<Device*>* devices)
1425  : RunState({}, {}, step_id, devices) {}
1426 
1427 NTSession::RunState::~RunState() {
1428  if (rendez != nullptr) {
1429  if (!executors_done.HasBeenNotified()) {
1430  rendez->StartAbort(errors::Cancelled("PRun cancellation"));
1431  executors_done.WaitForNotification();
1432  }
1433  rendez->Unref();
1434  }
1435 }
1436 
1437 bool NTSession::RunState::PendingDone() const {
1438  for (const auto& it : pending_inputs) {
1439  if (!it.second) return false;
1440  }
1441  for (const auto& it : pending_outputs) {
1442  if (!it.second) return false;
1443  }
1444  return true;
1445 }
1446 
1447 void NTSession::WaitForNotification(RunState* run_state,
1448  CancellationManager* cm,
1449  int64 timeout_in_ms) {
1450  const Status status =
1451  WaitForNotification(&run_state->executors_done, timeout_in_ms);
1452  if (!status.ok()) {
1453  {
1454  mutex_lock l(run_state->mu_);
1455  run_state->status.Update(status);
1456  }
1457  cm->StartCancel();
1458  // We must wait for the executors to complete, because they have borrowed
1459  // references to `cm` and other per-step state. After this notification, it
1460  // is safe to clean up the step.
1461  run_state->executors_done.WaitForNotification();
1462  }
1463 }
1464 
1465 ::tensorflow::Status NTSession::WaitForNotification(
1466  Notification* notification, int64 timeout_in_ms) {
1467  if (timeout_in_ms > 0) {
1468  const int64 timeout_in_us = timeout_in_ms * 1000;
1469  const bool notified =
1470  WaitForNotificationWithTimeout(notification, timeout_in_us);
1471  if (!notified) {
1472  return Status(error::DEADLINE_EXCEEDED,
1473  "Timed out waiting for notification");
1474  }
1475  } else {
1476  notification->WaitForNotification();
1477  }
1478  return Status::OK();
1479 }
1480 
1481 } // namespace tensorflow