39 #include "tensorflow/core/common_runtime/constant_folding.h" 40 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 41 #include "tensorflow/core/common_runtime/device_factory.h" 42 #include "tensorflow/core/common_runtime/executor.h" 43 #include "tensorflow/core/common_runtime/function.h" 44 #include "tensorflow/core/common_runtime/graph_optimizer.h" 45 #include "tensorflow/core/common_runtime/memory_types.h" 46 #include "tensorflow/core/common_runtime/optimization_registry.h" 47 #include "tensorflow/core/common_runtime/step_stats_collector.h" 48 #include "tensorflow/core/framework/function.h" 49 #include "tensorflow/core/framework/graph.pb_text.h" 50 #include "tensorflow/core/framework/graph.pb.h" 51 #include "tensorflow/core/framework/graph_def_util.h" 52 #include "tensorflow/core/framework/log_memory.h" 53 #include "tensorflow/core/framework/node_def.pb.h" 54 #include "tensorflow/core/framework/tensor.h" 55 #include "tensorflow/core/framework/versions.pb.h" 56 #include "tensorflow/core/graph/algorithm.h" 57 #include "tensorflow/core/graph/graph.h" 58 #include "tensorflow/core/graph/graph_constructor.h" 59 #include "tensorflow/core/graph/graph_partition.h" 60 #include "tensorflow/core/graph/subgraph.h" 61 #include "tensorflow/core/graph/tensor_id.h" 62 #include "tensorflow/core/lib/core/errors.h" 63 #include "tensorflow/core/lib/core/notification.h" 64 #include "tensorflow/core/lib/core/refcount.h" 65 #include "tensorflow/core/lib/core/status.h" 66 #include "tensorflow/core/lib/gtl/array_slice.h" 67 #include "tensorflow/core/lib/gtl/stl_util.h" 68 #include "tensorflow/core/lib/monitoring/counter.h" 69 #include "tensorflow/core/lib/strings/numbers.h" 70 #include "tensorflow/core/lib/strings/str_util.h" 71 #include "tensorflow/core/lib/strings/strcat.h" 72 #include "tensorflow/core/platform/cpu_info.h" 73 #include "tensorflow/core/platform/device_tracer.h" 74 #include "tensorflow/core/platform/logging.h" 75 #include "tensorflow/core/platform/mutex.h" 76 #include "tensorflow/core/platform/types.h" 77 #include "tensorflow/core/util/device_name_utils.h" 78 #include "tensorflow/core/util/env_var.h" 84 CMS_THREAD_SAFE auto* nothreads_session_runs = monitoring::Counter<0>::New(
85 "/tensorflow/core/nothreads_session_runs",
86 "The number of times NTSession::Run() has been called.");
92 string GetRendezvousKey(
const string& tensor_name,
93 const DeviceAttributes& device_info,
94 const FrameAndIter& frame_iter) {
95 return strings::StrCat(device_info.name(),
";",
96 strings::FpToString(device_info.incarnation()),
";",
97 device_info.name(),
";", tensor_name,
";",
98 frame_iter.frame_id,
":", frame_iter.iter_id);
108 return options.target ==
"no_threads";
113 if (options.config.graph_options().build_cost_model() > 0) {
114 EnableCPUAllocatorFullStats(
true);
116 std::vector<Device*> devices;
117 const Status s = DeviceFactory::AddDevices(
118 options,
"/job:localhost/replica:0/task:0", &devices);
125 new NTSession(options,
new DeviceMgr(devices),
this);
128 sessions_.push_back(session);
134 const std::vector<string>& containers)
override {
135 std::vector<NTSession*> sessions_to_reset;
144 for (
auto session : sessions_to_reset) {
145 s.Update(
session->Reset(containers));
149 for (
auto session : sessions_to_reset) {
163 std::vector<NTSession*> sessions_
GUARDED_BY(sessions_lock_);
197 const DeviceMgr* device_mgr,
200 device_mgr_(device_mgr),
202 cancellation_manager_(new CancellationManager()),
203 operation_timeout_in_ms_(options_.
config.operation_timeout_in_ms()) {
209 LOG(
ERROR) << status.error_message();
215 int devices_added = 0;
216 if (options.config.log_device_placement()) {
217 const string mapping_str =
device_mgr_->DeviceMappingString();
218 if (mapping_str.empty()) {
219 printf(
"Device mapping: no known devices.\n");
221 printf(
"Device mapping:\n%s", mapping_str.c_str());
223 LOG(
INFO) <<
"Device mapping:\n" << mapping_str;
232 if (devices_added == 0) {
240 if (!closed_)
Close().IgnoreError();
241 for (
auto& it : partial_runs_) {
242 it.second.reset(
nullptr);
244 for (
auto& it : executors_) {
251 d->ClearResourceMgr();
256 execution_state_.reset(
nullptr);
261 const GraphDef& graph,
bool* out_already_initialized) {
264 *out_already_initialized =
true;
271 new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
272 GraphExecutionStateOptions
options;
274 options.session_options = &
options_;
283 GraphDef
temp(graph);
285 GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_));
286 graph_created_ =
true;
287 *out_already_initialized =
false;
293 if (graph.node_size() > 0) {
295 if (graph_created_) {
296 return errors::AlreadyExists(
297 "A Graph has already been created for this session.");
311 bool already_initialized;
316 if (already_initialized) {
317 TF_RETURN_IF_ERROR(
flib_def_->AddLibrary(graph.library()));
318 std::unique_ptr<GraphExecutionState> state;
319 TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
320 execution_state_.swap(state);
326 const std::vector<string>& output_names,
327 const std::vector<string>& target_nodes,
328 std::vector<Tensor>*
outputs) {
329 RunMetadata run_metadata;
330 return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
335 const DebugOptions& debug_options, int64 session_run_index,
336 int64 executor_step_index,
const std::vector<string>&
input_names,
337 const std::vector<string>& output_names,
338 const std::vector<string>& target_names,
339 std::unique_ptr<DebuggerStateInterface>* debugger_state) {
341 DebuggerStateRegistry::CreateState(debug_options, debugger_state));
342 TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
343 debug_options.global_step(), session_run_index, executor_step_index,
349 const DebugOptions& debug_options,
Graph* graph, Device* device) {
350 std::unique_ptr<DebugGraphDecoratorInterface> decorator;
352 DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
354 TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
355 TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
361 const std::vector<string>& output_names,
362 const std::vector<string>& target_nodes,
364 RunMetadata* run_metadata) {
366 nothreads_session_runs->GetCell()->IncrementBy(1);
369 if (!graph_created_) {
370 return errors::InvalidArgument(
371 "Session was not created with a graph before Run()!");
376 std::vector<string> input_tensor_names;
377 input_tensor_names.reserve(inputs.size());
378 for (
const auto& it : inputs) {
379 input_tensor_names.push_back(it.first);
384 RunStateArgs run_state_args(run_options.debug_options());
391 target_nodes, &executors_and_keys,
393 const int64 executor_step_count = executors_and_keys->
step_count.fetch_add(1);
395 std::unique_ptr<DebuggerStateInterface> debugger_state;
396 if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
398 run_options.debug_options(), args.step_id, executor_step_count,
399 input_tensor_names, output_names, target_nodes, &debugger_state));
404 FunctionCallFrame call_frame(executors_and_keys->
input_types,
406 gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
407 for (
const auto& it : inputs) {
408 if (it.second.dtype() == DT_RESOURCE) {
409 Tensor tensor_from_handle;
418 const Status s = call_frame.SetArgs(feed_args);
419 if (errors::IsInternal(s)) {
420 return errors::InvalidArgument(s.error_message());
421 }
else if (!s.ok()) {
428 CancellationManager step_cancellation_manager;
429 args.call_frame = &call_frame;
432 const size_t num_executors = executors_and_keys->
items.size();
433 ExecutorBarrier* barrier =
new ExecutorBarrier(
434 num_executors, run_state.
rendez, [&run_state](
const Status& ret) {
436 mutex_lock l(run_state.mu_);
437 run_state.status.Update(ret);
442 args.rendezvous = run_state.rendez;
443 args.cancellation_manager = &step_cancellation_manager;
446 args.tensor_store = &run_state.tensor_store;
447 args.step_container = &run_state.step_container;
448 if (LogMemory::IsEnabled()) {
449 LogMemory::RecordStep(
args.step_id, run_state_args.handle);
453 const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
455 bool update_cost_model =
false;
456 if (
options_.config.graph_options().build_cost_model() > 0) {
457 const int64 build_cost_model_every =
458 options_.config.graph_options().build_cost_model();
459 const int64 build_cost_model_after =
460 options_.config.graph_options().build_cost_model_after();
461 int64 measure_step_count = executor_step_count - build_cost_model_after;
462 if (measure_step_count >= 0) {
464 ((measure_step_count + 1) % build_cost_model_every == 0);
467 if (do_trace || update_cost_model ||
468 run_options.report_tensor_allocations_upon_oom()) {
469 run_state.collector.reset(
470 new StepStatsCollector(run_metadata->mutable_step_stats()));
471 args.stats_collector = run_state.collector.get();
474 std::unique_ptr<DeviceTracer> tracer;
475 if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
476 tracer = CreateDeviceTracer();
481 run_state.executors_done.Notify();
490 const CancellationToken cancellation_token =
493 cancellation_token, [&step_cancellation_manager]() {
494 step_cancellation_manager.StartCancel();
496 if (already_cancelled) {
500 run_state.executors_done.Notify();
502 return errors::Cancelled(
"Run call was cancelled");
507 Executor::Args::Runner default_runner = [
this](Executor::Args::Closure
c) {
510 for (
const auto& item : executors_and_keys->items) {
523 args.runner = default_runner;
524 item.executor->RunAsync(
args, barrier->Get());
528 run_options.timeout_in_ms() > 0
529 ? run_options.timeout_in_ms()
535 mutex_lock
l(run_state.mu_);
536 run_state.status.Update(errors::Cancelled(
"Run call was cancelled"));
540 TF_RETURN_IF_ERROR(tracer->Stop());
541 TF_RETURN_IF_ERROR(tracer->Collect(
args.stats_collector));
545 mutex_lock
l(run_state.mu_);
546 TF_RETURN_IF_ERROR(run_state.status);
551 std::vector<Tensor> sorted_outputs;
552 const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
553 if (errors::IsInternal(s)) {
554 return errors::InvalidArgument(s.error_message());
555 }
else if (!s.ok()) {
558 const bool unique_outputs =
559 output_names.size() == executors_and_keys->output_name_to_index.size();
562 std::vector<int> first_indices;
563 if (!unique_outputs) {
564 first_indices.resize(output_names.size());
565 for (
int i = 0; i < static_cast<int>(output_names.size()); ++
i) {
566 for (
int j = 0; j <=
i; ++j) {
567 if (output_names[
i] == output_names[j]) {
568 first_indices[
i] = j;
575 outputs->reserve(sorted_outputs.size());
576 for (
int i = 0; i < static_cast<int>(output_names.size()); ++
i) {
577 const string& output_name = output_names[
i];
578 if (first_indices.empty() || first_indices[
i] ==
i) {
579 outputs->emplace_back(
580 std::move(sorted_outputs[executors_and_keys
581 ->output_name_to_index[output_name]]));
583 outputs->push_back((*outputs)[first_indices[
i]]);
590 run_state.tensor_store.SaveTensors(output_names, &
session_state_));
591 if (
args.stats_collector) {
592 args.stats_collector->Finalize();
597 if (update_cost_model) {
599 std::unordered_map<string, const Graph*> device_to_graph;
601 executors_and_keys->items) {
603 const string device = partition.
flib->device()->name();
604 device_to_graph[device] = graph;
609 CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
610 for (
const auto& item : executors_and_keys->items) {
617 if (run_options.output_partition_graphs()) {
618 protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
619 run_metadata->mutable_partition_graphs();
621 executors_and_keys->items) {
622 GraphDef* partition_graph_def = partition_graph_defs->Add();
623 exec_and_lib.
graph->ToGraphDef(partition_graph_def);
631 const std::vector<string>& output_names,
632 const std::vector<string>& target_nodes,
637 if (!graph_created_) {
638 return errors::InvalidArgument(
639 "Session was not created with a graph before PRunSetup()!");
646 DebugOptions debug_options;
650 target_nodes, &executors_and_keys,
662 .emplace(run_state_args.
handle,
663 std::unique_ptr<RunState>(run_state))
665 return errors::Internal(
"The handle '", run_state_args.
handle,
666 "' created for this partial run is not unique.");
671 const size_t num_executors = executors_and_keys->
items.size();
672 ExecutorBarrier* barrier =
new ExecutorBarrier(
673 num_executors, run_state->
rendez, [run_state](
const Status& ret) {
675 mutex_lock l(run_state->mu_);
676 run_state->status.Update(ret);
681 args.rendezvous = run_state->rendez;
683 args.runner = [
this](Executor::Args::Closure
c) {
687 args.tensor_store = &run_state->tensor_store;
688 args.step_container = &run_state->step_container;
689 if (LogMemory::IsEnabled()) {
690 LogMemory::RecordStep(
args.step_id, run_state_args.handle);
694 if (
options_.config.graph_options().build_cost_model()) {
695 run_state->collector.reset(
new StepStatsCollector(
nullptr));
696 args.stats_collector = run_state->collector.get();
699 for (
auto& item : executors_and_keys->items) {
700 item.executor->RunAsync(
args, barrier->Get());
703 *
handle = run_state_args.handle;
708 const std::vector<string>& output_names,
709 std::vector<Tensor>*
outputs) {
711 std::vector<string>
parts = str_util::Split(handle,
';');
712 const string&
key = parts[0];
718 auto exc_it = executors_.find(key);
719 if (exc_it == executors_.end()) {
720 return errors::InvalidArgument(
721 "Must run 'setup' before performing partial runs!");
723 executors_and_keys = exc_it->second.get();
725 auto prun_it = partial_runs_.find(handle);
726 if (prun_it == partial_runs_.end()) {
727 return errors::InvalidArgument(
728 "Must run 'setup' before performing partial runs!");
730 run_state = prun_it->second.get();
733 for (
const auto&
input : inputs) {
736 return errors::InvalidArgument(
737 "The feed ",
input.first,
738 " was not specified in partial_run_setup.");
739 }
else if (it->second) {
740 return errors::InvalidArgument(
"The feed ",
input.first,
741 " has already been fed.");
745 for (
const auto&
output : output_names) {
748 return errors::InvalidArgument(
749 "The fetch ",
output,
" was not specified in partial_run_setup.");
750 }
else if (it->second) {
751 return errors::InvalidArgument(
"The fetch ",
output,
752 " has already been fetched.");
760 CheckFetch(inputs, output_names, executors_and_keys, run_state));
767 s =
RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
781 mutex_lock
l(run_state->
mu_);
782 if (!run_state->status.ok()) {
783 LOG(
WARNING) <<
"An error unrelated to this prun has been detected. " 784 << run_state->status;
787 for (
const auto&
input : inputs) {
791 for (
const auto&
name : output_names) {
800 partial_runs_.erase(handle);
808 Tensor* retrieved_tensor) {
809 if (resource_tensor.dtype() != DT_RESOURCE) {
810 return errors::InvalidArgument(strings::StrCat(
811 "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
812 resource_tensor.dtype()));
815 const ResourceHandle& resource_handle =
816 resource_tensor.scalar<ResourceHandle>()();
818 if (resource_handle.container() ==
819 SessionState::kTensorHandleResourceTypeName) {
820 return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
822 return errors::InvalidArgument(strings::StrCat(
823 "Invalid resource type hash code: ", resource_handle.hash_code(),
824 "(name: ", resource_handle.name(),
825 " type: ", resource_handle.maybe_type_name(),
826 "). Perhaps a resource tensor was being provided as a feed? That is " 827 "not currently allowed. Please file an issue at " 828 "https://github.com/tensorflow/tensorflow/issues/new, ideally with a " 829 "short code snippet that leads to this error message."));
835 IntraProcessRendezvous* rendez) {
837 Rendezvous::ParsedKey parsed;
840 for (
const auto&
input : inputs) {
844 return errors::Internal(
"'",
input.first,
"' is not a pre-defined feed.");
846 const string& input_key = it->second;
848 s = Rendezvous::ParseKey(input_key, &parsed);
850 rendez->StartAbort(s);
854 if (
input.second.dtype() == DT_RESOURCE) {
855 Tensor tensor_from_handle;
858 s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle,
false);
861 s = rendez->Send(parsed, Rendezvous::Args(),
input.second,
false);
865 rendez->StartAbort(s);
873 const std::vector<string>& output_names,
875 std::vector<Tensor>*
outputs) {
877 if (!output_names.empty()) {
878 outputs->resize(output_names.size());
881 Rendezvous::ParsedKey parsed;
883 for (
size_t output_offset = 0; output_offset < output_names.size();
885 const string& output_name = output_names[output_offset];
889 return errors::Internal(
"'", output_name,
890 "' is not a pre-defined fetch.");
892 const string& output_key = it->second;
893 Tensor output_tensor;
895 IntraProcessRendezvous* rendez = run_state->
rendez;
897 s = Rendezvous::ParseKey(output_key, &parsed);
900 s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead,
902 if (is_dead && s.ok()) {
903 s = errors::InvalidArgument(
"The tensor returned for ", output_name,
908 rendez->StartAbort(s);
913 (*outputs)[output_offset] = output_tensor;
919 const std::vector<string>& fetches,
922 const Graph* graph = executors_and_keys->
graph.get();
926 std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
931 if (
input.second)
continue;
932 TensorId
id(ParseTensorName(
input.first));
933 auto it = name_to_node->find(
id.
first);
934 if (it == name_to_node->end()) {
937 pending_feeds.insert(
id);
940 for (
const auto& it : feeds) {
941 TensorId
id(ParseTensorName(it.first));
942 pending_feeds.erase(
id);
946 std::vector<const Node*>
stack;
947 for (
const string&
fetch : fetches) {
948 TensorId
id(ParseTensorName(
fetch));
949 auto it = name_to_node->find(
id.
first);
950 if (it == name_to_node->end()) {
953 stack.push_back(it->second);
957 std::vector<bool>
visited(graph->num_node_ids(),
false);
958 while (!stack.empty()) {
959 const Node*
n = stack.back();
962 for (
const Edge* in_edge : n->in_edges()) {
963 const Node* in_node = in_edge->src();
964 if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
965 return errors::InvalidArgument(
"Fetch ", in_node->name(),
":",
966 in_edge->src_output(),
967 " can't be computed from the feeds" 968 " that have been fed so far.");
972 stack.push_back(in_node);
980 gtl::ArraySlice<string>
inputs, gtl::ArraySlice<string>
outputs,
981 gtl::ArraySlice<string> target_nodes,
ExecutorsAndKeys** executors_and_keys,
983 int64 handle_name_counter_value = -1;
988 string debug_tensor_watches_summary;
989 if (!run_state_args->
debug_options.debug_tensor_watch_opts().empty()) {
990 debug_tensor_watches_summary = SummarizeDebugTensorWatches(
995 const string key = strings::StrCat(
996 str_util::Join(inputs,
","),
"->", str_util::Join(outputs,
","),
"/",
997 str_util::Join(target_nodes,
","),
"/", run_state_args->
is_partial_run,
998 "/", debug_tensor_watches_summary);
1000 if (handle_name_counter_value >= 0) {
1002 strings::StrCat(key,
";", handle_name_counter_value);
1008 auto it = executors_.find(key);
1009 if (it != executors_.end()) {
1010 *executors_and_keys = it->second.get();
1021 std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
1022 std::sort(inputs_sorted.begin(), inputs_sorted.end());
1023 std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
1024 std::sort(outputs_sorted.begin(), outputs_sorted.end());
1025 std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
1026 std::sort(tn_sorted.begin(), tn_sorted.end());
1028 const string sorted_key = strings::StrCat(
1029 str_util::Join(inputs_sorted,
","),
"->",
1030 str_util::Join(outputs_sorted,
","),
"/", str_util::Join(tn_sorted,
","),
1031 "/", run_state_args->
is_partial_run,
"/", debug_tensor_watches_summary);
1033 if (handle_name_counter_value >= 0) {
1035 strings::StrCat(sorted_key,
";", handle_name_counter_value);
1041 auto it = executors_.find(sorted_key);
1042 if (it != executors_.end()) {
1043 *executors_and_keys = it->second.get();
1045 executors_.emplace(key, it->second);
1052 options.feed_endpoints = inputs_sorted;
1053 options.fetch_endpoints = outputs_sorted;
1054 options.target_nodes = tn_sorted;
1055 options.use_function_convention = !run_state_args->
is_partial_run;
1056 if (!run_state_args->
debug_options.debug_tensor_watch_opts().empty()) {
1060 std::unique_ptr<FunctionInfo> func_info(
new FunctionInfo);
1065 std::unordered_map<string, std::unique_ptr<Graph>>
graphs;
1066 TF_RETURN_IF_ERROR(
CreateGraphs(options, &graphs, &func_info->flib_def,
1067 run_state_args, &ek->input_types,
1068 &ek->output_types));
1072 std::unordered_set<StringPiece, StringPieceHasher>
names;
1073 for (
const string&
input : inputs) {
1074 TensorId
id(ParseTensorName(
input));
1075 names.emplace(
id.
first);
1077 for (
const string&
output : outputs) {
1078 TensorId
id(ParseTensorName(
output));
1079 names.emplace(
id.
first);
1081 for (
Node*
n : ek->graph->nodes()) {
1082 if (names.count(
n->name()) > 0) {
1083 ek->name_to_node.insert({
n->name(),
n});
1087 ek->items.reserve(graphs.size());
1088 const auto& optimizer_opts =
1089 options_.config.graph_options().optimizer_options();
1091 int graph_def_version;
1095 execution_state_->original_graph_def().versions().producer();
1097 func_info->proc_flr.reset(
new ProcessFunctionLibraryRuntime(
1099 func_info->flib_def.get(), optimizer_opts));
1101 GraphOptimizer optimizer(optimizer_opts);
1102 for (
auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
1103 const string& partition_name = iter->first;
1104 std::unique_ptr<Graph>& partition_graph = iter->second;
1107 TF_RETURN_IF_ERROR(
device_mgr_->LookupDevice(partition_name, &device));
1109 ek->items.resize(ek->items.size() + 1);
1110 auto* item = &(ek->items.back());
1111 auto lib = func_info->proc_flr->GetFLR(partition_name);
1112 if (
lib ==
nullptr) {
1113 return errors::Internal(
"Could not find device: ", partition_name);
1117 LocalExecutorParams params;
1118 params.device = device;
1119 params.function_library =
lib;
1120 auto opseg = device->op_segment();
1121 params.create_kernel = [
this,
lib, opseg](
const NodeDef& ndef,
1122 OpKernel** kernel) {
1129 if (!
lib->IsStateful(ndef.op()) ||
1130 lib->GetFunctionLibraryDefinition()->Find(ndef.op()) !=
nullptr) {
1131 return lib->CreateKernel(ndef, kernel);
1133 auto create_fn = [
lib, &ndef](OpKernel** kernel) {
1134 return lib->CreateKernel(ndef, kernel);
1142 params.delete_kernel = [
lib](OpKernel* kernel) {
1144 if (kernel && !
lib->IsStateful(kernel->type_string())) {
1150 optimizer.Optimize(
lib,
options_.env, device, &iter->second,
1154 if (!options.debug_options.debug_tensor_watch_opts().empty()) {
1156 options.debug_options, partition_graph.get(), params.device));
1159 TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
1161 partition_graph.get()));
1163 item->graph = partition_graph.get();
1164 item->executor =
nullptr;
1165 item->device = device;
1168 NewLocalExecutor(params, partition_graph.release(), &executor));
1169 item->executor.reset(executor);
1178 for (
size_t i = 0;
i < inputs_sorted.size(); ++
i) {
1179 const string&
input = inputs_sorted[
i];
1180 ek->input_name_to_index[
input] =
i;
1182 for (
size_t i = 0;
i < outputs_sorted.size(); ++
i) {
1183 const string&
output = outputs_sorted[
i];
1184 ek->output_name_to_index[
output] =
i;
1192 for (
size_t i = 0;
i < inputs_sorted.size(); ++
i) {
1193 const string&
input = inputs_sorted[
i];
1194 ek->input_name_to_rendezvous_key[
input] = GetRendezvousKey(
1195 input,
device_set_.client_device()->attributes(), FrameAndIter(0, 0));
1197 for (
size_t i = 0;
i < outputs_sorted.size(); ++
i) {
1198 const string&
output = outputs_sorted[
i];
1199 ek->output_name_to_rendezvous_key[
output] =
1200 GetRendezvousKey(output,
device_set_.client_device()->attributes(),
1201 FrameAndIter(0, 0));
1207 functions_.push_back(
std::move(func_info));
1211 auto insert_result = executors_.emplace(sorted_key, ek);
1214 executors_.emplace(key, insert_result.first->second);
1215 *executors_and_keys = insert_result.first->second.get();
1221 const BuildGraphOptions& subgraph_options,
1222 std::unordered_map<
string, std::unique_ptr<Graph>>*
outputs,
1223 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
1224 RunStateArgs* run_state_args, DataTypeVector* input_types,
1225 DataTypeVector* output_types) {
1227 std::unique_ptr<ClientGraph> client_graph;
1229 std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
1230 GraphExecutionState* execution_state =
nullptr;
1231 if (
options_.config.graph_options().place_pruned_graph()) {
1235 GraphExecutionStateOptions prune_options;
1237 prune_options.session_options = &
options_;
1238 prune_options.stateful_placements = stateful_placements_;
1239 TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
1240 execution_state_->original_graph_def().library(), prune_options,
1241 execution_state_->original_graph_def(), subgraph_options,
1242 &temp_exec_state_holder, &client_graph));
1243 execution_state = temp_exec_state_holder.get();
1245 execution_state = execution_state_.get();
1247 execution_state->BuildGraph(subgraph_options, &client_graph));
1250 if (subgraph_options.feed_endpoints.size() !=
1251 client_graph->feed_types.size()) {
1252 return errors::Internal(
1253 "Graph pruning failed: requested number of feed endpoints = ",
1254 subgraph_options.feed_endpoints.size(),
1255 " versus number of pruned feed endpoints = ",
1256 client_graph->feed_types.size());
1258 if (subgraph_options.fetch_endpoints.size() !=
1259 client_graph->fetch_types.size()) {
1260 return errors::Internal(
1261 "Graph pruning failed: requested number of fetch endpoints = ",
1262 subgraph_options.fetch_endpoints.size(),
1263 " versus number of pruned fetch endpoints = ",
1264 client_graph->fetch_types.size());
1267 auto current_stateful_placements = execution_state->GetStatefulPlacements();
1271 for (
auto placement_pair : current_stateful_placements) {
1272 const string& node_name = placement_pair.first;
1273 const string& placement = placement_pair.second;
1274 auto iter = stateful_placements_.find(node_name);
1275 if (iter == stateful_placements_.end()) {
1276 stateful_placements_.insert(std::make_pair(node_name, placement));
1277 }
else if (iter->second != placement) {
1278 return errors::Internal(
1279 "Stateful placement mismatch. " 1280 "Current assignment of ",
1281 node_name,
" to ", iter->second,
" does not match ", placement);
1285 stateful_placements_ = execution_state->GetStatefulPlacements();
1290 CopyGraph(*execution_state->full_graph(), run_state_args->
graph.get());
1294 PartitionOptions popts;
1295 popts.node_to_loc = [](
const Node* node) {
1296 assert(node !=
nullptr);
1297 return node->assigned_device_name();
1299 popts.new_name = [
this](
const string&
prefix) {
1302 popts.get_incarnation = [](
const string&
name) {
1307 popts.flib_def = &client_graph->graph.flib_def();
1308 popts.control_flow_added =
false;
1310 std::unordered_map<string, GraphDef>
partitions;
1311 TF_RETURN_IF_ERROR(
Partition(popts, &client_graph->graph, &partitions));
1313 std::vector<string> device_names;
1316 device_names.push_back(DeviceNameUtils::LocalName(device->name()));
1320 for (
const auto& partition : partitions) {
1321 const string local_partition_name =
1322 DeviceNameUtils::LocalName(partition.first);
1323 if (
std::count(device_names.begin(), device_names.end(),
1324 local_partition_name) == 0) {
1325 return errors::InvalidArgument(
1326 "Creating a partition for ", local_partition_name,
1327 " which doesn't exist in the list of available devices. Available " 1329 str_util::Join(device_names,
","));
1333 for (
const auto& partition : partitions) {
1334 std::unique_ptr<Graph> device_graph(
1335 new Graph(client_graph->flib_def.get()));
1336 GraphConstructorOptions device_opts;
1338 device_opts.allow_internal_ops =
true;
1339 device_opts.expect_device_spec =
true;
1340 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
1341 device_graph.get()));
1345 GraphOptimizationPassOptions optimization_options;
1346 optimization_options.session_options = &
options_;
1347 optimization_options.flib_def = client_graph->flib_def.get();
1348 optimization_options.partition_graphs =
outputs;
1349 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
1350 OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
1353 for (
auto& partition : *
outputs) {
1354 const string& partition_name = partition.first;
1355 std::unique_ptr<Graph>* graph = &partition.second;
1357 VLOG(2) <<
"Created " << DebugString(graph->get()) <<
" for " 1362 s =
device_mgr_->LookupDevice(partition_name, &d);
1364 s = d->MaybeRewriteGraph(graph);
1369 *flib_def =
std::move(client_graph->flib_def);
1370 std::swap(*input_types, client_graph->feed_types);
1371 std::swap(*output_types, client_graph->fetch_types);
1376 std::vector<DeviceAttributes>* response) {
1378 response->reserve(
devices_.size());
1380 const DeviceAttributes& attrs =
d->attributes();
1381 response->emplace_back(attrs);
1387 const std::vector<string>& containers) {
1404 const std::vector<string>& pending_input_names,
1405 const std::vector<string>& pending_output_names, int64 step_id,
1406 const std::vector<Device*>* devices)
1407 : step_container(step_id, [devices](const
string&
name) {
1408 for (
auto d : *devices) {
1409 if (!
d->resource_manager()->Cleanup(
name).ok()) {
1415 for (
auto&
name : pending_input_names) {
1418 for (
auto&
name : pending_output_names) {
1424 const std::vector<Device*>* devices)
1425 :
RunState({}, {}, step_id, devices) {}
1430 rendez->StartAbort(errors::Cancelled(
"PRun cancellation"));
1439 if (!it.second)
return false;
1442 if (!it.second)
return false;
1448 CancellationManager* cm,
1449 int64 timeout_in_ms) {
1454 mutex_lock
l(run_state->
mu_);
1455 run_state->status.Update(status);
1466 Notification* notification, int64 timeout_in_ms) {
1467 if (timeout_in_ms > 0) {
1468 const int64 timeout_in_us = timeout_in_ms * 1000;
1469 const bool notified =
1470 WaitForNotificationWithTimeout(notification, timeout_in_us);
1472 return Status(error::DEADLINE_EXCEEDED,
1473 "Timed out waiting for notification");
1476 notification->WaitForNotification();
DataTypeVector output_types
static boost::mutex mutex
::tensorflow::Status CreateDebuggerState(const DebugOptions &debug_options, int64 session_run_index, int64 executor_step_index, const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_names, std::unique_ptr< DebuggerStateInterface > *debugger_state)
std::vector< PerPartitionExecutorsAndLib > items
static std::atomic_int_fast64_t step_id_counter_
::tensorflow::Status Reset(const std::vector< string > &containers)
FunctionLibraryRuntime * flib
Status MaybeInitializeExecutionState(const GraphDef &graph, bool *out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
const SessionOptions options_
std::unordered_map< string, string > input_name_to_rendezvous_key
IntraProcessRendezvous * rendez
::tensorflow::Status ResourceHandleToInputTensor(const Tensor &resource_tensor, Tensor *retrieved_tensor)
::tensorflow::Status PRunSetup(const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_nodes, string *handle) override
void Deregister(const NTSession *session)
::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions &debug_options, Graph *graph, Device *device)
SessionState session_state_
std::unordered_map< string, string > output_name_to_rendezvous_key
const std::string names[nVars_]
bool AcceptsOptions(const SessionOptions &options) override
::tensorflow::Status SendPRunInputs(const std::vector< std::pair< string, Tensor >> &inputs, const ExecutorsAndKeys *executors_and_keys, IntraProcessRendezvous *rendez)
std::vector< std::pair< string, Tensor > > NamedTensorList
Status Reset(const SessionOptions &options, const std::vector< string > &containers) override
RunState(int64 step_id, const std::vector< Device * > *devices)
::tensorflow::Status PRun(const string &handle, const NamedTensorList &inputs, const std::vector< string > &output_names, std::vector< Tensor > *outputs) override
static std::string const input
DataTypeVector input_types
Notification executors_done
std::unordered_map< string, bool > pending_outputs
::tensorflow::Status Run(const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs) override
static NTSessionRegistrar registrar
::tensorflow::Status CheckNotClosed()
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
::tensorflow::Status ExtendLocked(const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
::tensorflow::Status WaitForNotification(Notification *n, int64 timeout_in_ms)
::tensorflow::Status ListDevices(std::vector< DeviceAttributes > *response) override
::tensorflow::Status CheckFetch(const std::vector< std::pair< string, Tensor >> &feeds, const std::vector< string > &fetches, const ExecutorsAndKeys *executors_and_keys, const RunState *run_state)
std::unordered_map< string, size_t > input_name_to_index
std::atomic< int64 > edge_name_counter_
::tensorflow::Status Close() override
std::atomic< int64 > handle_name_counter_
const std::unique_ptr< const DeviceMgr > device_mgr_
Session * NewSession(const SessionOptions &options) override
std::pair< int, edm::FunctionWithDict > OK
std::unique_ptr< Graph > graph
::tensorflow::Status RecvPRunOutputs(const std::vector< string > &output_names, const ExecutorsAndKeys *executors_and_keys, RunState *run_state, std::vector< Tensor > *outputs)
::tensorflow::Status CreateGraphs(const BuildGraphOptions &options, std::unordered_map< string, std::unique_ptr< Graph >> *outputs, std::unique_ptr< FunctionLibraryDefinition > *flib_def, RunStateArgs *run_state_args, DataTypeVector *input_types, DataTypeVector *output_types)
std::vector< Device * > devices_
Executor::Args::NodeOutputsCallback node_outputs_callback_
std::atomic_int_fast64_t step_count
def remove(d, key, TELL=False)
void SchedClosure(std::function< void()> c)
std::vector< NTSession * > sessions_ GUARDED_BY(sessions_lock_)
::tensorflow::Status GetOrCreateExecutors(gtl::ArraySlice< string > inputs, gtl::ArraySlice< string > outputs, gtl::ArraySlice< string > target_nodes, ExecutorsAndKeys **executors_and_keys, RunStateArgs *run_state_args)
std::unordered_map< string, bool > pending_inputs
NTSession(const SessionOptions &options, const DeviceMgr *device_mgr, NTSessionFactory *factory)
CancellationManager * cancellation_manager_
std::unique_ptr< Graph > graph
::tensorflow::Status Create(const GraphDef &graph) override
const int64 operation_timeout_in_ms_
NTSessionFactory *const factory_
std::unique_ptr< FunctionLibraryDefinition > flib_def_
const DebugOptions & debug_options
std::pair< std::string, std::shared_ptr< void > > fetch(const cond::Hash &payloadId, Session &session)
DDCompactView::Graph Graph
CostModelManager cost_model_manager_
::tensorflow::Status Extend(const GraphDef &graph) override
std::unordered_map< StringPiece, Node *, StringPieceHasher > NameNodeMap