40 #include "tbb/task_group.h" 42 #include "tensorflow/core/common_runtime/constant_folding.h" 43 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 44 #include "tensorflow/core/common_runtime/device_factory.h" 45 #include "tensorflow/core/common_runtime/executor.h" 46 #include "tensorflow/core/common_runtime/function.h" 47 #include "tensorflow/core/common_runtime/graph_optimizer.h" 48 #include "tensorflow/core/common_runtime/memory_types.h" 49 #include "tensorflow/core/common_runtime/optimization_registry.h" 50 #include "tensorflow/core/common_runtime/simple_placer.h" 51 #include "tensorflow/core/common_runtime/step_stats_collector.h" 52 #include "tensorflow/core/framework/function.h" 53 #include "tensorflow/core/framework/graph.pb_text.h" 54 #include "tensorflow/core/framework/graph.pb.h" 55 #include "tensorflow/core/framework/graph_def_util.h" 56 #include "tensorflow/core/framework/log_memory.h" 57 #include "tensorflow/core/framework/node_def.pb.h" 58 #include "tensorflow/core/framework/tensor.h" 59 #include "tensorflow/core/framework/versions.pb.h" 60 #include "tensorflow/core/graph/algorithm.h" 61 #include "tensorflow/core/graph/graph.h" 62 #include "tensorflow/core/graph/graph_constructor.h" 63 #include "tensorflow/core/graph/graph_partition.h" 64 #include "tensorflow/core/graph/subgraph.h" 65 #include "tensorflow/core/graph/tensor_id.h" 66 #include "tensorflow/core/lib/core/errors.h" 67 #include "tensorflow/core/lib/core/notification.h" 68 #include "tensorflow/core/lib/core/refcount.h" 69 #include "tensorflow/core/lib/core/status.h" 70 #include "tensorflow/core/lib/gtl/array_slice.h" 71 #include "tensorflow/core/lib/gtl/stl_util.h" 72 #include "tensorflow/core/lib/monitoring/counter.h" 73 #include "tensorflow/core/lib/strings/numbers.h" 74 #include "tensorflow/core/lib/strings/str_util.h" 75 #include "tensorflow/core/lib/strings/strcat.h" 76 #include "tensorflow/core/platform/cpu_info.h" 77 #include "tensorflow/core/platform/logging.h" 78 #include "tensorflow/core/platform/mutex.h" 79 #include "tensorflow/core/platform/types.h" 80 #include "tensorflow/core/util/device_name_utils.h" 81 #include "tensorflow/core/util/env_var.h" 84 #include "tensorflow/core/common_runtime/gpu/gpu_tracer.h" 92 "/tensorflow/core/tbb_session_runs",
93 "The number of times TBBSession::Run() has been called.");
99 string GetRendezvousKey(
const string& tensor_name,
100 const DeviceAttributes& device_info,
101 const FrameAndIter& frame_iter) {
102 return strings::StrCat(device_info.name(),
";",
103 strings::FpToString(device_info.incarnation()),
";",
104 device_info.name(),
";", tensor_name,
";",
105 frame_iter.frame_id,
":", frame_iter.iter_id);
115 return options.target ==
"tbb";
120 if (options.config.graph_options().build_cost_model() > 0) {
121 EnableCPUAllocatorFullStats(
true);
123 std::vector<Device*> devices;
124 Status s = DeviceFactory::AddDevices(
125 options,
"/job:localhost/replica:0/task:0", &devices);
132 new TBBSession(options,
new DeviceMgr(devices),
this);
135 sessions_.push_back(session);
141 const std::vector<string>& containers)
override {
142 std::vector<TBBSession*> sessions_to_reset;
151 for (
auto session : sessions_to_reset) {
152 s.Update(
session->Reset(containers));
156 for (
auto session : sessions_to_reset) {
170 std::vector<TBBSession*> sessions_
GUARDED_BY(sessions_lock_);
199 arena.execute( [&g,&
c] () {g.run(
c ); } );
203 const DeviceMgr* device_mgr,
206 device_mgr_(device_mgr),
208 cancellation_manager_(new CancellationManager()),
209 operation_timeout_in_ms_(options_.
config.operation_timeout_in_ms()) {
215 LOG(
ERROR) << status.error_message();
221 int devices_added = 0;
222 if (options.config.log_device_placement()) {
223 const string mapping_str =
device_mgr_->DeviceMappingString();
224 if (mapping_str.empty()) {
225 printf(
"Device mapping: no known devices.\n");
227 printf(
"Device mapping:\n%s", mapping_str.c_str());
229 LOG(
INFO) <<
"Device mapping:\n" << mapping_str;
238 if (devices_added == 0) {
246 if (!closed_)
Close().IgnoreError();
247 for (
auto& it : partial_runs_) {
248 it.second.reset(
nullptr);
250 for (
auto& it : executors_) {
258 execution_state_.reset(
nullptr);
263 const GraphDef& graph,
bool* out_already_initialized) {
266 *out_already_initialized =
true;
273 new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
274 SimpleGraphExecutionStateOptions
options;
276 options.session_options = &
options_;
285 GraphDef
temp(graph);
286 TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
287 &temp, options, &execution_state_));
288 graph_created_ =
true;
289 *out_already_initialized =
false;
295 if (graph.node_size() > 0) {
297 if (graph_created_) {
298 return errors::AlreadyExists(
299 "A Graph has already been created for this session.");
313 bool already_initialized;
318 if (already_initialized) {
319 TF_RETURN_IF_ERROR(
flib_def_->AddLibrary(graph.library()));
320 std::unique_ptr<SimpleGraphExecutionState> state;
321 TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
322 execution_state_.swap(state);
328 const std::vector<string>& output_names,
329 const std::vector<string>& target_nodes,
330 std::vector<Tensor>*
outputs) {
331 RunMetadata run_metadata;
332 return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
337 const DebugOptions& debug_options, int64 session_run_index,
338 int64 executor_step_index,
const std::vector<string>& input_names,
339 const std::vector<string>& output_names,
340 const std::vector<string>& target_names,
341 std::unique_ptr<DebuggerStateInterface>* debugger_state) {
343 DebuggerStateRegistry::CreateState(debug_options, debugger_state));
344 TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
345 debug_options.global_step(), session_run_index, executor_step_index,
346 input_names, output_names, target_names));
351 const DebugOptions& debug_options, Graph* graph, Device* device) {
352 std::unique_ptr<DebugGraphDecoratorInterface> decorator;
354 DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
356 TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
357 TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
363 const std::vector<string>& output_names,
364 const std::vector<string>& target_nodes,
366 RunMetadata* run_metadata) {
368 tbb_session_runs->GetCell()->IncrementBy(1);
371 if (!graph_created_) {
372 return errors::InvalidArgument(
373 "Session was not created with a graph before Run()!");
378 std::vector<string> input_tensor_names;
379 input_tensor_names.reserve(inputs.size());
380 for (
const auto& it : inputs) {
381 input_tensor_names.push_back(it.first);
387 RunStateArgs run_state_args(run_options.debug_options());
394 &executors_and_keys, &run_state_args));
395 const int64 executor_step_count = executors_and_keys->
step_count.fetch_add(1);
397 std::unique_ptr<DebuggerStateInterface> debugger_state;
398 if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
400 run_options.debug_options(), args.step_id, executor_step_count,
401 input_tensor_names, output_names, target_nodes, &debugger_state));
406 FunctionCallFrame call_frame(executors_and_keys->
input_types,
408 gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
409 for (
const auto& it : inputs) {
410 if (it.second.dtype() == DT_RESOURCE) {
411 Tensor tensor_from_handle;
420 Status s = call_frame.SetArgs(feed_args);
421 if (errors::IsInternal(s)) {
422 return errors::InvalidArgument(s.error_message());
423 }
else if (!s.ok()) {
430 CancellationManager step_cancellation_manager;
431 args.call_frame = &call_frame;
435 tbb::task_arena taskArena;
436 tbb::task_group taskGroup;
438 auto doneWithTaskGroup = [&taskArena, &taskGroup](
void *) { taskArena.execute([&taskGroup]() { taskGroup.wait();}); };
439 std::unique_ptr<tbb::task_group, decltype(doneWithTaskGroup) > guard(&taskGroup, doneWithTaskGroup);
442 const size_t num_executors = executors_and_keys->
items.size();
443 ExecutorBarrier* barrier =
new ExecutorBarrier(
444 num_executors, run_state.
rendez, [&run_state](
const Status& ret) {
446 mutex_lock l(run_state.mu_);
447 run_state.status.Update(ret);
452 args.rendezvous = run_state.rendez;
453 args.cancellation_manager = &step_cancellation_manager;
454 args.runner = [
this, &taskArena, &taskGroup](Executor::Args::Closure
c) {
458 args.tensor_store = &run_state.tensor_store;
459 args.step_container = &run_state.step_container;
460 if (LogMemory::IsEnabled()) {
461 LogMemory::RecordStep(
args.step_id, run_state_args.handle);
465 const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
467 bool update_cost_model =
false;
468 if (
options_.config.graph_options().build_cost_model() > 0) {
469 const int64 build_cost_model_every =
470 options_.config.graph_options().build_cost_model();
471 const int64 build_cost_model_after =
472 options_.config.graph_options().build_cost_model_after();
473 int64 measure_step_count = executor_step_count - build_cost_model_after;
474 if (measure_step_count >= 0) {
476 ((measure_step_count + 1) % build_cost_model_every == 0);
479 if (do_trace || update_cost_model) {
480 run_state.collector.reset(
481 new StepStatsCollector(run_metadata->mutable_step_stats()));
482 args.stats_collector = run_state.collector.get();
486 std::unique_ptr<GPUTracer> tracer;
487 if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
488 tracer.reset(CreateGPUTracer());
491 if (tracer) tracer->Start().IgnoreError();
493 #endif // GOOGLE_CUDA 497 CancellationToken cancellation_token =
500 cancellation_token, [&step_cancellation_manager]() {
501 step_cancellation_manager.StartCancel();
503 if (already_cancelled) {
507 run_state.executors_done.Notify();
509 return errors::Cancelled(
"Run call was cancelled");
512 for (
const auto& item : executors_and_keys->items) {
513 item.executor->RunAsync(
args, barrier->Get());
519 run_options.timeout_in_ms() > 0
520 ? run_options.timeout_in_ms()
526 mutex_lock
l(run_state.mu_);
527 run_state.status.Update(errors::Cancelled(
"Run call was cancelled"));
533 tracer->Stop().IgnoreError();
534 tracer->Collect(
args.stats_collector).IgnoreError();
536 #endif // GOOGLE_CUDA 539 mutex_lock
l(run_state.mu_);
540 TF_RETURN_IF_ERROR(run_state.status);
545 std::vector<Tensor> sorted_outputs;
546 Status s = call_frame.ConsumeRetvals(&sorted_outputs);
547 if (errors::IsInternal(s)) {
548 return errors::InvalidArgument(s.error_message());
549 }
else if (!s.ok()) {
552 const bool unique_outputs =
553 output_names.size() == executors_and_keys->output_name_to_index.size();
556 std::vector<int> first_indices;
557 if (!unique_outputs) {
558 first_indices.resize(output_names.size());
559 for (
int i = 0; i < static_cast<int>(output_names.size()); ++
i) {
560 for (
int j = 0; j <=
i; ++j) {
561 if (output_names[
i] == output_names[j]) {
562 first_indices[
i] = j;
569 outputs->reserve(sorted_outputs.size());
570 for (
int i = 0; i < static_cast<int>(output_names.size()); ++
i) {
571 const string& output_name = output_names[
i];
572 if (first_indices.empty() || first_indices[
i] ==
i) {
573 outputs->emplace_back(
574 std::move(sorted_outputs[executors_and_keys
575 ->output_name_to_index[output_name]]));
577 outputs->push_back((*outputs)[first_indices[
i]]);
584 run_state.tensor_store.SaveTensors(output_names, &
session_state_));
588 if (update_cost_model) {
590 std::unordered_map<string, const Graph*> device_to_graph;
592 executors_and_keys->items) {
593 const Graph* graph = partition.
graph;
594 const string device = partition.
flib->device()->name();
595 device_to_graph[device] = graph;
600 CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
601 for (
const auto& item : executors_and_keys->items) {
608 if (run_options.output_partition_graphs()) {
609 protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
610 run_metadata->mutable_partition_graphs();
612 executors_and_keys->items) {
613 GraphDef* partition_graph_def = partition_graph_defs->Add();
614 exec_and_lib.
graph->ToGraphDef(partition_graph_def);
623 Tensor* retrieved_tensor) {
624 if (resource_tensor.dtype() != DT_RESOURCE) {
625 return errors::InvalidArgument(strings::StrCat(
626 "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
627 resource_tensor.dtype()));
630 ResourceHandle resource_handle = resource_tensor.scalar<ResourceHandle>()();
632 if (resource_handle.container() ==
633 SessionState::kTensorHandleResourceTypeName) {
634 return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
636 return errors::InvalidArgument(strings::StrCat(
637 "Invalid resource type hash code: ", resource_handle.hash_code(),
638 "(name: ", resource_handle.name(),
639 " type: ", resource_handle.maybe_type_name(),
")"));
644 gtl::ArraySlice<string>
inputs,
645 gtl::ArraySlice<string>
outputs, gtl::ArraySlice<string> target_nodes,
647 int64 handle_name_counter_value = -1;
652 string debug_tensor_watches_summary;
653 if (!run_state_args->
debug_options.debug_tensor_watch_opts().empty()) {
654 debug_tensor_watches_summary = SummarizeDebugTensorWatches(
659 const string key = strings::StrCat(
660 str_util::Join(inputs,
","),
"->", str_util::Join(outputs,
","),
"/",
661 str_util::Join(target_nodes,
","),
"/", run_state_args->
is_partial_run,
662 "/", debug_tensor_watches_summary);
664 if (handle_name_counter_value >= 0) {
666 strings::StrCat(key,
";", handle_name_counter_value);
672 auto it = executors_.find(key);
673 if (it != executors_.end()) {
674 *executors_and_keys = it->second.get();
685 std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
686 std::sort(inputs_sorted.begin(), inputs_sorted.end());
687 std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
688 std::sort(outputs_sorted.begin(), outputs_sorted.end());
689 std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
690 std::sort(tn_sorted.begin(), tn_sorted.end());
692 const string sorted_key = strings::StrCat(
693 str_util::Join(inputs_sorted,
","),
"->",
694 str_util::Join(outputs_sorted,
","),
"/", str_util::Join(tn_sorted,
","),
695 "/", run_state_args->
is_partial_run,
"/", debug_tensor_watches_summary);
697 if (handle_name_counter_value >= 0) {
699 strings::StrCat(sorted_key,
";", handle_name_counter_value);
705 auto it = executors_.find(sorted_key);
706 if (it != executors_.end()) {
707 *executors_and_keys = it->second.get();
709 executors_.emplace(key, it->second);
716 options.feed_endpoints = inputs_sorted;
717 options.fetch_endpoints = outputs_sorted;
718 options.target_nodes = tn_sorted;
719 options.use_function_convention = !run_state_args->
is_partial_run;
720 if (!run_state_args->
debug_options.debug_tensor_watch_opts().empty()) {
728 std::unordered_map<string, std::unique_ptr<Graph>>
graphs;
729 TF_RETURN_IF_ERROR(
CreateGraphs(options, &graphs, &ek->flib_def,
730 run_state_args, &ek->input_types,
735 std::unordered_set<StringPiece, StringPiece::Hasher>
names;
736 for (
const string&
input : inputs) {
737 TensorId
id(ParseTensorName(
input));
738 names.emplace(
id.
first);
740 for (
const string&
output : outputs) {
741 TensorId
id(ParseTensorName(
output));
742 names.emplace(
id.
first);
744 for (Node*
n : ek->graph->nodes()) {
745 if (names.count(
n->name()) > 0) {
746 ek->name_to_node.insert({
n->name(),
n});
750 ek->items.reserve(graphs.size());
751 const auto& optimizer_opts =
752 options_.config.graph_options().optimizer_options();
753 GraphOptimizer optimizer(optimizer_opts);
754 for (
auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
755 const string& partition_name = iter->first;
756 std::unique_ptr<Graph>& partition_graph = iter->second;
757 const int graph_def_version = partition_graph->versions().producer();
760 TF_RETURN_IF_ERROR(
device_mgr_->LookupDevice(partition_name, &device));
762 ek->items.resize(ek->items.size() + 1);
763 auto* item = &(ek->items.back());
764 item->flib.reset(NewFunctionLibraryRuntime(
766 ek->flib_def.get(), optimizer_opts));
768 LocalExecutorParams params;
769 params.device = device;
770 params.function_library = item->flib.get();
771 auto lib = item->flib.get();
772 auto opseg = device->op_segment();
773 params.create_kernel = [
this,
lib, opseg](
const NodeDef& ndef,
776 if (!
lib->IsStateful(ndef.op())) {
777 return lib->CreateKernel(ndef, kernel);
779 auto create_fn = [
lib, &ndef](OpKernel** kernel) {
780 return lib->CreateKernel(ndef, kernel);
788 params.delete_kernel = [
lib](OpKernel* kernel) {
790 if (kernel && !
lib->IsStateful(kernel->type_string())) {
796 optimizer.Optimize(
lib,
options_.env, device, &iter->second);
799 if (!options.debug_options.debug_tensor_watch_opts().empty()) {
801 options.debug_options, partition_graph.get(), params.device));
804 TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
806 partition_graph.get()));
808 item->graph = partition_graph.get();
809 item->executor =
nullptr;
812 NewLocalExecutor(params, partition_graph.release(), &executor));
813 item->executor.reset(executor);
822 for (
size_t i = 0;
i < inputs_sorted.size(); ++
i) {
823 const string&
input = inputs_sorted[
i];
824 ek->input_name_to_index[
input] =
i;
826 for (
size_t i = 0;
i < outputs_sorted.size(); ++
i) {
827 const string&
output = outputs_sorted[
i];
828 ek->output_name_to_index[
output] =
i;
836 for (
size_t i = 0;
i < inputs_sorted.size(); ++
i) {
837 const string&
input = inputs_sorted[
i];
838 ek->input_name_to_rendezvous_key[
input] = GetRendezvousKey(
839 input,
device_set_.client_device()->attributes(), FrameAndIter(0, 0));
841 for (
size_t i = 0;
i < outputs_sorted.size(); ++
i) {
842 const string&
output = outputs_sorted[
i];
843 ek->output_name_to_rendezvous_key[
output] =
844 GetRendezvousKey(output,
device_set_.client_device()->attributes(),
854 auto insert_result = executors_.emplace(sorted_key, ek);
857 executors_.emplace(key, insert_result.first->second);
858 *executors_and_keys = insert_result.first->second.get();
864 const BuildGraphOptions& subgraph_options,
865 std::unordered_map<
string, std::unique_ptr<Graph>>*
outputs,
866 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
867 RunStateArgs* run_state_args, DataTypeVector* input_types,
868 DataTypeVector* output_types) {
870 std::unique_ptr<SimpleClientGraph> client_graph;
872 std::unique_ptr<SimpleGraphExecutionState> temp_exec_state_holder;
873 SimpleGraphExecutionState* execution_state =
nullptr;
874 if (
options_.config.graph_options().place_pruned_graph()) {
878 SimpleGraphExecutionStateOptions prune_options;
880 prune_options.session_options = &
options_;
881 prune_options.stateful_placements = stateful_placements_;
882 TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForPrunedGraph(
883 execution_state_->original_graph_def().library(), prune_options,
884 execution_state_->original_graph_def(), subgraph_options,
885 &temp_exec_state_holder, &client_graph));
886 execution_state = temp_exec_state_holder.get();
888 execution_state = execution_state_.get();
890 execution_state->BuildGraph(subgraph_options, &client_graph));
893 if (subgraph_options.feed_endpoints.size() !=
894 client_graph->feed_types.size()) {
895 return errors::Internal(
896 "Graph pruning failed: requested number of feed endpoints = ",
897 subgraph_options.feed_endpoints.size(),
898 " versus number of pruned feed endpoints = ",
899 client_graph->feed_types.size());
901 if (subgraph_options.fetch_endpoints.size() !=
902 client_graph->fetch_types.size()) {
903 return errors::Internal(
904 "Graph pruning failed: requested number of fetch endpoints = ",
905 subgraph_options.fetch_endpoints.size(),
906 " versus number of pruned fetch endpoints = ",
907 client_graph->fetch_types.size());
910 auto current_stateful_placements = execution_state->GetStatefulPlacements();
914 for (
auto placement_pair : current_stateful_placements) {
915 const string& node_name = placement_pair.first;
916 const string& placement = placement_pair.second;
917 auto iter = stateful_placements_.find(node_name);
918 if (iter == stateful_placements_.end()) {
919 stateful_placements_.insert(std::make_pair(node_name, placement));
920 }
else if (iter->second != placement) {
921 return errors::Internal(
922 "Stateful placement mismatch. " 923 "Current assignment of ",
924 node_name,
" to ", iter->second,
" does not match ", placement);
928 stateful_placements_ = execution_state->GetStatefulPlacements();
933 CopyGraph(*execution_state->full_graph(), run_state_args->
graph.get());
937 PartitionOptions popts;
938 popts.node_to_loc = [](
const Node* node) {
939 assert(node !=
nullptr);
940 return node->assigned_device_name();
942 popts.new_name = [
this](
const string&
prefix) {
945 popts.get_incarnation = [](
const string&
name) {
950 popts.control_flow_added =
false;
952 std::unordered_map<string, GraphDef> partitions;
953 TF_RETURN_IF_ERROR(
Partition(popts, &client_graph->graph, &partitions));
955 std::vector<string> device_names;
958 device_names.push_back(DeviceNameUtils::LocalName(device->name()));
962 for (
const auto& partition : partitions) {
963 const string local_partition_name =
964 DeviceNameUtils::LocalName(partition.first);
965 if (
std::count(device_names.begin(), device_names.end(),
966 local_partition_name) == 0) {
967 return errors::InvalidArgument(
968 "Creating a partition for ", local_partition_name,
969 " which doesn't exist in the list of available devices. Available " 971 str_util::Join(device_names,
","));
975 for (
const auto& partition : partitions) {
976 std::unique_ptr<Graph> device_graph(
977 new Graph(client_graph->flib_def.get()));
978 GraphConstructorOptions device_opts;
980 device_opts.allow_internal_ops =
true;
981 device_opts.expect_device_spec =
true;
982 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
983 device_graph.get()));
987 GraphOptimizationPassOptions optimization_options;
988 optimization_options.session_options = &
options_;
989 optimization_options.flib_def = client_graph->flib_def.get();
990 optimization_options.partition_graphs =
outputs;
991 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
992 OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
995 for (
auto& partition : *
outputs) {
996 const string& partition_name = partition.first;
997 std::unique_ptr<Graph>* graph = &partition.second;
999 VLOG(2) <<
"Created " << DebugString(graph->get()) <<
" for " 1004 s =
device_mgr_->LookupDevice(partition_name, &d);
1010 s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph);
1015 *flib_def =
std::move(client_graph->flib_def);
1016 std::swap(*input_types, client_graph->feed_types);
1017 std::swap(*output_types, client_graph->fetch_types);
1022 std::vector<DeviceAttributes>* response) {
1024 response->reserve(
devices_.size());
1026 const DeviceAttributes& attrs =
d->attributes();
1027 response->emplace_back(attrs);
1033 const std::vector<string>& containers) {
1050 const std::vector<string>& pending_input_names,
1051 const std::vector<string>& pending_output_names, int64 step_id,
1052 const std::vector<Device*>* devices)
1053 : step_container(step_id, [devices](const
string&
name) {
1054 for (
auto d : *devices) {
1055 if (!
d->resource_manager()->Cleanup(
name).ok()) {
1061 for (
auto&
name : pending_input_names) {
1064 for (
auto&
name : pending_output_names) {
1070 const std::vector<Device*>* devices)
1071 :
RunState({}, {}, step_id, devices) {}
1076 rendez->StartAbort(errors::Cancelled(
"PRun cancellation"));
1085 if (!it.second)
return false;
1088 if (!it.second)
return false;
1094 tbb::task_group& taskGroup,
1096 CancellationManager* cm,
1097 int64 timeout_in_ms) {
1100 arena.execute([&taskGroup]() { taskGroup.wait();} );
1106 mutex_lock
l(run_state->
mu_);
1107 run_state->status.Update(status);
1118 Notification* notification, int64 timeout_in_ms) {
1119 if (timeout_in_ms > 0) {
1120 int64 timeout_in_us = timeout_in_ms * 1000;
1121 bool notified = WaitForNotificationWithTimeout(notification, timeout_in_us);
1123 return Status(error::DEADLINE_EXCEEDED,
1124 "Timed out waiting for notification");
1127 notification->WaitForNotification();
std::unique_ptr< FunctionLibraryDefinition > flib_def_
std::vector< PerPartitionExecutorsAndLib > items
RunState(int64 step_id, const std::vector< Device * > *devices)
DataTypeVector output_types
std::unordered_map< string, size_t > input_name_to_index
static const HistoName names[]
Notification executors_done
::tensorflow::Status Close() override
static boost::mutex mutex
std::vector< Device * > devices_
std::atomic< int64 > handle_name_counter_
std::vector< std::pair< string, Tensor > > NamedTensorList
std::atomic< int64 > edge_name_counter_
::tensorflow::Status WaitForNotification(Notification *n, int64 timeout_in_ms)
std::unique_ptr< FunctionLibraryRuntime > flib
::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions &debug_options, Graph *graph, Device *device)
CostModelManager cost_model_manager_
std::unordered_map< string, bool > pending_inputs
::tensorflow::Status Create(const GraphDef &graph) override
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e g
Status Reset(const SessionOptions &options, const std::vector< string > &containers) override
Session * NewSession(const SessionOptions &options) override
static std::string const input
void Deregister(const TBBSession *session)
Status MaybeInitializeExecutionState(const GraphDef &graph, bool *out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
::tensorflow::Status GetOrCreateExecutors(gtl::ArraySlice< string > inputs, gtl::ArraySlice< string > outputs, gtl::ArraySlice< string > target_nodes, ExecutorsAndKeys **executors_and_keys, RunStateArgs *run_state_args)
static NTSessionRegistrar registrar
IntraProcessRendezvous * rendez
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
std::vector< TBBSession * > sessions_ GUARDED_BY(sessions_lock_)
void SchedClosure(tbb::task_arena &arena, tbb::task_group &g, std::function< void()> c)
::tensorflow::Status ResourceHandleToInputTensor(const Tensor &resource_tensor, Tensor *retrieved_tensor)
CancellationManager * cancellation_manager_
::tensorflow::Status CreateGraphs(const BuildGraphOptions &options, std::unordered_map< string, std::unique_ptr< Graph >> *outputs, std::unique_ptr< FunctionLibraryDefinition > *flib_def, RunStateArgs *run_state_args, DataTypeVector *input_types, DataTypeVector *output_types)
::tensorflow::Status Reset(const std::vector< string > &containers)
const DebugOptions & debug_options
::tensorflow::Status Extend(const GraphDef &graph) override
DataTypeVector input_types
std::pair< int, edm::FunctionWithDict > OK
std::unordered_map< string, bool > pending_outputs
::tensorflow::Status CheckNotClosed()
std::atomic_int_fast64_t step_count
const SessionOptions options_
TBBSessionFactory *const factory_
static std::atomic_int_fast64_t step_id_counter_
def remove(d, key, TELL=False)
const std::unique_ptr< const DeviceMgr > device_mgr_
const int64 operation_timeout_in_ms_
Executor::Args::NodeOutputsCallback node_outputs_callback_
::tensorflow::Status Run(const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs) override
TBBSession(const SessionOptions &options, const DeviceMgr *device_mgr, TBBSessionFactory *factory)
std::unique_ptr< Graph > graph
::tensorflow::Status CreateDebuggerState(const DebugOptions &debug_options, int64 session_run_index, int64 executor_step_index, const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_names, std::unique_ptr< DebuggerStateInterface > *debugger_state)
SessionState session_state_
::tensorflow::Status ExtendLocked(const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
::tensorflow::Status ListDevices(std::vector< DeviceAttributes > *response) override
bool AcceptsOptions(const SessionOptions &options) override