36 #ifndef PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H 37 #define PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H 42 #include <unordered_map> 43 #include <unordered_set> 46 #include "tbb/task_arena.h" 48 #include "tensorflow/core/common_runtime/costmodel_manager.h" 49 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 50 #include "tensorflow/core/common_runtime/device_mgr.h" 51 #include "tensorflow/core/common_runtime/device_set.h" 52 #include "tensorflow/core/common_runtime/executor.h" 53 #include "tensorflow/core/common_runtime/graph_execution_state.h" 54 #include "tensorflow/core/common_runtime/process_function_library_runtime.h" 55 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 56 #include "tensorflow/core/common_runtime/session_factory.h" 57 #include "tensorflow/core/framework/cancellation.h" 58 #include "tensorflow/core/framework/graph.pb.h" 59 #include "tensorflow/core/framework/session_state.h" 60 #include "tensorflow/core/framework/tensor.h" 61 #include "tensorflow/core/lib/core/errors.h" 62 #include "tensorflow/core/lib/core/status.h" 63 #include "tensorflow/core/platform/macros.h" 64 #include "tensorflow/core/platform/mutex.h" 65 #include "tensorflow/core/platform/types.h" 66 #include "tensorflow/core/public/session.h" 79 class TBBSessionFactory;
93 typedef std::unordered_map<StringPiece, Node*, StringPieceHasher>
NameNodeMap;
98 const std::vector<string>& output_names,
99 const std::vector<string>& target_nodes,
100 std::vector<Tensor>*
outputs)
override;
104 const NamedTensorList&
inputs,
105 const std::vector<string>& output_names,
106 const std::vector<string>& target_nodes,
108 RunMetadata* run_metadata)
override;
117 *output = device_mgr_.get();
122 cost_model_manager_.ExportCostModels(cost_models);
130 Device* device =
nullptr;
131 FunctionLibraryRuntime* flib =
nullptr;
149 std::vector<PerPartitionExecutorsAndLib>
items;
171 std::unique_ptr<FunctionLibraryDefinition>
flib_def;
172 std::unique_ptr<ProcessFunctionLibraryRuntime>
proc_flr;
182 IntraProcessRendezvous* rendez =
nullptr;
190 RunState(int64 step_id,
const std::vector<Device*>* devices);
192 RunState(
const std::vector<string>& pending_input_names,
193 const std::vector<string>& pending_output_names,
195 const std::vector<Device*>* devices);
198 bool PendingDone()
const;
206 bool is_partial_run =
false;
214 Status MaybeInitializeExecutionState(
const GraphDef& graph,
bool* out_already_initialized)
215 EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
220 gtl::ArraySlice<string>
outputs,
221 gtl::ArraySlice<string> target_nodes,
229 std::unordered_map<
string, std::unique_ptr<Graph>>*
outputs,
230 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
232 DataTypeVector* input_types,
233 DataTypeVector* output_types);
235 ::tensorflow::Status ExtendLocked(
const GraphDef& graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
237 ::tensorflow::Status ResourceHandleToInputTensor(
const Tensor& resource_tensor, Tensor* retrieved_tensor);
244 void WaitForNotification(tbb::task_arena& arena,
245 tbb::task_group&
group,
247 CancellationManager* cm,
248 int64 timeout_in_ms);
251 mutex_lock
l(closed_lock_);
253 return errors::Cancelled(
"Session has been closed.");
258 int64 session_run_index,
259 int64 executor_step_index,
261 const std::vector<string>& output_names,
262 const std::vector<string>& target_names,
263 std::unique_ptr<DebuggerStateInterface>* debugger_state);
277 bool graph_created_ GUARDED_BY(graph_def_lock_) =
false;
280 GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
285 bool sync_on_finish_ =
true;
286 void SchedClosure(tbb::task_arena& arena, tbb::task_group&
g,
std::function<
void()>
c);
288 std::vector<std::unique_ptr<FunctionInfo>> functions_ GUARDED_BY(executor_lock_);
296 std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_ GUARDED_BY(executor_lock_);
299 std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_ GUARDED_BY(executor_lock_);
311 std::unordered_map<string, string> stateful_placements_ GUARDED_BY(graph_def_lock_);
314 std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(graph_def_lock_);
323 bool closed_ GUARDED_BY(closed_lock_) =
false;
326 std::atomic<int64> edge_name_counter_ = {0};
327 std::atomic<int64> handle_name_counter_ = {0};
333 const int64 operation_timeout_in_ms_ = 0;
338 Executor::Args::NodeOutputsCallback node_outputs_callback_ =
nullptr;
343 friend class DebugGateway;
348 #endif // PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H std::unique_ptr< FunctionLibraryDefinition > flib_def_
static boost::mutex mutex
std::vector< PerPartitionExecutorsAndLib > items
DataTypeVector output_types
std::unordered_map< string, size_t > input_name_to_index
ScopedStepContainer step_container
Notification executors_done
std::vector< Device * > devices_
std::vector< std::pair< string, Tensor > > NamedTensorList
std::unique_ptr< ProcessFunctionLibraryRuntime > proc_flr
std::unique_ptr< FunctionLibraryDefinition > flib_def
RunStateArgs(const DebugOptions &options)
std::unordered_map< string, string > output_name_to_rendezvous_key
CostModelManager cost_model_manager_
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
std::unordered_map< string, bool > pending_inputs
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e g
CancellationManager * cancellation_manager_
std::unique_ptr< StepStatsCollector > collector
std::unordered_map< string, size_t > output_name_to_index
std::unordered_map< StringPiece, Node *, StringPieceHasher > NameNodeMap
const DebugOptions & debug_options
DataTypeVector input_types
std::pair< int, edm::FunctionWithDict > OK
std::unordered_map< string, bool > pending_outputs
::tensorflow::Status CheckNotClosed()
::tensorflow::Status LocalDeviceManager(const DeviceMgr **output) override
std::atomic_int_fast64_t step_count
const SessionOptions options_
TBBSessionFactory *const factory_
static std::atomic_int_fast64_t step_id_counter_
const std::unique_ptr< const DeviceMgr > device_mgr_
std::function< void(Session *)> CloseCallback
std::unique_ptr< Executor > executor
std::unique_ptr< Graph > graph
SessionState session_state_
void Reset(std::vector< TH2F > &depth)
std::unordered_map< string, string > input_name_to_rendezvous_key
std::unique_ptr< Graph > graph