37 #include "tbb/task_group.h" 41 #include "tensorflow/core/common_runtime/constant_folding.h" 42 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 43 #include "tensorflow/core/common_runtime/device_factory.h" 44 #include "tensorflow/core/common_runtime/executor.h" 45 #include "tensorflow/core/common_runtime/function.h" 46 #include "tensorflow/core/common_runtime/graph_optimizer.h" 47 #include "tensorflow/core/common_runtime/memory_types.h" 48 #include "tensorflow/core/common_runtime/optimization_registry.h" 49 #include "tensorflow/core/common_runtime/step_stats_collector.h" 50 #include "tensorflow/core/framework/function.h" 51 #include "tensorflow/core/framework/graph.pb_text.h" 52 #include "tensorflow/core/framework/graph.pb.h" 53 #include "tensorflow/core/framework/graph_def_util.h" 54 #include "tensorflow/core/framework/log_memory.h" 55 #include "tensorflow/core/framework/node_def.pb.h" 56 #include "tensorflow/core/framework/tensor.h" 57 #include "tensorflow/core/framework/versions.pb.h" 58 #include "tensorflow/core/graph/algorithm.h" 59 #include "tensorflow/core/graph/graph.h" 60 #include "tensorflow/core/graph/graph_constructor.h" 61 #include "tensorflow/core/graph/graph_partition.h" 62 #include "tensorflow/core/graph/subgraph.h" 63 #include "tensorflow/core/graph/tensor_id.h" 64 #include "tensorflow/core/lib/core/errors.h" 65 #include "tensorflow/core/lib/core/notification.h" 66 #include "tensorflow/core/lib/core/refcount.h" 67 #include "tensorflow/core/lib/core/status.h" 68 #include "tensorflow/core/lib/gtl/array_slice.h" 69 #include "tensorflow/core/lib/gtl/stl_util.h" 70 #include "tensorflow/core/lib/monitoring/counter.h" 71 #include "tensorflow/core/lib/strings/numbers.h" 72 #include "tensorflow/core/lib/strings/str_util.h" 73 #include "tensorflow/core/lib/strings/strcat.h" 74 #include "tensorflow/core/platform/cpu_info.h" 75 #include "tensorflow/core/platform/device_tracer.h" 76 #include "tensorflow/core/platform/logging.h" 77 #include "tensorflow/core/platform/mutex.h" 78 #include "tensorflow/core/platform/types.h" 79 #include "tensorflow/core/util/device_name_utils.h" 80 #include "tensorflow/core/util/env_var.h" 87 "/tensorflow/core/tbb_session_runs",
88 "The number of times TBBSession::Run() has been called.");
93 string GetRendezvousKey(
const string& tensor_name,
94 const DeviceAttributes& device_info,
95 const FrameAndIter& frame_iter) {
96 return strings::StrCat(device_info.name(),
";",
97 strings::FpToString(device_info.incarnation()),
";",
98 device_info.name(),
";", tensor_name,
";",
99 frame_iter.frame_id,
":", frame_iter.iter_id);
109 return options.target ==
"tbb";
114 if (options.config.graph_options().build_cost_model() > 0) {
115 EnableCPUAllocatorFullStats(
true);
117 std::vector<Device*> devices;
118 const Status s = DeviceFactory::AddDevices(
119 options,
"/job:localhost/replica:0/task:0", &devices);
126 new TBBSession(options,
new DeviceMgr(devices),
this);
129 sessions_.push_back(session);
135 const std::vector<string>& containers)
override {
136 std::vector<TBBSession*> sessions_to_reset;
145 for (
auto session : sessions_to_reset) {
146 s.Update(
session->Reset(containers));
150 for (
auto session : sessions_to_reset) {
164 std::vector<TBBSession*> sessions_
GUARDED_BY(sessions_lock_);
194 arena.execute( [&g, &
c] () {g.run(
c ); } );
198 const DeviceMgr* device_mgr,
201 device_mgr_(device_mgr),
203 cancellation_manager_(new CancellationManager()),
204 operation_timeout_in_ms_(options_.
config.operation_timeout_in_ms()) {
210 LOG(
ERROR) << status.error_message();
216 int devices_added = 0;
217 if (options.config.log_device_placement()) {
218 const string mapping_str =
device_mgr_->DeviceMappingString();
219 if (mapping_str.empty()) {
220 printf(
"Device mapping: no known devices.\n");
222 printf(
"Device mapping:\n%s", mapping_str.c_str());
224 LOG(
INFO) <<
"Device mapping:\n" << mapping_str;
233 if (devices_added == 0) {
241 if (!closed_)
Close().IgnoreError();
242 for (
auto& it : partial_runs_) {
243 it.second.reset(
nullptr);
245 for (
auto& it : executors_) {
252 d->ClearResourceMgr();
257 execution_state_.reset(
nullptr);
262 const GraphDef& graph,
bool* out_already_initialized) {
265 *out_already_initialized =
true;
272 new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
273 GraphExecutionStateOptions
options;
275 options.session_options = &
options_;
284 GraphDef
temp(graph);
286 GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_));
287 graph_created_ =
true;
288 *out_already_initialized =
false;
294 if (graph.node_size() > 0) {
296 if (graph_created_) {
297 return errors::AlreadyExists(
298 "A Graph has already been created for this session.");
312 bool already_initialized;
317 if (already_initialized) {
318 TF_RETURN_IF_ERROR(
flib_def_->AddLibrary(graph.library()));
319 std::unique_ptr<GraphExecutionState> state;
320 TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
321 execution_state_.swap(state);
327 const std::vector<string>& output_names,
328 const std::vector<string>& target_nodes,
329 std::vector<Tensor>*
outputs) {
330 RunMetadata run_metadata;
331 return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
336 const DebugOptions& debug_options, int64 session_run_index,
337 int64 executor_step_index,
const std::vector<string>&
input_names,
338 const std::vector<string>& output_names,
339 const std::vector<string>& target_names,
340 std::unique_ptr<DebuggerStateInterface>* debugger_state) {
342 DebuggerStateRegistry::CreateState(debug_options, debugger_state));
343 TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
344 debug_options.global_step(), session_run_index, executor_step_index,
350 const DebugOptions& debug_options,
Graph* graph, Device* device) {
351 std::unique_ptr<DebugGraphDecoratorInterface> decorator;
353 DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
355 TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
356 TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
362 const std::vector<string>& output_names,
363 const std::vector<string>& target_nodes,
365 RunMetadata* run_metadata) {
367 tbb_session_runs->GetCell()->IncrementBy(1);
370 if (!graph_created_) {
371 return errors::InvalidArgument(
372 "Session was not created with a graph before Run()!");
377 std::vector<string> input_tensor_names;
378 input_tensor_names.reserve(inputs.size());
379 for (
const auto& it : inputs) {
380 input_tensor_names.push_back(it.first);
385 RunStateArgs run_state_args(run_options.debug_options());
392 target_nodes, &executors_and_keys,
394 const int64 executor_step_count = executors_and_keys->
step_count.fetch_add(1);
396 std::unique_ptr<DebuggerStateInterface> debugger_state;
397 if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
399 run_options.debug_options(), args.step_id, executor_step_count,
400 input_tensor_names, output_names, target_nodes, &debugger_state));
405 FunctionCallFrame call_frame(executors_and_keys->
input_types,
407 gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
408 for (
const auto& it : inputs) {
409 if (it.second.dtype() == DT_RESOURCE) {
410 Tensor tensor_from_handle;
419 const Status s = call_frame.SetArgs(feed_args);
420 if (errors::IsInternal(s)) {
421 return errors::InvalidArgument(s.error_message());
422 }
else if (!s.ok()) {
429 CancellationManager step_cancellation_manager;
430 args.call_frame = &call_frame;
434 tbb::task_arena taskArena;
435 tbb::task_group taskGroup;
437 auto doneWithTaskGroup = [&taskArena, &taskGroup](
void *) { taskArena.execute([&taskGroup]() { taskGroup.wait();}); };
438 std::unique_ptr<tbb::task_group, decltype(doneWithTaskGroup) > guard(&taskGroup, doneWithTaskGroup);
441 const size_t num_executors = executors_and_keys->
items.size();
442 ExecutorBarrier* barrier =
new ExecutorBarrier(
443 num_executors, run_state.
rendez, [&run_state](
const Status& ret) {
445 mutex_lock l(run_state.mu_);
446 run_state.status.Update(ret);
451 args.rendezvous = run_state.rendez;
452 args.cancellation_manager = &step_cancellation_manager;
455 args.tensor_store = &run_state.tensor_store;
456 args.step_container = &run_state.step_container;
457 if (LogMemory::IsEnabled()) {
458 LogMemory::RecordStep(
args.step_id, run_state_args.handle);
462 const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
464 bool update_cost_model =
false;
465 if (
options_.config.graph_options().build_cost_model() > 0) {
466 const int64 build_cost_model_every =
467 options_.config.graph_options().build_cost_model();
468 const int64 build_cost_model_after =
469 options_.config.graph_options().build_cost_model_after();
470 int64 measure_step_count = executor_step_count - build_cost_model_after;
471 if (measure_step_count >= 0) {
473 ((measure_step_count + 1) % build_cost_model_every == 0);
476 if (do_trace || update_cost_model ||
477 run_options.report_tensor_allocations_upon_oom()) {
478 run_state.collector.reset(
479 new StepStatsCollector(run_metadata->mutable_step_stats()));
480 args.stats_collector = run_state.collector.get();
483 std::unique_ptr<DeviceTracer> tracer;
484 if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
485 tracer = CreateDeviceTracer();
490 run_state.executors_done.Notify();
499 const CancellationToken cancellation_token =
502 cancellation_token, [&step_cancellation_manager]() {
503 step_cancellation_manager.StartCancel();
505 if (already_cancelled) {
509 run_state.executors_done.Notify();
511 return errors::Cancelled(
"Run call was cancelled");
516 Executor::Args::Runner default_runner = [
this, &taskArena, &taskGroup](Executor::Args::Closure
c) {
519 for (
const auto& item : executors_and_keys->items) {
532 args.runner = default_runner;
533 item.executor->RunAsync(
args, barrier->Get());
539 run_options.timeout_in_ms() > 0
540 ? run_options.timeout_in_ms()
546 mutex_lock
l(run_state.mu_);
547 run_state.status.Update(errors::Cancelled(
"Run call was cancelled"));
551 TF_RETURN_IF_ERROR(tracer->Stop());
552 TF_RETURN_IF_ERROR(tracer->Collect(
args.stats_collector));
556 mutex_lock
l(run_state.mu_);
557 TF_RETURN_IF_ERROR(run_state.status);
562 std::vector<Tensor> sorted_outputs;
563 const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
564 if (errors::IsInternal(s)) {
565 return errors::InvalidArgument(s.error_message());
566 }
else if (!s.ok()) {
569 const bool unique_outputs =
570 output_names.size() == executors_and_keys->output_name_to_index.size();
573 std::vector<int> first_indices;
574 if (!unique_outputs) {
575 first_indices.resize(output_names.size());
576 for (
int i = 0; i < static_cast<int>(output_names.size()); ++
i) {
577 for (
int j = 0; j <=
i; ++j) {
578 if (output_names[
i] == output_names[j]) {
579 first_indices[
i] = j;
586 outputs->reserve(sorted_outputs.size());
587 for (
int i = 0; i < static_cast<int>(output_names.size()); ++
i) {
588 const string& output_name = output_names[
i];
589 if (first_indices.empty() || first_indices[
i] ==
i) {
590 outputs->emplace_back(
591 std::move(sorted_outputs[executors_and_keys
592 ->output_name_to_index[output_name]]));
594 outputs->push_back((*outputs)[first_indices[
i]]);
601 run_state.tensor_store.SaveTensors(output_names, &
session_state_));
602 if (
args.stats_collector) {
603 args.stats_collector->Finalize();
608 if (update_cost_model) {
610 std::unordered_map<string, const Graph*> device_to_graph;
612 executors_and_keys->items) {
614 const string device = partition.
flib->device()->name();
615 device_to_graph[device] = graph;
620 CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
621 for (
const auto& item : executors_and_keys->items) {
628 if (run_options.output_partition_graphs()) {
629 protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
630 run_metadata->mutable_partition_graphs();
632 executors_and_keys->items) {
633 GraphDef* partition_graph_def = partition_graph_defs->Add();
634 exec_and_lib.
graph->ToGraphDef(partition_graph_def);
642 Tensor* retrieved_tensor) {
643 if (resource_tensor.dtype() != DT_RESOURCE) {
644 return errors::InvalidArgument(strings::StrCat(
645 "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
646 resource_tensor.dtype()));
649 const ResourceHandle& resource_handle =
650 resource_tensor.scalar<ResourceHandle>()();
652 if (resource_handle.container() ==
653 SessionState::kTensorHandleResourceTypeName) {
654 return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
656 return errors::InvalidArgument(strings::StrCat(
657 "Invalid resource type hash code: ", resource_handle.hash_code(),
658 "(name: ", resource_handle.name(),
659 " type: ", resource_handle.maybe_type_name(),
660 "). Perhaps a resource tensor was being provided as a feed? That is " 661 "not currently allowed. Please file an issue at " 662 "https://github.com/tensorflow/tensorflow/issues/new, ideally with a " 663 "short code snippet that leads to this error message."));
668 gtl::ArraySlice<string>
inputs, gtl::ArraySlice<string>
outputs,
669 gtl::ArraySlice<string> target_nodes,
ExecutorsAndKeys** executors_and_keys,
671 int64 handle_name_counter_value = -1;
676 string debug_tensor_watches_summary;
677 if (!run_state_args->
debug_options.debug_tensor_watch_opts().empty()) {
678 debug_tensor_watches_summary = SummarizeDebugTensorWatches(
683 const string key = strings::StrCat(
684 str_util::Join(inputs,
","),
"->", str_util::Join(outputs,
","),
"/",
685 str_util::Join(target_nodes,
","),
"/", run_state_args->
is_partial_run,
686 "/", debug_tensor_watches_summary);
688 if (handle_name_counter_value >= 0) {
690 strings::StrCat(key,
";", handle_name_counter_value);
696 auto it = executors_.find(key);
697 if (it != executors_.end()) {
698 *executors_and_keys = it->second.get();
709 std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
710 std::sort(inputs_sorted.begin(), inputs_sorted.end());
711 std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
712 std::sort(outputs_sorted.begin(), outputs_sorted.end());
713 std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
714 std::sort(tn_sorted.begin(), tn_sorted.end());
716 const string sorted_key = strings::StrCat(
717 str_util::Join(inputs_sorted,
","),
"->",
718 str_util::Join(outputs_sorted,
","),
"/", str_util::Join(tn_sorted,
","),
719 "/", run_state_args->
is_partial_run,
"/", debug_tensor_watches_summary);
721 if (handle_name_counter_value >= 0) {
723 strings::StrCat(sorted_key,
";", handle_name_counter_value);
729 auto it = executors_.find(sorted_key);
730 if (it != executors_.end()) {
731 *executors_and_keys = it->second.get();
733 executors_.emplace(key, it->second);
740 options.feed_endpoints = inputs_sorted;
741 options.fetch_endpoints = outputs_sorted;
742 options.target_nodes = tn_sorted;
743 options.use_function_convention = !run_state_args->
is_partial_run;
744 if (!run_state_args->
debug_options.debug_tensor_watch_opts().empty()) {
748 std::unique_ptr<FunctionInfo> func_info(
new FunctionInfo);
753 std::unordered_map<string, std::unique_ptr<Graph>>
graphs;
754 TF_RETURN_IF_ERROR(
CreateGraphs(options, &graphs, &func_info->flib_def,
755 run_state_args, &ek->input_types,
760 std::unordered_set<StringPiece, StringPieceHasher>
names;
761 for (
const string&
input : inputs) {
762 TensorId
id(ParseTensorName(
input));
763 names.emplace(
id.
first);
765 for (
const string&
output : outputs) {
766 TensorId
id(ParseTensorName(
output));
767 names.emplace(
id.
first);
769 for (
Node*
n : ek->graph->nodes()) {
770 if (names.count(
n->name()) > 0) {
771 ek->name_to_node.insert({
n->name(),
n});
775 ek->items.reserve(graphs.size());
776 const auto& optimizer_opts =
777 options_.config.graph_options().optimizer_options();
779 int graph_def_version;
783 execution_state_->original_graph_def().versions().producer();
785 func_info->proc_flr.reset(
new ProcessFunctionLibraryRuntime(
787 func_info->flib_def.get(), optimizer_opts));
789 GraphOptimizer optimizer(optimizer_opts);
790 for (
auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
791 const string& partition_name = iter->first;
792 std::unique_ptr<Graph>& partition_graph = iter->second;
795 TF_RETURN_IF_ERROR(
device_mgr_->LookupDevice(partition_name, &device));
797 ek->items.resize(ek->items.size() + 1);
798 auto* item = &(ek->items.back());
799 auto lib = func_info->proc_flr->GetFLR(partition_name);
800 if (
lib ==
nullptr) {
801 return errors::Internal(
"Could not find device: ", partition_name);
805 LocalExecutorParams params;
806 params.device = device;
807 params.function_library =
lib;
808 auto opseg = device->op_segment();
809 params.create_kernel = [
this,
lib, opseg](
const NodeDef& ndef,
817 if (!
lib->IsStateful(ndef.op()) ||
818 lib->GetFunctionLibraryDefinition()->Find(ndef.op()) !=
nullptr) {
819 return lib->CreateKernel(ndef, kernel);
821 auto create_fn = [
lib, &ndef](OpKernel** kernel) {
822 return lib->CreateKernel(ndef, kernel);
830 params.delete_kernel = [
lib](OpKernel* kernel) {
832 if (kernel && !
lib->IsStateful(kernel->type_string())) {
838 optimizer.Optimize(
lib,
options_.env, device, &iter->second,
842 if (!options.debug_options.debug_tensor_watch_opts().empty()) {
844 options.debug_options, partition_graph.get(), params.device));
847 TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
849 partition_graph.get()));
851 item->graph = partition_graph.get();
852 item->executor =
nullptr;
853 item->device = device;
856 NewLocalExecutor(params, partition_graph.release(), &executor));
857 item->executor.reset(executor);
866 for (
size_t i = 0;
i < inputs_sorted.size(); ++
i) {
867 const string&
input = inputs_sorted[
i];
868 ek->input_name_to_index[
input] =
i;
870 for (
size_t i = 0;
i < outputs_sorted.size(); ++
i) {
871 const string&
output = outputs_sorted[
i];
872 ek->output_name_to_index[
output] =
i;
880 for (
size_t i = 0;
i < inputs_sorted.size(); ++
i) {
881 const string&
input = inputs_sorted[
i];
882 ek->input_name_to_rendezvous_key[
input] = GetRendezvousKey(
883 input,
device_set_.client_device()->attributes(), FrameAndIter(0, 0));
885 for (
size_t i = 0;
i < outputs_sorted.size(); ++
i) {
886 const string&
output = outputs_sorted[
i];
887 ek->output_name_to_rendezvous_key[
output] =
888 GetRendezvousKey(output,
device_set_.client_device()->attributes(),
895 functions_.push_back(
std::move(func_info));
899 auto insert_result = executors_.emplace(sorted_key, ek);
902 executors_.emplace(key, insert_result.first->second);
903 *executors_and_keys = insert_result.first->second.get();
909 const BuildGraphOptions& subgraph_options,
910 std::unordered_map<
string, std::unique_ptr<Graph>>*
outputs,
911 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
912 RunStateArgs* run_state_args, DataTypeVector* input_types,
913 DataTypeVector* output_types) {
915 std::unique_ptr<ClientGraph> client_graph;
917 std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
918 GraphExecutionState* execution_state =
nullptr;
919 if (
options_.config.graph_options().place_pruned_graph()) {
923 GraphExecutionStateOptions prune_options;
925 prune_options.session_options = &
options_;
926 prune_options.stateful_placements = stateful_placements_;
927 TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
928 execution_state_->original_graph_def().library(), prune_options,
929 execution_state_->original_graph_def(), subgraph_options,
930 &temp_exec_state_holder, &client_graph));
931 execution_state = temp_exec_state_holder.get();
933 execution_state = execution_state_.get();
935 execution_state->BuildGraph(subgraph_options, &client_graph));
938 if (subgraph_options.feed_endpoints.size() !=
939 client_graph->feed_types.size()) {
940 return errors::Internal(
941 "Graph pruning failed: requested number of feed endpoints = ",
942 subgraph_options.feed_endpoints.size(),
943 " versus number of pruned feed endpoints = ",
944 client_graph->feed_types.size());
946 if (subgraph_options.fetch_endpoints.size() !=
947 client_graph->fetch_types.size()) {
948 return errors::Internal(
949 "Graph pruning failed: requested number of fetch endpoints = ",
950 subgraph_options.fetch_endpoints.size(),
951 " versus number of pruned fetch endpoints = ",
952 client_graph->fetch_types.size());
955 auto current_stateful_placements = execution_state->GetStatefulPlacements();
959 for (
auto placement_pair : current_stateful_placements) {
960 const string& node_name = placement_pair.first;
961 const string& placement = placement_pair.second;
962 auto iter = stateful_placements_.find(node_name);
963 if (iter == stateful_placements_.end()) {
964 stateful_placements_.insert(std::make_pair(node_name, placement));
965 }
else if (iter->second != placement) {
966 return errors::Internal(
967 "Stateful placement mismatch. " 968 "Current assignment of ",
969 node_name,
" to ", iter->second,
" does not match ", placement);
973 stateful_placements_ = execution_state->GetStatefulPlacements();
978 CopyGraph(*execution_state->full_graph(), run_state_args->
graph.get());
982 PartitionOptions popts;
983 popts.node_to_loc = [](
const Node* node) {
984 assert(node !=
nullptr);
985 return node->assigned_device_name();
987 popts.new_name = [
this](
const string&
prefix) {
990 popts.get_incarnation = [](
const string&
name) {
995 popts.flib_def = &client_graph->graph.flib_def();
996 popts.control_flow_added =
false;
998 std::unordered_map<string, GraphDef>
partitions;
999 TF_RETURN_IF_ERROR(
Partition(popts, &client_graph->graph, &partitions));
1001 std::vector<string> device_names;
1004 device_names.push_back(DeviceNameUtils::LocalName(device->name()));
1008 for (
const auto& partition : partitions) {
1009 const string local_partition_name =
1010 DeviceNameUtils::LocalName(partition.first);
1011 if (
std::count(device_names.begin(), device_names.end(),
1012 local_partition_name) == 0) {
1013 return errors::InvalidArgument(
1014 "Creating a partition for ", local_partition_name,
1015 " which doesn't exist in the list of available devices. Available " 1017 str_util::Join(device_names,
","));
1021 for (
const auto& partition : partitions) {
1022 std::unique_ptr<Graph> device_graph(
1023 new Graph(client_graph->flib_def.get()));
1024 GraphConstructorOptions device_opts;
1026 device_opts.allow_internal_ops =
true;
1027 device_opts.expect_device_spec =
true;
1028 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
1029 device_graph.get()));
1033 GraphOptimizationPassOptions optimization_options;
1034 optimization_options.session_options = &
options_;
1035 optimization_options.flib_def = client_graph->flib_def.get();
1036 optimization_options.partition_graphs =
outputs;
1037 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
1038 OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
1041 for (
auto& partition : *
outputs) {
1042 const string& partition_name = partition.first;
1043 std::unique_ptr<Graph>* graph = &partition.second;
1045 VLOG(2) <<
"Created " << DebugString(graph->get()) <<
" for " 1050 s =
device_mgr_->LookupDevice(partition_name, &d);
1052 s = d->MaybeRewriteGraph(graph);
1057 *flib_def =
std::move(client_graph->flib_def);
1058 std::swap(*input_types, client_graph->feed_types);
1059 std::swap(*output_types, client_graph->fetch_types);
1064 std::vector<DeviceAttributes>* response) {
1066 response->reserve(
devices_.size());
1068 const DeviceAttributes& attrs =
d->attributes();
1069 response->emplace_back(attrs);
1075 const std::vector<string>& containers) {
1092 const std::vector<string>& pending_input_names,
1093 const std::vector<string>& pending_output_names, int64 step_id,
1094 const std::vector<Device*>* devices)
1095 : step_container(step_id, [devices](const
string&
name) {
1096 for (
auto d : *devices) {
1097 if (!
d->resource_manager()->Cleanup(
name).ok()) {
1103 for (
auto&
name : pending_input_names) {
1106 for (
auto&
name : pending_output_names) {
1112 const std::vector<Device*>* devices)
1113 :
RunState({}, {}, step_id, devices) {}
1118 rendez->StartAbort(errors::Cancelled(
"PRun cancellation"));
1127 if (!it.second)
return false;
1130 if (!it.second)
return false;
1136 RunState* run_state, CancellationManager* cm, int64 timeout_in_ms) {
1139 arena.execute([&taskGroup]() { taskGroup.wait();} );
1145 mutex_lock
l(run_state->
mu_);
1146 run_state->status.Update(status);
1157 Notification* notification, int64 timeout_in_ms) {
1158 if (timeout_in_ms > 0) {
1159 const int64 timeout_in_us = timeout_in_ms * 1000;
1160 const bool notified =
1161 WaitForNotificationWithTimeout(notification, timeout_in_us);
1163 return Status(error::DEADLINE_EXCEEDED,
1164 "Timed out waiting for notification");
1167 notification->WaitForNotification();
std::unique_ptr< FunctionLibraryDefinition > flib_def_
static boost::mutex mutex
std::vector< PerPartitionExecutorsAndLib > items
RunState(int64 step_id, const std::vector< Device * > *devices)
DataTypeVector output_types
std::unordered_map< string, size_t > input_name_to_index
Notification executors_done
::tensorflow::Status Close() override
std::vector< Device * > devices_
std::atomic< int64 > handle_name_counter_
std::vector< std::pair< string, Tensor > > NamedTensorList
std::atomic< int64 > edge_name_counter_
::tensorflow::Status WaitForNotification(Notification *n, int64 timeout_in_ms)
::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions &debug_options, Graph *graph, Device *device)
CostModelManager cost_model_manager_
std::unordered_map< string, bool > pending_inputs
::tensorflow::Status Create(const GraphDef &graph) override
const std::string names[nVars_]
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e g
Status Reset(const SessionOptions &options, const std::vector< string > &containers) override
Session * NewSession(const SessionOptions &options) override
static std::string const input
FunctionLibraryRuntime * flib
void Deregister(const TBBSession *session)
Status MaybeInitializeExecutionState(const GraphDef &graph, bool *out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
::tensorflow::Status GetOrCreateExecutors(gtl::ArraySlice< string > inputs, gtl::ArraySlice< string > outputs, gtl::ArraySlice< string > target_nodes, ExecutorsAndKeys **executors_and_keys, RunStateArgs *run_state_args)
static NTSessionRegistrar registrar
IntraProcessRendezvous * rendez
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
std::vector< TBBSession * > sessions_ GUARDED_BY(sessions_lock_)
void SchedClosure(tbb::task_arena &arena, tbb::task_group &g, std::function< void()> c)
::tensorflow::Status ResourceHandleToInputTensor(const Tensor &resource_tensor, Tensor *retrieved_tensor)
CancellationManager * cancellation_manager_
::tensorflow::Status CreateGraphs(const BuildGraphOptions &options, std::unordered_map< string, std::unique_ptr< Graph >> *outputs, std::unique_ptr< FunctionLibraryDefinition > *flib_def, RunStateArgs *run_state_args, DataTypeVector *input_types, DataTypeVector *output_types)
::tensorflow::Status Reset(const std::vector< string > &containers)
const DebugOptions & debug_options
::tensorflow::Status Extend(const GraphDef &graph) override
DataTypeVector input_types
std::pair< int, edm::FunctionWithDict > OK
std::unordered_map< string, bool > pending_outputs
::tensorflow::Status CheckNotClosed()
std::atomic_int_fast64_t step_count
const SessionOptions options_
TBBSessionFactory *const factory_
static std::atomic_int_fast64_t step_id_counter_
def remove(d, key, TELL=False)
const std::unique_ptr< const DeviceMgr > device_mgr_
const int64 operation_timeout_in_ms_
Executor::Args::NodeOutputsCallback node_outputs_callback_
::tensorflow::Status Run(const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs) override
TBBSession(const SessionOptions &options, const DeviceMgr *device_mgr, TBBSessionFactory *factory)
std::unique_ptr< Graph > graph
::tensorflow::Status CreateDebuggerState(const DebugOptions &debug_options, int64 session_run_index, int64 executor_step_index, const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_names, std::unique_ptr< DebuggerStateInterface > *debugger_state)
SessionState session_state_
::tensorflow::Status ExtendLocked(const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
::tensorflow::Status ListDevices(std::vector< DeviceAttributes > *response) override
DDCompactView::Graph Graph
bool AcceptsOptions(const SessionOptions &options) override