#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/simple_placer.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"

CMS_THREAD_SAFE auto* nothreads_session_runs = monitoring::Counter<0>::New(
    "/tensorflow/core/nothreads_session_runs",
    "The number of times NTSession::Run() has been called.");
string GetRendezvousKey(const string& tensor_name,
                        const DeviceAttributes& device_info,
                        const FrameAndIter& frame_iter) {
  return strings::StrCat(device_info.name(), ";",
                         strings::FpToString(device_info.incarnation()), ";",
                         device_info.name(), ";", tensor_name, ";",
                         frame_iter.frame_id, ":", frame_iter.iter_id);
}
  return options.target == "no_threads";
    if (options.config.graph_options().build_cost_model() > 0) {
      EnableCPUAllocatorFullStats(true);
    std::vector<Device*> devices;
    Status s = DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices);
        new NTSession(options, new DeviceMgr(devices), this);
      sessions_.push_back(session);
               const std::vector<string>& containers) override {
    std::vector<NTSession*> sessions_to_reset;
    for (auto session : sessions_to_reset) {
      s.Update(session->Reset(containers));
    for (auto session : sessions_to_reset) {

  std::vector<NTSession*> sessions_ GUARDED_BY(sessions_lock_);
                     const DeviceMgr* device_mgr,
      device_mgr_(device_mgr),
      cancellation_manager_(new CancellationManager()),
      operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
    LOG(ERROR) << status.error_message();
  int devices_added = 0;
  if (options.config.log_device_placement()) {
    const string mapping_str = device_mgr_->DeviceMappingString();
    if (mapping_str.empty()) {
      printf("Device mapping: no known devices.\n");
      printf("Device mapping:\n%s", mapping_str.c_str());
    LOG(INFO) << "Device mapping:\n" << mapping_str;

    if (devices_added == 0) {
  if (!closed_) Close().IgnoreError();
  for (auto& it : partial_runs_) {
    it.second.reset(nullptr);
  for (auto& it : executors_) {
  execution_state_.reset(nullptr);
    const GraphDef& graph, bool* out_already_initialized) {
    *out_already_initialized = true;
      new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
  SimpleGraphExecutionStateOptions options;
  options.session_options = &options_;
  GraphDef temp(graph);
  TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
      &temp, options, &execution_state_));
  graph_created_ = true;
  *out_already_initialized = false;
  if (graph.node_size() > 0) {
    if (graph_created_) {
      return errors::AlreadyExists(
          "A Graph has already been created for this session.");

  bool already_initialized;
  if (already_initialized) {
    TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
    std::unique_ptr<SimpleGraphExecutionState> state;
    TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
    execution_state_.swap(state);
                       const std::vector<string>& output_names,
                       const std::vector<string>& target_nodes,
                       std::vector<Tensor>* outputs) {
  RunMetadata run_metadata;
  return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
    const DebugOptions& debug_options, int64 session_run_index,
    int64 executor_step_index, const std::vector<string>& input_names,
    const std::vector<string>& output_names,
    const std::vector<string>& target_names,
    std::unique_ptr<DebuggerStateInterface>* debugger_state) {
      DebuggerStateRegistry::CreateState(debug_options, debugger_state));
  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
      debug_options.global_step(), session_run_index, executor_step_index,
    const DebugOptions& debug_options, Graph* graph, Device* device) {
  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
      DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
                       const std::vector<string>& output_names,
                       const std::vector<string>& target_nodes,
                       RunMetadata* run_metadata) {
  nothreads_session_runs->GetCell()->IncrementBy(1);
  if (!graph_created_) {
    return errors::InvalidArgument(
        "Session was not created with a graph before Run()!");

  std::vector<string> input_tensor_names;
  input_tensor_names.reserve(inputs.size());
  for (const auto& it : inputs) {
    input_tensor_names.push_back(it.first);
  RunStateArgs run_state_args(run_options.debug_options());
                                  &executors_and_keys, &run_state_args));
  const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);

  std::unique_ptr<DebuggerStateInterface> debugger_state;
  if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
        run_options.debug_options(), args.step_id, executor_step_count,
        input_tensor_names, output_names, target_nodes, &debugger_state));

  FunctionCallFrame call_frame(executors_and_keys->input_types,
  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
  for (const auto& it : inputs) {
    if (it.second.dtype() == DT_RESOURCE) {
      Tensor tensor_from_handle;
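      // Note: in the elided lines the DT_RESOURCE handle is presumably
      // resolved back to the underlying tensor (see ResourceHandleToInputTensor
      // below), and tensor_from_handle is what gets pushed into feed_args.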
  Status s = call_frame.SetArgs(feed_args);
  if (errors::IsInternal(s)) {
    return errors::InvalidArgument(s.error_message());
  } else if (!s.ok()) {
  CancellationManager step_cancellation_manager;
  args.call_frame = &call_frame;

  const size_t num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, run_state.rendez, [&run_state](const Status& ret) {
        mutex_lock l(run_state.mu_);
        run_state.status.Update(ret);

  args.rendezvous = run_state.rendez;
  args.cancellation_manager = &step_cancellation_manager;
  args.runner = [this](Executor::Args::Closure c) {
  args.tensor_store = &run_state.tensor_store;
  args.step_container = &run_state.step_container;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, run_state_args.handle);
  const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);

  bool update_cost_model = false;
  if (options_.config.graph_options().build_cost_model() > 0) {
    const int64 build_cost_model_every =
        options_.config.graph_options().build_cost_model();
    const int64 build_cost_model_after =
        options_.config.graph_options().build_cost_model_after();
    int64 measure_step_count = executor_step_count - build_cost_model_after;
    if (measure_step_count >= 0) {
          ((measure_step_count + 1) % build_cost_model_every == 0);
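      // Note: once build_cost_model_after steps have executed, the cost model
      // is refreshed every build_cost_model_every steps.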
  if (do_trace || update_cost_model) {
    run_state.collector.reset(
        new StepStatsCollector(run_metadata->mutable_step_stats()));
    args.stats_collector = run_state.collector.get();
  std::unique_ptr<GPUTracer> tracer;
  if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
    tracer.reset(CreateGPUTracer());
  if (tracer) tracer->Start().IgnoreError();
#endif  // GOOGLE_CUDA

  CancellationToken cancellation_token =
      cancellation_token, [&step_cancellation_manager]() {
        step_cancellation_manager.StartCancel();
  if (already_cancelled) {
    run_state.executors_done.Notify();
    return errors::Cancelled("Run call was cancelled");
  for (const auto& item : executors_and_keys->items) {
    item.executor->RunAsync(args, barrier->Get());

      run_options.timeout_in_ms() > 0
          ? run_options.timeout_in_ms()
    mutex_lock l(run_state.mu_);
    run_state.status.Update(errors::Cancelled("Run call was cancelled"));

    tracer->Stop().IgnoreError();
    tracer->Collect(args.stats_collector).IgnoreError();
#endif  // GOOGLE_CUDA

    mutex_lock l(run_state.mu_);
    TF_RETURN_IF_ERROR(run_state.status);
  std::vector<Tensor> sorted_outputs;
  Status s = call_frame.ConsumeRetvals(&sorted_outputs);
  if (errors::IsInternal(s)) {
    return errors::InvalidArgument(s.error_message());
  } else if (!s.ok()) {

  const bool unique_outputs =
      output_names.size() == executors_and_keys->output_name_to_index.size();
  std::vector<int> first_indices;
  if (!unique_outputs) {
    first_indices.resize(output_names.size());
    for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
      for (int j = 0; j <= i; ++j) {
        if (output_names[i] == output_names[j]) {
          first_indices[i] = j;
  outputs->reserve(sorted_outputs.size());
  for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
    const string& output_name = output_names[i];
    if (first_indices.empty() || first_indices[i] == i) {
      outputs->emplace_back(
          std::move(sorted_outputs[executors_and_keys
                                       ->output_name_to_index[output_name]]));
      outputs->push_back((*outputs)[first_indices[i]]);
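      // Note: first_indices maps each output name to its first occurrence, so
      // duplicate fetches reuse the tensor already placed in *outputs instead
      // of consuming sorted_outputs twice.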
      run_state.tensor_store.SaveTensors(output_names, &session_state_));
  if (update_cost_model) {
    std::unordered_map<string, const Graph*> device_to_graph;
         executors_and_keys->items) {
      const Graph* graph = partition.graph;
      const string device = partition.flib->device()->name();
      device_to_graph[device] = graph;

    CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
    for (const auto& item : executors_and_keys->items) {

  if (run_options.output_partition_graphs()) {
    protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
        run_metadata->mutable_partition_graphs();
         executors_and_keys->items) {
      GraphDef* partition_graph_def = partition_graph_defs->Add();
      exec_and_lib.graph->ToGraphDef(partition_graph_def);
                            const std::vector<string>& output_names,
                            const std::vector<string>& target_nodes,
  if (!graph_created_) {
    return errors::InvalidArgument(
        "Session was not created with a graph before PRunSetup()!");

  DebugOptions debug_options;
                                          target_nodes, &executors_and_keys,
           .emplace(run_state_args.handle,
                    std::unique_ptr<RunState>(run_state))
    return errors::Internal("The handle '", run_state_args.handle,
                            "' created for this partial run is not unique.");
  const size_t num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, run_state->rendez, [run_state](const Status& ret) {
        mutex_lock l(run_state->mu_);
        run_state->status.Update(ret);

  args.rendezvous = run_state->rendez;
  args.runner = [this](Executor::Args::Closure c) {
  args.tensor_store = &run_state->tensor_store;
  args.step_container = &run_state->step_container;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, run_state_args.handle);
  if (options_.config.graph_options().build_cost_model()) {
    run_state->collector.reset(new StepStatsCollector(nullptr));
    args.stats_collector = run_state->collector.get();

  for (auto& item : executors_and_keys->items) {
    item.executor->RunAsync(args, barrier->Get());

  *handle = run_state_args.handle;
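  // Illustrative partial-run flow from the client side; tensor names are
  // placeholders, a minimal sketch rather than code from this file:
  //
  //   string handle;
  //   TF_CHECK_OK(session->PRunSetup({"a:0", "b:0"}, {"c:0"}, {}, &handle));
  //   std::vector<tensorflow::Tensor> outputs;
  //   TF_CHECK_OK(session->PRun(handle, {{"a:0", tensor_a}, {"b:0", tensor_b}},
  //                             {"c:0"}, &outputs));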
                        const std::vector<string>& output_names,
                        std::vector<Tensor>* outputs) {
  std::vector<string> parts = str_util::Split(handle, ';');
  const string& key = parts[0];

    auto exc_it = executors_.find(key);
    if (exc_it == executors_.end()) {
      return errors::InvalidArgument(
          "Must run 'setup' before performing partial runs!");
    executors_and_keys = exc_it->second.get();

    auto prun_it = partial_runs_.find(handle);
    if (prun_it == partial_runs_.end()) {
      return errors::InvalidArgument(
          "Must run 'setup' before performing partial runs!");
    run_state = prun_it->second.get();

    for (const auto& input : inputs) {
        return errors::InvalidArgument(
            "The feed ", input.first,
            " was not specified in partial_run_setup.");
      } else if (it->second) {
        return errors::InvalidArgument("The feed ", input.first,
                                       " has already been fed.");
    for (const auto& output : output_names) {
        return errors::InvalidArgument(
            "The fetch ", output,
            " was not specified in partial_run_setup.");
      } else if (it->second) {
        return errors::InvalidArgument("The fetch ", output,
                                       " has already been fetched.");

      CheckFetch(inputs, output_names, executors_and_keys, run_state));

  s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);

    mutex_lock l(run_state->mu_);
    if (!run_state->status.ok()) {
      LOG(WARNING) << "An error unrelated to this prun has been detected. "
                   << run_state->status;
    for (const auto& input : inputs) {
    for (const auto& name : output_names) {
      partial_runs_.erase(handle);
    Tensor* retrieved_tensor) {
  if (resource_tensor.dtype() != DT_RESOURCE) {
    return errors::InvalidArgument(strings::StrCat(
        "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
        resource_tensor.dtype()));

  ResourceHandle resource_handle = resource_tensor.scalar<ResourceHandle>()();

  if (resource_handle.container() ==
      SessionState::kTensorHandleResourceTypeName) {
    return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
    return errors::InvalidArgument(strings::StrCat(
        "Invalid resource type hash code: ", resource_handle.hash_code(),
        "(name: ", resource_handle.name(),
        " type: ", resource_handle.maybe_type_name(), ")"));
    IntraProcessRendezvous* rendez) {
  Rendezvous::ParsedKey parsed;
  for (const auto& input : inputs) {
      return errors::Internal("'", input.first,
                              "' is not a pre-defined feed.");
    const string& input_key = it->second;

    s = Rendezvous::ParseKey(input_key, &parsed);
      rendez->StartAbort(s);

    if (input.second.dtype() == DT_RESOURCE) {
      Tensor tensor_from_handle;
      s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
      s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
      rendez->StartAbort(s);
    const std::vector<string>& output_names,
    std::vector<Tensor>* outputs) {
  if (!output_names.empty()) {
    outputs->resize(output_names.size());

  Rendezvous::ParsedKey parsed;
  for (size_t output_offset = 0; output_offset < output_names.size();
    const string& output_name = output_names[output_offset];
      return errors::Internal("'", output_name,
                              "' is not a pre-defined fetch.");
    const string& output_key = it->second;
    Tensor output_tensor;
    IntraProcessRendezvous* rendez = run_state->rendez;

    s = Rendezvous::ParseKey(output_key, &parsed);
    s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead,
    if (is_dead && s.ok()) {
      s = errors::InvalidArgument("The tensor returned for ", output_name,
      rendez->StartAbort(s);
    (*outputs)[output_offset] = output_tensor;
    const std::vector<string>& fetches,
  const Graph* graph = executors_and_keys->graph.get();

  std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
      if (input.second) continue;
      TensorId id(ParseTensorName(input.first));
      auto it = name_to_node->find(id.first);
      if (it == name_to_node->end()) {
      pending_feeds.insert(id);
    for (const auto& it : feeds) {
      TensorId id(ParseTensorName(it.first));
      pending_feeds.erase(id);

  std::vector<const Node*> stack;
  for (const string& fetch : fetches) {
    TensorId id(ParseTensorName(fetch));
    auto it = name_to_node->find(id.first);
    if (it == name_to_node->end()) {
    stack.push_back(it->second);

  std::vector<bool> visited(graph->num_node_ids(), false);
  while (!stack.empty()) {
    const Node* n = stack.back();
    for (const Edge* in_edge : n->in_edges()) {
      const Node* in_node = in_edge->src();
      if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
        return errors::InvalidArgument("Fetch ", in_node->name(), ":",
                                       in_edge->src_output(),
                                       " can't be computed from the feeds"
                                       " that have been fed so far.");
      stack.push_back(in_node);
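  // Note: CheckFetch walks backwards from each fetch node; if the traversal
  // reaches a feed that has not been provided yet (still in pending_feeds),
  // the fetch cannot be computed in this partial run and an error is returned.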
    gtl::ArraySlice<string> inputs,
    gtl::ArraySlice<string> outputs, gtl::ArraySlice<string> target_nodes,
  int64 handle_name_counter_value = -1;

  string debug_tensor_watches_summary;
  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
    debug_tensor_watches_summary = SummarizeDebugTensorWatches(

  const string key = strings::StrCat(
      str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
      str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
      "/", debug_tensor_watches_summary);
  if (handle_name_counter_value >= 0) {
        strings::StrCat(key, ";", handle_name_counter_value);
    auto it = executors_.find(key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();

  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
  std::sort(inputs_sorted.begin(), inputs_sorted.end());
  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
  std::sort(outputs_sorted.begin(), outputs_sorted.end());
  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
  std::sort(tn_sorted.begin(), tn_sorted.end());

  const string sorted_key = strings::StrCat(
      str_util::Join(inputs_sorted, ","), "->",
      str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
      "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
  if (handle_name_counter_value >= 0) {
        strings::StrCat(sorted_key, ";", handle_name_counter_value);

    auto it = executors_.find(sorted_key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      executors_.emplace(key, it->second);

  options.feed_endpoints = inputs_sorted;
  options.fetch_endpoints = outputs_sorted;
  options.target_nodes = tn_sorted;
  options.use_function_convention = !run_state_args->is_partial_run;
  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
  TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &ek->flib_def,
                                  run_state_args, &ek->input_types,
                                  &ek->output_types));

    std::unordered_set<StringPiece, StringPiece::Hasher> names;
    for (const string& input : inputs) {
      TensorId id(ParseTensorName(input));
      names.emplace(id.first);
    for (const string& output : outputs) {
      TensorId id(ParseTensorName(output));
      names.emplace(id.first);
    for (Node* n : ek->graph->nodes()) {
      if (names.count(n->name()) > 0) {
        ek->name_to_node.insert({n->name(), n});
  ek->items.reserve(graphs.size());
  const auto& optimizer_opts =
      options_.config.graph_options().optimizer_options();
  GraphOptimizer optimizer(optimizer_opts);
  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
    const string& partition_name = iter->first;
    std::unique_ptr<Graph>& partition_graph = iter->second;
    const int graph_def_version = partition_graph->versions().producer();

    TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));

    ek->items.resize(ek->items.size() + 1);
    auto* item = &(ek->items.back());
    item->flib.reset(NewFunctionLibraryRuntime(
        ek->flib_def.get(), optimizer_opts));

    LocalExecutorParams params;
    params.device = device;
    params.function_library = item->flib.get();
    auto lib = item->flib.get();
    auto opseg = device->op_segment();
    params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
                                              OpKernel** kernel) {
      if (!lib->IsStateful(ndef.op())) {
        return lib->CreateKernel(ndef, kernel);
      auto create_fn = [lib, &ndef](OpKernel** kernel) {
        return lib->CreateKernel(ndef, kernel);
    params.delete_kernel = [lib](OpKernel* kernel) {
      if (kernel && !lib->IsStateful(kernel->type_string())) {
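      // Note: the create_kernel/delete_kernel pair appears to treat stateful
      // ops specially: stateless kernels are created and destroyed through the
      // function library, while stateful kernels are cached via the device's
      // OpSegment (opseg) in the elided lines so they persist across steps.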
    optimizer.Optimize(lib, options_.env, device, &iter->second);

    if (!options.debug_options.debug_tensor_watch_opts().empty()) {
          options.debug_options, partition_graph.get(), params.device));
    TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
                                         partition_graph.get()));

    item->graph = partition_graph.get();
    item->executor = nullptr;
        NewLocalExecutor(params, partition_graph.release(), &executor));
    item->executor.reset(executor);
    for (size_t i = 0; i < inputs_sorted.size(); ++i) {
      const string& input = inputs_sorted[i];
      ek->input_name_to_index[input] = i;
    for (size_t i = 0; i < outputs_sorted.size(); ++i) {
      const string& output = outputs_sorted[i];
      ek->output_name_to_index[output] = i;

    for (size_t i = 0; i < inputs_sorted.size(); ++i) {
      const string& input = inputs_sorted[i];
      ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
          input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
    for (size_t i = 0; i < outputs_sorted.size(); ++i) {
      const string& output = outputs_sorted[i];
      ek->output_name_to_rendezvous_key[output] =
          GetRendezvousKey(output, device_set_.client_device()->attributes(),
                           FrameAndIter(0, 0));
  auto insert_result = executors_.emplace(sorted_key, ek);
  executors_.emplace(key, insert_result.first->second);
  *executors_and_keys = insert_result.first->second.get();
    const BuildGraphOptions& subgraph_options,
    std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    RunStateArgs* run_state_args, DataTypeVector* input_types,
    DataTypeVector* output_types) {
  std::unique_ptr<SimpleClientGraph> client_graph;

  std::unique_ptr<SimpleGraphExecutionState> temp_exec_state_holder;
  SimpleGraphExecutionState* execution_state = nullptr;
  if (options_.config.graph_options().place_pruned_graph()) {
    SimpleGraphExecutionStateOptions prune_options;
    prune_options.session_options = &options_;
    prune_options.stateful_placements = stateful_placements_;
    TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForPrunedGraph(
        execution_state_->original_graph_def().library(), prune_options,
        execution_state_->original_graph_def(), subgraph_options,
        &temp_exec_state_holder, &client_graph));
    execution_state = temp_exec_state_holder.get();
    execution_state = execution_state_.get();
        execution_state->BuildGraph(subgraph_options, &client_graph));

  if (subgraph_options.feed_endpoints.size() !=
      client_graph->feed_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of feed endpoints = ",
        subgraph_options.feed_endpoints.size(),
        " versus number of pruned feed endpoints = ",
        client_graph->feed_types.size());
  if (subgraph_options.fetch_endpoints.size() !=
      client_graph->fetch_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of fetch endpoints = ",
        subgraph_options.fetch_endpoints.size(),
        " versus number of pruned fetch endpoints = ",
        client_graph->fetch_types.size());
  auto current_stateful_placements = execution_state->GetStatefulPlacements();
  for (auto placement_pair : current_stateful_placements) {
    const string& node_name = placement_pair.first;
    const string& placement = placement_pair.second;
    auto iter = stateful_placements_.find(node_name);
    if (iter == stateful_placements_.end()) {
      stateful_placements_.insert(std::make_pair(node_name, placement));
    } else if (iter->second != placement) {
      return errors::Internal(
          "Stateful placement mismatch. "
          "Current assignment of ",
          node_name, " to ", iter->second, " does not match ", placement);

  stateful_placements_ = execution_state->GetStatefulPlacements();

    CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
  PartitionOptions popts;
  popts.node_to_loc = [](const Node* node) {
    assert(node != nullptr);
    return node->assigned_device_name();
  popts.new_name = [this](const string& prefix) {
  popts.get_incarnation = [](const string& name) {
  popts.control_flow_added = false;

  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));

  std::vector<string> device_names;
    device_names.push_back(DeviceNameUtils::LocalName(device->name()));

  for (const auto& partition : partitions) {
    const string local_partition_name =
        DeviceNameUtils::LocalName(partition.first);
    if (std::count(device_names.begin(), device_names.end(),
                   local_partition_name) == 0) {
      return errors::InvalidArgument(
          "Creating a partition for ", local_partition_name,
          " which doesn't exist in the list of available devices. Available "
          str_util::Join(device_names, ","));
  for (const auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(
        new Graph(client_graph->flib_def.get()));
    GraphConstructorOptions device_opts;
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
                                              device_graph.get()));

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = &options_;
  optimization_options.flib_def = client_graph->flib_def.get();
  optimization_options.partition_graphs = outputs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
  for (auto& partition : *outputs) {
    const string& partition_name = partition.first;
    std::unique_ptr<Graph>* graph = &partition.second;
    VLOG(2) << "Created " << DebugString(graph->get()) << " for "

    s = device_mgr_->LookupDevice(partition_name, &d);
    s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph);

  *flib_def = std::move(client_graph->flib_def);
  std::swap(*input_types, client_graph->feed_types);
  std::swap(*output_types, client_graph->fetch_types);
    std::vector<DeviceAttributes>* response) {
  response->reserve(devices_.size());
    const DeviceAttributes& attrs = d->attributes();
    response->emplace_back(attrs);

    const std::vector<string>& containers) {
    const std::vector<string>& pending_input_names,
    const std::vector<string>& pending_output_names, int64 step_id,
    const std::vector<Device*>* devices)
    : step_container(step_id, [devices](const string& name) {
        for (auto d : *devices) {
          if (!d->resource_manager()->Cleanup(name).ok()) {
  for (auto& name : pending_input_names) {
  for (auto& name : pending_output_names) {

                   const std::vector<Device*>* devices)
    : RunState({}, {}, step_id, devices) {}
    rendez->StartAbort(errors::Cancelled("PRun cancellation"));

    if (!it.second) return false;
    if (!it.second) return false;

                                         CancellationManager* cm,
                                         int64 timeout_in_ms) {
    mutex_lock l(run_state->mu_);
    run_state->status.Update(status);
    Notification* notification, int64 timeout_in_ms) {
  if (timeout_in_ms > 0) {
    int64 timeout_in_us = timeout_in_ms * 1000;
    bool notified = WaitForNotificationWithTimeout(notification, timeout_in_us);
      return Status(error::DEADLINE_EXCEEDED,
                    "Timed out waiting for notification");
  notification->WaitForNotification();
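// End-to-end use of the session created by this factory; graph and tensor
// names are placeholders, a minimal sketch rather than code from this file:
//
//   tensorflow::GraphDef graph_def;  // loaded or built elsewhere
//   TF_CHECK_OK(session->Create(graph_def));
//   std::vector<tensorflow::Tensor> outputs;
//   TF_CHECK_OK(session->Run({{"input:0", input_tensor}}, {"output:0"}, {},
//                            &outputs));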