CMS 3D CMS Logo

List of all members | Classes | Public Types | Public Member Functions | Private Types | Private Member Functions | Private Attributes | Static Private Attributes | Friends
tensorflow::TBBSession Class Reference

#include <TBBSession.h>

Inheritance diagram for tensorflow::TBBSession:
Session

Classes

struct  ExecutorsAndKeys
 
struct  PerPartitionExecutorsAndLib
 
struct  RunState
 
struct  RunStateArgs
 

Public Types

typedef std::function< void(Session *)> CloseCallback
 
typedef std::vector< std::pair< string, Tensor > > NamedTensorList
 
typedef std::unordered_map< StringPiece, Node *, StringPiece::Hasher > NameNodeMap
 

Public Member Functions

::tensorflow::Status Close () override
 
::tensorflow::Status Create (const GraphDef &graph) override
 
void ExportCostModels (CostModelManager::CostModelMap *cost_models)
 
::tensorflow::Status Extend (const GraphDef &graph) override
 
::tensorflow::Status ListDevices (std::vector< DeviceAttributes > *response) override
 
::tensorflow::Status Reset (const std::vector< string > &containers)
 
::tensorflow::Status Run (const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs) override
 
::tensorflow::Status Run (const ::tensorflow::RunOptions &run_options, const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs, RunMetadata *run_metadata) override
 
 TBBSession (const SessionOptions &options, const DeviceMgr *device_mgr, TBBSessionFactory *factory)
 
 ~TBBSession () override
 

Private Types

typedef TBBSession ME
 

Private Member Functions

::tensorflow::Status CheckNotClosed ()
 
::tensorflow::Status CreateDebuggerState (const DebugOptions &debug_options, int64 session_run_index, int64 executor_step_index, const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_names, std::unique_ptr< DebuggerStateInterface > *debugger_state)
 
::tensorflow::Status CreateGraphs (const BuildGraphOptions &options, std::unordered_map< string, std::unique_ptr< Graph >> *outputs, std::unique_ptr< FunctionLibraryDefinition > *flib_def, RunStateArgs *run_state_args, DataTypeVector *input_types, DataTypeVector *output_types)
 
::tensorflow::Status DecorateAndPublishGraphForDebug (const DebugOptions &debug_options, Graph *graph, Device *device)
 
