
NTSession.cc
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // NOTE: The memory layout of the Node class changes depending on whether NDEBUG
16 // was set when TensorFlow was compiled: Node holds two EdgeSet instances, and
17 // EdgeSet gains an extra data member when NDEBUG is not set.
18 
19 /*
20 This file is an adaptation of the original direct_session.cc file located at
21 https://github.com/tensorflow/tensorflow/blob/v1.3.0/tensorflow/core/common_runtime/direct_session.cc
22 to meet the demands of the software environment developed and used by the CMS collaboration.
23 
24 Changes with respect to the original code are documented in the NTSession.h header file.
25 */
26 
27 #if !defined(NDEBUG)
28 #define NDEBUG 1
29 #endif
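
For illustration, a minimal sketch of the pattern behind the note above, loosely modeled on tensorflow/core/graph/edgeset.h (the member names here are assumptions, not the verbatim TensorFlow code). A debug build adds a mutation counter to EdgeSet, so sizeof(Node) differs between builds; forcing NDEBUG on keeps this translation unit consistent with a release-built TensorFlow library:

    class EdgeSetSketch {
      const void* ptrs_[2];        // inline storage, present in all builds
    #ifndef NDEBUG
      unsigned int mutations_ = 0; // debug-only member: changes the object size
    #endif
    };
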
30 
31 #include "NTSession.h"
32 
33 #include <atomic>
34 #include <string>
35 #include <vector>
36 
37 #include "FWCore/Utilities/interface/thread_safety_macros.h"
38 
39 #include "tensorflow/core/common_runtime/constant_folding.h"
40 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
41 #include "tensorflow/core/common_runtime/device_factory.h"
42 #include "tensorflow/core/common_runtime/executor.h"
43 #include "tensorflow/core/common_runtime/function.h"
44 #include "tensorflow/core/common_runtime/graph_optimizer.h"
45 #include "tensorflow/core/common_runtime/memory_types.h"
46 #include "tensorflow/core/common_runtime/optimization_registry.h"
47 #include "tensorflow/core/common_runtime/simple_placer.h"
48 #include "tensorflow/core/common_runtime/step_stats_collector.h"
49 #include "tensorflow/core/framework/function.h"
50 #include "tensorflow/core/framework/graph.pb_text.h"
51 #include "tensorflow/core/framework/graph.pb.h"
52 #include "tensorflow/core/framework/graph_def_util.h"
53 #include "tensorflow/core/framework/log_memory.h"
54 #include "tensorflow/core/framework/node_def.pb.h"
55 #include "tensorflow/core/framework/tensor.h"
56 #include "tensorflow/core/framework/versions.pb.h"
57 #include "tensorflow/core/graph/algorithm.h"
58 #include "tensorflow/core/graph/graph.h"
59 #include "tensorflow/core/graph/graph_constructor.h"
60 #include "tensorflow/core/graph/graph_partition.h"
61 #include "tensorflow/core/graph/subgraph.h"
62 #include "tensorflow/core/graph/tensor_id.h"
63 #include "tensorflow/core/lib/core/errors.h"
64 #include "tensorflow/core/lib/core/notification.h"
65 #include "tensorflow/core/lib/core/refcount.h"
66 #include "tensorflow/core/lib/core/status.h"
67 #include "tensorflow/core/lib/gtl/array_slice.h"
68 #include "tensorflow/core/lib/gtl/stl_util.h"
69 #include "tensorflow/core/lib/monitoring/counter.h"
70 #include "tensorflow/core/lib/strings/numbers.h"
71 #include "tensorflow/core/lib/strings/str_util.h"
72 #include "tensorflow/core/lib/strings/strcat.h"
73 #include "tensorflow/core/platform/cpu_info.h"
74 #include "tensorflow/core/platform/logging.h"
75 #include "tensorflow/core/platform/mutex.h"
76 #include "tensorflow/core/platform/types.h"
77 #include "tensorflow/core/util/device_name_utils.h"
78 #include "tensorflow/core/util/env_var.h"
79 
80 #if GOOGLE_CUDA
81 #include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
82 #endif // GOOGLE_CUDA
83 
84 namespace tensorflow {
85 
86 namespace {
87 
88 CMS_THREAD_SAFE auto* nothreads_session_runs = monitoring::Counter<0>::New(
89  "/tensorflow/core/nothreads_session_runs",
90  "The number of times NTSession::Run() has been called.");
91 
92 
93 // TODO(vrv): Figure out how to unify the many different functions
94 // that generate RendezvousKey, since many of them have to be
95 // consistent with each other.
96 string GetRendezvousKey(const string& tensor_name,
97  const DeviceAttributes& device_info,
98  const FrameAndIter& frame_iter) {
99  return strings::StrCat(device_info.name(), ";",
100  strings::FpToString(device_info.incarnation()), ";",
101  device_info.name(), ";", tensor_name, ";",
102  frame_iter.frame_id, ":", frame_iter.iter_id);
103 }
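
As a worked example (the device name and incarnation value are hypothetical), the key built above for a tensor "logits:0" on the client CPU device would look like:

    string key = GetRendezvousKey("logits:0", device_attrs, FrameAndIter(0, 0));
    // key == "/job:localhost/replica:0/task:0/cpu:0;0000000000000001;"
    //        "/job:localhost/replica:0/task:0/cpu:0;logits:0;0:0"
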
104 
105 } // namespace
106 
107 class NTSessionFactory : public SessionFactory {
108  public:
109  NTSessionFactory() {}
110 
111  bool AcceptsOptions(const SessionOptions& options) override {
112  return options.target == "no_threads";
113  }
114 
115  Session* NewSession(const SessionOptions& options) override {
116  // Must do this before the CPU allocator is created.
117  if (options.config.graph_options().build_cost_model() > 0) {
118  EnableCPUAllocatorFullStats(true);
119  }
120  std::vector<Device*> devices;
121  Status s = DeviceFactory::AddDevices(
122  options, "/job:localhost/replica:0/task:0", &devices);
123  if (!s.ok()) {
124  LOG(ERROR) << s;
125  return nullptr;
126  }
127 
128  NTSession* session =
129  new NTSession(options, new DeviceMgr(devices), this);
130  {
131  mutex_lock l(sessions_lock_);
132  sessions_.push_back(session);
133  }
134  return session;
135  }
136 
137  Status Reset(const SessionOptions& options,
138  const std::vector<string>& containers) override {
139  std::vector<NTSession*> sessions_to_reset;
140  {
141  mutex_lock l(sessions_lock_);
142  // We create a copy to ensure that we don't have a deadlock when
143  // session->Close calls the NTSessionFactory.Deregister, which
144  // acquires sessions_lock_.
145  std::swap(sessions_to_reset, sessions_);
146  }
147  Status s;
148  for (auto session : sessions_to_reset) {
149  s.Update(session->Reset(containers));
150  }
151  // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
152  // it doesn't close the sessions?
153  for (auto session : sessions_to_reset) {
154  s.Update(session->Close());
155  }
156  return s;
157  }
158 
159  void Deregister(const NTSession* session) {
160  mutex_lock l(sessions_lock_);
161  sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
162  sessions_.end());
163  }
164 
165  private:
166  mutex sessions_lock_;
167  std::vector<NTSession*> sessions_ GUARDED_BY(sessions_lock_);
168 };
169 
170 class NTSessionRegistrar {
171  public:
172  NTSessionRegistrar() {
173  SessionFactory::Register("NOTHREADS_SESSION", new NTSessionFactory());
174  }
175 };
176 static NTSessionRegistrar registrar;
177 
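
A minimal sketch of how client code reaches this factory through the standard tensorflow::NewSession() entry point (the options shown are an assumption; only the target string is prescribed by the registration above):

    tensorflow::SessionOptions opts;
    opts.target = "no_threads";  // matched by NTSessionFactory::AcceptsOptions()
    tensorflow::Session* session = tensorflow::NewSession(opts);
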
178 std::atomic_int_fast64_t NTSession::step_id_counter_(1);
179 
180 // NOTE: On Android with a single device, there is never
181 // a risk of an OpKernel blocking indefinitely:
182 //
183 // 1) No operations do I/O that depends on other simultaneous kernels,
184 //
185 // 2) Recv nodes always complete immediately: The inputs are sent into
186 // the local rendezvous before we start the executor, so the
187 // corresponding recvs will not block.
188 //
189 // Based on these assumptions, we can use the same thread pool for
190 // both "non-blocking" and "blocking" OpKernels on Android.
191 //
192 // This may change down the road when we add support for multiple
193 // devices that run concurrently, in which case we will need to
194 // revisit this decision.
195 void NTSession::SchedClosure(std::function<void()> c) {
196 
197  // On Android, there is no implementation of ThreadPool that takes
198  // std::function, only Closure, which we cannot easily convert.
199  //
200  // Instead, we just run the function in-line, which is currently
201  // safe given the reasoning above.
202 
203  // Overridden for CMSSW: the framework does the scheduling, so run the closure inline.
204  c();
205  //pool->Schedule(std::move(c));
206 }
207 
208 NTSession::NTSession(const SessionOptions& options,
209  const DeviceMgr* device_mgr,
210  NTSessionFactory* const factory)
211  : options_(options),
212  device_mgr_(device_mgr),
213  factory_(factory),
214  cancellation_manager_(new CancellationManager()),
215  operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
216  // The default value of sync_on_finish will be flipped soon and this
217  // environment variable will be removed as well.
218  Status status =
219  ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
220  if (!status.ok()) {
221  LOG(ERROR) << status.error_message();
222  }
223  // NOTE(mrry): We do not need to use a unique string for the session
224  // handle, because NTSession owns its devices. This may change
225  // in future versions.
226  session_handle_ = "no_threads";
227  int devices_added = 0;
228  if (options.config.log_device_placement()) {
229  const string mapping_str = device_mgr_->DeviceMappingString();
230  if (mapping_str.empty()) {
231  printf("Device mapping: no known devices.\n");
232  } else {
233  printf("Device mapping:\n%s", mapping_str.c_str());
234  }
235  LOG(INFO) << "Device mapping:\n" << mapping_str;
236  }
237  for (auto d : device_mgr_->ListDevices()) {
238  devices_.push_back(d);
239  device_set_.AddDevice(d);
240  d->op_segment()->AddHold(session_handle_);
241 
242  // The first device added is special: it is the 'client device' (a
243  // CPU device) from which we feed and fetch Tensors.
244  if (devices_added == 0) {
245  device_set_.set_client_device(d);
246  }
247  ++devices_added;
248  }
249 }
250 
251 NTSession::~NTSession() {
252  if (!closed_) Close().IgnoreError();
253  for (auto& it : partial_runs_) {
254  it.second.reset(nullptr);
255  }
256  for (auto& it : executors_) {
257  it.second.reset();
258  }
259  for (auto d : device_mgr_->ListDevices()) {
260  d->op_segment()->RemoveHold(session_handle_);
261  }
262  delete cancellation_manager_;
263 
264  execution_state_.reset(nullptr);
265  flib_def_.reset(nullptr);
266 }
267 
268 Status NTSession::MaybeInitializeExecutionState(
269  const GraphDef& graph, bool* out_already_initialized) {
270  // If already initialized, do nothing.
271  if (flib_def_ && execution_state_) {
272  *out_already_initialized = true;
273  return Status::OK();
274  }
275  // Set up the per-session execution state.
276  // NOTE(mrry): The function library created here will be used for
277  // all subsequent extensions of the graph.
278  flib_def_.reset(
279  new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
280  SimpleGraphExecutionStateOptions options;
281  options.device_set = &device_set_;
282  options.session_options = &options_;
283  // TODO(mrry,suharshs): We explicitly copy `graph` so that
284  // `MakeForBaseGraph()` can take ownership of its
285  // contents. Previously this happened implicitly in calls to the
286  // `SimpleGraphExecutionState`. Other sessions call
287  // `MakeForBaseGraph` in such a way that we can destructively read
288  // the passed-in `GraphDef`. In principle we could do the same here,
289  // with a wider refactoring; we might revise the direct session so
290  // that it copies the graph fewer times.
291  GraphDef temp(graph);
292  TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
293  &temp, options, &execution_state_));
294  graph_created_ = true;
295  *out_already_initialized = false;
296  return Status::OK();
297 }
298 
299 Status NTSession::Create(const GraphDef& graph) {
300  TF_RETURN_IF_ERROR(init_error_);
301  if (graph.node_size() > 0) {
302  mutex_lock l(graph_def_lock_);
303  if (graph_created_) {
304  return errors::AlreadyExists(
305  "A Graph has already been created for this session.");
306  }
307  return ExtendLocked(graph);
308  }
309  return Status::OK();
310 }
311 
312 Status NTSession::Extend(const GraphDef& graph) {
313  TF_RETURN_IF_ERROR(CheckNotClosed());
314  mutex_lock l(graph_def_lock_);
315  return ExtendLocked(graph);
316 }
317 
318 Status NTSession::ExtendLocked(const GraphDef& graph) {
319  bool already_initialized;
320  // If this is the first call, we can initialize the execution state
321  // with `graph` and do not need to call `Extend()`.
322  TF_RETURN_IF_ERROR(
323  MaybeInitializeExecutionState(graph, &already_initialized));
324  if (already_initialized) {
325  TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
326  std::unique_ptr<SimpleGraphExecutionState> state;
327  TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
328  execution_state_.swap(state);
329  }
330  return Status::OK();
331 }
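
A short usage sketch of the Create()/Extend() contract implemented above (the GraphDef contents and the session pointer are assumptions): the first non-empty graph initializes the execution state, and later graphs are merged in by ExtendLocked():

    tensorflow::GraphDef base_graph, extra_nodes;
    // ... populate the two GraphDefs ...
    TF_CHECK_OK(session->Create(base_graph));   // first call: initializes state
    TF_CHECK_OK(session->Extend(extra_nodes));  // later calls: merged via ExtendLocked()
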
332 
333 Status NTSession::Run(const NamedTensorList& inputs,
334  const std::vector<string>& output_names,
335  const std::vector<string>& target_nodes,
336  std::vector<Tensor>* outputs) {
337  RunMetadata run_metadata;
338  return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
339  &run_metadata);
340 }
341 
342 Status NTSession::CreateDebuggerState(
343  const DebugOptions& debug_options, int64 session_run_index,
344  int64 executor_step_index, const std::vector<string>& input_names,
345  const std::vector<string>& output_names,
346  const std::vector<string>& target_names,
347  std::unique_ptr<DebuggerStateInterface>* debugger_state) {
348  TF_RETURN_IF_ERROR(
349  DebuggerStateRegistry::CreateState(debug_options, debugger_state));
350  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
351  debug_options.global_step(), session_run_index, executor_step_index,
352  input_names, output_names, target_names));
353  return Status::OK();
354 }
355 
356 Status NTSession::DecorateAndPublishGraphForDebug(
357  const DebugOptions& debug_options, Graph* graph, Device* device) {
358  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
359  TF_RETURN_IF_ERROR(
360  DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
361 
362  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
363  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
364  return Status::OK();
365 }
366 
367 Status NTSession::Run(const RunOptions& run_options,
368  const NamedTensorList& inputs,
369  const std::vector<string>& output_names,
370  const std::vector<string>& target_nodes,
371  std::vector<Tensor>* outputs,
372  RunMetadata* run_metadata) {
373  TF_RETURN_IF_ERROR(CheckNotClosed());
374  nothreads_session_runs->GetCell()->IncrementBy(1);
375  {
376  mutex_lock l(graph_def_lock_);
377  if (!graph_created_) {
378  return errors::InvalidArgument(
379  "Session was not created with a graph before Run()!");
380  }
381  }
382 
383  // Extract the input names for this run of the session.
384  std::vector<string> input_tensor_names;
385  input_tensor_names.reserve(inputs.size());
386  for (const auto& it : inputs) {
387  input_tensor_names.push_back(it.first);
388  }
389 
390 
391  // Check if we already have an executor for these arguments.
392  ExecutorsAndKeys* executors_and_keys;
393  RunStateArgs run_state_args(run_options.debug_options());
394 
395  Executor::Args args;
396  args.step_id = step_id_counter_.fetch_add(1);
397 
398  TF_RETURN_IF_ERROR(
399  GetOrCreateExecutors(input_tensor_names, output_names, target_nodes,
400  &executors_and_keys, &run_state_args));
401  const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
402 
403  std::unique_ptr<DebuggerStateInterface> debugger_state;
404  if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
405  TF_RETURN_IF_ERROR(CreateDebuggerState(
406  run_options.debug_options(), args.step_id, executor_step_count,
407  input_tensor_names, output_names, target_nodes, &debugger_state));
408  }
409 
410  // Configure a call frame for the step, which we use to feed and
411  // fetch values to and from the executors.
412  FunctionCallFrame call_frame(executors_and_keys->input_types,
413  executors_and_keys->output_types);
414  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
415  for (const auto& it : inputs) {
416  if (it.second.dtype() == DT_RESOURCE) {
417  Tensor tensor_from_handle;
418  TF_RETURN_IF_ERROR(
419  ResourceHandleToInputTensor(it.second, &tensor_from_handle));
420  feed_args[executors_and_keys->input_name_to_index[it.first]] =
421  tensor_from_handle;
422  } else {
423  feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
424  }
425  }
426  Status s = call_frame.SetArgs(feed_args);
427  if (errors::IsInternal(s)) {
428  return errors::InvalidArgument(s.error_message());
429  } else if (!s.ok()) {
430  return s;
431  }
432 
433  // Create a run state and start execution.
434  RunState run_state(args.step_id, &devices_);
435  run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
436  CancellationManager step_cancellation_manager;
437  args.call_frame = &call_frame;
438 
439  // Start parallel Executors.
440  const size_t num_executors = executors_and_keys->items.size();
441  ExecutorBarrier* barrier = new ExecutorBarrier(
442  num_executors, run_state.rendez, [&run_state](const Status& ret) {
443  {
444  mutex_lock l(run_state.mu_);
445  run_state.status.Update(ret);
446  }
447  run_state.executors_done.Notify();
448  });
449 
450  args.rendezvous = run_state.rendez;
451  args.cancellation_manager = &step_cancellation_manager;
452  args.runner = [this](Executor::Args::Closure c) {
453  SchedClosure(std::move(c));
454  };
455  args.session_state = &session_state_;
456  args.tensor_store = &run_state.tensor_store;
457  args.step_container = &run_state.step_container;
458  if (LogMemory::IsEnabled()) {
459  LogMemory::RecordStep(args.step_id, run_state_args.handle);
460  }
461  args.sync_on_finish = sync_on_finish_;
462 
463  const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
464 
465  bool update_cost_model = false;
466  if (options_.config.graph_options().build_cost_model() > 0) {
467  const int64 build_cost_model_every =
468  options_.config.graph_options().build_cost_model();
469  const int64 build_cost_model_after =
470  options_.config.graph_options().build_cost_model_after();
471  int64 measure_step_count = executor_step_count - build_cost_model_after;
472  if (measure_step_count >= 0) {
473  update_cost_model =
474  ((measure_step_count + 1) % build_cost_model_every == 0);
475  }
476  }
477  if (do_trace || update_cost_model) {
478  run_state.collector.reset(
479  new StepStatsCollector(run_metadata->mutable_step_stats()));
480  args.stats_collector = run_state.collector.get();
481  }
482 
483 #if GOOGLE_CUDA
484  std::unique_ptr<GPUTracer> tracer;
485  if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
486  tracer.reset(CreateGPUTracer());
487  // tracer will be NULL on non-GPU platforms.
488  // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
489  if (tracer) tracer->Start().IgnoreError();
490  }
491 #endif // GOOGLE_CUDA
492 
493  // Register this step with session's cancellation manager, so that
494  // `Session::Close()` will cancel the step.
495  CancellationToken cancellation_token =
496  cancellation_manager_->get_cancellation_token();
497  bool already_cancelled = !cancellation_manager_->RegisterCallback(
498  cancellation_token, [&step_cancellation_manager]() {
499  step_cancellation_manager.StartCancel();
500  });
501  if (already_cancelled) {
502  // NOTE(mrry): If we don't explicitly notify
503  // `run_state.executors_done`, the RunState destructor would
504  // block on this notification.
505  run_state.executors_done.Notify();
506  delete barrier;
507  return errors::Cancelled("Run call was cancelled");
508  }
509 
510  for (const auto& item : executors_and_keys->items) {
511  item.executor->RunAsync(args, barrier->Get());
512  }
513 
514  WaitForNotification(&run_state, &step_cancellation_manager,
515  run_options.timeout_in_ms() > 0
516  ? run_options.timeout_in_ms()
517  : operation_timeout_in_ms_);
518 
519  if (!cancellation_manager_->DeregisterCallback(cancellation_token)) {
520  // The step has been cancelled: make sure we don't attempt to receive the
521  // outputs as this would make it block forever.
522  mutex_lock l(run_state.mu_);
523  run_state.status.Update(errors::Cancelled("Run call was cancelled"));
524  }
525 
526 #if GOOGLE_CUDA
527  if (tracer) {
528  // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
529  tracer->Stop().IgnoreError();
530  tracer->Collect(args.stats_collector).IgnoreError();
531  }
532 #endif // GOOGLE_CUDA
533 
534  {
535  mutex_lock l(run_state.mu_);
536  TF_RETURN_IF_ERROR(run_state.status);
537  }
538 
539  // Receive outputs.
540  if (outputs) {
541  std::vector<Tensor> sorted_outputs;
542  Status s = call_frame.ConsumeRetvals(&sorted_outputs);
543  if (errors::IsInternal(s)) {
544  return errors::InvalidArgument(s.error_message());
545  } else if (!s.ok()) {
546  return s;
547  }
548  const bool unique_outputs =
549  output_names.size() == executors_and_keys->output_name_to_index.size();
550  // first_indices[i] = j implies that j is the smallest value for which
551  // output_names[i] == output_names[j].
552  std::vector<int> first_indices;
553  if (!unique_outputs) {
554  first_indices.resize(output_names.size());
555  for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
556  for (int j = 0; j <= i; ++j) {
557  if (output_names[i] == output_names[j]) {
558  first_indices[i] = j;
559  break;
560  }
561  }
562  }
563  }
564  outputs->clear();
565  outputs->reserve(sorted_outputs.size());
566  for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
567  const string& output_name = output_names[i];
568  if (first_indices.empty() || first_indices[i] == i) {
569  outputs->emplace_back(
570  std::move(sorted_outputs[executors_and_keys
571  ->output_name_to_index[output_name]]));
572  } else {
573  outputs->push_back((*outputs)[first_indices[i]]);
574  }
575  }
576  }
577 
578  // Save the output tensors of this run we choose to keep.
579  TF_RETURN_IF_ERROR(
580  run_state.tensor_store.SaveTensors(output_names, &session_state_));
581 
582  // Build and return the cost model as instructed.
583  mutex_lock l(executor_lock_);
584  if (update_cost_model) {
585  // Build the cost model
586  std::unordered_map<string, const Graph*> device_to_graph;
587  for (const PerPartitionExecutorsAndLib& partition :
588  executors_and_keys->items) {
589  const Graph* graph = partition.graph;
590  const string device = partition.flib->device()->name();
591  device_to_graph[device] = graph;
592  }
593  args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);
594 
595  // Annotate stats onto the cost graph.
596  CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
597  for (const auto& item : executors_and_keys->items) {
598  TF_RETURN_IF_ERROR(
599  cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
600  }
601  }
602 
603  // If requested via RunOptions, output the partition graphs.
604  if (run_options.output_partition_graphs()) {
605  protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
606  run_metadata->mutable_partition_graphs();
607  for (const PerPartitionExecutorsAndLib& exec_and_lib :
608  executors_and_keys->items) {
609  GraphDef* partition_graph_def = partition_graph_defs->Add();
610  exec_and_lib.graph->ToGraphDef(partition_graph_def);
611  }
612  }
613 
614  return Status::OK();
615 }
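
For reference, a minimal sketch of a Run() call against this implementation; the endpoint names "x:0" and "y:0", the input tensor, and the session pointer are hypothetical:

    std::vector<std::pair<std::string, tensorflow::Tensor>> feeds = {{"x:0", x}};
    std::vector<tensorflow::Tensor> fetched;
    tensorflow::Status s = session->Run(feeds, {"y:0"}, /*target_nodes=*/{}, &fetched);
    if (s.ok()) { /* fetched[0] holds the value of "y:0" */ }
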
616 
617 Status NTSession::PRunSetup(const std::vector<string>& input_names,
618  const std::vector<string>& output_names,
619  const std::vector<string>& target_nodes,
620  string* handle) {
621  TF_RETURN_IF_ERROR(CheckNotClosed());
622  {
623  mutex_lock l(graph_def_lock_);
624  if (!graph_created_) {
625  return errors::InvalidArgument(
626  "Session was not created with a graph before PRunSetup()!");
627  }
628  }
629 
630  // Check if we already have an executor for these arguments.
631  ExecutorsAndKeys* executors_and_keys;
632  // TODO(cais): TFDBG support for partial runs.
633  DebugOptions debug_options;
634  RunStateArgs run_state_args(debug_options);
635  run_state_args.is_partial_run = true;
636  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
637  target_nodes, &executors_and_keys,
638  &run_state_args));
639 
640  // Create the run state and save it for future PRun calls.
641  Executor::Args args;
642  args.step_id = step_id_counter_.fetch_add(1);
643  RunState* run_state =
644  new RunState(input_names, output_names, args.step_id, &devices_);
645  run_state->rendez = new IntraProcessRendezvous(device_mgr_.get());
646  {
647  mutex_lock l(executor_lock_);
648  if (!partial_runs_
649  .emplace(run_state_args.handle,
650  std::unique_ptr<RunState>(run_state))
651  .second) {
652  return errors::Internal("The handle '", run_state_args.handle,
653  "' created for this partial run is not unique.");
654  }
655  }
656 
657  // Start parallel Executors.
658  const size_t num_executors = executors_and_keys->items.size();
659  ExecutorBarrier* barrier = new ExecutorBarrier(
660  num_executors, run_state->rendez, [run_state](const Status& ret) {
661  if (!ret.ok()) {
662  mutex_lock l(run_state->mu_);
663  run_state->status.Update(ret);
664  }
665  run_state->executors_done.Notify();
666  });
667 
668  args.rendezvous = run_state->rendez;
669  args.cancellation_manager = cancellation_manager_;
670  args.runner = [this](Executor::Args::Closure c) {
671  SchedClosure(std::move(c));
672  };
673  args.session_state = &session_state_;
674  args.tensor_store = &run_state->tensor_store;
675  args.step_container = &run_state->step_container;
676  if (LogMemory::IsEnabled()) {
677  LogMemory::RecordStep(args.step_id, run_state_args.handle);
678  }
679  args.sync_on_finish = sync_on_finish_;
680 
681  if (options_.config.graph_options().build_cost_model()) {
682  run_state->collector.reset(new StepStatsCollector(nullptr));
683  args.stats_collector = run_state->collector.get();
684  }
685 
686  for (auto& item : executors_and_keys->items) {
687  item.executor->RunAsync(args, barrier->Get());
688  }
689 
690  *handle = run_state_args.handle;
691  return Status::OK();
692 }
693 
694 Status NTSession::PRun(const string& handle, const NamedTensorList& inputs,
695  const std::vector<string>& output_names,
696  std::vector<Tensor>* outputs) {
697  TF_RETURN_IF_ERROR(CheckNotClosed());
698  std::vector<string> parts = str_util::Split(handle, ';');
699  const string& key = parts[0];
700  // Get the executors for this partial run.
701  ExecutorsAndKeys* executors_and_keys;
702  RunState* run_state;
703  {
704  mutex_lock l(executor_lock_); // could use reader lock
705  auto exc_it = executors_.find(key);
706  if (exc_it == executors_.end()) {
707  return errors::InvalidArgument(
708  "Must run 'setup' before performing partial runs!");
709  }
710  executors_and_keys = exc_it->second.get();
711 
712  auto prun_it = partial_runs_.find(handle);
713  if (prun_it == partial_runs_.end()) {
714  return errors::InvalidArgument(
715  "Must run 'setup' before performing partial runs!");
716  }
717  run_state = prun_it->second.get();
718 
719  // Make sure that this is a new set of feeds that are still pending.
720  for (const auto& input : inputs) {
721  auto it = run_state->pending_inputs.find(input.first);
722  if (it == run_state->pending_inputs.end()) {
723  return errors::InvalidArgument(
724  "The feed ", input.first,
725  " was not specified in partial_run_setup.");
726  } else if (it->second) {
727  return errors::InvalidArgument("The feed ", input.first,
728  " has already been fed.");
729  }
730  }
731  // Check that this is a new set of fetches that are still pending.
732  for (const auto& output : output_names) {
733  auto it = run_state->pending_outputs.find(output);
734  if (it == run_state->pending_outputs.end()) {
735  return errors::InvalidArgument(
736  "The fetch ", output, " was not specified in partial_run_setup.");
737  } else if (it->second) {
738  return errors::InvalidArgument("The fetch ", output,
739  " has already been fetched.");
740  }
741  }
742  }
743 
744  // Check that this new set of fetches can be computed from all the
745  // feeds we have supplied.
746  TF_RETURN_IF_ERROR(
747  CheckFetch(inputs, output_names, executors_and_keys, run_state));
748 
749  // Send inputs.
750  Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez);
751 
752  // Receive outputs.
753  if (s.ok()) {
754  s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
755  }
756 
757  // Save the output tensors of this run we choose to keep.
758  if (s.ok()) {
759  s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
760  }
761 
762  {
763  mutex_lock l(executor_lock_);
764  // Delete the run state if there is an error or all fetches are done.
765  bool done = true;
766  if (s.ok()) {
767  {
768  mutex_lock l(run_state->mu_);
769  if (!run_state->status.ok()) {
770  LOG(WARNING) << "An error unrelated to this prun has been detected. "
771  << run_state->status;
772  }
773  }
774  for (const auto& input : inputs) {
775  auto it = run_state->pending_inputs.find(input.first);
776  it->second = true;
777  }
778  for (const auto& name : output_names) {
779  auto it = run_state->pending_outputs.find(name);
780  it->second = true;
781  }
782  done = run_state->PendingDone();
783  }
784  if (done) {
785  WaitForNotification(run_state, cancellation_manager_,
786  operation_timeout_in_ms_);
787  partial_runs_.erase(handle);
788  }
789  }
790 
791  return s;
792 }
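
A sketch of the partial-run protocol served by PRunSetup() and PRun() above (all endpoint names, tensors, and the session pointer are hypothetical). Every feed and fetch is declared up front, then supplied and retrieved incrementally against the returned handle:

    std::string handle;
    TF_CHECK_OK(session->PRunSetup({"a:0", "b:0"}, {"sum:0", "prod:0"}, {}, &handle));
    std::vector<tensorflow::Tensor> out;
    TF_CHECK_OK(session->PRun(handle, {{"a:0", ta}, {"b:0", tb}}, {"sum:0"}, &out));
    TF_CHECK_OK(session->PRun(handle, {}, {"prod:0"}, &out));  // no new feeds required
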
793 
794 Status NTSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
795  Tensor* retrieved_tensor) {
796  if (resource_tensor.dtype() != DT_RESOURCE) {
797  return errors::InvalidArgument(strings::StrCat(
798  "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
799  resource_tensor.dtype()));
800  }
801 
802  ResourceHandle resource_handle = resource_tensor.scalar<ResourceHandle>()();
803 
804  if (resource_handle.container() ==
805  SessionState::kTensorHandleResourceTypeName) {
806  return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
807  } else {
808  return errors::InvalidArgument(strings::StrCat(
809  "Invalid resource type hash code: ", resource_handle.hash_code(),
810  "(name: ", resource_handle.name(),
811  " type: ", resource_handle.maybe_type_name(), ")"));
812  }
813 }
814 
815 Status NTSession::SendPRunInputs(const std::vector<std::pair<string, Tensor>>& inputs,
816  const ExecutorsAndKeys* executors_and_keys,
817  IntraProcessRendezvous* rendez) {
818  Status s;
819  Rendezvous::ParsedKey parsed;
820  // Insert the input tensors into the local rendezvous by their
821  // rendezvous key.
822  for (const auto& input : inputs) {
823  auto it =
824  executors_and_keys->input_name_to_rendezvous_key.find(input.first);
825  if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
826  return errors::Internal("'", input.first, "' is not a pre-defined feed.");
827  }
828  const string& input_key = it->second;
829 
830  s = Rendezvous::ParseKey(input_key, &parsed);
831  if (!s.ok()) {
832  rendez->StartAbort(s);
833  return s;
834  }
835 
836  if (input.second.dtype() == DT_RESOURCE) {
837  Tensor tensor_from_handle;
838  s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
839  if (s.ok()) {
840  s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
841  }
842  } else {
843  s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
844  }
845 
846  if (!s.ok()) {
847  rendez->StartAbort(s);
848  return s;
849  }
850  }
851  return Status::OK();
852 }
853 
854 Status NTSession::RecvPRunOutputs(
855  const std::vector<string>& output_names,
856  const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
857  std::vector<Tensor>* outputs) {
858  Status s;
859  if (!output_names.empty()) {
860  outputs->resize(output_names.size());
861  }
862 
863  Rendezvous::ParsedKey parsed;
864  // Get the outputs from the rendezvous
865  for (size_t output_offset = 0; output_offset < output_names.size();
866  ++output_offset) {
867  const string& output_name = output_names[output_offset];
868  auto it =
869  executors_and_keys->output_name_to_rendezvous_key.find(output_name);
870  if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
871  return errors::Internal("'", output_name,
872  "' is not a pre-defined fetch.");
873  }
874  const string& output_key = it->second;
875  Tensor output_tensor;
876  bool is_dead;
877  IntraProcessRendezvous* rendez = run_state->rendez;
878 
879  s = Rendezvous::ParseKey(output_key, &parsed);
880  if (s.ok()) {
881  // Fetch data from the Rendezvous.
882  s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead,
883  operation_timeout_in_ms_);
884  if (is_dead && s.ok()) {
885  s = errors::InvalidArgument("The tensor returned for ", output_name,
886  " was not valid.");
887  }
888  }
889  if (!s.ok()) {
890  rendez->StartAbort(s);
891  outputs->clear();
892  return s;
893  }
894 
895  (*outputs)[output_offset] = output_tensor;
896  }
897  return Status::OK();
898 }
899 
900 Status NTSession::CheckFetch(const std::vector<std::pair<string, Tensor>>& feeds,
901  const std::vector<string>& fetches,
902  const ExecutorsAndKeys* executors_and_keys,
903  const RunState* run_state) {
904  const Graph* graph = executors_and_keys->graph.get();
905  const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;
906 
907  // Build the set of pending feeds that we haven't seen.
908  std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
909  {
910  mutex_lock l(executor_lock_);
911  for (const auto& input : run_state->pending_inputs) {
912  // Skip if the feed has already been fed.
913  if (input.second) continue;
914  TensorId id(ParseTensorName(input.first));
915  auto it = name_to_node->find(id.first);
916  if (it == name_to_node->end()) {
917  return errors::NotFound("Feed ", input.first, ": not found");
918  }
919  pending_feeds.insert(id);
920  }
921  }
922  for (const auto& it : feeds) {
923  TensorId id(ParseTensorName(it.first));
924  pending_feeds.erase(id);
925  }
926 
927  // Initialize the stack with the fetch nodes.
928  std::vector<const Node*> stack;
929  for (const string& fetch : fetches) {
930  TensorId id(ParseTensorName(fetch));
931  auto it = name_to_node->find(id.first);
932  if (it == name_to_node->end()) {
933  return errors::NotFound("Fetch ", fetch, ": not found");
934  }
935  stack.push_back(it->second);
936  }
937 
938  // Any tensor needed for fetches can't be in pending_feeds.
939  std::vector<bool> visited(graph->num_node_ids(), false);
940  while (!stack.empty()) {
941  const Node* n = stack.back();
942  stack.pop_back();
943 
944  for (const Edge* in_edge : n->in_edges()) {
945  const Node* in_node = in_edge->src();
946  if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
947  return errors::InvalidArgument("Fetch ", in_node->name(), ":",
948  in_edge->src_output(),
949  " can't be computed from the feeds"
950  " that have been fed so far.");
951  }
952  if (!visited[in_node->id()]) {
953  visited[in_node->id()] = true;
954  stack.push_back(in_node);
955  }
956  }
957  }
958  return Status::OK();
959 }
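
A concrete reading of the check above, under a hypothetical graph:

    // Hypothetical graph: c = Add(a, b), declared in PRunSetup with feeds
    // {"a:0", "b:0"} and fetch {"c:0"}. A PRun that has fed only "a:0" and asks
    // for "c:0" fails here: "b:0" is still in pending_feeds, and the reverse
    // DFS from c reaches it.
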
960 
961 Status NTSession::GetOrCreateExecutors(
962  gtl::ArraySlice<string> inputs,
963  gtl::ArraySlice<string> outputs, gtl::ArraySlice<string> target_nodes,
964  ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args) {
965  int64 handle_name_counter_value = -1;
966  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
967  handle_name_counter_value = handle_name_counter_.fetch_add(1);
968  }
969 
970  string debug_tensor_watches_summary;
971  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
972  debug_tensor_watches_summary = SummarizeDebugTensorWatches(
973  run_state_args->debug_options.debug_tensor_watch_opts());
974  }
975 
976  // Fast lookup path, no sorting.
977  const string key = strings::StrCat(
978  str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
979  str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
980  "/", debug_tensor_watches_summary);
981  // Set the handle, if it's needed to log memory or for partial run.
982  if (handle_name_counter_value >= 0) {
983  run_state_args->handle =
984  strings::StrCat(key, ";", handle_name_counter_value);
985  }
986 
987  // See if we already have the executors for this run.
988  {
989  mutex_lock l(executor_lock_); // could use reader lock
990  auto it = executors_.find(key);
991  if (it != executors_.end()) {
992  *executors_and_keys = it->second.get();
993  return Status::OK();
994  }
995  }
996 
997  // Slow lookup path, the unsorted key missed the cache.
998  // Sort the inputs and outputs, and look up with the sorted key in case an
999  // earlier call used a different order of inputs and outputs.
1000  //
1001  // We could consider some other signature instead of sorting that
1002  // preserves the same property to avoid the sort in the future.
1003  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
1004  std::sort(inputs_sorted.begin(), inputs_sorted.end());
1005  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
1006  std::sort(outputs_sorted.begin(), outputs_sorted.end());
1007  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
1008  std::sort(tn_sorted.begin(), tn_sorted.end());
1009 
1010  const string sorted_key = strings::StrCat(
1011  str_util::Join(inputs_sorted, ","), "->",
1012  str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
1013  "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
1014  // Set the handle, if it's needed to log memory or for partial run.
1015  if (handle_name_counter_value >= 0) {
1016  run_state_args->handle =
1017  strings::StrCat(sorted_key, ";", handle_name_counter_value);
1018  }
1019 
1020  // See if we already have the executors for this run.
1021  {
1022  mutex_lock l(executor_lock_);
1023  auto it = executors_.find(sorted_key);
1024  if (it != executors_.end()) {
1025  *executors_and_keys = it->second.get();
1026  // Insert this under the original key.
1027  executors_.emplace(key, it->second);
1028  return Status::OK();
1029  }
1030  }
1031 
1032  // Nothing found, so create the executors and store in the cache.
1033  BuildGraphOptions options;
1034  options.feed_endpoints = inputs_sorted;
1035  options.fetch_endpoints = outputs_sorted;
1036  options.target_nodes = tn_sorted;
1037  options.use_function_convention = !run_state_args->is_partial_run;
1038  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
1039  options.debug_options = run_state_args->debug_options;
1040  }
1041 
1042  std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
1043 
1044  // The executor_lock_ is intentionally released while executor is
1045  // being created.
1046  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
1047  TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &ek->flib_def,
1048  run_state_args, &ek->input_types,
1049  &ek->output_types));
1050 
1051  if (run_state_args->is_partial_run) {
1052  ek->graph = std::move(run_state_args->graph);
1053  std::unordered_set<StringPiece, StringPiece::Hasher> names;
1054  for (const string& input : inputs) {
1055  TensorId id(ParseTensorName(input));
1056  names.emplace(id.first);
1057  }
1058  for (const string& output : outputs) {
1059  TensorId id(ParseTensorName(output));
1060  names.emplace(id.first);
1061  }
1062  for (Node* n : ek->graph->nodes()) {
1063  if (names.count(n->name()) > 0) {
1064  ek->name_to_node.insert({n->name(), n});
1065  }
1066  }
1067  }
1068  ek->items.reserve(graphs.size());
1069  const auto& optimizer_opts =
1070  options_.config.graph_options().optimizer_options();
1071  GraphOptimizer optimizer(optimizer_opts);
1072  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
1073  const string& partition_name = iter->first;
1074  std::unique_ptr<Graph>& partition_graph = iter->second;
1075  const int graph_def_version = partition_graph->versions().producer();
1076 
1077  Device* device;
1078  TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
1079 
1080  ek->items.resize(ek->items.size() + 1);
1081  auto* item = &(ek->items.back());
1082  item->flib.reset(NewFunctionLibraryRuntime(
1083  device_mgr_.get(), options_.env, device, graph_def_version,
1084  ek->flib_def.get(), optimizer_opts));
1085 
1086  LocalExecutorParams params;
1087  params.device = device;
1088  params.function_library = item->flib.get();
1089  auto lib = item->flib.get();
1090  auto opseg = device->op_segment();
1091  params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
1092  OpKernel** kernel) {
1093  // Caches the kernel only if the node is stateful.
1094  if (!lib->IsStateful(ndef.op())) {
1095  return lib->CreateKernel(ndef, kernel);
1096  }
1097  auto create_fn = [lib, &ndef](OpKernel** kernel) {
1098  return lib->CreateKernel(ndef, kernel);
1099  };
1100  // Kernels created for subgraph nodes need to be cached. On
1101  // cache miss, create_fn() is invoked to create a kernel based
1102  // on the function library here + global op registry.
1103  return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
1104  create_fn);
1105  };
1106  params.delete_kernel = [lib](OpKernel* kernel) {
1107  // If the node is stateful, opseg owns it. Otherwise, delete it.
1108  if (kernel && !lib->IsStateful(kernel->type_string())) {
1109  delete kernel;
1110  }
1111  };
1112  params.node_outputs_cb = node_outputs_callback_;
1113 
1114  optimizer.Optimize(lib, options_.env, device, &iter->second);
1115 
1116  // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
1117  if (!options.debug_options.debug_tensor_watch_opts().empty()) {
1118  TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
1119  options.debug_options, partition_graph.get(), params.device));
1120  }
1121 
1122  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
1123  device->name(),
1124  partition_graph.get()));
1125  // NewLocalExecutor takes ownership of partition_graph.
1126  item->graph = partition_graph.get();
1127  item->executor = nullptr;
1128  Executor* executor;
1129  TF_RETURN_IF_ERROR(
1130  NewLocalExecutor(params, partition_graph.release(), &executor));
1131  item->executor.reset(executor);
1132  }
1133 
1134  // Cache the mapping from input/output names to graph elements to
1135  // avoid recomputing it every time.
1136  if (!run_state_args->is_partial_run) {
1137  // For regular `Run()`, we use the function calling convention, and so
1138  // maintain a mapping from input/output names to
1139  // argument/return-value ordinal index.
1140  for (size_t i = 0; i < inputs_sorted.size(); ++i) {
1141  const string& input = inputs_sorted[i];
1142  ek->input_name_to_index[input] = i;
1143  }
1144  for (size_t i = 0; i < outputs_sorted.size(); ++i) {
1145  const string& output = outputs_sorted[i];
1146  ek->output_name_to_index[output] = i;
1147  }
1148  } else {
1149  // For `PRun()`, we use the rendezvous calling convention, and so
1150  // maintain a mapping from input/output names to rendezvous keys.
1151  //
1152  // We always use the first device as the device name portion of the
1153  // key, even if we're feeding another graph.
1154  for (size_t i = 0; i < inputs_sorted.size(); ++i) {
1155  const string& input = inputs_sorted[i];
1156  ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
1157  input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
1158  }
1159  for (size_t i = 0; i < outputs_sorted.size(); ++i) {
1160  const string& output = outputs_sorted[i];
1161  ek->output_name_to_rendezvous_key[output] =
1162  GetRendezvousKey(output, device_set_.client_device()->attributes(),
1163  FrameAndIter(0, 0));
1164  }
1165  }
1166 
1167  // Reacquire the lock, try to insert into the map.
1168  mutex_lock l(executor_lock_);
1169 
1170  // Another thread may have created the entry before us, in which case we will
1171  // reuse the already created one.
1172  auto insert_result = executors_.emplace(sorted_key, ek);
1173  // Insert the value under the original key, so the fast path lookup will work
1174  // if the user uses the same order of inputs, outputs, and targets again.
1175  executors_.emplace(key, insert_result.first->second);
1176  *executors_and_keys = insert_result.first->second.get();
1177 
1178  return Status::OK();
1179 }
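
To make the executor caching above concrete: for inputs {"x:0"}, outputs {"y:0"}, target node {"init"} (hypothetical names), a non-partial run, and no debug watches, the fast-path key built from the caller's ordering would be

    // "x:0->y:0/init/0/"

The sorted key has the same shape with each name list sorted, so requests that differ only in argument order end up sharing a single ExecutorsAndKeys entry.
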
1180 
1181 Status NTSession::CreateGraphs(
1182  const BuildGraphOptions& subgraph_options,
1183  std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
1184  std::unique_ptr<FunctionLibraryDefinition>* flib_def,
1185  RunStateArgs* run_state_args, DataTypeVector* input_types,
1186  DataTypeVector* output_types) {
1187  mutex_lock l(graph_def_lock_);
1188  std::unique_ptr<SimpleClientGraph> client_graph;
1189 
1190  std::unique_ptr<SimpleGraphExecutionState> temp_exec_state_holder;
1191  SimpleGraphExecutionState* execution_state = nullptr;
1192  if (options_.config.graph_options().place_pruned_graph()) {
1193  // Because we are placing pruned graphs, we need to create a
1194  // new SimpleGraphExecutionState for every new unseen graph,
1195  // and then place it.
1196  SimpleGraphExecutionStateOptions prune_options;
1197  prune_options.device_set = &device_set_;
1198  prune_options.session_options = &options_;
1199  prune_options.stateful_placements = stateful_placements_;
1200  TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForPrunedGraph(
1201  execution_state_->original_graph_def().library(), prune_options,
1202  execution_state_->original_graph_def(), subgraph_options,
1203  &temp_exec_state_holder, &client_graph));
1204  execution_state = temp_exec_state_holder.get();
1205  } else {
1206  execution_state = execution_state_.get();
1207  TF_RETURN_IF_ERROR(
1208  execution_state->BuildGraph(subgraph_options, &client_graph));
1209  }
1210 
1211  if (subgraph_options.feed_endpoints.size() !=
1212  client_graph->feed_types.size()) {
1213  return errors::Internal(
1214  "Graph pruning failed: requested number of feed endpoints = ",
1215  subgraph_options.feed_endpoints.size(),
1216  " versus number of pruned feed endpoints = ",
1217  client_graph->feed_types.size());
1218  }
1219  if (subgraph_options.fetch_endpoints.size() !=
1220  client_graph->fetch_types.size()) {
1221  return errors::Internal(
1222  "Graph pruning failed: requested number of fetch endpoints = ",
1223  subgraph_options.fetch_endpoints.size(),
1224  " versus number of pruned fetch endpoints = ",
1225  client_graph->fetch_types.size());
1226  }
1227 
1228  auto current_stateful_placements = execution_state->GetStatefulPlacements();
1229  // Update our current state based on the execution_state's
1230  // placements. If there are any mismatches for a node,
1231  // we should fail, as this should never happen.
1232  for (auto placement_pair : current_stateful_placements) {
1233  const string& node_name = placement_pair.first;
1234  const string& placement = placement_pair.second;
1235  auto iter = stateful_placements_.find(node_name);
1236  if (iter == stateful_placements_.end()) {
1237  stateful_placements_.insert(std::make_pair(node_name, placement));
1238  } else if (iter->second != placement) {
1239  return errors::Internal(
1240  "Stateful placement mismatch. "
1241  "Current assignment of ",
1242  node_name, " to ", iter->second, " does not match ", placement);
1243  }
1244  }
1245 
1246  stateful_placements_ = execution_state->GetStatefulPlacements();
1247 
1248  // Remember the graph in run state if this is a partial run.
1249  if (run_state_args->is_partial_run) {
1250  run_state_args->graph.reset(new Graph(flib_def_.get()));
1251  CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
1252  }
1253 
1254  // Partition the graph across devices.
1255  PartitionOptions popts;
1256  popts.node_to_loc = [](const Node* node) {
1257  assert(node != nullptr);
1258  return node->assigned_device_name();
1259  };
1260  popts.new_name = [this](const string& prefix) {
1261  return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
1262  };
1263  popts.get_incarnation = [](const string& name) {
1264  // The direct session does not have changing incarnation numbers.
1265  // Just return '1'.
1266  return 1;
1267  };
1268  popts.control_flow_added = false;
1269 
1270  std::unordered_map<string, GraphDef> partitions;
1271  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
1272 
1273  std::vector<string> device_names;
1274  for (auto device : devices_) {
1275  // Extract the LocalName from the device.
1276  device_names.push_back(DeviceNameUtils::LocalName(device->name()));
1277  }
1278 
1279  // Check for valid partitions.
1280  for (const auto& partition : partitions) {
1281  const string local_partition_name =
1282  DeviceNameUtils::LocalName(partition.first);
1283  if (std::count(device_names.begin(), device_names.end(),
1284  local_partition_name) == 0) {
1285  return errors::InvalidArgument(
1286  "Creating a partition for ", local_partition_name,
1287  " which doesn't exist in the list of available devices. Available "
1288  "devices: ",
1289  str_util::Join(device_names, ","));
1290  }
1291  }
1292 
1293  for (const auto& partition : partitions) {
1294  std::unique_ptr<Graph> device_graph(
1295  new Graph(client_graph->flib_def.get()));
1296  GraphConstructorOptions device_opts;
1297  // There are internal operations (e.g., send/recv) that we now allow.
1298  device_opts.allow_internal_ops = true;
1299  device_opts.expect_device_spec = true;
1300  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
1301  device_graph.get()));
1302  outputs->emplace(partition.first, std::move(device_graph));
1303  }
1304 
1305  GraphOptimizationPassOptions optimization_options;
1306  optimization_options.session_options = &options_;
1307  optimization_options.flib_def = client_graph->flib_def.get();
1308  optimization_options.partition_graphs = outputs;
1309  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
1310  OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
1311 
1312  Status s;
1313  for (auto& partition : *outputs) {
1314  const string& partition_name = partition.first;
1315  std::unique_ptr<Graph>* graph = &partition.second;
1316 
1317  VLOG(2) << "Created " << DebugString(graph->get()) << " for "
1318  << partition_name;
1319 
1320  // Give the device an opportunity to rewrite its subgraph.
1321  Device* d;
1322  s = device_mgr_->LookupDevice(partition_name, &d);
1323  if (!s.ok()) break;
1324  // TODO(pbar) The library is currently shared and immutable. There
1325  // may be possible use cases where a device may want to modify
1326  // function definitions - in which case the library would need to be
1327  // replicated per device.
1328  s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph);
1329  if (!s.ok()) {
1330  break;
1331  }
1332  }
1333  *flib_def = std::move(client_graph->flib_def);
1334  std::swap(*input_types, client_graph->feed_types);
1335  std::swap(*output_types, client_graph->fetch_types);
1336  return s;
1337 }
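
A sketch of what the partitioning above produces, assuming a hypothetical two-device placement:

    // Nodes assigned to "/job:localhost/replica:0/task:0/cpu:0" and to a
    // second device yield two entries in `partitions`, keyed by device name.
    // Each GraphDef is converted back into a Graph (send/recv ops allowed)
    // and later paired with its own executor by the caller,
    // GetOrCreateExecutors().
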
1338 
1339 ::tensorflow::Status NTSession::ListDevices(
1340  std::vector<DeviceAttributes>* response) {
1341  response->clear();
1342  response->reserve(devices_.size());
1343  for (Device* d : devices_) {
1344  const DeviceAttributes& attrs = d->attributes();
1345  response->emplace_back(attrs);
1346  }
1347  return ::tensorflow::Status::OK();
1348 }
1349 
1350 ::tensorflow::Status NTSession::Reset(
1351  const std::vector<string>& containers) {
1352  device_mgr_->ClearContainers(containers);
1353  return ::tensorflow::Status::OK();
1354 }
1355 
1356 ::tensorflow::Status NTSession::Close() {
1357  cancellation_manager_->StartCancel();
1358  {
1359  mutex_lock l(closed_lock_);
1360  if (closed_) return ::tensorflow::Status::OK();
1361  closed_ = true;
1362  }
1363  if (factory_ != nullptr) factory_->Deregister(this);
1364  return ::tensorflow::Status::OK();
1365 }
1366 
1367 NTSession::RunState::RunState(
1368  const std::vector<string>& pending_input_names,
1369  const std::vector<string>& pending_output_names, int64 step_id,
1370  const std::vector<Device*>* devices)
1371  : step_container(step_id, [devices](const string& name) {
1372  for (auto d : *devices) {
1373  if (!d->resource_manager()->Cleanup(name).ok()) {
1374  // Do nothing...
1375  }
1376  }
1377  }) {
1378  // Initially all the feeds and fetches are pending.
1379  for (auto& name : pending_input_names) {
1380  pending_inputs[name] = false;
1381  }
1382  for (auto& name : pending_output_names) {
1383  pending_outputs[name] = false;
1384  }
1385 }
1386 
1387 NTSession::RunState::RunState(int64 step_id,
1388  const std::vector<Device*>* devices)
1389  : RunState({}, {}, step_id, devices) {}
1390 
1391 NTSession::RunState::~RunState() {
1392  if (rendez != nullptr) {
1393  if (!executors_done.HasBeenNotified()) {
1394  rendez->StartAbort(errors::Cancelled("PRun cancellation"));
1395  executors_done.WaitForNotification();
1396  }
1397  rendez->Unref();
1398  }
1399 }
1400 
1401 bool NTSession::RunState::PendingDone() const {
1402  for (const auto& it : pending_inputs) {
1403  if (!it.second) return false;
1404  }
1405  for (const auto& it : pending_outputs) {
1406  if (!it.second) return false;
1407  }
1408  return true;
1409 }
1410 
1411 void NTSession::WaitForNotification(RunState* run_state,
1412  CancellationManager* cm,
1413  int64 timeout_in_ms) {
1414  Status status =
1415  WaitForNotification(&run_state->executors_done, timeout_in_ms);
1416  if (!status.ok()) {
1417  {
1418  mutex_lock l(run_state->mu_);
1419  run_state->status.Update(status);
1420  }
1421  cm->StartCancel();
1422  // We must wait for the executors to complete, because they have borrowed
1423  // references to `cm` and other per-step state. After this notification, it
1424  // is safe to clean up the step.
1425  run_state->executors_done.WaitForNotification();
1426  }
1427 }
1428 
1429 ::tensorflow::Status NTSession::WaitForNotification(
1430  Notification* notification, int64 timeout_in_ms) {
1431  if (timeout_in_ms > 0) {
1432  int64 timeout_in_us = timeout_in_ms * 1000;
1433  bool notified = WaitForNotificationWithTimeout(notification, timeout_in_us);
1434  if (!notified) {
1435  return Status(error::DEADLINE_EXCEEDED,
1436  "Timed out waiting for notification");
1437  }
1438  } else {
1439  notification->WaitForNotification();
1440  }
1441  return Status::OK();
1442 }
1443 
1444 } // namespace tensorflow