34 #ifndef PHYSICSTOOLS_TENSORFLOW_NTSESSION_H 35 #define PHYSICSTOOLS_TENSORFLOW_NTSESSION_H 40 #include <unordered_map> 41 #include <unordered_set> 44 #include "tensorflow/core/common_runtime/costmodel_manager.h" 45 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 46 #include "tensorflow/core/common_runtime/device_mgr.h" 47 #include "tensorflow/core/common_runtime/device_set.h" 48 #include "tensorflow/core/common_runtime/executor.h" 49 #include "tensorflow/core/common_runtime/graph_execution_state.h" 50 #include "tensorflow/core/common_runtime/process_function_library_runtime.h" 51 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 52 #include "tensorflow/core/common_runtime/session_factory.h" 53 #include "tensorflow/core/framework/cancellation.h" 54 #include "tensorflow/core/framework/graph.pb.h" 55 #include "tensorflow/core/framework/session_state.h" 56 #include "tensorflow/core/framework/tensor.h" 57 #include "tensorflow/core/lib/core/errors.h" 58 #include "tensorflow/core/lib/core/status.h" 59 #include "tensorflow/core/platform/macros.h" 60 #include "tensorflow/core/platform/mutex.h" 61 #include "tensorflow/core/platform/types.h" 62 #include "tensorflow/core/public/session.h" 69 class NTSessionFactory;
83 typedef std::unordered_map<StringPiece, Node*, StringPieceHasher>
NameNodeMap;
88 const std::vector<string>& output_names,
89 const std::vector<string>& target_nodes,
90 std::vector<Tensor>*
outputs)
override;
94 const NamedTensorList&
inputs,
95 const std::vector<string>& output_names,
96 const std::vector<string>& target_nodes,
98 RunMetadata* run_metadata)
override;
103 const std::vector<string>& output_names,
104 const std::vector<string>& target_nodes,
107 const NamedTensorList&
inputs,
108 const std::vector<string>& output_names,
109 std::vector<Tensor>*
outputs)
override;
132 FunctionLibraryRuntime*
flib =
nullptr;
150 std::vector<PerPartitionExecutorsAndLib>
items;
172 std::unique_ptr<FunctionLibraryDefinition>
flib_def;
173 std::unique_ptr<ProcessFunctionLibraryRuntime>
proc_flr;
183 IntraProcessRendezvous* rendez =
nullptr;
191 RunState(int64 step_id,
const std::vector<Device*>* devices);
193 RunState(
const std::vector<string>& pending_input_names,
194 const std::vector<string>& pending_output_names,
196 const std::vector<Device*>* devices);
199 bool PendingDone()
const;
207 bool is_partial_run =
false;
221 gtl::ArraySlice<string>
outputs,
222 gtl::ArraySlice<string> target_nodes,
230 std::unordered_map<
string, std::unique_ptr<Graph>>*
outputs,
231 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
233 DataTypeVector* input_types,
234 DataTypeVector* output_types);
243 IntraProcessRendezvous* rendez);
255 const std::vector<string>& fetches,
269 return errors::Cancelled(
"Session has been closed.");
274 int64 session_run_index,
275 int64 executor_step_index,
277 const std::vector<string>& output_names,
278 const std::vector<string>& target_names,
279 std::unique_ptr<DebuggerStateInterface>* debugger_state);
296 GraphDef graph_def_
GUARDED_BY(graph_def_lock_);
312 std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
GUARDED_BY(executor_lock_);
315 std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
GUARDED_BY(executor_lock_);
327 std::unordered_map<string, string> stateful_placements_
GUARDED_BY(graph_def_lock_);
330 std::unique_ptr<GraphExecutionState> execution_state_
GUARDED_BY(graph_def_lock_);
339 bool closed_
GUARDED_BY(closed_lock_) =
false;
364 #endif // PHYSICSTOOLS_TENSORFLOW_NTSESSION_H DataTypeVector output_types
static boost::mutex mutex
::tensorflow::Status CreateDebuggerState(const DebugOptions &debug_options, int64 session_run_index, int64 executor_step_index, const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_names, std::unique_ptr< DebuggerStateInterface > *debugger_state)
std::vector< PerPartitionExecutorsAndLib > items
static std::atomic_int_fast64_t step_id_counter_
::tensorflow::Status Reset(const std::vector< string > &containers)
FunctionLibraryRuntime * flib
std::unique_ptr< Executor > executor
TF_DISALLOW_COPY_AND_ASSIGN(NTSession)
Status MaybeInitializeExecutionState(const GraphDef &graph, bool *out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
const SessionOptions options_
std::unique_ptr< StepStatsCollector > collector
std::unordered_map< string, string > input_name_to_rendezvous_key
::tensorflow::Status ResourceHandleToInputTensor(const Tensor &resource_tensor, Tensor *retrieved_tensor)
::tensorflow::Status PRunSetup(const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_nodes, string *handle) override
::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions &debug_options, Graph *graph, Device *device)
SessionState session_state_
std::unordered_map< string, string > output_name_to_rendezvous_key
RunStateArgs(const DebugOptions &options)
::tensorflow::Status SendPRunInputs(const std::vector< std::pair< string, Tensor >> &inputs, const ExecutorsAndKeys *executors_and_keys, IntraProcessRendezvous *rendez)
std::vector< std::pair< string, Tensor > > NamedTensorList
std::unique_ptr< ProcessFunctionLibraryRuntime > proc_flr
::tensorflow::Status PRun(const string &handle, const NamedTensorList &inputs, const std::vector< string > &output_names, std::vector< Tensor > *outputs) override
DataTypeVector input_types
Notification executors_done
std::unordered_map< string, bool > pending_outputs
::tensorflow::Status Run(const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs) override
bool graph_created_ GUARDED_BY(graph_def_lock_)
::tensorflow::Status CheckNotClosed()
::tensorflow::Status ExtendLocked(const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
std::function< void(Session *)> CloseCallback
::tensorflow::Status WaitForNotification(Notification *n, int64 timeout_in_ms)
::tensorflow::Status ListDevices(std::vector< DeviceAttributes > *response) override
::tensorflow::Status CheckFetch(const std::vector< std::pair< string, Tensor >> &feeds, const std::vector< string > &fetches, const ExecutorsAndKeys *executors_and_keys, const RunState *run_state)
std::unordered_map< string, size_t > input_name_to_index
std::atomic< int64 > edge_name_counter_
::tensorflow::Status Close() override
std::atomic< int64 > handle_name_counter_
const std::unique_ptr< const DeviceMgr > device_mgr_
std::unique_ptr< FunctionLibraryDefinition > flib_def
std::pair< int, edm::FunctionWithDict > OK
std::unique_ptr< Graph > graph
std::unordered_map< string, size_t > output_name_to_index
::tensorflow::Status RecvPRunOutputs(const std::vector< string > &output_names, const ExecutorsAndKeys *executors_and_keys, RunState *run_state, std::vector< Tensor > *outputs)
::tensorflow::Status CreateGraphs(const BuildGraphOptions &options, std::unordered_map< string, std::unique_ptr< Graph >> *outputs, std::unique_ptr< FunctionLibraryDefinition > *flib_def, RunStateArgs *run_state_args, DataTypeVector *input_types, DataTypeVector *output_types)
ScopedStepContainer step_container
std::vector< Device * > devices_
Executor::Args::NodeOutputsCallback node_outputs_callback_
std::atomic_int_fast64_t step_count
friend class DebugGateway
void SchedClosure(std::function< void()> c)
::tensorflow::Status GetOrCreateExecutors(gtl::ArraySlice< string > inputs, gtl::ArraySlice< string > outputs, gtl::ArraySlice< string > target_nodes, ExecutorsAndKeys **executors_and_keys, RunStateArgs *run_state_args)
std::unordered_map< string, bool > pending_inputs
NTSession(const SessionOptions &options, const DeviceMgr *device_mgr, NTSessionFactory *factory)
CancellationManager * cancellation_manager_
::tensorflow::Status LocalDeviceManager(const DeviceMgr **output) override
std::unique_ptr< Graph > graph
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
::tensorflow::Status Create(const GraphDef &graph) override
const int64 operation_timeout_in_ms_
NTSessionFactory *const factory_
std::unique_ptr< FunctionLibraryDefinition > flib_def_
const DebugOptions & debug_options
CostModelManager cost_model_manager_
::tensorflow::Status Extend(const GraphDef &graph) override
std::unordered_map< StringPiece, Node *, StringPieceHasher > NameNodeMap