::tensorflow::Status ExtendLocked (const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
 
::tensorflow::Status GetOrCreateExecutors (gtl::ArraySlice< string > inputs, gtl::ArraySlice< string > outputs, gtl::ArraySlice< string > target_nodes, ExecutorsAndKeys **executors_and_keys, RunStateArgs *run_state_args)
 
bool graph_created_ GUARDED_BY (graph_def_lock_)
 
GraphDef graph_def_ GUARDED_BY (graph_def_lock_)
 
std::unordered_map< string, std::shared_ptr< ExecutorsAndKeys > > executors_ GUARDED_BY (executor_lock_)
 
std::unordered_map< string, std::unique_ptr< RunState > > partial_runs_ GUARDED_BY (executor_lock_)
 
std::unordered_map< string, string > stateful_placements_ GUARDED_BY (graph_def_lock_)
 
std::unique_ptr< SimpleGraphExecutionState > execution_state_ GUARDED_BY (graph_def_lock_)
 
bool closed_ GUARDED_BY (closed_lock_)
 
Status MaybeInitializeExecutionState (const GraphDef &graph, bool *out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
 
::tensorflow::Status ResourceHandleToInputTensor (const Tensor &resource_tensor, Tensor *retrieved_tensor)
 
void SchedClosure (tbb::task_arena &arena, tbb::task_group &g, std::function< void()> c)
 
 TF_DISALLOW_COPY_AND_ASSIGN (TBBSession)
 
::tensorflow::Status WaitForNotification (Notification *n, int64 timeout_in_ms)
 
void WaitForNotification (tbb::task_arena &arena, tbb::task_group &group, RunState *run_state, CancellationManager *cm, int64 timeout_in_ms)
 

Private Attributes

CancellationManager * cancellation_manager_
 
mutex closed_lock_
 
CostModelManager cost_model_manager_
 
const std::unique_ptr< const DeviceMgr > device_mgr_
 
DeviceSet device_set_
 
std::vector< Device * > devices_
 
std::atomic< int64 > edge_name_counter_ = {0}
 
mutex executor_lock_
 
TBBSessionFactory *const factory_
 
std::unique_ptr< FunctionLibraryDefinition > flib_def_
 
mutex graph_def_lock_
 
std::atomic< int64 > handle_name_counter_ = {0}
 
Status init_error_
 
Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr
 
const int64 operation_timeout_in_ms_ = 0
 
const SessionOptions options_
 
string session_handle_
 
SessionState session_state_
 
bool sync_on_finish_ = true
 

Static Private Attributes

static std::atomic_int_fast64_t step_id_counter_
 

Friends

class DebugGateway
 

Detailed Description

Definition at line 78 of file TBBSession.h.

Member Typedef Documentation

typedef std::function<void(Session*)> tensorflow::TBBSession::CloseCallback

Definition at line 80 of file TBBSession.h.

typedef TBBSession tensorflow::TBBSession::ME
private

Definition at line 123 of file TBBSession.h.

typedef std::vector<std::pair<string, Tensor> > tensorflow::TBBSession::NamedTensorList

Definition at line 90 of file TBBSession.h.

typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher> tensorflow::TBBSession::NameNodeMap

Definition at line 92 of file TBBSession.h.

Constructor & Destructor Documentation

tensorflow::TBBSession::TBBSession ( const SessionOptions &  options,
const DeviceMgr *  device_mgr,
TBBSessionFactory *  factory 
)

Definition at line 202 of file TBBSession.cc.

References edmIntegrityCheck::d, device_mgr_, device_set_, devices_, dqm::qstatus::ERROR, MessageLogger_cfi::INFO, LOG, session_handle_, btagGenBb_cfi::Status, mps_update::status, and sync_on_finish_.

205  : options_(options),
206  device_mgr_(device_mgr),
207  factory_(factory),
208  cancellation_manager_(new CancellationManager()),
209  operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
210  // The default value of sync_on_finish will be flipped soon and this
211  // environment variable will be removed as well.
212  Status status =
213  ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
214  if (!status.ok()) {
215  LOG(ERROR) << status.error_message();
216  }
217  // NOTE(mrry): We do not need to use a unique string for the session
218  // handle, because TBBSession owns its devices. This may change
219  // in future versions.
220  session_handle_ = "tbb";
221  int devices_added = 0;
222  if (options.config.log_device_placement()) {
223  const string mapping_str = device_mgr_->DeviceMappingString();
224  if (mapping_str.empty()) {
225  printf("Device mapping: no known devices.\n");
226  } else {
227  printf("Device mapping:\n%s", mapping_str.c_str());
228  }
229  LOG(INFO) << "Device mapping:\n" << mapping_str;
230  }
231  for (auto d : device_mgr_->ListDevices()) {
232  devices_.push_back(d);
233  device_set_.AddDevice(d);
234  d->op_segment()->AddHold(session_handle_);
235 
236  // The first device added is special: it is the 'client device' (a
237  // CPU device) from which we feed and fetch Tensors.
238  if (devices_added == 0) {
239  device_set_.set_client_device(d);
240  }
241  ++devices_added;
242  }
243 }
std::vector< Device * > devices_
Definition: TBBSession.h:257
#define LOG(A)
CancellationManager * cancellation_manager_
Definition: TBBSession.h:289
const SessionOptions options_
Definition: TBBSession.h:253
TBBSessionFactory *const factory_
Definition: TBBSession.h:288
const std::unique_ptr< const DeviceMgr > device_mgr_
Definition: TBBSession.h:256
const int64 operation_timeout_in_ms_
Definition: TBBSession.h:319
static const int ERROR
tensorflow::TBBSession::~TBBSession ( )
override

Definition at line 245 of file TBBSession.cc.

References cancellation_manager_, Close(), edmIntegrityCheck::d, device_mgr_, flib_def_, and session_handle_.

245  {
246  if (!closed_) Close().IgnoreError();
247  for (auto& it : partial_runs_) {
248  it.second.reset(nullptr);
249  }
250  for (auto& it : executors_) {
251  it.second.reset();
252  }
253  for (auto d : device_mgr_->ListDevices()) {
254  d->op_segment()->RemoveHold(session_handle_);
255  }
256  delete cancellation_manager_;
257 
258  execution_state_.reset(nullptr);
259  flib_def_.reset(nullptr);
260 }
std::unique_ptr< FunctionLibraryDefinition > flib_def_
Definition: TBBSession.h:305
::tensorflow::Status Close() override
Definition: TBBSession.cc:1038
CancellationManager * cancellation_manager_
Definition: TBBSession.h:289
const std::unique_ptr< const DeviceMgr > device_mgr_
Definition: TBBSession.h:256

Member Function Documentation

::tensorflow::Status tensorflow::TBBSession::CheckNotClosed ( )
inlineprivate

Definition at line 237 of file TBBSession.h.

References pfDeepBoostedJetPreprocessParams_cfi::input_names, checklumidiff::l, and btagGenBb_cfi::Status.

Referenced by DecorateAndPublishGraphForDebug(), and Extend().

237  {
238  mutex_lock l(closed_lock_);
239  if (closed_) return errors::Cancelled("Session has been closed.");
240  return ::tensorflow::Status::OK();
241  }
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
tensorflow::Status tensorflow::TBBSession::Close ( )
override

Definition at line 1038 of file TBBSession.cc.

References cancellation_manager_, closed_lock_, tensorflow::TBBSessionFactory::Deregister(), factory_, and checklumidiff::l.

Referenced by ~TBBSession().

1038  {
1039  cancellation_manager_->StartCancel();
1040  {
1041  mutex_lock l(closed_lock_);
1042  if (closed_) return ::tensorflow::Status::OK();
1043  closed_ = true;
1044  }
1045  if (factory_ != nullptr) factory_->Deregister(this);
1046  return ::tensorflow::Status::OK();
1047 }
void Deregister(const TBBSession *session)
Definition: TBBSession.cc:162
CancellationManager * cancellation_manager_
Definition: TBBSession.h:289
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
TBBSessionFactory *const factory_
Definition: TBBSession.h:288
Status tensorflow::TBBSession::Create ( const GraphDef &  graph)
override

Definition at line 293 of file TBBSession.cc.

References ExtendLocked(), graph_def_lock_, init_error_, and checklumidiff::l.

293  {
294  TF_RETURN_IF_ERROR(init_error_);
295  if (graph.node_size() > 0) {
296  mutex_lock l(graph_def_lock_);
297  if (graph_created_) {
298  return errors::AlreadyExists(
299  "A Graph has already been created for this session.");
300  }
301  return ExtendLocked(graph);
302  }
303  return Status::OK();
304 }
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
::tensorflow::Status ExtendLocked(const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
Definition: TBBSession.cc:312
Status tensorflow::TBBSession::CreateDebuggerState ( const DebugOptions &  debug_options,
int64  session_run_index,
int64  executor_step_index,
const std::vector< string > &  input_names,
const std::vector< string > &  output_names,
const std::vector< string > &  target_names,
std::unique_ptr< DebuggerStateInterface > *  debugger_state 
)
private

Definition at line 336 of file TBBSession.cc.

References pfDeepBoostedJetPreprocessParams_cfi::input_names.

Referenced by DecorateAndPublishGraphForDebug().

341  {
342  TF_RETURN_IF_ERROR(
343  DebuggerStateRegistry::CreateState(debug_options, debugger_state));
344  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
345  debug_options.global_step(), session_run_index, executor_step_index,
346  input_names, output_names, target_names));
347  return Status::OK();
348 }
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
Status tensorflow::TBBSession::CreateGraphs ( const BuildGraphOptions &  options,
std::unordered_map< string, std::unique_ptr< Graph >> *  outputs,
std::unique_ptr< FunctionLibraryDefinition > *  flib_def,
RunStateArgs *  run_state_args,
DataTypeVector *  input_types,
DataTypeVector *  output_types 
)
private

