CMS 3D CMS Logo

NTSession.h
Go to the documentation of this file.
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 /*
17 This file is an adaptation of the original direct_session.h file located at
18 https://github.com/tensorflow/tensorflow/blob/v1.6.0/tensorflow/core/common_runtime/direct_session.h
19 to meet the demands of the software environment developed and used by the CMS collaboration.
20 
21 Changes:
22  - Renamed the session class to NTSession (NT = non-threading)
23  - Renamed some members to reflect that change
24  - Removed the thread_pools_ member
25  - Set the session handle to "no_threads"
26  - Removed the ThreadPool arguments from GetOrCreateExecutors and SchedClosure
27  - Removed obsolete helper functions NumInterOpThreadsFromSessionOptions,
28  NewThreadPoolFromSessionOptions, NewThreadPoolFromThreadPoolOptions and GlobalThreadPool
29  - Renamed the session factory class to NTSessionFactory
30  - Renamed the session registrar class to NTSessionRegistrar
31  - Renamed include guard to reflect location within CMSSW
32 */
33 
34 #ifndef PHYSICSTOOLS_TENSORFLOW_NTSESSION_H
35 #define PHYSICSTOOLS_TENSORFLOW_NTSESSION_H
36 
37 #include <atomic>
38 #include <memory>
39 #include <string>
40 #include <unordered_map>
41 #include <unordered_set>
42 #include <vector>
43 
44 #include "tensorflow/core/common_runtime/costmodel_manager.h"
45 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
46 #include "tensorflow/core/common_runtime/device_mgr.h"
47 #include "tensorflow/core/common_runtime/device_set.h"
48 #include "tensorflow/core/common_runtime/executor.h"
49 #include "tensorflow/core/common_runtime/graph_execution_state.h"
50 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
51 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
52 #include "tensorflow/core/common_runtime/session_factory.h"
53 #include "tensorflow/core/framework/cancellation.h"
54 #include "tensorflow/core/framework/graph.pb.h"
55 #include "tensorflow/core/framework/session_state.h"
56 #include "tensorflow/core/framework/tensor.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/core/status.h"
59 #include "tensorflow/core/platform/macros.h"
60 #include "tensorflow/core/platform/mutex.h"
61 #include "tensorflow/core/platform/types.h"
62 #include "tensorflow/core/public/session.h"
63 
64 namespace tensorflow {
65 
66 class CostModel;
67 class DebugGateway;
68 class Device;
69 class NTSessionFactory;
70 
71 class NTSession : public Session {
72  public:
73  typedef std::function<void(Session*)> CloseCallback;
74 
75  // Takes ownership of 'device_mgr'.
76  // 'factory' is used to unregister the NTSession with 'factory' when its
77  // closed. This ensures that Reset requests from the 'factory' don't get sent
78  // to sessions that are already closed.
79  NTSession(const SessionOptions& options, const DeviceMgr* device_mgr,
80  NTSessionFactory* factory);
81  ~NTSession() override;
82 
83  typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
84  typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap;
85 
86  ::tensorflow::Status Create(const GraphDef& graph) override;
87  ::tensorflow::Status Extend(const GraphDef& graph) override;
88  ::tensorflow::Status Run(const NamedTensorList& inputs,
89  const std::vector<string>& output_names,
90  const std::vector<string>& target_nodes,
91  std::vector<Tensor>* outputs) override;
92 
93  // NOTE: Experimental and subject to change.
94  ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
95  const NamedTensorList& inputs,
96  const std::vector<string>& output_names,
97  const std::vector<string>& target_nodes,
98  std::vector<Tensor>* outputs,
99  RunMetadata* run_metadata) override;
100 
101  // NOTE: PRunSetup and PRun are added to support partial execution. This
102  // feature is experimental and subject to change.
103  ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
104  const std::vector<string>& output_names,
105  const std::vector<string>& target_nodes,
106  string* handle) override;
107  ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs,
108  const std::vector<string>& output_names,
109  std::vector<Tensor>* outputs) override;
110 
111  // Reset clears 'containers' from the device_mgr of the NTSession.
112  // If 'containers' is empty, then Reset clears the default container.
113  ::tensorflow::Status Reset(const std::vector<string>& containers);
114 
116  std::vector<DeviceAttributes>* response) override;
117  ::tensorflow::Status Close() override;
118  ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override {
119  *output = device_mgr_.get();
121  }
122 
123  void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
124  cost_model_manager_.ExportCostModels(cost_models);
125  }
126 
127  private:
128  // We create one executor and its dependent library runtime for
129  // every partition.
131  Graph* graph = nullptr; // not owned.
132  Device* device = nullptr; // not owned.
133  FunctionLibraryRuntime* flib = nullptr; // not owned.
134  std::unique_ptr<Executor> executor;
135  };
136 
137  // An ExecutorsAndKeys is created for a given set of feeds/fetches.
138  // 'step_count' is the number of times this graph is executed.
139  // 'graph' is the entire graph being executed. 'name_to_node'
140  // maps node name to node. We keep 'graph' and 'name_to_node' only in
141  // the case of partial runs. Each item in 'items' is the executor for
142  // a partition of the graph bundled with its dependent library runtime.
143  // 'input_keys' are the rendezvous keys for the feeds and 'output_keys'
144  // are rendezvous keys for the fetches.
146  ExecutorsAndKeys() : step_count(0) {}
147 
148  std::atomic_int_fast64_t step_count;
149  std::unique_ptr<Graph> graph;
150  NameNodeMap name_to_node;
151  std::vector<PerPartitionExecutorsAndLib> items;
152  std::unordered_map<string, size_t> input_name_to_index;
153  std::unordered_map<string, string> input_name_to_rendezvous_key;
154  std::unordered_map<string, size_t> output_name_to_index;
155  std::unordered_map<string, string> output_name_to_rendezvous_key;
156 
157  DataTypeVector input_types;
158  DataTypeVector output_types;
159  };
160 
161  // A FunctionInfo object is created for every unique set of feeds/fetches.
162  // This info could be folded into the ExecutorsAndKeys object but we would
163  // like to maintain a deletion order in which the OpKernels (owned by the
164  // executor) should be destroyed first, followed by the resources in the
165  // device and then followed by the function stuff.
166  // TODO(rohanj): Consolidate function library definitions so that we can
167  // instantiate only one ProcFLR and lib_def and make this just a member
168  // variable and not a vector.
169  // 'flib_def' is the function library used.
170  // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per
171  // device.
172  struct FunctionInfo {
173  std::unique_ptr<FunctionLibraryDefinition> flib_def;
174  std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
175  };
176 
177  // For each live partial execution, the session maintains a RunState.
178  // 'status' is the current status of this partial execution. 'executor_done'
179  // is "notified" when all executors are done. 'pending_inputs' are the set
180  // of pending feeds and 'pending_outputs' are the set of pending fetches.
181  struct RunState {
183  Status status GUARDED_BY(mu_);
184  IntraProcessRendezvous* rendez = nullptr;
185  std::unique_ptr<StepStatsCollector> collector;
186  Notification executors_done;
187  std::unordered_map<string, bool> pending_inputs; // true if fed
188  std::unordered_map<string, bool> pending_outputs; // true if fetched
189  TensorStore tensor_store;
190  ScopedStepContainer step_container;
191 
192  RunState(int64 step_id, const std::vector<Device*>* devices);
193 
194  RunState(const std::vector<string>& pending_input_names,
195  const std::vector<string>& pending_output_names, int64 step_id,
196  const std::vector<Device*>* devices);
197 
198  // Returns true if all pending inputs and outputs have been completed.
199  bool PendingDone() const;
200 
201  ~RunState();
202  };
203 
204  struct RunStateArgs {
205  RunStateArgs(const DebugOptions& options) : debug_options(options) {}
206 
207  bool is_partial_run = false;
208  string handle;
209  std::unique_ptr<Graph> graph;
210  const DebugOptions& debug_options;
211  };
212 
213  // Initializes the base execution state given the 'graph',
214  // if not already initialized.
216  bool* out_already_initialized)
217  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
218 
219  // Retrieves an already existing set of executors to run 'inputs' and
220  // 'outputs', or creates and caches them for future use.
222  gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
223  gtl::ArraySlice<string> target_nodes,
224  ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);
225 
226  // Creates several graphs given the existing graph_def_ and the
227  // input feeds and fetches, given 'devices'. The graphs share a common
228  // function library 'flib_def'.
230  const BuildGraphOptions& options,
231  std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
232  std::unique_ptr<FunctionLibraryDefinition>* flib_def,
233  RunStateArgs* run_state_args, DataTypeVector* input_types,
234  DataTypeVector* output_types);
235 
236  ::tensorflow::Status ExtendLocked(const GraphDef& graph)
237  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
238 
240  const Tensor& resource_tensor, Tensor* retrieved_tensor);
241 
242  // Feeds more inputs to the executors, triggering further execution.
244  const std::vector<std::pair<string, Tensor>>& inputs,
245  const ExecutorsAndKeys* executors_and_keys,
246  IntraProcessRendezvous* rendez);
247 
248  // Fetches more outputs from the executors. It waits until the output
249  // tensors are computed.
251  const std::vector<string>& output_names,
252  const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
253  std::vector<Tensor>* outputs);
254 
255  // Check if the specified fetches can be computed from the feeds
256  // that we have already provided.
258  const std::vector<std::pair<string, Tensor>>& feeds,
259  const std::vector<string>& fetches,
260  const ExecutorsAndKeys* executors_and_keys, const RunState* run_state);
261 
262  // Use the appropriate WaitForNotification function based on whether
263  // operation_timeout_in_ms is greater than 0.
264  //
265  // If the timeout expires, the `cm->StartCancel()` will be called.
267  int64 timeout_in_ms);
268  void WaitForNotification(RunState* run_state, CancellationManager* cm,
269  int64 timeout_in_ms);
270 
272  mutex_lock l(closed_lock_);
273  if (closed_) return errors::Cancelled("Session has been closed.");
275  }
276 
278  const DebugOptions& debug_options, int64 session_run_index,
279  int64 executor_step_index, const std::vector<string>& input_names,
280  const std::vector<string>& output_names,
281  const std::vector<string>& target_names,
282  std::unique_ptr<DebuggerStateInterface>* debugger_state);
283 
285  const DebugOptions& debug_options, Graph* graph, Device* device);
286 
287  const SessionOptions options_;
288 
289  // Device structures.
290  const std::unique_ptr<const DeviceMgr> device_mgr_;
291  std::vector<Device*> devices_; // not owned
292  DeviceSet device_set_;
293 
295  bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
296 
298  GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
299 
300  Status init_error_; // Set to an error if construction failed.
301 
302  // If true, blocks until device has finished all queued operations in a step.
303  bool sync_on_finish_ = true;
304  void SchedClosure(std::function<void()> c);
305 
306  std::vector<std::unique_ptr<FunctionInfo>> functions_
308 
309  mutex executor_lock_; // protects executors_
310  // Holds mappings from signature to the executors that process
311  // it. The reason for a level of indirection around mapped_type is
312  // to guarantee address stability.
313  // The map value is a shared_ptr since multiple map keys can point to the
314  // same ExecutorsAndKey object.
315  std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
316  GUARDED_BY(executor_lock_);
317 
318  // Holds mappings from handle to partial run state.
319  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
320  GUARDED_BY(executor_lock_);
321 
322  // This holds all the tensors that are currently alive in the session.
323  SessionState session_state_;
324 
325  NTSessionFactory* const factory_; // not owned
326  CancellationManager* cancellation_manager_;
327 
328  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
329  // is true, such as "params" and "queue" nodes. Once placed these
330  // nodes can not be moved to a different device. Maps node names to
331  // device names.
332  std::unordered_map<string, string> stateful_placements_
333  GUARDED_BY(graph_def_lock_);
334 
335  // Execution_state; used when placing the entire graph.
336  std::unique_ptr<GraphExecutionState> execution_state_
337  GUARDED_BY(graph_def_lock_);
338 
339  // The function library, before any rewrites or optimizations have been
340  // performed. In particular, CreateGraphs() may need to modify the function
341  // library; it copies and modifies the function library.
342  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
343 
344  // true if the Session has been Closed.
346  bool closed_ GUARDED_BY(closed_lock_) = false;
347 
348  // For generating unique names for this session instance.
349  std::atomic<int64> edge_name_counter_ = {0};
350  std::atomic<int64> handle_name_counter_ = {0};
351 
352  // For generating step ids that are unique across all sessions.
353  static std::atomic_int_fast64_t step_id_counter_;
354 
355  // Global timeout for all blocking operations in this session.
356  const int64 operation_timeout_in_ms_ = 0;
357 
358  // Manages all the cost models for the graphs executed in this session.
359  CostModelManager cost_model_manager_;
360 
361  Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
362 
364 
365  // EXPERIMENTAL: debugger (tfdbg) related
366  friend class DebugGateway;
367 };
368 
369 } // end namespace tensorflow
370 
371 #endif // PHYSICSTOOLS_TENSORFLOW_NTSESSION_H
static boost::mutex mutex
Definition: Proxy.cc:11
::tensorflow::Status CreateDebuggerState(const DebugOptions &debug_options, int64 session_run_index, int64 executor_step_index, const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_names, std::unique_ptr< DebuggerStateInterface > *debugger_state)
Definition: NTSession.cc:334
std::vector< PerPartitionExecutorsAndLib > items
Definition: NTSession.h:151
static std::atomic_int_fast64_t step_id_counter_
Definition: NTSession.h:353
::tensorflow::Status Reset(const std::vector< string > &containers)
Definition: NTSession.cc:1386
TF_DISALLOW_COPY_AND_ASSIGN(NTSession)
Status MaybeInitializeExecutionState(const GraphDef &graph, bool *out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
Definition: NTSession.cc:260
const SessionOptions options_
Definition: NTSession.h:287
std::unique_ptr< StepStatsCollector > collector
Definition: NTSession.h:185
std::unordered_map< string, string > input_name_to_rendezvous_key
Definition: NTSession.h:153
::tensorflow::Status ResourceHandleToInputTensor(const Tensor &resource_tensor, Tensor *retrieved_tensor)
Definition: NTSession.cc:807
::tensorflow::Status PRunSetup(const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_nodes, string *handle) override
Definition: NTSession.cc:630
::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions &debug_options, Graph *graph, Device *device)
Definition: NTSession.cc:348
SessionState session_state_
Definition: NTSession.h:323
std::unordered_map< string, string > output_name_to_rendezvous_key
Definition: NTSession.h:155
RunStateArgs(const DebugOptions &options)
Definition: NTSession.h:205
::tensorflow::Status SendPRunInputs(const std::vector< std::pair< string, Tensor >> &inputs, const ExecutorsAndKeys *executors_and_keys, IntraProcessRendezvous *rendez)
Definition: NTSession.cc:833
std::vector< std::pair< string, Tensor > > NamedTensorList
Definition: NTSession.h:83
std::unique_ptr< ProcessFunctionLibraryRuntime > proc_flr
Definition: NTSession.h:174
::tensorflow::Status PRun(const string &handle, const NamedTensorList &inputs, const std::vector< string > &output_names, std::vector< Tensor > *outputs) override
Definition: NTSession.cc:707
~NTSession() override
Definition: NTSession.cc:239
DeviceSet device_set_
Definition: NTSession.h:292
std::unordered_map< string, bool > pending_outputs
Definition: NTSession.h:188
::tensorflow::Status Run(const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs) override
Definition: NTSession.cc:325
bool graph_created_ GUARDED_BY(graph_def_lock_)
::tensorflow::Status CheckNotClosed()
Definition: NTSession.h:271
::tensorflow::Status ExtendLocked(const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
Definition: NTSession.cc:310
std::function< void(Session *)> CloseCallback
Definition: NTSession.h:73
::tensorflow::Status WaitForNotification(Notification *n, int64 timeout_in_ms)
Definition: NTSession.cc:1465
::tensorflow::Status ListDevices(std::vector< DeviceAttributes > *response) override
Definition: NTSession.cc:1375
::tensorflow::Status CheckFetch(const std::vector< std::pair< string, Tensor >> &feeds, const std::vector< string > &fetches, const ExecutorsAndKeys *executors_and_keys, const RunState *run_state)
Definition: NTSession.cc:918
std::unordered_map< string, size_t > input_name_to_index
Definition: NTSession.h:152
std::atomic< int64 > edge_name_counter_
Definition: NTSession.h:349
::tensorflow::Status Close() override
Definition: NTSession.cc:1392
std::atomic< int64 > handle_name_counter_
Definition: NTSession.h:350
const std::unique_ptr< const DeviceMgr > device_mgr_
Definition: NTSession.h:290
std::unique_ptr< FunctionLibraryDefinition > flib_def
Definition: NTSession.h:173
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
std::unique_ptr< Graph > graph
Definition: NTSession.h:209
std::unordered_map< string, size_t > output_name_to_index
Definition: NTSession.h:154
::tensorflow::Status RecvPRunOutputs(const std::vector< string > &output_names, const ExecutorsAndKeys *executors_and_keys, RunState *run_state, std::vector< Tensor > *outputs)
Definition: NTSession.cc:872
::tensorflow::Status CreateGraphs(const BuildGraphOptions &options, std::unordered_map< string, std::unique_ptr< Graph >> *outputs, std::unique_ptr< FunctionLibraryDefinition > *flib_def, RunStateArgs *run_state_args, DataTypeVector *input_types, DataTypeVector *output_types)
Definition: NTSession.cc:1220
ScopedStepContainer step_container
Definition: NTSession.h:190
std::vector< Device * > devices_
Definition: NTSession.h:291
Executor::Args::NodeOutputsCallback node_outputs_callback_
Definition: NTSession.h:361
std::atomic_int_fast64_t step_count
Definition: NTSession.h:148
friend class DebugGateway
Definition: NTSession.h:366
void SchedClosure(std::function< void()> c)
Definition: NTSession.cc:192
::tensorflow::Status GetOrCreateExecutors(gtl::ArraySlice< string > inputs, gtl::ArraySlice< string > outputs, gtl::ArraySlice< string > target_nodes, ExecutorsAndKeys **executors_and_keys, RunStateArgs *run_state_args)
Definition: NTSession.cc:979
std::unordered_map< string, bool > pending_inputs
Definition: NTSession.h:187
NTSession(const SessionOptions &options, const DeviceMgr *device_mgr, NTSessionFactory *factory)
Definition: NTSession.cc:196
CancellationManager * cancellation_manager_
Definition: NTSession.h:326
::tensorflow::Status LocalDeviceManager(const DeviceMgr **output) override
Definition: NTSession.h:118
std::unique_ptr< Graph > graph
Definition: NTSession.h:149
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
Definition: NTSession.h:123
::tensorflow::Status Create(const GraphDef &graph) override
Definition: NTSession.cc:291
const int64 operation_timeout_in_ms_
Definition: NTSession.h:356
NTSessionFactory *const factory_
Definition: NTSession.h:325
std::unique_ptr< FunctionLibraryDefinition > flib_def_
Definition: NTSession.h:342
const DebugOptions & debug_options
Definition: NTSession.h:210
CostModelManager cost_model_manager_
Definition: NTSession.h:359
::tensorflow::Status Extend(const GraphDef &graph) override
Definition: NTSession.cc:304
std::unordered_map< StringPiece, Node *, StringPieceHasher > NameNodeMap
Definition: NTSession.h:84