34 #ifndef PHYSICSTOOLS_TENSORFLOW_NTSESSION_H 35 #define PHYSICSTOOLS_TENSORFLOW_NTSESSION_H 40 #include <unordered_map> 41 #include <unordered_set> 44 #include "tensorflow/core/common_runtime/costmodel_manager.h" 45 #include "tensorflow/core/common_runtime/debugger_state_interface.h" 46 #include "tensorflow/core/common_runtime/device_mgr.h" 47 #include "tensorflow/core/common_runtime/device_set.h" 48 #include "tensorflow/core/common_runtime/executor.h" 49 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 50 #include "tensorflow/core/common_runtime/session_factory.h" 51 #include "tensorflow/core/common_runtime/simple_graph_execution_state.h" 52 #include "tensorflow/core/framework/cancellation.h" 53 #include "tensorflow/core/framework/graph.pb.h" 54 #include "tensorflow/core/framework/session_state.h" 55 #include "tensorflow/core/framework/tensor.h" 56 #include "tensorflow/core/lib/core/errors.h" 57 #include "tensorflow/core/lib/core/status.h" 58 #include "tensorflow/core/platform/macros.h" 59 #include "tensorflow/core/platform/mutex.h" 60 #include "tensorflow/core/platform/types.h" 61 #include "tensorflow/core/public/session.h" 68 class NTSessionFactory;
83 typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher>
89 const std::vector<string>& output_names,
90 const std::vector<string>& target_nodes,
91 std::vector<Tensor>*
outputs)
override;
95 const NamedTensorList&
inputs,
96 const std::vector<string>& output_names,
97 const std::vector<string>& target_nodes,
99 RunMetadata* run_metadata)
override;
104 const std::vector<string>& output_names,
105 const std::vector<string>& target_nodes,
108 const std::vector<string>& output_names,
109 std::vector<Tensor>*
outputs)
override;
116 std::vector<DeviceAttributes>* response)
override;
130 std::unique_ptr<FunctionLibraryRuntime>
flib;
152 std::unique_ptr<FunctionLibraryDefinition>
flib_def;
153 std::vector<PerPartitionExecutorsAndLib>
items;
170 IntraProcessRendezvous* rendez =
nullptr;
178 RunState(int64 step_id,
const std::vector<Device*>* devices);
180 RunState(
const std::vector<string>& pending_input_names,
181 const std::vector<string>& pending_output_names, int64 step_id,
182 const std::vector<Device*>* devices);
185 bool PendingDone()
const;
193 bool is_partial_run =
false;
202 bool* out_already_initialized)
208 gtl::ArraySlice<string>
inputs,
209 gtl::ArraySlice<string>
outputs, gtl::ArraySlice<string> target_nodes,
216 const BuildGraphOptions&
options,
217 std::unordered_map<
string, std::unique_ptr<Graph>>*
outputs,
218 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
219 RunStateArgs* run_state_args, DataTypeVector* input_types,
220 DataTypeVector* output_types);
226 const Tensor& resource_tensor, Tensor* retrieved_tensor);
230 const std::vector<std::pair<string, Tensor>>&
inputs,
232 IntraProcessRendezvous* rendez);
237 const std::vector<string>& output_names,
244 const std::vector<std::pair<string, Tensor>>& feeds,
245 const std::vector<string>& fetches,
253 int64 timeout_in_ms);
255 int64 timeout_in_ms);
259 if (closed_)
return errors::Cancelled(
"Session has been closed.");
264 const DebugOptions& debug_options, int64 session_run_index,
265 int64 executor_step_index,
const std::vector<string>&
input_names,
266 const std::vector<string>& output_names,
267 const std::vector<string>& target_names,
268 std::unique_ptr<DebuggerStateInterface>* debugger_state);
271 const DebugOptions& debug_options, Graph*
graph, Device* device);
284 GraphDef graph_def_
GUARDED_BY(graph_def_lock_);
298 std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
302 std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
315 std::unordered_map<string, string> stateful_placements_
319 std::unique_ptr<SimpleGraphExecutionState> execution_state_
329 bool closed_
GUARDED_BY(closed_lock_) =
false;
354 #endif // PHYSICSTOOLS_TENSORFLOW_NTSESSION_H std::unique_ptr< FunctionLibraryDefinition > flib_def
DataTypeVector output_types
::tensorflow::Status CreateDebuggerState(const DebugOptions &debug_options, int64 session_run_index, int64 executor_step_index, const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_names, std::unique_ptr< DebuggerStateInterface > *debugger_state)
std::vector< PerPartitionExecutorsAndLib > items
static std::atomic_int_fast64_t step_id_counter_
::tensorflow::Status Reset(const std::vector< string > &containers)
std::unique_ptr< Executor > executor
TF_DISALLOW_COPY_AND_ASSIGN(NTSession)
Status MaybeInitializeExecutionState(const GraphDef &graph, bool *out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
const SessionOptions options_
std::unique_ptr< StepStatsCollector > collector
static boost::mutex mutex
std::unordered_map< string, string > input_name_to_rendezvous_key
::tensorflow::Status ResourceHandleToInputTensor(const Tensor &resource_tensor, Tensor *retrieved_tensor)
::tensorflow::Status PRunSetup(const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_nodes, string *handle) override
::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions &debug_options, Graph *graph, Device *device)
SessionState session_state_
std::unordered_map< string, string > output_name_to_rendezvous_key
RunStateArgs(const DebugOptions &options)
::tensorflow::Status SendPRunInputs(const std::vector< std::pair< string, Tensor >> &inputs, const ExecutorsAndKeys *executors_and_keys, IntraProcessRendezvous *rendez)
std::vector< std::pair< string, Tensor > > NamedTensorList
::tensorflow::Status PRun(const string &handle, const NamedTensorList &inputs, const std::vector< string > &output_names, std::vector< Tensor > *outputs) override
DataTypeVector input_types
Notification executors_done
std::unordered_map< string, bool > pending_outputs
::tensorflow::Status Run(const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs) override
bool graph_created_ GUARDED_BY(graph_def_lock_)
::tensorflow::Status CheckNotClosed()
::tensorflow::Status ExtendLocked(const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
std::function< void(Session *)> CloseCallback
::tensorflow::Status WaitForNotification(Notification *n, int64 timeout_in_ms)
::tensorflow::Status ListDevices(std::vector< DeviceAttributes > *response) override
::tensorflow::Status CheckFetch(const std::vector< std::pair< string, Tensor >> &feeds, const std::vector< string > &fetches, const ExecutorsAndKeys *executors_and_keys, const RunState *run_state)
std::unordered_map< string, size_t > input_name_to_index
std::atomic< int64 > edge_name_counter_
::tensorflow::Status Close() override
std::atomic< int64 > handle_name_counter_
const std::unique_ptr< const DeviceMgr > device_mgr_
std::pair< int, edm::FunctionWithDict > OK
std::unique_ptr< Graph > graph
std::unordered_map< string, size_t > output_name_to_index
::tensorflow::Status RecvPRunOutputs(const std::vector< string > &output_names, const ExecutorsAndKeys *executors_and_keys, RunState *run_state, std::vector< Tensor > *outputs)
::tensorflow::Status CreateGraphs(const BuildGraphOptions &options, std::unordered_map< string, std::unique_ptr< Graph >> *outputs, std::unique_ptr< FunctionLibraryDefinition > *flib_def, RunStateArgs *run_state_args, DataTypeVector *input_types, DataTypeVector *output_types)
ScopedStepContainer step_container
std::vector< Device * > devices_
Executor::Args::NodeOutputsCallback node_outputs_callback_
std::atomic_int_fast64_t step_count
std::unique_ptr< FunctionLibraryRuntime > flib
friend class DebugGateway
void SchedClosure(std::function< void()> c)
::tensorflow::Status GetOrCreateExecutors(gtl::ArraySlice< string > inputs, gtl::ArraySlice< string > outputs, gtl::ArraySlice< string > target_nodes, ExecutorsAndKeys **executors_and_keys, RunStateArgs *run_state_args)
std::unordered_map< string, bool > pending_inputs
NTSession(const SessionOptions &options, const DeviceMgr *device_mgr, NTSessionFactory *factory)
CancellationManager * cancellation_manager_
std::unique_ptr< Graph > graph
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
::tensorflow::Status Create(const GraphDef &graph) override
std::unordered_map< StringPiece, Node *, StringPiece::Hasher > NameNodeMap
const int64 operation_timeout_in_ms_
NTSessionFactory *const factory_
std::unique_ptr< FunctionLibraryDefinition > flib_def_
const DebugOptions & debug_options
CostModelManager cost_model_manager_
::tensorflow::Status Extend(const GraphDef &graph) override