Definition at line 863 of file TBBSession.cc.

References KineDebug3::count(), edmIntegrityCheck::d, device_mgr_, device_set_, devices_, edge_name_counter_, flib_def_, tensorflow::TBBSession::RunStateArgs::graph, graph_def_lock_, tensorflow::TBBSession::RunStateArgs::is_partial_run, checklumidiff::l, eostools::move(), dataset::name, options_, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, tablePrinter::prefix, alignCSCRings::s, btagGenBb_cfi::Status, and std::swap().

Referenced by GetOrCreateExecutors().

868  {
869  mutex_lock l(graph_def_lock_);
870  std::unique_ptr<SimpleClientGraph> client_graph;
871 
872  std::unique_ptr<SimpleGraphExecutionState> temp_exec_state_holder;
873  SimpleGraphExecutionState* execution_state = nullptr;
874  if (options_.config.graph_options().place_pruned_graph()) {
875  // Because we are placing pruned graphs, we need to create a
876  // new SimpleGraphExecutionState for every new unseen graph,
877  // and then place it.
878  SimpleGraphExecutionStateOptions prune_options;
879  prune_options.device_set = &device_set_;
880  prune_options.session_options = &options_;
881  prune_options.stateful_placements = stateful_placements_;
882  TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForPrunedGraph(
883  execution_state_->original_graph_def().library(), prune_options,
884  execution_state_->original_graph_def(), subgraph_options,
885  &temp_exec_state_holder, &client_graph));
886  execution_state = temp_exec_state_holder.get();
887  } else {
888  execution_state = execution_state_.get();
889  TF_RETURN_IF_ERROR(
890  execution_state->BuildGraph(subgraph_options, &client_graph));
891  }
892 
893  if (subgraph_options.feed_endpoints.size() !=
894  client_graph->feed_types.size()) {
895  return errors::Internal(
896  "Graph pruning failed: requested number of feed endpoints = ",
897  subgraph_options.feed_endpoints.size(),
898  " versus number of pruned feed endpoints = ",
899  client_graph->feed_types.size());
900  }
901  if (subgraph_options.fetch_endpoints.size() !=
902  client_graph->fetch_types.size()) {
903  return errors::Internal(
904  "Graph pruning failed: requested number of fetch endpoints = ",
905  subgraph_options.fetch_endpoints.size(),
906  " versus number of pruned fetch endpoints = ",
907  client_graph->fetch_types.size());
908  }
909 
910  auto current_stateful_placements = execution_state->GetStatefulPlacements();
911  // Update our current state based on the execution_state's
912  // placements. If there are any mismatches for a node,
913  // we should fail, as this should never happen.
914  for (auto placement_pair : current_stateful_placements) {
915  const string& node_name = placement_pair.first;
916  const string& placement = placement_pair.second;
917  auto iter = stateful_placements_.find(node_name);
918  if (iter == stateful_placements_.end()) {
919  stateful_placements_.insert(std::make_pair(node_name, placement));
920  } else if (iter->second != placement) {
921  return errors::Internal(
922  "Stateful placement mismatch. "
923  "Current assignment of ",
924  node_name, " to ", iter->second, " does not match ", placement);
925  }
926  }
927 
928  stateful_placements_ = execution_state->GetStatefulPlacements();
929 
930  // Remember the graph in run state if this is a partial run.
931  if (run_state_args->is_partial_run) {
932  run_state_args->graph.reset(new Graph(flib_def_.get()));
933  CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
934  }
935 
936  // Partition the graph across devices.
937  PartitionOptions popts;
938  popts.node_to_loc = [](const Node* node) {
939  assert(node != nullptr);
940  return node->assigned_device_name();
941  };
942  popts.new_name = [this](const string& prefix) {
943  return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
944  };
945  popts.get_incarnation = [](const string& name) {
946  // The direct session does not have changing incarnation numbers.
947  // Just return '1'.
948  return 1;
949  };
950  popts.control_flow_added = false;
951 
952  std::unordered_map<string, GraphDef> partitions;
953  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
954 
955  std::vector<string> device_names;
956  for (auto device : devices_) {
957  // Extract the LocalName from the device.
958  device_names.push_back(DeviceNameUtils::LocalName(device->name()));
959  }
960 
961  // Check for valid partitions.
962  for (const auto& partition : partitions) {
963  const string local_partition_name =
964  DeviceNameUtils::LocalName(partition.first);
965  if (std::count(device_names.begin(), device_names.end(),
966  local_partition_name) == 0) {
967  return errors::InvalidArgument(
968  "Creating a partition for ", local_partition_name,
969  " which doesn't exist in the list of available devices. Available "
970  "devices: ",
971  str_util::Join(device_names, ","));
972  }
973  }
974 
975  for (const auto& partition : partitions) {
976  std::unique_ptr<Graph> device_graph(
977  new Graph(client_graph->flib_def.get()));
978  GraphConstructorOptions device_opts;
979  // There are internal operations (e.g., send/recv) that we now allow.
980  device_opts.allow_internal_ops = true;
981  device_opts.expect_device_spec = true;
982  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
983  device_graph.get()));
984  outputs->emplace(partition.first, std::move(device_graph));
985  }
986 
987  GraphOptimizationPassOptions optimization_options;
988  optimization_options.session_options = &options_;
989  optimization_options.flib_def = client_graph->flib_def.get();
990  optimization_options.partition_graphs = outputs;
991  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
992  OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
993 
994  Status s;
995  for (auto& partition : *outputs) {
996  const string& partition_name = partition.first;
997  std::unique_ptr<Graph>* graph = &partition.second;
998 
999  VLOG(2) << "Created " << DebugString(graph->get()) << " for "
1000  << partition_name;
1001 
1002  // Give the device an opportunity to rewrite its subgraph.
1003  Device* d;
1004  s = device_mgr_->LookupDevice(partition_name, &d);
1005  if (!s.ok()) break;
1006  // TODO(pbar) The library is currently shared and immutable. There
1007  // may be possible use cases where a device may want to modify
1008  // function definitions - in which case the library would need to be
1009  // replicated per device.
1010  s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph);
1011  if (!s.ok()) {
1012  break;
1013  }
1014  }
1015  *flib_def = std::move(client_graph->flib_def);
1016  std::swap(*input_types, client_graph->feed_types);
1017  std::swap(*output_types, client_graph->fetch_types);
1018  return s;
1019 }
std::unique_ptr< FunctionLibraryDefinition > flib_def_
Definition: TBBSession.h:305
std::vector< Device * > devices_
Definition: TBBSession.h:257
std::atomic< int64 > edge_name_counter_
Definition: TBBSession.h:312
Partition
Definition: HLTHPDFilter.cc:32
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
const SessionOptions options_
Definition: TBBSession.h:253
const std::unique_ptr< const DeviceMgr > device_mgr_
Definition: TBBSession.h:256
def move(src, dest)
Definition: eostools.py:510
Status tensorflow::TBBSession::DecorateAndPublishGraphForDebug ( const DebugOptions &  debug_options,
Graph *  graph,
Device *  device 
)
private

