34 #ifndef PHYSICSTOOLS_TENSORFLOW_NTSESSION_H 35 #define PHYSICSTOOLS_TENSORFLOW_NTSESSION_H 40 #include <unordered_map> 41 #include <unordered_set> 44 #include "tensorflow/core/common_runtime/costmodel_manager.h" 45 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 46 #include "tensorflow/core/common_runtime/device_mgr.h" 47 #include "tensorflow/core/common_runtime/device_set.h" 48 #include "tensorflow/core/common_runtime/executor.h" 49 #include "tensorflow/core/common_runtime/graph_execution_state.h" 50 #include "tensorflow/core/common_runtime/process_function_library_runtime.h" 51 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 52 #include "tensorflow/core/common_runtime/session_factory.h" 53 #include "tensorflow/core/framework/cancellation.h" 54 #include "tensorflow/core/framework/graph.pb.h" 55 #include "tensorflow/core/framework/session_state.h" 56 #include "tensorflow/core/framework/tensor.h" 57 #include "tensorflow/core/lib/core/errors.h" 58 #include "tensorflow/core/lib/core/status.h" 59 #include "tensorflow/core/platform/macros.h" 60 #include "tensorflow/core/platform/mutex.h" 61 #include "tensorflow/core/platform/types.h" 62 #include "tensorflow/core/public/session.h" 69 class NTSessionFactory;
84 typedef std::unordered_map<StringPiece, Node*, StringPieceHasher>
NameNodeMap;
89 const std::vector<string>& output_names,
90 const std::vector<string>& target_nodes,
91 std::vector<Tensor>*
outputs)
override;
95 const NamedTensorList&
inputs,
96 const std::vector<string>& output_names,
97 const std::vector<string>& target_nodes,
99 RunMetadata* run_metadata)
override;
104 const std::vector<string>& output_names,
105 const std::vector<string>& target_nodes,
108 const std::vector<string>& output_names,
109 std::vector<Tensor>*
outputs)
override;
116 std::vector<DeviceAttributes>* response)
override;
133 FunctionLibraryRuntime*
flib =
nullptr;
157 std::unique_ptr<FunctionLibraryDefinition>
flib_def;
158 std::unique_ptr<ProcessFunctionLibraryRuntime>
proc_flr;
159 std::vector<PerPartitionExecutorsAndLib>
items;
176 IntraProcessRendezvous* rendez =
nullptr;
184 RunState(int64 step_id,
const std::vector<Device*>* devices);
186 RunState(
const std::vector<string>& pending_input_names,
187 const std::vector<string>& pending_output_names, int64 step_id,
188 const std::vector<Device*>* devices);
191 bool PendingDone()
const;
199 bool is_partial_run =
false;
208 bool* out_already_initialized)
214 gtl::ArraySlice<string>
inputs, gtl::ArraySlice<string>
outputs,
215 gtl::ArraySlice<string> target_nodes,
222 const BuildGraphOptions&
options,
223 std::unordered_map<
string, std::unique_ptr<Graph>>*
outputs,
224 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
225 RunStateArgs* run_state_args, DataTypeVector* input_types,
226 DataTypeVector* output_types);
232 const Tensor& resource_tensor, Tensor* retrieved_tensor);
236 const std::vector<std::pair<string, Tensor>>&
inputs,
238 IntraProcessRendezvous* rendez);
243 const std::vector<string>& output_names,
250 const std::vector<std::pair<string, Tensor>>& feeds,
251 const std::vector<string>& fetches,
259 int64 timeout_in_ms);
261 int64 timeout_in_ms);
265 if (closed_)
return errors::Cancelled(
"Session has been closed.");
270 const DebugOptions& debug_options, int64 session_run_index,
271 int64 executor_step_index,
const std::vector<string>& input_names,
272 const std::vector<string>& output_names,
273 const std::vector<string>& target_names,
274 std::unique_ptr<DebuggerStateInterface>* debugger_state);
277 const DebugOptions& debug_options, Graph*
graph, Device*
device);
290 GraphDef graph_def_
GUARDED_BY(graph_def_lock_);
304 std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
308 std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
321 std::unordered_map<string, string> stateful_placements_
325 std::unique_ptr<GraphExecutionState> execution_state_
335 bool closed_
GUARDED_BY(closed_lock_) =
false;
360 #endif // PHYSICSTOOLS_TENSORFLOW_NTSESSION_H std::unique_ptr< FunctionLibraryDefinition > flib_def
DataTypeVector output_types
static boost::mutex mutex
::tensorflow::Status CreateDebuggerState(const DebugOptions &debug_options, int64 session_run_index, int64 executor_step_index, const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_names, std::unique_ptr< DebuggerStateInterface > *debugger_state)
std::vector< PerPartitionExecutorsAndLib > items
static std::atomic_int_fast64_t step_id_counter_
::tensorflow::Status Reset(const std::vector< string > &containers)
FunctionLibraryRuntime * flib
std::unique_ptr< Executor > executor
TF_DISALLOW_COPY_AND_ASSIGN(NTSession)
Status MaybeInitializeExecutionState(const GraphDef &graph, bool *out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
const SessionOptions options_
std::unique_ptr< StepStatsCollector > collector
std::unordered_map< string, string > input_name_to_rendezvous_key
::tensorflow::Status ResourceHandleToInputTensor(const Tensor &resource_tensor, Tensor *retrieved_tensor)
::tensorflow::Status PRunSetup(const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_nodes, string *handle) override
::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions &debug_options, Graph *graph, Device *device)
SessionState session_state_
std::unordered_map< string, string > output_name_to_rendezvous_key
std::unique_ptr< ProcessFunctionLibraryRuntime > proc_flr
RunStateArgs(const DebugOptions &options)
::tensorflow::Status SendPRunInputs(const std::vector< std::pair< string, Tensor >> &inputs, const ExecutorsAndKeys *executors_and_keys, IntraProcessRendezvous *rendez)
std::vector< std::pair< string, Tensor > > NamedTensorList
::tensorflow::Status PRun(const string &handle, const NamedTensorList &inputs, const std::vector< string > &output_names, std::vector< Tensor > *outputs) override
DataTypeVector input_types
Notification executors_done
std::unordered_map< string, bool > pending_outputs
::tensorflow::Status Run(const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs) override
bool graph_created_ GUARDED_BY(graph_def_lock_)
::tensorflow::Status CheckNotClosed()
::tensorflow::Status ExtendLocked(const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
std::function< void(Session *)> CloseCallback
::tensorflow::Status WaitForNotification(Notification *n, int64 timeout_in_ms)
::tensorflow::Status ListDevices(std::vector< DeviceAttributes > *response) override
::tensorflow::Status CheckFetch(const std::vector< std::pair< string, Tensor >> &feeds, const std::vector< string > &fetches, const ExecutorsAndKeys *executors_and_keys, const RunState *run_state)
std::unordered_map< string, size_t > input_name_to_index
std::atomic< int64 > edge_name_counter_
::tensorflow::Status Close() override
std::atomic< int64 > handle_name_counter_
const std::unique_ptr< const DeviceMgr > device_mgr_
std::pair< int, edm::FunctionWithDict > OK
std::unique_ptr< Graph > graph
std::unordered_map< string, size_t > output_name_to_index
::tensorflow::Status RecvPRunOutputs(const std::vector< string > &output_names, const ExecutorsAndKeys *executors_and_keys, RunState *run_state, std::vector< Tensor > *outputs)
::tensorflow::Status CreateGraphs(const BuildGraphOptions &options, std::unordered_map< string, std::unique_ptr< Graph >> *outputs, std::unique_ptr< FunctionLibraryDefinition > *flib_def, RunStateArgs *run_state_args, DataTypeVector *input_types, DataTypeVector *output_types)
ScopedStepContainer step_container
std::vector< Device * > devices_
Executor::Args::NodeOutputsCallback node_outputs_callback_
std::atomic_int_fast64_t step_count
friend class DebugGateway
void SchedClosure(std::function< void()> c)
::tensorflow::Status GetOrCreateExecutors(gtl::ArraySlice< string > inputs, gtl::ArraySlice< string > outputs, gtl::ArraySlice< string > target_nodes, ExecutorsAndKeys **executors_and_keys, RunStateArgs *run_state_args)
std::unordered_map< string, bool > pending_inputs
NTSession(const SessionOptions &options, const DeviceMgr *device_mgr, NTSessionFactory *factory)
CancellationManager * cancellation_manager_
::tensorflow::Status LocalDeviceManager(const DeviceMgr **output) override
std::unique_ptr< Graph > graph
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
::tensorflow::Status Create(const GraphDef &graph) override
const int64 operation_timeout_in_ms_
NTSessionFactory *const factory_
std::unique_ptr< FunctionLibraryDefinition > flib_def_
const DebugOptions & debug_options
CostModelManager cost_model_manager_
::tensorflow::Status Extend(const GraphDef &graph) override
std::unordered_map< StringPiece, Node *, StringPieceHasher > NameNodeMap