36 #ifndef PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H 37 #define PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H 42 #include <unordered_map> 43 #include <unordered_set> 46 #include "tbb/task_arena.h" 48 #include "tensorflow/core/common_runtime/costmodel_manager.h" 49 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 50 #include "tensorflow/core/common_runtime/device_mgr.h" 51 #include "tensorflow/core/common_runtime/device_set.h" 52 #include "tensorflow/core/common_runtime/executor.h" 53 #include "tensorflow/core/common_runtime/graph_execution_state.h" 54 #include "tensorflow/core/common_runtime/process_function_library_runtime.h" 55 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 56 #include "tensorflow/core/common_runtime/session_factory.h" 57 #include "tensorflow/core/framework/cancellation.h" 58 #include "tensorflow/core/framework/graph.pb.h" 59 #include "tensorflow/core/framework/session_state.h" 60 #include "tensorflow/core/framework/tensor.h" 61 #include "tensorflow/core/lib/core/errors.h" 62 #include "tensorflow/core/lib/core/status.h" 63 #include "tensorflow/core/platform/macros.h" 64 #include "tensorflow/core/platform/mutex.h" 65 #include "tensorflow/core/platform/types.h" 66 #include "tensorflow/core/public/session.h" 79 class TBBSessionFactory;
94 typedef std::unordered_map<StringPiece, Node*, StringPieceHasher>
NameNodeMap;
99 const std::vector<string>& output_names,
100 const std::vector<string>& target_nodes,
101 std::vector<Tensor>*
outputs)
override;
105 const NamedTensorList&
inputs,
106 const std::vector<string>& output_names,
107 const std::vector<string>& target_nodes,
109 RunMetadata* run_metadata)
override;
116 std::vector<DeviceAttributes>* response)
override;
119 *output = device_mgr_.get();
124 cost_model_manager_.ExportCostModels(cost_models);
132 Device* device =
nullptr;
133 FunctionLibraryRuntime* flib =
nullptr;
151 std::vector<PerPartitionExecutorsAndLib>
items;
173 std::unique_ptr<FunctionLibraryDefinition>
flib_def;
174 std::unique_ptr<ProcessFunctionLibraryRuntime>
proc_flr;
184 IntraProcessRendezvous* rendez =
nullptr;
192 RunState(int64 step_id,
const std::vector<Device*>* devices);
194 RunState(
const std::vector<string>& pending_input_names,
195 const std::vector<string>& pending_output_names, int64 step_id,
196 const std::vector<Device*>* devices);
199 bool PendingDone()
const;
207 bool is_partial_run =
false;
215 Status MaybeInitializeExecutionState(
const GraphDef& graph,
216 bool* out_already_initialized)
217 EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
222 gtl::ArraySlice<string>
inputs, gtl::ArraySlice<string>
outputs,
223 gtl::ArraySlice<string> target_nodes,
230 const BuildGraphOptions&
options,
231 std::unordered_map<
string, std::unique_ptr<Graph>>*
outputs,
232 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
233 RunStateArgs* run_state_args, DataTypeVector* input_types,
234 DataTypeVector* output_types);
237 EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
240 const Tensor& resource_tensor, Tensor* retrieved_tensor);
247 int64 timeout_in_ms);
248 void WaitForNotification(tbb::task_arena& arena, tbb::task_group&
group,
249 RunState* run_state, CancellationManager* cm, int64 timeout_in_ms);
252 mutex_lock
l(closed_lock_);
253 if (closed_)
return errors::Cancelled(
"Session has been closed.");
258 const DebugOptions& debug_options, int64 session_run_index,
259 int64 executor_step_index,
const std::vector<string>& input_names,
260 const std::vector<string>& output_names,
261 const std::vector<string>& target_names,
262 std::unique_ptr<DebuggerStateInterface>* debugger_state);
265 const DebugOptions& debug_options,
Graph* graph, Device* device);
275 bool graph_created_ GUARDED_BY(graph_def_lock_) =
false;
278 GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
283 bool sync_on_finish_ =
true;
284 void SchedClosure(tbb::task_arena& arena, tbb::task_group&
g,
std::function<
void()>
c);
286 std::vector<std::unique_ptr<FunctionInfo>> functions_
287 GUARDED_BY(executor_lock_);
295 std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
296 GUARDED_BY(executor_lock_);
299 std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
300 GUARDED_BY(executor_lock_);
312 std::unordered_map<string, string> stateful_placements_
313 GUARDED_BY(graph_def_lock_);
316 std::unique_ptr<GraphExecutionState> execution_state_
317 GUARDED_BY(graph_def_lock_);
326 bool closed_ GUARDED_BY(closed_lock_) =
false;
329 std::atomic<int64> edge_name_counter_ = {0};
330 std::atomic<int64> handle_name_counter_ = {0};
336 const int64 operation_timeout_in_ms_ = 0;
341 Executor::Args::NodeOutputsCallback node_outputs_callback_ =
nullptr;
346 friend class DebugGateway;
351 #endif // PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H std::unique_ptr< FunctionLibraryDefinition > flib_def_
static boost::mutex mutex
std::vector< PerPartitionExecutorsAndLib > items
DataTypeVector output_types
std::unordered_map< string, size_t > input_name_to_index
ScopedStepContainer step_container
Notification executors_done
std::vector< Device * > devices_
std::vector< std::pair< string, Tensor > > NamedTensorList
std::unique_ptr< ProcessFunctionLibraryRuntime > proc_flr
std::unique_ptr< FunctionLibraryDefinition > flib_def
RunStateArgs(const DebugOptions &options)
std::unordered_map< string, string > output_name_to_rendezvous_key
CostModelManager cost_model_manager_
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
std::unordered_map< string, bool > pending_inputs
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e g
CancellationManager * cancellation_manager_
std::unique_ptr< StepStatsCollector > collector
std::unordered_map< string, size_t > output_name_to_index
std::unordered_map< StringPiece, Node *, StringPieceHasher > NameNodeMap
const DebugOptions & debug_options
DataTypeVector input_types
std::pair< int, edm::FunctionWithDict > OK
std::unordered_map< string, bool > pending_outputs
::tensorflow::Status CheckNotClosed()
::tensorflow::Status LocalDeviceManager(const DeviceMgr **output) override
std::atomic_int_fast64_t step_count
const SessionOptions options_
TBBSessionFactory *const factory_
static std::atomic_int_fast64_t step_id_counter_
const std::unique_ptr< const DeviceMgr > device_mgr_
std::function< void(Session *)> CloseCallback
std::unique_ptr< Executor > executor
std::unique_ptr< Graph > graph
SessionState session_state_
void Reset(std::vector< TH2F > &depth)
std::unordered_map< string, string > input_name_to_rendezvous_key
std::unique_ptr< Graph > graph