37 #include "tbb/task_group.h" 41 #include "tensorflow/core/common_runtime/constant_folding.h" 42 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 43 #include "tensorflow/core/common_runtime/device_factory.h" 44 #include "tensorflow/core/common_runtime/executor.h" 45 #include "tensorflow/core/common_runtime/function.h" 46 #include "tensorflow/core/common_runtime/graph_optimizer.h" 47 #include "tensorflow/core/common_runtime/memory_types.h" 48 #include "tensorflow/core/common_runtime/optimization_registry.h" 49 #include "tensorflow/core/common_runtime/step_stats_collector.h" 50 #include "tensorflow/core/framework/function.h" 51 #include "tensorflow/core/framework/graph.pb_text.h" 52 #include "tensorflow/core/framework/graph.pb.h" 53 #include "tensorflow/core/framework/graph_def_util.h" 54 #include "tensorflow/core/framework/log_memory.h" 55 #include "tensorflow/core/framework/node_def.pb.h" 56 #include "tensorflow/core/framework/tensor.h" 57 #include "tensorflow/core/framework/versions.pb.h" 58 #include "tensorflow/core/graph/algorithm.h" 59 #include "tensorflow/core/graph/graph.h" 60 #include "tensorflow/core/graph/graph_constructor.h" 61 #include "tensorflow/core/graph/graph_partition.h" 62 #include "tensorflow/core/graph/subgraph.h" 63 #include "tensorflow/core/graph/tensor_id.h" 64 #include "tensorflow/core/lib/core/errors.h" 65 #include "tensorflow/core/lib/core/notification.h" 66 #include "tensorflow/core/lib/core/refcount.h" 67 #include "tensorflow/core/lib/core/status.h" 68 #include "tensorflow/core/lib/gtl/array_slice.h" 69 #include "tensorflow/core/lib/gtl/stl_util.h" 70 #include "tensorflow/core/lib/monitoring/counter.h" 71 #include "tensorflow/core/lib/strings/numbers.h" 72 #include "tensorflow/core/lib/strings/str_util.h" 73 #include "tensorflow/core/lib/strings/strcat.h" 74 #include "tensorflow/core/platform/cpu_info.h" 75 #include "tensorflow/core/platform/device_tracer.h" 76 #include "tensorflow/core/platform/logging.h" 77 #include "tensorflow/core/platform/mutex.h" 78 #include "tensorflow/core/platform/types.h" 79 #include "tensorflow/core/util/device_name_utils.h" 80 #include "tensorflow/core/util/env_var.h" 88 "/tensorflow/core/tbb_session_runs",
89 "The number of times TBBSession::Run() has been called.");
94 string GetRendezvousKey(
const string& tensor_name,
95 const DeviceAttributes& device_info,
96 const FrameAndIter& frame_iter) {
97 return strings::StrCat(device_info.name(),
";",
98 strings::FpToString(device_info.incarnation()),
";",
99 device_info.name(),
";", tensor_name,
";",
100 frame_iter.frame_id,
":", frame_iter.iter_id);
110 return options.target ==
"tbb";
115 if (options.config.graph_options().build_cost_model() > 0) {
116 EnableCPUAllocatorFullStats(
true);
118 std::vector<Device*> devices;
119 const Status s = DeviceFactory::AddDevices(
120 options,
"/job:localhost/replica:0/task:0", &devices);
127 new TBBSession(options,
new DeviceMgr(devices),
this);
130 sessions_.push_back(session);
136 const std::vector<string>& containers)
override {
137 std::vector<TBBSession*> sessions_to_reset;
146 for (
auto session : sessions_to_reset) {
147 s.Update(
session->Reset(containers));
151 for (
auto session : sessions_to_reset) {
165 std::vector<TBBSession*> sessions_
GUARDED_BY(sessions_lock_);
195 arena.execute( [&g, &
c] () {g.run(
c ); } );
199 const DeviceMgr* device_mgr,
202 device_mgr_(device_mgr),
204 cancellation_manager_(new CancellationManager()),
205 operation_timeout_in_ms_(options_.
config.operation_timeout_in_ms()) {
211 LOG(
ERROR) << status.error_message();
217 int devices_added = 0;
218 if (options.config.log_device_placement()) {
219 const string mapping_str =
device_mgr_->DeviceMappingString();
220 if (mapping_str.empty()) {
221 printf(
"Device mapping: no known devices.\n");
223 printf(
"Device mapping:\n%s", mapping_str.c_str());
225 LOG(
INFO) <<
"Device mapping:\n" << mapping_str;
234 if (devices_added == 0) {
242 if (!closed_)
Close().IgnoreError();
243 for (
auto& it : partial_runs_) {
244 it.second.reset(
nullptr);
246 for (
auto& it : executors_) {
254 execution_state_.reset(
nullptr);
259 const GraphDef& graph,
bool* out_already_initialized) {
262 *out_already_initialized =
true;
269 new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
270 GraphExecutionStateOptions
options;
272 options.session_options = &
options_;
281 GraphDef
temp(graph);
283 GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_));
284 graph_created_ =
true;
285 *out_already_initialized =
false;
291 if (graph.node_size() > 0) {
293 if (graph_created_) {
294 return errors::AlreadyExists(
295 "A Graph has already been created for this session.");
309 bool already_initialized;
314 if (already_initialized) {
315 TF_RETURN_IF_ERROR(
flib_def_->AddLibrary(graph.library()));
316 std::unique_ptr<GraphExecutionState> state;
317 TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
318 execution_state_.swap(state);
324 const std::vector<string>& output_names,
325 const std::vector<string>& target_nodes,
326 std::vector<Tensor>*
outputs) {
327 RunMetadata run_metadata;
328 return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
333 const DebugOptions& debug_options, int64 session_run_index,
334 int64 executor_step_index,
const std::vector<string>& input_names,
335 const std::vector<string>& output_names,
336 const std::vector<string>& target_names,
337 std::unique_ptr<DebuggerStateInterface>* debugger_state) {
339 DebuggerStateRegistry::CreateState(debug_options, debugger_state));
340 TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
341 debug_options.global_step(), session_run_index, executor_step_index,
342 input_names, output_names, target_names));
347 const DebugOptions& debug_options, Graph* graph, Device* device) {
348 std::unique_ptr<DebugGraphDecoratorInterface> decorator;
350 DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
352 TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
353 TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
359 const std::vector<string>& output_names,
360 const std::vector<string>& target_nodes,
362 RunMetadata* run_metadata) {
364 tbb_session_runs->GetCell()->IncrementBy(1);
367 if (!graph_created_) {
368 return errors::InvalidArgument(
369 "Session was not created with a graph before Run()!");
374 std::vector<string> input_tensor_names;
375 input_tensor_names.reserve(inputs.size());
376 for (
const auto& it : inputs) {
377 input_tensor_names.push_back(it.first);
382 RunStateArgs run_state_args(run_options.debug_options());
389 &executors_and_keys, &run_state_args));
390 const int64 executor_step_count = executors_and_keys->
step_count.fetch_add(1);
392 std::unique_ptr<DebuggerStateInterface> debugger_state;
393 if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
395 run_options.debug_options(), args.step_id, executor_step_count,
396 input_tensor_names, output_names, target_nodes, &debugger_state));
401 FunctionCallFrame call_frame(executors_and_keys->
input_types,
403 gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
404 for (
const auto& it : inputs) {
405 if (it.second.dtype() == DT_RESOURCE) {
406 Tensor tensor_from_handle;
415 const Status s = call_frame.SetArgs(feed_args);
416 if (errors::IsInternal(s)) {
417 return errors::InvalidArgument(s.error_message());
418 }
else if (!s.ok()) {
425 CancellationManager step_cancellation_manager;
426 args.call_frame = &call_frame;
430 tbb::task_arena taskArena;
431 tbb::task_group taskGroup;
433 auto doneWithTaskGroup = [&taskArena, &taskGroup](
void *) { taskArena.execute([&taskGroup]() { taskGroup.wait();}); };
434 std::unique_ptr<tbb::task_group, decltype(doneWithTaskGroup) > guard(&taskGroup, doneWithTaskGroup);
437 const size_t num_executors = executors_and_keys->
items.size();
438 ExecutorBarrier* barrier =
new ExecutorBarrier(
439 num_executors, run_state.
rendez, [&run_state](
const Status& ret) {
441 mutex_lock l(run_state.mu_);
442 run_state.status.Update(ret);
447 args.rendezvous = run_state.rendez;
448 args.cancellation_manager = &step_cancellation_manager;
451 args.tensor_store = &run_state.tensor_store;
452 args.step_container = &run_state.step_container;
453 if (LogMemory::IsEnabled()) {
454 LogMemory::RecordStep(
args.step_id, run_state_args.handle);
458 const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
460 bool update_cost_model =
false;
461 if (
options_.config.graph_options().build_cost_model() > 0) {
462 const int64 build_cost_model_every =
463 options_.config.graph_options().build_cost_model();
464 const int64 build_cost_model_after =
465 options_.config.graph_options().build_cost_model_after();
466 int64 measure_step_count = executor_step_count - build_cost_model_after;
467 if (measure_step_count >= 0) {
469 ((measure_step_count + 1) % build_cost_model_every == 0);
472 if (do_trace || update_cost_model ||
473 run_options.report_tensor_allocations_upon_oom()) {
474 run_state.collector.reset(
475 new StepStatsCollector(run_metadata->mutable_step_stats()));
476 args.stats_collector = run_state.collector.get();
479 std::unique_ptr<DeviceTracer> tracer;
480 if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
481 tracer = CreateDeviceTracer();
486 run_state.executors_done.Notify();
495 const CancellationToken cancellation_token =
498 cancellation_token, [&step_cancellation_manager]() {
499 step_cancellation_manager.StartCancel();
501 if (already_cancelled) {
505 run_state.executors_done.Notify();
507 return errors::Cancelled(
"Run call was cancelled");
512 Executor::Args::Runner default_runner = [
this, &taskArena, &taskGroup](Executor::Args::Closure
c) {
515 for (
const auto& item : executors_and_keys->items) {
528 args.runner = default_runner;
529 item.executor->RunAsync(
args, barrier->Get());
535 run_options.timeout_in_ms() > 0
536 ? run_options.timeout_in_ms()
542 mutex_lock
l(run_state.mu_);
543 run_state.status.Update(errors::Cancelled(
"Run call was cancelled"));
547 TF_RETURN_IF_ERROR(tracer->Stop());
548 TF_RETURN_IF_ERROR(tracer->Collect(
args.stats_collector));
552 mutex_lock
l(run_state.mu_);
553 TF_RETURN_IF_ERROR(run_state.status);
558 std::vector<Tensor> sorted_outputs;
559 const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
560 if (errors::IsInternal(s)) {
561 return errors::InvalidArgument(s.error_message());
562 }
else if (!s.ok()) {
565 const bool unique_outputs =
566 output_names.size() == executors_and_keys->output_name_to_index.size();
569 std::vector<int> first_indices;
570 if (!unique_outputs) {
571 first_indices.resize(output_names.size());
572 for (
int i = 0; i < static_cast<int>(output_names.size()); ++
i) {
573 for (
int j = 0; j <=
i; ++j) {
574 if (output_names[
i] == output_names[j]) {
575 first_indices[
i] = j;
582 outputs->reserve(sorted_outputs.size());
583 for (
int i = 0; i < static_cast<int>(output_names.size()); ++
i) {
584 const string& output_name = output_names[
i];
585 if (first_indices.empty() || first_indices[
i] ==
i) {
586 outputs->emplace_back(
587 std::move(sorted_outputs[executors_and_keys
588 ->output_name_to_index[output_name]]));
590 outputs->push_back((*outputs)[first_indices[
i]]);
597 run_state.tensor_store.SaveTensors(output_names, &
session_state_));
598 if (
args.stats_collector) {
599 args.stats_collector->Finalize();
604 if (update_cost_model) {
606 std::unordered_map<string, const Graph*> device_to_graph;
608 executors_and_keys->items) {
609 const Graph* graph = partition.
graph;
610 const string device = partition.
flib->device()->name();
611 device_to_graph[device] = graph;
616 CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
617 for (
const auto& item : executors_and_keys->items) {
624 if (run_options.output_partition_graphs()) {
625 protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
626 run_metadata->mutable_partition_graphs();
628 executors_and_keys->items) {
629 GraphDef* partition_graph_def = partition_graph_defs->Add();
630 exec_and_lib.
graph->ToGraphDef(partition_graph_def);
638 Tensor* retrieved_tensor) {
639 if (resource_tensor.dtype() != DT_RESOURCE) {
640 return errors::InvalidArgument(strings::StrCat(
641 "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
642 resource_tensor.dtype()));
645 const ResourceHandle& resource_handle =
646 resource_tensor.scalar<ResourceHandle>()();
648 if (resource_handle.container() ==
649 SessionState::kTensorHandleResourceTypeName) {
650 return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
652 return errors::InvalidArgument(strings::StrCat(
653 "Invalid resource type hash code: ", resource_handle.hash_code(),
654 "(name: ", resource_handle.name(),
655 " type: ", resource_handle.maybe_type_name(),
656 "). Perhaps a resource tensor was being provided as a feed? That is " 657 "not currently allowed. Please file an issue at " 658 "https://github.com/tensorflow/tensorflow/issues/new, ideally with a " 659 "short code snippet that leads to this error message."));
664 gtl::ArraySlice<string>
inputs, gtl::ArraySlice<string>
outputs,
665 gtl::ArraySlice<string> target_nodes,
ExecutorsAndKeys** executors_and_keys,
667 int64 handle_name_counter_value = -1;
672 string debug_tensor_watches_summary;
673 if (!run_state_args->
debug_options.debug_tensor_watch_opts().empty()) {
674 debug_tensor_watches_summary = SummarizeDebugTensorWatches(
679 const string key = strings::StrCat(
680 str_util::Join(inputs,
","),
"->", str_util::Join(outputs,
","),
"/",
681 str_util::Join(target_nodes,
","),
"/", run_state_args->
is_partial_run,
682 "/", debug_tensor_watches_summary);
684 if (handle_name_counter_value >= 0) {
686 strings::StrCat(key,
";", handle_name_counter_value);
692 auto it = executors_.find(key);
693 if (it != executors_.end()) {
694 *executors_and_keys = it->second.get();
705 std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
706 std::sort(inputs_sorted.begin(), inputs_sorted.end());
707 std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
708 std::sort(outputs_sorted.begin(), outputs_sorted.end());
709 std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
710 std::sort(tn_sorted.begin(), tn_sorted.end());
712 const string sorted_key = strings::StrCat(
713 str_util::Join(inputs_sorted,
","),
"->",
714 str_util::Join(outputs_sorted,
","),
"/", str_util::Join(tn_sorted,
","),
715 "/", run_state_args->
is_partial_run,
"/", debug_tensor_watches_summary);
717 if (handle_name_counter_value >= 0) {
719 strings::StrCat(sorted_key,
";", handle_name_counter_value);
725 auto it = executors_.find(sorted_key);
726 if (it != executors_.end()) {
727 *executors_and_keys = it->second.get();
729 executors_.emplace(key, it->second);
736 options.feed_endpoints = inputs_sorted;
737 options.fetch_endpoints = outputs_sorted;
738 options.target_nodes = tn_sorted;
739 options.use_function_convention = !run_state_args->
is_partial_run;
740 if (!run_state_args->
debug_options.debug_tensor_watch_opts().empty()) {
748 std::unordered_map<string, std::unique_ptr<Graph>>
graphs;
749 TF_RETURN_IF_ERROR(
CreateGraphs(options, &graphs, &ek->flib_def,
750 run_state_args, &ek->input_types,
755 std::unordered_set<StringPiece, StringPieceHasher>
names;
756 for (
const string&
input : inputs) {
757 TensorId
id(ParseTensorName(
input));
758 names.emplace(
id.
first);
760 for (
const string&
output : outputs) {
761 TensorId
id(ParseTensorName(
output));
762 names.emplace(
id.
first);
764 for (Node*
n : ek->graph->nodes()) {
765 if (names.count(
n->name()) > 0) {
766 ek->name_to_node.insert({
n->name(),
n});
770 ek->items.reserve(graphs.size());
771 const auto& optimizer_opts =
772 options_.config.graph_options().optimizer_options();
774 int graph_def_version;
778 execution_state_->original_graph_def().versions().producer();
780 ek->proc_flr.reset(
new ProcessFunctionLibraryRuntime(
784 GraphOptimizer optimizer(optimizer_opts);
785 for (
auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
786 const string& partition_name = iter->first;
787 std::unique_ptr<Graph>& partition_graph = iter->second;
790 TF_RETURN_IF_ERROR(
device_mgr_->LookupDevice(partition_name, &device));
792 ek->items.resize(ek->items.size() + 1);
793 auto* item = &(ek->items.back());
794 auto lib = ek->proc_flr->GetFLR(partition_name);
795 if (
lib ==
nullptr) {
796 return errors::Internal(
"Could not find device: ", partition_name);
800 LocalExecutorParams params;
801 params.device = device;
802 params.function_library =
lib;
803 auto opseg = device->op_segment();
804 params.create_kernel = [
this,
lib, opseg](
const NodeDef& ndef,
812 if (!
lib->IsStateful(ndef.op()) ||
813 lib->GetFunctionLibraryDefinition()->Find(ndef.op()) !=
nullptr) {
814 return lib->CreateKernel(ndef, kernel);
816 auto create_fn = [
lib, &ndef](OpKernel** kernel) {
817 return lib->CreateKernel(ndef, kernel);
825 params.delete_kernel = [
lib](OpKernel* kernel) {
827 if (kernel && !
lib->IsStateful(kernel->type_string())) {
833 optimizer.Optimize(
lib,
options_.env, device, &iter->second,
837 if (!options.debug_options.debug_tensor_watch_opts().empty()) {
839 options.debug_options, partition_graph.get(), params.device));
842 TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
844 partition_graph.get()));
846 item->graph = partition_graph.get();
847 item->executor =
nullptr;
848 item->device = device;
851 NewLocalExecutor(params, partition_graph.release(), &executor));
852 item->executor.reset(executor);
861 for (
size_t i = 0;
i < inputs_sorted.size(); ++
i) {
862 const string&
input = inputs_sorted[
i];
863 ek->input_name_to_index[
input] =
i;
865 for (
size_t i = 0;
i < outputs_sorted.size(); ++
i) {
866 const string&
output = outputs_sorted[
i];
867 ek->output_name_to_index[
output] =
i;
875 for (
size_t i = 0;
i < inputs_sorted.size(); ++
i) {
876 const string&
input = inputs_sorted[
i];
877 ek->input_name_to_rendezvous_key[
input] = GetRendezvousKey(
878 input,
device_set_.client_device()->attributes(), FrameAndIter(0, 0));
880 for (
size_t i = 0;
i < outputs_sorted.size(); ++
i) {
881 const string&
output = outputs_sorted[
i];
882 ek->output_name_to_rendezvous_key[
output] =
883 GetRendezvousKey(output,
device_set_.client_device()->attributes(),
893 auto insert_result = executors_.emplace(sorted_key, ek);
896 executors_.emplace(key, insert_result.first->second);
897 *executors_and_keys = insert_result.first->second.get();
903 const BuildGraphOptions& subgraph_options,
904 std::unordered_map<
string, std::unique_ptr<Graph>>*
outputs,
905 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
906 RunStateArgs* run_state_args, DataTypeVector* input_types,
907 DataTypeVector* output_types) {
909 std::unique_ptr<ClientGraph> client_graph;
911 std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
912 GraphExecutionState* execution_state =
nullptr;
913 if (
options_.config.graph_options().place_pruned_graph()) {
917 GraphExecutionStateOptions prune_options;
919 prune_options.session_options = &
options_;
920 prune_options.stateful_placements = stateful_placements_;
921 TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
922 execution_state_->original_graph_def().library(), prune_options,
923 execution_state_->original_graph_def(), subgraph_options,
924 &temp_exec_state_holder, &client_graph));
925 execution_state = temp_exec_state_holder.get();
927 execution_state = execution_state_.get();
929 execution_state->BuildGraph(subgraph_options, &client_graph));
932 if (subgraph_options.feed_endpoints.size() !=
933 client_graph->feed_types.size()) {
934 return errors::Internal(
935 "Graph pruning failed: requested number of feed endpoints = ",
936 subgraph_options.feed_endpoints.size(),
937 " versus number of pruned feed endpoints = ",
938 client_graph->feed_types.size());
940 if (subgraph_options.fetch_endpoints.size() !=
941 client_graph->fetch_types.size()) {
942 return errors::Internal(
943 "Graph pruning failed: requested number of fetch endpoints = ",
944 subgraph_options.fetch_endpoints.size(),
945 " versus number of pruned fetch endpoints = ",
946 client_graph->fetch_types.size());
949 auto current_stateful_placements = execution_state->GetStatefulPlacements();
953 for (
auto placement_pair : current_stateful_placements) {
954 const string& node_name = placement_pair.first;
955 const string& placement = placement_pair.second;
956 auto iter = stateful_placements_.find(node_name);
957 if (iter == stateful_placements_.end()) {
958 stateful_placements_.insert(std::make_pair(node_name, placement));
959 }
else if (iter->second != placement) {
960 return errors::Internal(
961 "Stateful placement mismatch. " 962 "Current assignment of ",
963 node_name,
" to ", iter->second,
" does not match ", placement);
967 stateful_placements_ = execution_state->GetStatefulPlacements();
972 CopyGraph(*execution_state->full_graph(), run_state_args->
graph.get());
976 PartitionOptions popts;
977 popts.node_to_loc = [](
const Node* node) {
978 assert(node !=
nullptr);
979 return node->assigned_device_name();
981 popts.new_name = [
this](
const string&
prefix) {
984 popts.get_incarnation = [](
const string&
name) {
989 popts.flib_def = &client_graph->graph.flib_def();
990 popts.control_flow_added =
false;
992 std::unordered_map<string, GraphDef>
partitions;
993 TF_RETURN_IF_ERROR(
Partition(popts, &client_graph->graph, &partitions));
995 std::vector<string> device_names;
998 device_names.push_back(DeviceNameUtils::LocalName(device->name()));
1002 for (
const auto& partition : partitions) {
1003 const string local_partition_name =
1004 DeviceNameUtils::LocalName(partition.first);
1005 if (
std::count(device_names.begin(), device_names.end(),
1006 local_partition_name) == 0) {
1007 return errors::InvalidArgument(
1008 "Creating a partition for ", local_partition_name,
1009 " which doesn't exist in the list of available devices. Available " 1011 str_util::Join(device_names,
","));
1015 for (
const auto& partition : partitions) {
1016 std::unique_ptr<Graph> device_graph(
1017 new Graph(client_graph->flib_def.get()));
1018 GraphConstructorOptions device_opts;
1020 device_opts.allow_internal_ops =
true;
1021 device_opts.expect_device_spec =
true;
1022 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
1023 device_graph.get()));
1027 GraphOptimizationPassOptions optimization_options;
1028 optimization_options.session_options = &
options_;
1029 optimization_options.flib_def = client_graph->flib_def.get();
1030 optimization_options.partition_graphs =
outputs;
1031 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
1032 OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
1035 for (
auto& partition : *
outputs) {
1036 const string& partition_name = partition.first;
1037 std::unique_ptr<Graph>* graph = &partition.second;
1039 VLOG(2) <<
"Created " << DebugString(graph->get()) <<
" for " 1044 s =
device_mgr_->LookupDevice(partition_name, &d);
1046 s = d->MaybeRewriteGraph(graph);
1051 *flib_def =
std::move(client_graph->flib_def);
1052 std::swap(*input_types, client_graph->feed_types);
1053 std::swap(*output_types, client_graph->fetch_types);
1058 std::vector<DeviceAttributes>* response) {
1060 response->reserve(
devices_.size());
1062 const DeviceAttributes& attrs =
d->attributes();
1063 response->emplace_back(attrs);
1069 const std::vector<string>& containers) {
1086 const std::vector<string>& pending_input_names,
1087 const std::vector<string>& pending_output_names, int64 step_id,
1088 const std::vector<Device*>* devices)
1089 : step_container(step_id, [devices](const
string&
name) {
1090 for (
auto d : *devices) {
1091 if (!
d->resource_manager()->Cleanup(
name).ok()) {
1097 for (
auto&
name : pending_input_names) {
1100 for (
auto&
name : pending_output_names) {
1106 const std::vector<Device*>* devices)
1107 :
RunState({}, {}, step_id, devices) {}
1112 rendez->StartAbort(errors::Cancelled(
"PRun cancellation"));
1121 if (!it.second)
return false;
1124 if (!it.second)
return false;
1130 RunState* run_state, CancellationManager* cm, int64 timeout_in_ms) {
1133 arena.execute([&taskGroup]() { taskGroup.wait();} );
1139 mutex_lock
l(run_state->
mu_);
1140 run_state->status.Update(status);
1151 Notification* notification, int64 timeout_in_ms) {
1152 if (timeout_in_ms > 0) {
1153 const int64 timeout_in_us = timeout_in_ms * 1000;
1154 const bool notified =
1155 WaitForNotificationWithTimeout(notification, timeout_in_us);
1157 return Status(error::DEADLINE_EXCEEDED,
1158 "Timed out waiting for notification");
1161 notification->WaitForNotification();
std::unique_ptr< FunctionLibraryDefinition > flib_def_
static boost::mutex mutex
std::vector< PerPartitionExecutorsAndLib > items
RunState(int64 step_id, const std::vector< Device * > *devices)
DataTypeVector output_types
std::unordered_map< string, size_t > input_name_to_index
static const HistoName names[]
Notification executors_done
::tensorflow::Status Close() override
std::vector< Device * > devices_
std::atomic< int64 > handle_name_counter_
std::vector< std::pair< string, Tensor > > NamedTensorList
std::atomic< int64 > edge_name_counter_
::tensorflow::Status WaitForNotification(Notification *n, int64 timeout_in_ms)
::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions &debug_options, Graph *graph, Device *device)
CostModelManager cost_model_manager_
std::unordered_map< string, bool > pending_inputs
::tensorflow::Status Create(const GraphDef &graph) override
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e g
Status Reset(const SessionOptions &options, const std::vector< string > &containers) override
Session * NewSession(const SessionOptions &options) override
static std::string const input
FunctionLibraryRuntime * flib
void Deregister(const TBBSession *session)
Status MaybeInitializeExecutionState(const GraphDef &graph, bool *out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
::tensorflow::Status GetOrCreateExecutors(gtl::ArraySlice< string > inputs, gtl::ArraySlice< string > outputs, gtl::ArraySlice< string > target_nodes, ExecutorsAndKeys **executors_and_keys, RunStateArgs *run_state_args)
static NTSessionRegistrar registrar
IntraProcessRendezvous * rendez
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
std::vector< TBBSession * > sessions_ GUARDED_BY(sessions_lock_)
void SchedClosure(tbb::task_arena &arena, tbb::task_group &g, std::function< void()> c)
::tensorflow::Status ResourceHandleToInputTensor(const Tensor &resource_tensor, Tensor *retrieved_tensor)
CancellationManager * cancellation_manager_
::tensorflow::Status CreateGraphs(const BuildGraphOptions &options, std::unordered_map< string, std::unique_ptr< Graph >> *outputs, std::unique_ptr< FunctionLibraryDefinition > *flib_def, RunStateArgs *run_state_args, DataTypeVector *input_types, DataTypeVector *output_types)
::tensorflow::Status Reset(const std::vector< string > &containers)
const DebugOptions & debug_options
::tensorflow::Status Extend(const GraphDef &graph) override
DataTypeVector input_types
std::pair< int, edm::FunctionWithDict > OK
std::unordered_map< string, bool > pending_outputs
::tensorflow::Status CheckNotClosed()
std::atomic_int_fast64_t step_count
const SessionOptions options_
TBBSessionFactory *const factory_
static std::atomic_int_fast64_t step_id_counter_
def remove(d, key, TELL=False)
const std::unique_ptr< const DeviceMgr > device_mgr_
const int64 operation_timeout_in_ms_
Executor::Args::NodeOutputsCallback node_outputs_callback_
::tensorflow::Status Run(const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs) override
TBBSession(const SessionOptions &options, const DeviceMgr *device_mgr, TBBSessionFactory *factory)
std::unique_ptr< Graph > graph
::tensorflow::Status CreateDebuggerState(const DebugOptions &debug_options, int64 session_run_index, int64 executor_step_index, const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_names, std::unique_ptr< DebuggerStateInterface > *debugger_state)
SessionState session_state_
::tensorflow::Status ExtendLocked(const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
::tensorflow::Status ListDevices(std::vector< DeviceAttributes > *response) override
bool AcceptsOptions(const SessionOptions &options) override