Definition at line 350 of file TBBSession.cc.

References createfilelist::args, EnergyCorrector::c, cancellation_manager_, CheckNotClosed(), cost_model_manager_, CreateDebuggerState(), device_mgr_, devices_, executor_lock_, tensorflow::TBBSession::RunState::executors_done, tensorflow::TBBSession::PerPartitionExecutorsAndLib::flib, GetOrCreateExecutors(), tensorflow::TBBSession::PerPartitionExecutorsAndLib::graph, graph_def_lock_, mps_fire::i, tensorflow::TBBSession::ExecutorsAndKeys::input_name_to_index, tensorflow::TBBSession::ExecutorsAndKeys::input_types, PatBasicFWLiteJetAnalyzer_Selector_cfg::inputs, tensorflow::TBBSession::ExecutorsAndKeys::items, checklumidiff::l, eostools::move(), operation_timeout_in_ms_, options_, tensorflow::TBBSession::ExecutorsAndKeys::output_types, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, tensorflow::TBBSession::RunState::rendez, ResourceHandleToInputTensor(), Run(), alignCSCRings::s, SchedClosure(), session_state_, btagGenBb_cfi::Status, tensorflow::TBBSession::ExecutorsAndKeys::step_count, step_id_counter_, sync_on_finish_, and WaitForNotification().

Referenced by GetOrCreateExecutors().

