36 #ifndef PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H 37 #define PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H 42 #include <unordered_map> 43 #include <unordered_set> 46 #include "tbb/task_arena.h" 48 #include "tensorflow/core/common_runtime/costmodel_manager.h" 49 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 50 #include "tensorflow/core/common_runtime/device_mgr.h" 51 #include "tensorflow/core/common_runtime/device_set.h" 52 #include "tensorflow/core/common_runtime/executor.h" 53 #include "tensorflow/core/common_runtime/graph_execution_state.h" 54 #include "tensorflow/core/common_runtime/process_function_library_runtime.h" 55 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 56 #include "tensorflow/core/common_runtime/session_factory.h" 57 #include "tensorflow/core/framework/cancellation.h" 58 #include "tensorflow/core/framework/graph.pb.h" 59 #include "tensorflow/core/framework/session_state.h" 60 #include "tensorflow/core/framework/tensor.h" 61 #include "tensorflow/core/lib/core/errors.h" 62 #include "tensorflow/core/lib/core/status.h" 63 #include "tensorflow/core/platform/macros.h" 64 #include "tensorflow/core/platform/mutex.h" 65 #include "tensorflow/core/platform/types.h" 66 #include "tensorflow/core/public/session.h" 79 class TBBSessionFactory;
// Alias for a hash map from StringPiece keys to Node pointers, hashed with
// StringPieceHasher.
// NOTE(review): presumably keyed by node name, as the alias name suggests --
// confirm against its users. StringPiece keys do not own their characters, so
// the referenced strings must outlive the map entries.
94 typedef std::unordered_map<StringPiece, Node*, StringPieceHasher>
NameNodeMap;
99 const std::vector<string>& output_names,
100 const std::vector<string>& target_nodes,
101 std::vector<Tensor>*
outputs)
override;
105 const NamedTensorList&
inputs,
106 const std::vector<string>& output_names,
107 const std::vector<string>& target_nodes,
109 RunMetadata* run_metadata)
override;
116 std::vector<DeviceAttributes>* response)
override;
119 *output = device_mgr_.get();
124 cost_model_manager_.ExportCostModels(cost_models);
131 Graph* graph =
nullptr;
132 Device* device =
nullptr;
133 FunctionLibraryRuntime* flib =
nullptr;
157 std::unique_ptr<FunctionLibraryDefinition>
flib_def;
158 std::unique_ptr<ProcessFunctionLibraryRuntime>
proc_flr;
159 std::vector<PerPartitionExecutorsAndLib>
items;
176 IntraProcessRendezvous* rendez =
nullptr;
// Constructor declaration taking a step id and a pointer to a vector of
// Device pointers.
// NOTE(review): the definition is not visible here; whether `devices` may be
// null and how long it must stay valid should be confirmed in the .cc file.
184 RunState(int64 step_id,
const std::vector<Device*>* devices);
// Constructor overload that additionally receives the names of pending inputs
// and pending outputs, followed by the step id and the device list pointer.
// NOTE(review): presumably used for partial runs (see `is_partial_run`) --
// confirm against the definition, which is not visible here.
186 RunState(
const std::vector<string>& pending_input_names,
187 const std::vector<string>& pending_output_names, int64 step_id,
188 const std::vector<Device*>* devices);
// Const query returning a bool; it does not modify the object it is called on.
// NOTE(review): presumably reports whether every pending input and output has
// been handled -- the definition is not visible here, confirm.
191 bool PendingDone()
const;
199 bool is_partial_run =
false;
// Declared to take a GraphDef and an output flag, and to return a Status.
// The declaration is annotated EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_), so
// callers are expected to hold graph_def_lock_ when calling it.
// NOTE(review): presumably `*out_already_initialized` is set to true when the
// execution state already existed -- confirm against the definition.
207 Status MaybeInitializeExecutionState(
const GraphDef& graph,
208 bool* out_already_initialized)
209 EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
214 gtl::ArraySlice<string>
inputs, gtl::ArraySlice<string>
outputs,
215 gtl::ArraySlice<string> target_nodes,
222 const BuildGraphOptions&
options,
223 std::unordered_map<
string, std::unique_ptr<Graph>>*
outputs,
224 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
225 RunStateArgs* run_state_args, DataTypeVector* input_types,
226 DataTypeVector* output_types);
229 EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
232 const Tensor& resource_tensor, Tensor* retrieved_tensor);
239 int64 timeout_in_ms);
// Overload that receives the TBB task arena and task group by reference,
// together with the run state, a cancellation manager and a timeout given in
// milliseconds.
// NOTE(review): presumably waits for the executors of `run_state` to finish
// and cancels through `cm` on timeout -- the definition is not visible here,
// confirm. The meaning of a zero or negative timeout is also not shown.
240 void WaitForNotification(tbb::task_arena& arena, tbb::task_group&
group,
241 RunState* run_state, CancellationManager* cm, int64 timeout_in_ms);
244 mutex_lock
l(closed_lock_);
245 if (closed_)
return errors::Cancelled(
"Session has been closed.");
250 const DebugOptions& debug_options, int64 session_run_index,
251 int64 executor_step_index,
const std::vector<string>& input_names,
252 const std::vector<string>& output_names,
253 const std::vector<string>& target_names,
254 std::unique_ptr<DebuggerStateInterface>* debugger_state);
257 const DebugOptions& debug_options, Graph* graph, Device* device);
// Flag, default false, annotated as guarded by graph_def_lock_.
// NOTE(review): presumably set once a graph has been created for the session
// -- confirm where it is written.
267 bool graph_created_ GUARDED_BY(graph_def_lock_) =
false;
// GraphDef member, also annotated as guarded by graph_def_lock_.
270 GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
// Flag that defaults to true.
// NOTE(review): presumably controls whether devices are synchronized when a
// step finishes -- its use is not visible here, confirm.
275 bool sync_on_finish_ =
true;
// Takes a TBB task arena and task group by reference and a std::function
// closure by value.
// NOTE(review): presumably schedules `c` to run on `g` inside `arena` -- the
// definition is not visible here, confirm.
276 void SchedClosure(tbb::task_arena& arena, tbb::task_group&
g,
std::function<
void()>
c);
// Map from a string key to a shared ExecutorsAndKeys; annotated as guarded by
// executor_lock_.
// NOTE(review): how the key is composed is not visible here.
284 std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
285 GUARDED_BY(executor_lock_);
// Map from a string key to a uniquely owned RunState; annotated as guarded by
// executor_lock_.
// NOTE(review): presumably holds the state of in-flight partial runs keyed by
// handle -- confirm.
288 std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
289 GUARDED_BY(executor_lock_);
// String-to-string map annotated as guarded by graph_def_lock_.
// NOTE(review): presumably maps stateful node names to their assigned devices
// -- confirm.
301 std::unordered_map<string, string> stateful_placements_
302 GUARDED_BY(graph_def_lock_);
// Uniquely owned GraphExecutionState, annotated as guarded by graph_def_lock_.
305 std::unique_ptr<GraphExecutionState> execution_state_
306 GUARDED_BY(graph_def_lock_);
// Flag, default false, annotated as guarded by closed_lock_. Elsewhere in this
// header it is read under closed_lock_ and, when true, a Cancelled status with
// the message "Session has been closed." is returned.
315 bool closed_ GUARDED_BY(closed_lock_) =
false;
// Two atomic 64-bit counters, each initialized to 0.
// NOTE(review): presumably used to generate unique edge and handle names --
// their uses are not visible here, confirm.
318 std::atomic<int64> edge_name_counter_ = {0};
319 std::atomic<int64> handle_name_counter_ = {0};
// Const member with a default member initializer of 0; the name gives the
// unit as milliseconds.
// NOTE(review): whether a constructor overrides this default, and what a value
// of 0 means (e.g. no timeout), is not visible here -- confirm.
325 const int64 operation_timeout_in_ms_ = 0;
// Callback of type Executor::Args::NodeOutputsCallback, initialized to nullptr
// (no callback installed).
// NOTE(review): presumably installed through the befriended DebugGateway --
// confirm.
330 Executor::Args::NodeOutputsCallback node_outputs_callback_ =
nullptr;
335 friend class DebugGateway;
340 #endif // PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H std::unique_ptr< FunctionLibraryDefinition > flib_def_
static boost::mutex mutex
std::vector< PerPartitionExecutorsAndLib > items
DataTypeVector output_types
std::unordered_map< string, size_t > input_name_to_index
ScopedStepContainer step_container
Notification executors_done
std::vector< Device * > devices_
std::vector< std::pair< string, Tensor > > NamedTensorList
RunStateArgs(const DebugOptions &options)
std::unordered_map< string, string > output_name_to_rendezvous_key
CostModelManager cost_model_manager_
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
std::unordered_map< string, bool > pending_inputs
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e g
CancellationManager * cancellation_manager_
std::unique_ptr< StepStatsCollector > collector
std::unique_ptr< ProcessFunctionLibraryRuntime > proc_flr
std::unordered_map< string, size_t > output_name_to_index
std::unordered_map< StringPiece, Node *, StringPieceHasher > NameNodeMap
const DebugOptions & debug_options
DataTypeVector input_types
std::pair< int, edm::FunctionWithDict > OK
std::unordered_map< string, bool > pending_outputs
::tensorflow::Status CheckNotClosed()
::tensorflow::Status LocalDeviceManager(const DeviceMgr **output) override
std::atomic_int_fast64_t step_count
const SessionOptions options_
TBBSessionFactory *const factory_
static std::atomic_int_fast64_t step_id_counter_
const std::unique_ptr< const DeviceMgr > device_mgr_
std::unique_ptr< FunctionLibraryDefinition > flib_def
std::function< void(Session *)> CloseCallback
std::unique_ptr< Executor > executor
std::unique_ptr< Graph > graph
SessionState session_state_
void Reset(std::vector< TH2F > &depth)
std::unordered_map< string, string > input_name_to_rendezvous_key
std::unique_ptr< Graph > graph