#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/device_tracer.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/env_var.h"

// ...

CMS_THREAD_SAFE auto* nothreads_session_runs = monitoring::Counter<0>::New(
    "/tensorflow/core/nothreads_session_runs",
    "The number of times NTSession::Run() has been called.");
string GetRendezvousKey(const string& tensor_name,
                        const DeviceAttributes& device_info,
                        const FrameAndIter& frame_iter) {
  return strings::StrCat(device_info.name(), ";",
                         strings::FpToString(device_info.incarnation()), ";",
                         device_info.name(), ";", tensor_name, ";",
                         frame_iter.frame_id, ":", frame_iter.iter_id);
}
  bool AcceptsOptions(const SessionOptions& options) override {
    return options.target == "no_threads";
  }
  Session* NewSession(const SessionOptions& options) override {
    // Must do this before the CPU allocator is created.
    if (options.config.graph_options().build_cost_model() > 0) {
      EnableCPUAllocatorFullStats(true);
    }
    std::vector<Device*> devices;
    const Status s = DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices);
    if (!s.ok()) {
      LOG(ERROR) << s;
      return nullptr;
    }

    NTSession* session = new NTSession(options, new DeviceMgr(devices), this);
    {
      mutex_lock l(sessions_lock_);
      sessions_.push_back(session);
    }
    return session;
  }
  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    std::vector<NTSession*> sessions_to_reset;
    {
      mutex_lock l(sessions_lock_);
      // Swap out the current sessions so that Close() below cannot
      // deadlock with Deregister(), which also takes sessions_lock_.
      std::swap(sessions_to_reset, sessions_);
    }
    Status s;
    for (auto session : sessions_to_reset) {
      s.Update(session->Reset(containers));
    }
    for (auto session : sessions_to_reset) {
      s.Update(session->Close());
    }
    return s;
  }

  // ...

  std::vector<NTSession*> sessions_ GUARDED_BY(sessions_lock_);
NTSession::NTSession(const SessionOptions& options,
                     const DeviceMgr* device_mgr,
                     NTSessionFactory* const factory)
    : options_(options),
      device_mgr_(device_mgr),
      factory_(factory),
      cancellation_manager_(new CancellationManager()),
      operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
  const Status status =
      ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
  if (!status.ok()) {
    LOG(ERROR) << status.error_message();
  }
  // ...
  int devices_added = 0;
  if (options.config.log_device_placement()) {
    const string mapping_str = device_mgr_->DeviceMappingString();
    if (mapping_str.empty()) {
      printf("Device mapping: no known devices.\n");
    } else {
      printf("Device mapping:\n%s", mapping_str.c_str());
    }
    LOG(INFO) << "Device mapping:\n" << mapping_str;
  }
  for (auto d : device_mgr_->ListDevices()) {
    devices_.push_back(d);
    device_set_.AddDevice(d);

    // The first device added is special: it is the 'client device' (a
    // CPU device) from which we feed and fetch Tensors.
    if (devices_added == 0) {
      device_set_.set_client_device(d);
    }
    ++devices_added;
  }
}
NTSession::~NTSession() {
  if (!closed_) Close().IgnoreError();
  for (auto& it : partial_runs_) {
    it.second.reset(nullptr);
  }
  for (auto& it : executors_) {
    it.second.reset();
  }
  // ...
  delete cancellation_manager_;

  execution_state_.reset(nullptr);
  flib_def_.reset(nullptr);
}
Status NTSession::MaybeInitializeExecutionState(
    const GraphDef& graph, bool* out_already_initialized) {
  // If the graph state has already been set up, there is nothing to do.
  if (flib_def_ && execution_state_) {
    *out_already_initialized = true;
    return Status::OK();
  }
  // Set up the per-session execution state.
  flib_def_.reset(
      new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
  GraphExecutionStateOptions options;
  options.device_set = &device_set_;
  options.session_options = &options_;
  // ...
  GraphDef temp(graph);
  TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
      &temp, options, &execution_state_));
  graph_created_ = true;
  *out_already_initialized = false;
  return Status::OK();
}
Status NTSession::Create(const GraphDef& graph) {
  if (graph.node_size() > 0) {
    mutex_lock l(graph_def_lock_);
    if (graph_created_) {
      return errors::AlreadyExists(
          "A Graph has already been created for this session.");
    }
    return ExtendLocked(graph);
  }
  return Status::OK();
}

// ...

Status NTSession::ExtendLocked(const GraphDef& graph) {
  bool already_initialized;
  TF_RETURN_IF_ERROR(
      MaybeInitializeExecutionState(graph, &already_initialized));
  if (already_initialized) {
    TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
    std::unique_ptr<GraphExecutionState> state;
    TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
    execution_state_.swap(state);
  }
  return Status::OK();
}
Status NTSession::Run(const NamedTensorList& inputs,
                      const std::vector<string>& output_names,
                      const std::vector<string>& target_nodes,
                      std::vector<Tensor>* outputs) {
  RunMetadata run_metadata;
  return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
             &run_metadata);
}
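// Hedged usage sketch of the overload above (client code, not from this
// file; the tensor names "x:0"/"y:0" are hypothetical):
//
//   std::vector<std::pair<std::string, tensorflow::Tensor>> feeds =
//       {{"x:0", x}};
//   std::vector<tensorflow::Tensor> fetched;
//   tensorflow::Status status = session->Run(feeds, {"y:0"}, {}, &fetched);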
Status NTSession::CreateDebuggerState(
    const DebugOptions& debug_options, int64 session_run_index,
    int64 executor_step_index, const std::vector<string>& input_names,
    const std::vector<string>& output_names,
    const std::vector<string>& target_names,
    std::unique_ptr<DebuggerStateInterface>* debugger_state) {
  TF_RETURN_IF_ERROR(
      DebuggerStateRegistry::CreateState(debug_options, debugger_state));
  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
      debug_options.global_step(), session_run_index, executor_step_index,
      input_names, output_names, target_names));
  return Status::OK();
}
Status NTSession::DecorateAndPublishGraphForDebug(
    const DebugOptions& debug_options, Graph* graph, Device* device) {
  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
  TF_RETURN_IF_ERROR(
      DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
  return Status::OK();
}
Status NTSession::Run(const RunOptions& run_options,
                      const NamedTensorList& inputs,
                      const std::vector<string>& output_names,
                      const std::vector<string>& target_nodes,
                      std::vector<Tensor>* outputs,
                      RunMetadata* run_metadata) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  nothreads_session_runs->GetCell()->IncrementBy(1);
  {
    mutex_lock l(graph_def_lock_);
    if (!graph_created_) {
      return errors::InvalidArgument(
          "Session was not created with a graph before Run()!");
    }
  }
  // Extract the inputs names for this run of the session.
  std::vector<string> input_tensor_names;
  input_tensor_names.reserve(inputs.size());
  for (const auto& it : inputs) {
    input_tensor_names.push_back(it.first);
  }
  // Check if we already have an executor for these arguments.
  ExecutorsAndKeys* executors_and_keys;
  RunStateArgs run_state_args(run_options.debug_options());

  Executor::Args args;
  args.step_id = step_id_counter_.fetch_add(1);

  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
                                          target_nodes, &executors_and_keys,
                                          &run_state_args));
  const int64 executor_step_count =
      executors_and_keys->step_count.fetch_add(1);

  std::unique_ptr<DebuggerStateInterface> debugger_state;
  if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
    TF_RETURN_IF_ERROR(CreateDebuggerState(
        run_options.debug_options(), args.step_id, executor_step_count,
        input_tensor_names, output_names, target_nodes, &debugger_state));
  }
  // Configure a call frame for the step, which we use to feed and
  // fetch values to and from the executors.
  FunctionCallFrame call_frame(executors_and_keys->input_types,
                               executors_and_keys->output_types);
  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
  for (const auto& it : inputs) {
    if (it.second.dtype() == DT_RESOURCE) {
      Tensor tensor_from_handle;
      TF_RETURN_IF_ERROR(
          ResourceHandleToInputTensor(it.second, &tensor_from_handle));
      feed_args[executors_and_keys->input_name_to_index[it.first]] =
          tensor_from_handle;
    } else {
      feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
    }
  }
  const Status s = call_frame.SetArgs(feed_args);
  if (errors::IsInternal(s)) {
    return errors::InvalidArgument(s.error_message());
  } else if (!s.ok()) {
    return s;
  }
  // Create a run state and start execution.
  RunState run_state(args.step_id, &devices_);
  run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
  CancellationManager step_cancellation_manager;
  args.call_frame = &call_frame;

  // Start parallel Executors.
  const size_t num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, run_state.rendez, [&run_state](const Status& ret) {
        {
          mutex_lock l(run_state.mu_);
          run_state.status.Update(ret);
        }
        run_state.executors_done.Notify();
      });

  args.rendezvous = run_state.rendez;
  args.cancellation_manager = &step_cancellation_manager;
  args.session_state = &session_state_;
  args.tensor_store = &run_state.tensor_store;
  args.step_container = &run_state.step_container;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, run_state_args.handle);
  }
  const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);

  bool update_cost_model = false;
  if (options_.config.graph_options().build_cost_model() > 0) {
    const int64 build_cost_model_every =
        options_.config.graph_options().build_cost_model();
    const int64 build_cost_model_after =
        options_.config.graph_options().build_cost_model_after();
    int64 measure_step_count = executor_step_count - build_cost_model_after;
    if (measure_step_count >= 0) {
      update_cost_model =
          ((measure_step_count + 1) % build_cost_model_every == 0);
    }
  }
  if (do_trace || update_cost_model ||
      run_options.report_tensor_allocations_upon_oom()) {
    run_state.collector.reset(
        new StepStatsCollector(run_metadata->mutable_step_stats()));
    args.stats_collector = run_state.collector.get();
  }
  std::unique_ptr<DeviceTracer> tracer;
  if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
    tracer = CreateDeviceTracer();
    // tracer may be nullptr on platforms without accelerator support; if
    // starting it fails, unblock the RunState destructor before returning.
    if (tracer) {
      Status s = tracer->Start();
      if (!s.ok()) {
        run_state.executors_done.Notify();
        delete barrier;
        return s;
      }
    }
  }
  // Register this step with the session's cancellation manager, so that
  // `Session::Close()` will cancel the step.
  const CancellationToken cancellation_token =
      cancellation_manager_->get_cancellation_token();
  const bool already_cancelled = !cancellation_manager_->RegisterCallback(
      cancellation_token, [&step_cancellation_manager]() {
        step_cancellation_manager.StartCancel();
      });
  if (already_cancelled) {
    // Notify `run_state.executors_done` so that the RunState destructor
    // does not block on it.
    run_state.executors_done.Notify();
    delete barrier;
    return errors::Cancelled("Run call was cancelled");
  }
  Executor::Args::Runner default_runner = [this](Executor::Args::Closure c) {
    SchedClosure(std::move(c));
  };
  for (const auto& item : executors_and_keys->items) {
    // ...
    args.runner = default_runner;
    item.executor->RunAsync(args, barrier->Get());
  }
  WaitForNotification(&run_state, &step_cancellation_manager,
                      run_options.timeout_in_ms() > 0
                          ? run_options.timeout_in_ms()
                          : operation_timeout_in_ms_);

  if (!cancellation_manager_->DeregisterCallback(cancellation_token)) {
    // The step has been cancelled: make sure we don't attempt to receive
    // the outputs, as this would block forever.
    mutex_lock l(run_state.mu_);
    run_state.status.Update(errors::Cancelled("Run call was cancelled"));
  }
  if (tracer) {
    TF_RETURN_IF_ERROR(tracer->Stop());
    TF_RETURN_IF_ERROR(tracer->Collect(args.stats_collector));
  }

  {
    mutex_lock l(run_state.mu_);
    TF_RETURN_IF_ERROR(run_state.status);
  }
  // Receive outputs.
  if (outputs) {
    std::vector<Tensor> sorted_outputs;
    const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
    if (errors::IsInternal(s)) {
      return errors::InvalidArgument(s.error_message());
    } else if (!s.ok()) {
      return s;
    }
    const bool unique_outputs =
        output_names.size() == executors_and_keys->output_name_to_index.size();
    // first_indices[i] = j implies that j is the smallest value for which
    // output_names[i] == output_names[j].
    std::vector<int> first_indices;
    if (!unique_outputs) {
      first_indices.resize(output_names.size());
      for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
        for (int j = 0; j <= i; ++j) {
          if (output_names[i] == output_names[j]) {
            first_indices[i] = j;
            break;
          }
        }
      }
    }
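    // Worked example (illustrative): with output_names = {"a:0", "b:0",
    // "a:0"}, the loops above produce first_indices = {0, 1, 0}; the
    // duplicated third fetch is later served by copying (*outputs)[0]
    // rather than consuming a second tensor from sorted_outputs.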
    outputs->clear();
    outputs->reserve(sorted_outputs.size());
    for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
      const string& output_name = output_names[i];
      if (first_indices.empty() || first_indices[i] == i) {
        outputs->emplace_back(
            std::move(sorted_outputs[executors_and_keys
                                         ->output_name_to_index[output_name]]));
      } else {
        outputs->push_back((*outputs)[first_indices[i]]);
      }
    }
  }
  // Save the output tensors of this run we choose to keep.
  TF_RETURN_IF_ERROR(
      run_state.tensor_store.SaveTensors(output_names, &session_state_));
  if (args.stats_collector) {
    args.stats_collector->Finalize();
  }
  // Build and return the cost model as instructed.
  if (update_cost_model) {
    // Build the cost model.
    std::unordered_map<string, const Graph*> device_to_graph;
    for (const PerPartitionExecutorsAndLib& partition :
         executors_and_keys->items) {
      const Graph* graph = partition.graph;
      const string device = partition.flib->device()->name();
      device_to_graph[device] = graph;
    }
    // ...
    // Annotate the collected stats onto the cost graph.
    CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
    for (const auto& item : executors_and_keys->items) {
      TF_RETURN_IF_ERROR(
          cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
    }
  }
  // If requested via RunOptions, output the partition graphs.
  if (run_options.output_partition_graphs()) {
    protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
        run_metadata->mutable_partition_graphs();
    for (const PerPartitionExecutorsAndLib& exec_and_lib :
         executors_and_keys->items) {
      GraphDef* partition_graph_def = partition_graph_defs->Add();
      exec_and_lib.graph->ToGraphDef(partition_graph_def);
    }
  }

  return Status::OK();
}
Status NTSession::PRunSetup(const std::vector<string>& input_names,
                            const std::vector<string>& output_names,
                            const std::vector<string>& target_nodes,
                            string* handle) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  {
    mutex_lock l(graph_def_lock_);
    if (!graph_created_) {
      return errors::InvalidArgument(
          "Session was not created with a graph before PRunSetup()!");
    }
  }
  // PRunSetup() does not take RunOptions, so there are no debug options.
  DebugOptions debug_options;

  ExecutorsAndKeys* executors_and_keys;
  RunStateArgs run_state_args(debug_options);
  run_state_args.is_partial_run = true;
  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
                                          target_nodes, &executors_and_keys,
                                          &run_state_args));

  // Create the run state and save it for future PRun calls.
  Executor::Args args;
  args.step_id = step_id_counter_.fetch_add(1);
  RunState* run_state =
      new RunState(input_names, output_names, args.step_id, &devices_);
  run_state->rendez = new IntraProcessRendezvous(device_mgr_.get());
  {
    mutex_lock l(executor_lock_);
    if (!partial_runs_
             .emplace(run_state_args.handle,
                      std::unique_ptr<RunState>(run_state))
             .second) {
      return errors::Internal("The handle '", run_state_args.handle,
                              "' created for this partial run is not unique.");
    }
  }
  // Start parallel Executors.
  const size_t num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, run_state->rendez, [run_state](const Status& ret) {
        if (!ret.ok()) {
          mutex_lock l(run_state->mu_);
          run_state->status.Update(ret);
        }
        run_state->executors_done.Notify();
      });

  args.rendezvous = run_state->rendez;
  args.cancellation_manager = cancellation_manager_;
  args.runner = [this](Executor::Args::Closure c) {
    SchedClosure(std::move(c));
  };
  args.session_state = &session_state_;
  args.tensor_store = &run_state->tensor_store;
  args.step_container = &run_state->step_container;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, run_state_args.handle);
  }
  if (options_.config.graph_options().build_cost_model()) {
    run_state->collector.reset(new StepStatsCollector(nullptr));
    args.stats_collector = run_state->collector.get();
  }
  for (auto& item : executors_and_keys->items) {
    item.executor->RunAsync(args, barrier->Get());
  }

  *handle = run_state_args.handle;
  return Status::OK();
}
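// Sketch of the partial-run flow that PRunSetup()/PRun() enable (client
// code, not from this file; the endpoint names are hypothetical):
//
//   string handle;
//   TF_CHECK_OK(session->PRunSetup({"a:0", "b:0"}, {"c:0", "d:0"}, {},
//                                  &handle));
//   std::vector<tensorflow::Tensor> out1, out2;
//   TF_CHECK_OK(session->PRun(handle, {{"a:0", a}}, {"c:0"}, &out1));
//   TF_CHECK_OK(session->PRun(handle, {{"b:0", b}}, {"d:0"}, &out2));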
Status NTSession::PRun(const string& handle, const NamedTensorList& inputs,
                       const std::vector<string>& output_names,
                       std::vector<Tensor>* outputs) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  std::vector<string> parts = str_util::Split(handle, ';');
  const string& key = parts[0];
  // Get the executors for this partial run.
  ExecutorsAndKeys* executors_and_keys;
  RunState* run_state;
  {
    mutex_lock l(executor_lock_);  // could use reader lock
    auto exc_it = executors_.find(key);
    if (exc_it == executors_.end()) {
      return errors::InvalidArgument(
          "Must run 'setup' before performing partial runs!");
    }
    executors_and_keys = exc_it->second.get();

    auto prun_it = partial_runs_.find(handle);
    if (prun_it == partial_runs_.end()) {
      return errors::InvalidArgument(
          "Must run 'setup' before performing partial runs!");
    }
    run_state = prun_it->second.get();
    // Make sure that this is a new set of feeds that are still pending.
    for (const auto& input : inputs) {
      auto it = run_state->pending_inputs.find(input.first);
      if (it == run_state->pending_inputs.end()) {
        return errors::InvalidArgument(
            "The feed ", input.first,
            " was not specified in partial_run_setup.");
      } else if (it->second) {
        return errors::InvalidArgument("The feed ", input.first,
                                       " has already been fed.");
      }
    }
    // Check that this is a new set of fetches that are still pending.
    for (const auto& output : output_names) {
      auto it = run_state->pending_outputs.find(output);
      if (it == run_state->pending_outputs.end()) {
        return errors::InvalidArgument(
            "The fetch ", output, " was not specified in partial_run_setup.");
      } else if (it->second) {
        return errors::InvalidArgument("The fetch ", output,
                                       " has already been fetched.");
      }
    }
  }
  // Check that this new set of fetches can be computed from all the
  // feeds we have supplied.
  TF_RETURN_IF_ERROR(
      CheckFetch(inputs, output_names, executors_and_keys, run_state));

  // Send inputs, then receive outputs.
  Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez);
  if (s.ok()) {
    s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
  }
  // Delete the run state if there is an error or all fetches are done.
  {
    mutex_lock l(executor_lock_);
    bool done = true;
    if (s.ok()) {
      {
        mutex_lock l(run_state->mu_);
        if (!run_state->status.ok()) {
          LOG(WARNING) << "An error unrelated to this prun has been detected. "
                       << run_state->status;
        }
      }
      for (const auto& input : inputs) {
        auto it = run_state->pending_inputs.find(input.first);
        it->second = true;
      }
      for (const auto& name : output_names) {
        auto it = run_state->pending_outputs.find(name);
        it->second = true;
      }
      done = run_state->PendingDone();
    }
    if (done) {
      WaitForNotification(run_state, cancellation_manager_,
                          operation_timeout_in_ms_);
      partial_runs_.erase(handle);
    }
  }

  return s;
}
Status NTSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
                                              Tensor* retrieved_tensor) {
  if (resource_tensor.dtype() != DT_RESOURCE) {
    return errors::InvalidArgument(strings::StrCat(
        "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
        resource_tensor.dtype()));
  }

  const ResourceHandle& resource_handle =
      resource_tensor.scalar<ResourceHandle>()();

  if (resource_handle.container() ==
      SessionState::kTensorHandleResourceTypeName) {
    return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
  } else {
    return errors::InvalidArgument(strings::StrCat(
        "Invalid resource type hash code: ", resource_handle.hash_code(),
        "(name: ", resource_handle.name(),
        " type: ", resource_handle.maybe_type_name(),
        "). Perhaps a resource tensor was being provided as a feed? That is "
        "not currently allowed. Please file an issue at "
        "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
        "short code snippet that leads to this error message."));
  }
}
Status NTSession::SendPRunInputs(
    const std::vector<std::pair<string, Tensor>>& inputs,
    const ExecutorsAndKeys* executors_and_keys,
    IntraProcessRendezvous* rendez) {
  Status s;
  Rendezvous::ParsedKey parsed;
  // Insert the input tensors into the local rendezvous by their
  // rendezvous keys.
  for (const auto& input : inputs) {
    auto it =
        executors_and_keys->input_name_to_rendezvous_key.find(input.first);
    if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
      return errors::Internal("'", input.first,
                              "' is not a pre-defined feed.");
    }
    const string& input_key = it->second;

    s = Rendezvous::ParseKey(input_key, &parsed);
    if (!s.ok()) {
      rendez->StartAbort(s);
      return s;
    }

    if (input.second.dtype() == DT_RESOURCE) {
      Tensor tensor_from_handle;
      s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
      if (s.ok()) {
        s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle,
                         false);
      }
    } else {
      s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
    }

    if (!s.ok()) {
      rendez->StartAbort(s);
      return s;
    }
  }
  return Status::OK();
}
Status NTSession::RecvPRunOutputs(const std::vector<string>& output_names,
                                  const ExecutorsAndKeys* executors_and_keys,
                                  RunState* run_state,
                                  std::vector<Tensor>* outputs) {
  Status s;
  if (!output_names.empty()) {
    outputs->resize(output_names.size());
  }

  Rendezvous::ParsedKey parsed;
  // Get the outputs from the rendezvous.
  for (size_t output_offset = 0; output_offset < output_names.size();
       ++output_offset) {
    const string& output_name = output_names[output_offset];
    auto it =
        executors_and_keys->output_name_to_rendezvous_key.find(output_name);
    if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
      return errors::Internal("'", output_name,
                              "' is not a pre-defined fetch.");
    }
    const string& output_key = it->second;
    Tensor output_tensor;
    bool is_dead;
    IntraProcessRendezvous* rendez = run_state->rendez;

    s = Rendezvous::ParseKey(output_key, &parsed);
    if (s.ok()) {
      // Fetch the data from the rendezvous.
      s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead,
                       operation_timeout_in_ms_);
      if (is_dead && s.ok()) {
        s = errors::InvalidArgument("The tensor returned for ", output_name,
                                    " was not valid.");
      }
    }
    if (!s.ok()) {
      rendez->StartAbort(s);
      outputs->clear();
      return s;
    }

    (*outputs)[output_offset] = output_tensor;
  }
  return Status::OK();
}
Status NTSession::CheckFetch(
    const std::vector<std::pair<string, Tensor>>& feeds,
    const std::vector<string>& fetches,
    const ExecutorsAndKeys* executors_and_keys, const RunState* run_state) {
  const Graph* graph = executors_and_keys->graph.get();
  const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;

  // Build the set of pending feeds that we haven't seen.
  std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
  {
    mutex_lock l(executor_lock_);
    for (const auto& input : run_state->pending_inputs) {
      // Skip if the feed has already been fed.
      if (input.second) continue;
      TensorId id(ParseTensorName(input.first));
      auto it = name_to_node->find(id.first);
      if (it == name_to_node->end()) {
        return errors::NotFound("Feed ", input.first, ": not found");
      }
      pending_feeds.insert(id);
    }
  }
  for (const auto& it : feeds) {
    TensorId id(ParseTensorName(it.first));
    pending_feeds.erase(id);
  }

  // Initialize the stack with the fetch nodes.
  std::vector<const Node*> stack;
  for (const string& fetch : fetches) {
    TensorId id(ParseTensorName(fetch));
    auto it = name_to_node->find(id.first);
    if (it == name_to_node->end()) {
      return errors::NotFound("Fetch ", fetch, ": not found");
    }
    stack.push_back(it->second);
  }

  // Any tensor needed for fetches can't be in pending_feeds.
  std::vector<bool> visited(graph->num_node_ids(), false);
  while (!stack.empty()) {
    const Node* n = stack.back();
    stack.pop_back();

    for (const Edge* in_edge : n->in_edges()) {
      const Node* in_node = in_edge->src();
      if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
        return errors::InvalidArgument("Fetch ", in_node->name(), ":",
                                       in_edge->src_output(),
                                       " can't be computed from the feeds"
                                       " that have been fed so far.");
      }
      if (!visited[in_node->id()]) {
        visited[in_node->id()] = true;
        stack.push_back(in_node);
      }
    }
  }
  return Status::OK();
}
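// Example of what CheckFetch() rejects (hypothetical graph): if "c:0" is
// produced by MatMul("a:0", "b:0"), "b:0" was declared as a feed in
// partial_run_setup, and "b:0" has not been fed yet, fetching "c:0" fails
// with the InvalidArgument error above, because it cannot be computed from
// the feeds supplied so far.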
Status NTSession::GetOrCreateExecutors(
    gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
    gtl::ArraySlice<string> target_nodes,
    ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args) {
  int64 handle_name_counter_value = -1;
  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
    handle_name_counter_value = handle_name_counter_.fetch_add(1);
  }

  string debug_tensor_watches_summary;
  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
    debug_tensor_watches_summary = SummarizeDebugTensorWatches(
        run_state_args->debug_options.debug_tensor_watch_opts());
  }

  // Fast lookup path, no string sorting.
  const string key = strings::StrCat(
      str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
      str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
      "/", debug_tensor_watches_summary);
  // Set the handle, if it's needed to log memory or for partial run.
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(key, ";", handle_name_counter_value);
  }

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);  // could use reader lock
    auto it = executors_.find(key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      return Status::OK();
    }
  }
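  // Illustrative cache key (hypothetical endpoints): feeding "a:0" and
  // "b:0", fetching "c:0", and targeting "train_op" in a full
  // (non-partial) run with no debug watches yields
  //   "a:0,b:0->c:0/train_op/0/"
  // i.e. inputs joined by ",", then "->", outputs, "/", targets, "/",
  // the is_partial_run flag, "/", and the debug-watches summary.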
  // Slow lookup path: the unsorted key missed the cache. Sort the inputs
  // and outputs, and look up with the sorted key, in case an earlier call
  // used a different order of inputs and outputs.
  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
  std::sort(inputs_sorted.begin(), inputs_sorted.end());
  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
  std::sort(outputs_sorted.begin(), outputs_sorted.end());
  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
  std::sort(tn_sorted.begin(), tn_sorted.end());

  const string sorted_key = strings::StrCat(
      str_util::Join(inputs_sorted, ","), "->",
      str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
      "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
  // Set the handle, if it's needed to log memory or for partial run.
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(sorted_key, ";", handle_name_counter_value);
  }

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);
    auto it = executors_.find(sorted_key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      // Insert this under the original key.
      executors_.emplace(key, it->second);
      return Status::OK();
    }
  }
  // Nothing found, so create the executors and store them in the cache.
  BuildGraphOptions options;
  options.feed_endpoints = inputs_sorted;
  options.fetch_endpoints = outputs_sorted;
  options.target_nodes = tn_sorted;
  options.use_function_convention = !run_state_args->is_partial_run;
  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
    options.debug_options = run_state_args->debug_options;
  }

  std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);

  // The executor_lock_ is intentionally released while the executors
  // are being created.
  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
  TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &ek->flib_def,
                                  run_state_args, &ek->input_types,
                                  &ek->output_types));
  if (run_state_args->is_partial_run) {
    ek->graph = std::move(run_state_args->graph);
    std::unordered_set<StringPiece, StringPieceHasher> names;
    for (const string& input : inputs) {
      TensorId id(ParseTensorName(input));
      names.emplace(id.first);
    }
    for (const string& output : outputs) {
      TensorId id(ParseTensorName(output));
      names.emplace(id.first);
    }
    for (Node* n : ek->graph->nodes()) {
      if (names.count(n->name()) > 0) {
        ek->name_to_node.insert({n->name(), n});
      }
    }
  }
  ek->items.reserve(graphs.size());
  const auto& optimizer_opts =
      options_.config.graph_options().optimizer_options();

  int graph_def_version;
  {
    mutex_lock l(graph_def_lock_);
    graph_def_version =
        execution_state_->original_graph_def().versions().producer();
  }
  ek->proc_flr.reset(new ProcessFunctionLibraryRuntime(
      device_mgr_.get(), options_.env, graph_def_version, ek->flib_def.get(),
      optimizer_opts));

  GraphOptimizer optimizer(optimizer_opts);
  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
    const string& partition_name = iter->first;
    std::unique_ptr<Graph>& partition_graph = iter->second;

    Device* device;
    TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));

    ek->items.resize(ek->items.size() + 1);
    auto* item = &(ek->items.back());
    auto lib = ek->proc_flr->GetFLR(partition_name);
    if (lib == nullptr) {
      return errors::Internal("Could not find device: ", partition_name);
    }
    item->flib = lib;
    LocalExecutorParams params;
    params.device = device;
    params.function_library = lib;
    auto opseg = device->op_segment();
    params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
                                              OpKernel** kernel) {
      // Kernels for stateless ops and for function-library ops are not
      // shared through the OpSegment: function kernels may be tied to a
      // particular subgraph, even if the function itself is stateful.
      if (!lib->IsStateful(ndef.op()) ||
          lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
        return lib->CreateKernel(ndef, kernel);
      }
      auto create_fn = [lib, &ndef](OpKernel** kernel) {
        return lib->CreateKernel(ndef, kernel);
      };
      // Stateful kernels are cached in the OpSegment; on a cache miss,
      // create_fn() builds the kernel.
      return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
                                 create_fn);
    };
    params.delete_kernel = [lib](OpKernel* kernel) {
      // If the node is stateful, opseg owns it; otherwise delete it here.
      if (kernel && !lib->IsStateful(kernel->type_string())) {
        delete kernel;
      }
    };
    optimizer.Optimize(lib, options_.env, device, &iter->second,
                       /*shape_map=*/nullptr);

    // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
    if (!options.debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          options.debug_options, partition_graph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
                                         device->name(),
                                         partition_graph.get()));
    // NewLocalExecutor takes ownership of partition_graph.
    item->graph = partition_graph.get();
    item->executor = nullptr;
    item->device = device;
    Executor* executor;
    TF_RETURN_IF_ERROR(
        NewLocalExecutor(params, partition_graph.release(), &executor));
    item->executor.reset(executor);
  }
  // Cache the mapping from input/output names to graph elements to
  // avoid recomputing it every time.
  if (!run_state_args->is_partial_run) {
    // For regular `Run()`, we use the function calling convention, and so
    // maintain a mapping from input/output names to
    // argument/return-value ordinal index.
    for (size_t i = 0; i < inputs_sorted.size(); ++i) {
      const string& input = inputs_sorted[i];
      ek->input_name_to_index[input] = i;
    }
    for (size_t i = 0; i < outputs_sorted.size(); ++i) {
      const string& output = outputs_sorted[i];
      ek->output_name_to_index[output] = i;
    }
  } else {
    // For `PRun()`, we use the rendezvous calling convention, and so
    // maintain a mapping from input/output names to rendezvous keys.
    for (size_t i = 0; i < inputs_sorted.size(); ++i) {
      const string& input = inputs_sorted[i];
      ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
          input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
    }
    for (size_t i = 0; i < outputs_sorted.size(); ++i) {
      const string& output = outputs_sorted[i];
      ek->output_name_to_rendezvous_key[output] =
          GetRendezvousKey(output, device_set_.client_device()->attributes(),
                           FrameAndIter(0, 0));
    }
  }
  // Reacquire the lock and try to insert into the cache. Another thread
  // may have created the entry concurrently, in which case we reuse it.
  mutex_lock l(executor_lock_);
  auto insert_result = executors_.emplace(sorted_key, ek);
  // Insert the value under the original key as well, so the fast-path
  // lookup succeeds if the caller reuses the same argument order.
  executors_.emplace(key, insert_result.first->second);
  *executors_and_keys = insert_result.first->second.get();

  return Status::OK();
}
Status NTSession::CreateGraphs(
    const BuildGraphOptions& subgraph_options,
    std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    RunStateArgs* run_state_args, DataTypeVector* input_types,
    DataTypeVector* output_types) {
  mutex_lock l(graph_def_lock_);
  std::unique_ptr<ClientGraph> client_graph;

  std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
  GraphExecutionState* execution_state = nullptr;
  if (options_.config.graph_options().place_pruned_graph()) {
    // Because we are placing pruned graphs, we need to create a new
    // GraphExecutionState for every new unseen graph, and then place it.
    GraphExecutionStateOptions prune_options;
    prune_options.device_set = &device_set_;
    prune_options.session_options = &options_;
    prune_options.stateful_placements = stateful_placements_;
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
        execution_state_->original_graph_def().library(), prune_options,
        execution_state_->original_graph_def(), subgraph_options,
        &temp_exec_state_holder, &client_graph));
    execution_state = temp_exec_state_holder.get();
  } else {
    execution_state = execution_state_.get();
    TF_RETURN_IF_ERROR(
        execution_state->BuildGraph(subgraph_options, &client_graph));
  }
  if (subgraph_options.feed_endpoints.size() !=
      client_graph->feed_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of feed endpoints = ",
        subgraph_options.feed_endpoints.size(),
        " versus number of pruned feed endpoints = ",
        client_graph->feed_types.size());
  }
  if (subgraph_options.fetch_endpoints.size() !=
      client_graph->fetch_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of fetch endpoints = ",
        subgraph_options.fetch_endpoints.size(),
        " versus number of pruned fetch endpoints = ",
        client_graph->fetch_types.size());
  }
  auto current_stateful_placements = execution_state->GetStatefulPlacements();
  // Update our current state based on the execution_state's placements.
  // If a node's placement ever changes, that is an internal error.
  for (auto placement_pair : current_stateful_placements) {
    const string& node_name = placement_pair.first;
    const string& placement = placement_pair.second;
    auto iter = stateful_placements_.find(node_name);
    if (iter == stateful_placements_.end()) {
      stateful_placements_.insert(std::make_pair(node_name, placement));
    } else if (iter->second != placement) {
      return errors::Internal(
          "Stateful placement mismatch. "
          "Current assignment of ",
          node_name, " to ", iter->second, " does not match ", placement);
    }
  }

  stateful_placements_ = execution_state->GetStatefulPlacements();

  // Remember the graph in run state if this is a partial run.
  if (run_state_args->is_partial_run) {
    run_state_args->graph.reset(new Graph(flib_def_.get()));
    CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
  }
  // Partition the graph across devices.
  PartitionOptions popts;
  popts.node_to_loc = [](const Node* node) {
    assert(node != nullptr);
    return node->assigned_device_name();
  };
  popts.new_name = [this](const string& prefix) {
    return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
  };
  popts.get_incarnation = [](const string& name) {
    // The local session does not have changing incarnation numbers.
    // Just return '1'.
    return 1;
  };
  popts.flib_def = &client_graph->graph.flib_def();
  popts.control_flow_added = false;

  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
  std::vector<string> device_names;
  for (auto device : devices_) {
    // Extract the LocalName from the device.
    device_names.push_back(DeviceNameUtils::LocalName(device->name()));
  }

  // Check for valid partitions.
  for (const auto& partition : partitions) {
    const string local_partition_name =
        DeviceNameUtils::LocalName(partition.first);
    if (std::count(device_names.begin(), device_names.end(),
                   local_partition_name) == 0) {
      return errors::InvalidArgument(
          "Creating a partition for ", local_partition_name,
          " which doesn't exist in the list of available devices. Available "
          "devices: ",
          str_util::Join(device_names, ","));
    }
  }
  for (const auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(
        new Graph(client_graph->flib_def.get()));
    GraphConstructorOptions device_opts;
    // There are internal operations (e.g., send/recv) that we now allow.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
                                              device_graph.get()));
    outputs->emplace(partition.first, std::move(device_graph));
  }
  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = &options_;
  optimization_options.flib_def = client_graph->flib_def.get();
  optimization_options.partition_graphs = outputs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
  Status s;
  for (auto& partition : *outputs) {
    const string& partition_name = partition.first;
    std::unique_ptr<Graph>* graph = &partition.second;

    VLOG(2) << "Created " << DebugString(graph->get()) << " for "
            << partition_name;

    // Give the device an opportunity to rewrite its subgraph.
    Device* d;
    s = device_mgr_->LookupDevice(partition_name, &d);
    if (!s.ok()) break;
    s = d->MaybeRewriteGraph(graph);
    if (!s.ok()) {
      break;
    }
  }
  *flib_def = std::move(client_graph->flib_def);
  std::swap(*input_types, client_graph->feed_types);
  std::swap(*output_types, client_graph->fetch_types);
  return s;
}
::tensorflow::Status NTSession::ListDevices(
    std::vector<DeviceAttributes>* response) {
  response->clear();
  response->reserve(devices_.size());
  for (Device* d : devices_) {
    const DeviceAttributes& attrs = d->attributes();
    response->emplace_back(attrs);
  }
  return ::tensorflow::Status::OK();
}
::tensorflow::Status NTSession::Reset(const std::vector<string>& containers) {
  device_mgr_->ClearContainers(containers);
  return ::tensorflow::Status::OK();
}
NTSession::RunState::RunState(const std::vector<string>& pending_input_names,
                              const std::vector<string>& pending_output_names,
                              int64 step_id,
                              const std::vector<Device*>* devices)
    : step_container(step_id, [devices](const string& name) {
        for (auto d : *devices) {
          if (!d->resource_manager()->Cleanup(name).ok()) {
            // Do nothing...
          }
        }
      }) {
  // Initially all the feeds and fetches are pending.
  for (auto& name : pending_input_names) {
    pending_inputs[name] = false;
  }
  for (auto& name : pending_output_names) {
    pending_outputs[name] = false;
  }
}
NTSession::RunState::RunState(int64 step_id,
                              const std::vector<Device*>* devices)
    : RunState({}, {}, step_id, devices) {}
NTSession::RunState::~RunState() {
  if (rendez != nullptr) {
    if (!executors_done.HasBeenNotified()) {
      rendez->StartAbort(errors::Cancelled("PRun cancellation"));
      executors_done.WaitForNotification();
    }
    rendez->Unref();
  }
}
bool NTSession::RunState::PendingDone() const {
  for (const auto& it : pending_inputs) {
    if (!it.second) return false;
  }
  for (const auto& it : pending_outputs) {
    if (!it.second) return false;
  }
  return true;
}
void NTSession::WaitForNotification(RunState* run_state,
                                    CancellationManager* cm,
                                    int64 timeout_in_ms) {
  const Status status =
      WaitForNotification(&run_state->executors_done, timeout_in_ms);
  if (!status.ok()) {
    {
      mutex_lock l(run_state->mu_);
      run_state->status.Update(status);
    }
    cm->StartCancel();
    // The executors hold references to `cm` and other per-step state;
    // wait for them to finish before the step is cleaned up.
    run_state->executors_done.WaitForNotification();
  }
}
::tensorflow::Status NTSession::WaitForNotification(Notification* notification,
                                                    int64 timeout_in_ms) {
  if (timeout_in_ms > 0) {
    const int64 timeout_in_us = timeout_in_ms * 1000;
    const bool notified =
        WaitForNotificationWithTimeout(notification, timeout_in_us);
    if (!notified) {
      return Status(error::DEADLINE_EXCEEDED,
                    "Timed out waiting for notification");
    }
  } else {
    notification->WaitForNotification();
  }
  return Status::OK();
}
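// Numeric example (illustrative): a Run() call with
// run_options.timeout_in_ms() = 5000 waits here for at most
// 5,000,000 microseconds; on timeout, the RunState overload above records
// DEADLINE_EXCEEDED into run_state->status and cancels the step.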