351  {
352  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
353  TF_RETURN_IF_ERROR(
354  DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
355 
356  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
357  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
358  return Status::OK();
359 }
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
void tensorflow::TBBSession::ExportCostModels ( CostModelManager::CostModelMap *  cost_models)
inline

Definition at line 118 of file TBBSession.h.

118  {
119  cost_model_manager_.ExportCostModels(cost_models);
120  }
CostModelManager cost_model_manager_
Definition: TBBSession.h:322
Status tensorflow::TBBSession::Extend ( const GraphDef &  graph)
override

Definition at line 306 of file TBBSession.cc.

References CheckNotClosed(), ExtendLocked(), graph_def_lock_, and checklumidiff::l.

306  {
307  TF_RETURN_IF_ERROR(CheckNotClosed());
308  mutex_lock l(graph_def_lock_);
309  return ExtendLocked(graph);
310 }
::tensorflow::Status CheckNotClosed()
Definition: TBBSession.h:237
::tensorflow::Status ExtendLocked(const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
Definition: TBBSession.cc:312
Status tensorflow::TBBSession::ExtendLocked ( const GraphDef &  graph)
private

Definition at line 312 of file TBBSession.cc.

References flib_def_, and MaybeInitializeExecutionState().

Referenced by Create(), and Extend().

312  {
313  bool already_initialized;
314  // If this is the first call, we can initialize the execution state
315  // with `graph` and do not need to call `Extend()`.
316  TF_RETURN_IF_ERROR(
317  MaybeInitializeExecutionState(graph, &already_initialized));
318  if (already_initialized) {
319  TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
320  std::unique_ptr<SimpleGraphExecutionState> state;
321  TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
322  execution_state_.swap(state);
323  }
324  return Status::OK();
325 }
std::unique_ptr< FunctionLibraryDefinition > flib_def_
Definition: TBBSession.h:305
Status MaybeInitializeExecutionState(const GraphDef &graph, bool *out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
Definition: TBBSession.cc:262
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
Status tensorflow::TBBSession::GetOrCreateExecutors ( gtl::ArraySlice< string >  inputs,
gtl::ArraySlice< string >  outputs,
gtl::ArraySlice< string >  target_nodes,
ExecutorsAndKeys **  executors_and_keys,
RunStateArgs *  run_state_args 
)
private

Definition at line 643 of file TBBSession.cc.

References CreateGraphs(), tensorflow::TBBSession::RunStateArgs::debug_options, DecorateAndPublishGraphForDebug(), device_mgr_, device_set_, executor_lock_, plotBeamSpotDB::first, tensorflow::TBBSession::RunStateArgs::graph, cuy::graphs, tensorflow::TBBSession::RunStateArgs::handle, handle_name_counter_, mps_fire::i, triggerObjects_cff::id, input, tensorflow::TBBSession::RunStateArgs::is_partial_run, crabWrapper::key, checklumidiff::l, mps_check::lib, eostools::move(), gen::n, cscdqm::h::names, node_outputs_callback_, AlcaSiPixelAliHarvester0T_cff::options, options_, convertSQLitetoXML_cfg::output, and session_handle_.

Referenced by DecorateAndPublishGraphForDebug().

646  {
647  int64 handle_name_counter_value = -1;
648  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
649  handle_name_counter_value = handle_name_counter_.fetch_add(1);
650  }
651 
652  string debug_tensor_watches_summary;
653  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
654  debug_tensor_watches_summary = SummarizeDebugTensorWatches(
655  run_state_args->debug_options.debug_tensor_watch_opts());
656  }
657 
658  // Fast lookup path, no sorting.
659  const string key = strings::StrCat(
660  str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
661  str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
662  "/", debug_tensor_watches_summary);
663  // Set the handle, if it's needed to log memory or for partial run.
664  if (handle_name_counter_value >= 0) {
665  run_state_args->handle =
666  strings::StrCat(key, ";", handle_name_counter_value);
667  }
668 
669  // See if we already have the executors for this run.
670  {
671  mutex_lock l(executor_lock_); // could use reader lock
672  auto it = executors_.find(key);
673  if (it != executors_.end()) {
674  *executors_and_keys = it->second.get();
675  return Status::OK();
676  }
677  }
678 
679  // Slow lookup path, the unsorted key missed the cache.
680  // Sort the inputs and outputs, and look up with the sorted key in case an
681  // earlier call used a different order of inputs and outputs.
682  //
683  // We could consider some other signature instead of sorting that
684  // preserves the same property to avoid the sort in the future.
685  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
686  std::sort(inputs_sorted.begin(), inputs_sorted.end());
687  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
688  std::sort(outputs_sorted.begin(), outputs_sorted.end());
689  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
690  std::sort(tn_sorted.begin(), tn_sorted.end());
691 
692  const string sorted_key = strings::StrCat(
693  str_util::Join(inputs_sorted, ","), "->",
694  str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
695  "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
696  // Set the handle, if its needed to log memory or for partial run.
697  if (handle_name_counter_value >= 0) {
698  run_state_args->handle =
699  strings::StrCat(sorted_key, ";", handle_name_counter_value);
700  }
701 
702  // See if we already have the executors for this run.
703  {
704  mutex_lock l(executor_lock_);
705  auto it = executors_.find(sorted_key);
706  if (it != executors_.end()) {
707  *executors_and_keys = it->second.get();
708  // Insert this under the original key.
709  executors_.emplace(key, it->second);
710  return Status::OK();
711  }
712  }
713 
714  // Nothing found, so create the executors and store in the cache.
715  BuildGraphOptions options;
716  options.feed_endpoints = inputs_sorted;
717  options.fetch_endpoints = outputs_sorted;
718  options.target_nodes = tn_sorted;
719  options.use_function_convention = !run_state_args->is_partial_run;
720  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
721  options.debug_options = run_state_args->debug_options;
722  }
723 
724  std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
725 
726  // The executor_lock_ is intentionally released while executor is
727  // being created.
728  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
729  TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &ek->flib_def,
730  run_state_args, &ek->input_types,
731  &ek->output_types));
732 
733  if (run_state_args->is_partial_run) {
734  ek->graph = std::move(run_state_args->graph);
735  std::unordered_set<StringPiece, StringPiece::Hasher> names;
736  for (const string& input : inputs) {
737  TensorId id(ParseTensorName(input));
738  names.emplace(id.first);
739  }
740  for (const string& output : outputs) {
741  TensorId id(ParseTensorName(output));
742  names.emplace(id.first);
743  }
744  for (Node* n : ek->graph->nodes()) {
745  if (names.count(n->name()) > 0) {
746  ek->name_to_node.insert({n->name(), n});
747  }
748  }
749  }
750  ek->items.reserve(graphs.size());
751  const auto& optimizer_opts =
752  options_.config.graph_options().optimizer_options();
753  GraphOptimizer optimizer(optimizer_opts);
754  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
755  const string& partition_name = iter->first;
756  std::unique_ptr<Graph>& partition_graph = iter->second;
757  const int graph_def_version = partition_graph->versions().producer();
758 
759  Device* device;
760  TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
761 
762  ek->items.resize(ek->items.size() + 1);
763  auto* item = &(ek->items.back());
764  item->flib.reset(NewFunctionLibraryRuntime(
765  device_mgr_.get(), options_.env, device, graph_def_version,
766  ek->flib_def.get(), optimizer_opts));
767 
768  LocalExecutorParams params;
769  params.device = device;
770  params.function_library = item->flib.get();
771  auto lib = item->flib.get();
772  auto opseg = device->op_segment();
773  params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
774  OpKernel** kernel) {
775  // Caches the kernel only if the node is stateful.
776  if (!lib->IsStateful(ndef.op())) {
777  return lib->CreateKernel(ndef, kernel);
778  }
779  auto create_fn = [lib, &ndef](OpKernel** kernel) {
780  return lib->CreateKernel(ndef, kernel);
781  };
782  // Kernels created for subgraph nodes need to be cached. On
783  // cache miss, create_fn() is invoked to create a kernel based
784  // on the function library here + global op registry.
785  return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
786  create_fn);
787  };
788  params.delete_kernel = [lib](OpKernel* kernel) {
789  // If the node is stateful, opseg owns it. Otherwise, delete it.
790  if (kernel && !lib->IsStateful(kernel->type_string())) {
791  delete kernel;
792  }
793  };
794  params.node_outputs_cb = node_outputs_callback_;
795 
796  optimizer.Optimize(lib, options_.env, device, &iter->second);
797 
798  // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
799  if (!options.debug_options.debug_tensor_watch_opts().empty()) {
800  TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
801  options.debug_options, partition_graph.get(), params.device));
802  }
803 
804  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
805  device->name(),
806  partition_graph.get()));
807  // NewLocalExecutor takes ownership of partition_graph.
808  item->graph = partition_graph.get();
809  item->executor = nullptr;
810  Executor* executor;
811  TF_RETURN_IF_ERROR(
812  NewLocalExecutor(params, partition_graph.release(), &executor));
813  item->executor.reset(executor);
814  }
815 
816  // Cache the mapping from input/output names to graph elements to
817  // avoid recomputing it every time.
818  if (!run_state_args->is_partial_run) {
819  // For regular `Run()`, we use the function calling convention, and so
820  // maintain a mapping from input/output names to
821  // argument/return-value ordinal index.
822  for (size_t i = 0; i < inputs_sorted.size(); ++i) {
823  const string& input = inputs_sorted[i];
824  ek->input_name_to_index[input] = i;
825  }
826  for (size_t i = 0; i < outputs_sorted.size(); ++i) {
827  const string& output = outputs_sorted[i];
828  ek->output_name_to_index[output] = i;
829  }
830  } else {
831  // For `PRun()`, we use the rendezvous calling convention, and so
832  // maintain a mapping from input/output names to rendezvous keys.
833  //
834  // We always use the first device as the device name portion of the
835  // key, even if we're feeding another graph.
836  for (size_t i = 0; i < inputs_sorted.size(); ++i) {
837  const string& input = inputs_sorted[i];
838  ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
839  input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
840  }
841  for (size_t i = 0; i < outputs_sorted.size(); ++i) {
842  const string& output = outputs_sorted[i];
843  ek->output_name_to_rendezvous_key[output] =
844  GetRendezvousKey(output, device_set_.client_device()->attributes(),
845  FrameAndIter(0, 0));
846  }
847  }
848 
849  // Reacquire the lock, try to insert into the map.
850  mutex_lock l(executor_lock_);
851 
852  // Another thread may have created the entry before us, in which case we will
853  // reuse the already created one.
854  auto insert_result = executors_.emplace(sorted_key, ek);
855  // Insert the value under the original key, so the fast path lookup will work
856  // if the user uses the same order of inputs, outputs, and targets again.
857  executors_.emplace(key, insert_result.first->second);
858  *executors_and_keys = insert_result.first->second.get();
859 
860  return Status::OK();
861 }
static const HistoName names[]
std::atomic< int64 > handle_name_counter_
Definition: TBBSession.h:313
::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions &debug_options, Graph *graph, Device *device)
Definition: TBBSession.cc:350
static std::string const input
Definition: EdmProvDump.cc:44
::tensorflow::Status CreateGraphs(const BuildGraphOptions &options, std::unordered_map< string, std::unique_ptr< Graph >> *outputs, std::unique_ptr< FunctionLibraryDefinition > *flib_def, RunStateArgs *run_state_args, DataTypeVector *input_types, DataTypeVector *output_types)
Definition: TBBSession.cc:863
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
const SessionOptions options_
Definition: TBBSession.h:253
const std::unique_ptr< const DeviceMgr > device_mgr_
Definition: TBBSession.h:256
Executor::Args::NodeOutputsCallback node_outputs_callback_
Definition: TBBSession.h:324
graphs
Definition: cuy.py:960
def move(src, dest)
Definition: eostools.py:510
bool graph_created_ tensorflow::TBBSession::GUARDED_BY ( graph_def_lock_  )
private
GraphDef graph_def_ tensorflow::TBBSession::GUARDED_BY ( graph_def_lock_  )
private
std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys> > executors_ tensorflow::TBBSession::GUARDED_BY ( executor_lock_  )
private
std::unordered_map<string, std::unique_ptr<RunState> > partial_runs_ tensorflow::TBBSession::GUARDED_BY ( executor_lock_  )
private
std::unordered_map<string, string> stateful_placements_ tensorflow::TBBSession::GUARDED_BY ( graph_def_lock_  )
private
std::unique_ptr<SimpleGraphExecutionState> execution_state_ tensorflow::TBBSession::GUARDED_BY ( graph_def_lock_  )
private
bool closed_ tensorflow::TBBSession::GUARDED_BY ( closed_lock_  )
private
tensorflow::Status tensorflow::TBBSession::ListDevices ( std::vector< DeviceAttributes > *  response)
override

