36 #ifndef PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H 37 #define PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H 42 #include <unordered_map> 43 #include <unordered_set> 46 #include "tbb/task_arena.h" 48 #include "tensorflow/core/common_runtime/costmodel_manager.h" 49 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 50 #include "tensorflow/core/common_runtime/device_mgr.h" 51 #include "tensorflow/core/common_runtime/device_set.h" 52 #include "tensorflow/core/common_runtime/executor.h" 53 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 54 #include "tensorflow/core/common_runtime/session_factory.h" 55 #include "tensorflow/core/common_runtime/simple_graph_execution_state.h" 56 #include "tensorflow/core/framework/cancellation.h" 57 #include "tensorflow/core/framework/graph.pb.h" 58 #include "tensorflow/core/framework/session_state.h" 59 #include "tensorflow/core/framework/tensor.h" 60 #include "tensorflow/core/lib/core/errors.h" 61 #include "tensorflow/core/lib/core/status.h" 62 #include "tensorflow/core/platform/macros.h" 63 #include "tensorflow/core/platform/mutex.h" 64 #include "tensorflow/core/platform/types.h" 65 #include "tensorflow/core/public/session.h" 76 class TBBSessionFactory;
91 typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher>
97 const std::vector<string>& output_names,
98 const std::vector<string>& target_nodes,
99 std::vector<Tensor>*
outputs)
override;
103 const NamedTensorList&
inputs,
104 const std::vector<string>& output_names,
105 const std::vector<string>& target_nodes,
107 RunMetadata* run_metadata)
override;
115 std::vector<DeviceAttributes>* response)
override;
119 cost_model_manager_.ExportCostModels(cost_models);
128 Graph* graph =
nullptr;
129 std::unique_ptr<FunctionLibraryRuntime>
flib;
151 std::unique_ptr<FunctionLibraryDefinition>
flib_def;
152 std::vector<PerPartitionExecutorsAndLib>
items;
169 IntraProcessRendezvous* rendez =
nullptr;
177 RunState(int64 step_id,
const std::vector<Device*>* devices);
179 RunState(
const std::vector<string>& pending_input_names,
180 const std::vector<string>& pending_output_names, int64 step_id,
181 const std::vector<Device*>* devices);
184 bool PendingDone()
const;
192 bool is_partial_run =
false;
200 Status MaybeInitializeExecutionState(
const GraphDef& graph,
201 bool* out_already_initialized)
202 EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
207 gtl::ArraySlice<string>
inputs,
208 gtl::ArraySlice<string>
outputs, gtl::ArraySlice<string> target_nodes,
215 const BuildGraphOptions&
options,
216 std::unordered_map<
string, std::unique_ptr<Graph>>*
outputs,
217 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
218 RunStateArgs* run_state_args, DataTypeVector* input_types,
219 DataTypeVector* output_types);
222 EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
225 const Tensor& resource_tensor, Tensor* retrieved_tensor);
232 int64 timeout_in_ms);
233 void WaitForNotification(tbb::task_arena& arena, tbb::task_group&
group,
234 RunState* run_state, CancellationManager* cm,
235 int64 timeout_in_ms);
238 mutex_lock
l(closed_lock_);
239 if (closed_)
return errors::Cancelled(
"Session has been closed.");
244 const DebugOptions& debug_options, int64 session_run_index,
245 int64 executor_step_index,
const std::vector<string>&
input_names,
246 const std::vector<string>& output_names,
247 const std::vector<string>& target_names,
248 std::unique_ptr<DebuggerStateInterface>* debugger_state);
251 const DebugOptions& debug_options, Graph* graph, Device* device);
261 bool graph_created_ GUARDED_BY(graph_def_lock_) =
false;
264 GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
269 bool sync_on_finish_ =
true;
270 void SchedClosure(tbb::task_arena& arena, tbb::task_group&
g,
std::function<
void()>
c);
278 std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
279 GUARDED_BY(executor_lock_);
282 std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
283 GUARDED_BY(executor_lock_);
295 std::unordered_map<string, string> stateful_placements_
296 GUARDED_BY(graph_def_lock_);
299 std::unique_ptr<SimpleGraphExecutionState> execution_state_
300 GUARDED_BY(graph_def_lock_);
309 bool closed_ GUARDED_BY(closed_lock_) =
false;
312 std::atomic<int64> edge_name_counter_ = {0};
313 std::atomic<int64> handle_name_counter_ = {0};
319 const int64 operation_timeout_in_ms_ = 0;
324 Executor::Args::NodeOutputsCallback node_outputs_callback_ =
nullptr;
329 friend class DebugGateway;
334 #endif // PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H std::unique_ptr< FunctionLibraryDefinition > flib_def_
std::vector< PerPartitionExecutorsAndLib > items
DataTypeVector output_types
std::unordered_map< string, size_t > input_name_to_index
ScopedStepContainer step_container
Notification executors_done
static boost::mutex mutex
std::vector< Device * > devices_
std::vector< std::pair< string, Tensor > > NamedTensorList
std::unique_ptr< FunctionLibraryRuntime > flib
RunStateArgs(const DebugOptions &options)
std::unordered_map< string, string > output_name_to_rendezvous_key
CostModelManager cost_model_manager_
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
std::unordered_map< string, bool > pending_inputs
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e g
std::unordered_map< StringPiece, Node *, StringPiece::Hasher > NameNodeMap
CancellationManager * cancellation_manager_
std::unique_ptr< StepStatsCollector > collector
std::unordered_map< string, size_t > output_name_to_index
const DebugOptions & debug_options
DataTypeVector input_types
std::pair< int, edm::FunctionWithDict > OK
std::unordered_map< string, bool > pending_outputs
::tensorflow::Status CheckNotClosed()
std::atomic_int_fast64_t step_count
const SessionOptions options_
TBBSessionFactory *const factory_
static std::atomic_int_fast64_t step_id_counter_
const std::unique_ptr< const DeviceMgr > device_mgr_
std::unique_ptr< FunctionLibraryDefinition > flib_def
std::function< void(Session *)> CloseCallback
std::unique_ptr< Executor > executor
std::unique_ptr< Graph > graph
SessionState session_state_
void Reset(std::vector< TH2F > &depth)
std::unordered_map< string, string > input_name_to_rendezvous_key
std::unique_ptr< Graph > graph