Definition at line 1021 of file TBBSession.cc.

References edmIntegrityCheck::d, and devices_.

1022  {
1023  response->clear();
1024  response->reserve(devices_.size());
1025  for (Device* d : devices_) {
1026  const DeviceAttributes& attrs = d->attributes();
1027  response->emplace_back(attrs);
1028  }
1029  return ::tensorflow::Status::OK();
1030 }
std::vector< Device * > devices_
Definition: TBBSession.h:257
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
Status tensorflow::TBBSession::MaybeInitializeExecutionState ( const GraphDef &  graph,
bool *  out_already_initialized 
)
private

Definition at line 262 of file TBBSession.cc.

References device_set_, flib_def_, AlcaSiPixelAliHarvester0T_cff::options, options_, and groupFilesInBlocks::temp.

Referenced by ExtendLocked().

263  {
264  // If already initialized, do nothing.
265  if (flib_def_ && execution_state_) {
266  *out_already_initialized = true;
267  return Status::OK();
268  }
269  // Set up the per-session execution state.
270  // NOTE(mrry): The function library created here will be used for
271  // all subsequent extensions of the graph.
272  flib_def_.reset(
273  new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
274  SimpleGraphExecutionStateOptions options;
275  options.device_set = &device_set_;
276  options.session_options = &options_;
277  // TODO(mrry,suharshs): We explicitly copy `graph` so that
278  // `MakeForBaseGraph()` can take ownership of its
279  // contents. Previously this happened implicitly in calls to the
280  // `SimpleGraphExecutionState`. Other sessions call
281  // `MakeForBaseGraph` in such a way that we can destructively read
282  // the passed-in `GraphDef`. In principle we could do the same here,
283  // with a wider refactoring; we might revise the direct session so
284  // that it copies the graph fewer times.
285  GraphDef temp(graph);
286  TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
287  &temp, options, &execution_state_));
288  graph_created_ = true;
289  *out_already_initialized = false;
290  return Status::OK();
291 }
std::unique_ptr< FunctionLibraryDefinition > flib_def_
Definition: TBBSession.h:305
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
const SessionOptions options_
Definition: TBBSession.h:253
tensorflow::Status tensorflow::TBBSession::Reset ( const std::vector< string > &  containers)

Definition at line 1032 of file TBBSession.cc.

References device_mgr_.

1033  {
1034  device_mgr_->ClearContainers(containers);
1035  return ::tensorflow::Status::OK();
1036 }
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
const std::unique_ptr< const DeviceMgr > device_mgr_
Definition: TBBSession.h:256
Status tensorflow::TBBSession::ResourceHandleToInputTensor ( const Tensor &  resource_tensor,
Tensor *  retrieved_tensor 
)
private

Definition at line 622 of file TBBSession.cc.

References session_state_.

Referenced by DecorateAndPublishGraphForDebug().

623  {
624  if (resource_tensor.dtype() != DT_RESOURCE) {
625  return errors::InvalidArgument(strings::StrCat(
626  "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
627  resource_tensor.dtype()));
628  }
629 
630  ResourceHandle resource_handle = resource_tensor.scalar<ResourceHandle>()();
631 
632  if (resource_handle.container() ==
633  SessionState::kTensorHandleResourceTypeName) {
634  return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
635  } else {
636  return errors::InvalidArgument(strings::StrCat(
637  "Invalid resource type hash code: ", resource_handle.hash_code(),
638  "(name: ", resource_handle.name(),
639  " type: ", resource_handle.maybe_type_name(), ")"));
640  }
641 }
SessionState session_state_
Definition: TBBSession.h:286
Status tensorflow::TBBSession::Run ( const NamedTensorList &  inputs,
const std::vector< string > &  output_names,
const std::vector< string > &  target_nodes,
std::vector< Tensor > *  outputs 
)
override

Definition at line 327 of file TBBSession.cc.

Referenced by DecorateAndPublishGraphForDebug().

330  {
331  RunMetadata run_metadata;
332  return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
333  &run_metadata);
334 }
::tensorflow::Status Run(const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs) override
Definition: TBBSession.cc:327
::tensorflow::Status tensorflow::TBBSession::Run ( const ::tensorflow::RunOptions &  run_options,
const NamedTensorList &  inputs,
const std::vector< string > &  output_names,
const std::vector< string > &  target_nodes,
std::vector< Tensor > *  outputs,
RunMetadata *  run_metadata 
)
override
void tensorflow::TBBSession::SchedClosure ( tbb::task_arena &  arena,
tbb::task_group &  g,
std::function< void()>  c 
)
private

Definition at line 198 of file TBBSession.cc.

References EnergyCorrector::c.

Referenced by DecorateAndPublishGraphForDebug().

198  {
199  arena.execute( [&g,&c] () {g.run( c ); } );
200 }
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e g
Definition: Activities.doc:4
tensorflow::TBBSession::TF_DISALLOW_COPY_AND_ASSIGN ( TBBSession  )
private
tensorflow::Status tensorflow::TBBSession::WaitForNotification ( Notification *  n,
int64  timeout_in_ms 
)
private

Definition at line 1117 of file TBBSession.cc.

References btagGenBb_cfi::Status.

Referenced by DecorateAndPublishGraphForDebug(), and WaitForNotification().

1118  {
1119  if (timeout_in_ms > 0) {
1120  int64 timeout_in_us = timeout_in_ms * 1000;
1121  bool notified = WaitForNotificationWithTimeout(notification, timeout_in_us);
1122  if (!notified) {
1123  return Status(error::DEADLINE_EXCEEDED,
1124  "Timed out waiting for notification");
1125  }
1126  } else {
1127  notification->WaitForNotification();
1128  }
1129  return Status::OK();
1130 }
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
void tensorflow::TBBSession::WaitForNotification ( tbb::task_arena &  arena,
tbb::task_group &  group,
RunState *  run_state,
CancellationManager *  cm,
int64  timeout_in_ms 
)
private

Definition at line 1093 of file TBBSession.cc.

References tensorflow::TBBSession::RunState::executors_done, checklumidiff::l, tensorflow::TBBSession::RunState::mu_, btagGenBb_cfi::Status, mps_update::status, and WaitForNotification().

1097  {
1098  // Doing the wait in the arena adds this thread to the arena
1099  // and therefore tasks associated to the group can run on this thread
1100  arena.execute([&taskGroup]() { taskGroup.wait();} );
1101 
1102  Status status =
1103  WaitForNotification(&run_state->executors_done, timeout_in_ms);
1104  if (!status.ok()) {
1105  {
1106  mutex_lock l(run_state->mu_);
1107  run_state->status.Update(status);
1108  }
1109  cm->StartCancel();
1110  // We must wait for the executors to complete, because they have borrowed
1111  // references to `cm` and other per-step state. After this notification, it
1112  // is safe to clean up the step.
1113  run_state->executors_done.WaitForNotification();
1114  }
1115 }
::tensorflow::Status WaitForNotification(Notification *n, int64 timeout_in_ms)
Definition: TBBSession.cc:1117

Friends And Related Function Documentation

friend class DebugGateway
friend

Definition at line 329 of file TBBSession.h.

Member Data Documentation

CancellationManager* tensorflow::TBBSession::cancellation_manager_
private

Definition at line 289 of file TBBSession.h.

Referenced by Close(), DecorateAndPublishGraphForDebug(), and ~TBBSession().

mutex tensorflow::TBBSession::closed_lock_
private

Definition at line 308 of file TBBSession.h.

Referenced by Close().

CostModelManager tensorflow::TBBSession::cost_model_manager_
private

Definition at line 322 of file TBBSession.h.

Referenced by DecorateAndPublishGraphForDebug().

const std::unique_ptr<const DeviceMgr> tensorflow::TBBSession::device_mgr_
private
DeviceSet tensorflow::TBBSession::device_set_
private
std::vector<Device*> tensorflow::TBBSession::devices_
private
std::atomic<int64> tensorflow::TBBSession::edge_name_counter_ = {0}
private

Definition at line 312 of file TBBSession.h.

Referenced by CreateGraphs().

mutex tensorflow::TBBSession::executor_lock_
private

Definition at line 272 of file TBBSession.h.

Referenced by DecorateAndPublishGraphForDebug(), and GetOrCreateExecutors().

TBBSessionFactory* const tensorflow::TBBSession::factory_
private

Definition at line 288 of file TBBSession.h.

Referenced by Close().

std::unique_ptr<FunctionLibraryDefinition> tensorflow::TBBSession::flib_def_
private
mutex tensorflow::TBBSession::graph_def_lock_
private

Definition at line 263 of file TBBSession.h.

Referenced by Create(), CreateGraphs(), DecorateAndPublishGraphForDebug(), and Extend().

std::atomic<int64> tensorflow::TBBSession::handle_name_counter_ = {0}
private

Definition at line 313 of file TBBSession.h.

Referenced by GetOrCreateExecutors().

Status tensorflow::TBBSession::init_error_
private

Definition at line 266 of file TBBSession.h.

Referenced by Create().

Executor::Args::NodeOutputsCallback tensorflow::TBBSession::node_outputs_callback_ = nullptr
private

Definition at line 324 of file TBBSession.h.

Referenced by GetOrCreateExecutors().

const int64 tensorflow::TBBSession::operation_timeout_in_ms_ = 0
private

Definition at line 319 of file TBBSession.h.

Referenced by DecorateAndPublishGraphForDebug().

const SessionOptions tensorflow::TBBSession::options_
private
string tensorflow::TBBSession::session_handle_
private

Definition at line 260 of file TBBSession.h.

Referenced by GetOrCreateExecutors(), TBBSession(), and ~TBBSession().

SessionState tensorflow::TBBSession::session_state_
private

Definition at line 286 of file TBBSession.h.

Referenced by DecorateAndPublishGraphForDebug(), and ResourceHandleToInputTensor().

std::atomic_int_fast64_t tensorflow::TBBSession::step_id_counter_
static private

Definition at line 316 of file TBBSession.h.

Referenced by DecorateAndPublishGraphForDebug().

bool tensorflow::TBBSession::sync_on_finish_ = true
private

Definition at line 269 of file TBBSession.h.

Referenced by DecorateAndPublishGraphForDebug(), and TBBSession().