CMS 3D CMS Logo

NTSession.h
Go to the documentation of this file.
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 /*
17 This file is an adaptation of the original direct_session.h file located at
18 https://github.com/tensorflow/tensorflow/blob/v1.6.0/tensorflow/core/common_runtime/direct_session.h
19 to meet the demands of the software environment developed and used by the CMS collaboration.
20 
21 Changes:
22  - Renamed the session class to NTSession (NT = non-threading)
23  - Renamed some members to reflect that change
24  - Removed the thread_pools_ member
25  - Set the session handle to "no_threads"
26  - Removed the ThreadPool arguments from GetOrCreateExecutors and SchedClosure
27  - Removed obsolete helper functions NumInterOpThreadsFromSessionOptions,
28  NewThreadPoolFromSessionOptions, NewThreadPoolFromThreadPoolOptions and GlobalThreadPool
29  - Renamed the session factory class to NTSessionFactory
30  - Renamed the session registrar class to NTSessionRegistrar
31  - Renamed include guard to reflect location within CMSSW
32 */
33 
34 #ifndef PHYSICSTOOLS_TENSORFLOW_NTSESSION_H
35 #define PHYSICSTOOLS_TENSORFLOW_NTSESSION_H
36 
37 #include <atomic>
38 #include <memory>
39 #include <string>
40 #include <unordered_map>
41 #include <unordered_set>
42 #include <vector>
43 
44 #include "tensorflow/core/common_runtime/costmodel_manager.h"
45 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
46 #include "tensorflow/core/common_runtime/device_mgr.h"
47 #include "tensorflow/core/common_runtime/device_set.h"
48 #include "tensorflow/core/common_runtime/executor.h"
49 #include "tensorflow/core/common_runtime/graph_execution_state.h"
50 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
51 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
52 #include "tensorflow/core/common_runtime/session_factory.h"
53 #include "tensorflow/core/framework/cancellation.h"
54 #include "tensorflow/core/framework/graph.pb.h"
55 #include "tensorflow/core/framework/session_state.h"
56 #include "tensorflow/core/framework/tensor.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/core/status.h"
59 #include "tensorflow/core/platform/macros.h"
60 #include "tensorflow/core/platform/mutex.h"
61 #include "tensorflow/core/platform/types.h"
62 #include "tensorflow/core/public/session.h"
63 
64 namespace tensorflow {
65 
66  class CostModel;
67  class DebugGateway;
68  class Device;
69  class NTSessionFactory;
70 
71  class NTSession : public Session {
72  public:
73  typedef std::function<void(Session*)> CloseCallback;
74 
75  // Takes ownership of 'device_mgr'.
76  // 'factory' is used to unregister the NTSession with 'factory' when its
77  // closed. This ensures that Reset requests from the 'factory' don't get sent
78  // to sessions that are already closed.
79  NTSession(const SessionOptions& options, const DeviceMgr* device_mgr, NTSessionFactory* factory);
80  ~NTSession() override;
81 
82  typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
83  typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap;
84 
85  ::tensorflow::Status Create(const GraphDef& graph) override;
86  ::tensorflow::Status Extend(const GraphDef& graph) override;
87  ::tensorflow::Status Run(const NamedTensorList& inputs,
88  const std::vector<string>& output_names,
89  const std::vector<string>& target_nodes,
90  std::vector<Tensor>* outputs) override;
91 
92  // NOTE: Experimental and subject to change.
93  ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
94  const NamedTensorList& inputs,
95  const std::vector<string>& output_names,
96  const std::vector<string>& target_nodes,
97  std::vector<Tensor>* outputs,
98  RunMetadata* run_metadata) override;
99 
100  // NOTE: PRunSetup and PRun are added to support partial execution. This
101  // feature is experimental and subject to change.
102  ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
103  const std::vector<string>& output_names,
104  const std::vector<string>& target_nodes,
105  string* handle) override;
106  ::tensorflow::Status PRun(const string& handle,
107  const NamedTensorList& inputs,
108  const std::vector<string>& output_names,
109  std::vector<Tensor>* outputs) override;
110 
111  // Reset clears 'containers' from the device_mgr of the NTSession.
112  // If 'containers' is empty, then Reset clears the default container.
113  ::tensorflow::Status Reset(const std::vector<string>& containers);
114 
115  ::tensorflow::Status ListDevices(std::vector<DeviceAttributes>* response) override;
116  ::tensorflow::Status Close() override;
117  ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override {
118  *output = device_mgr_.get();
120  }
121 
122  void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
123  cost_model_manager_.ExportCostModels(cost_models);
124  }
125 
126  private:
127  // We create one executor and its dependent library runtime for
128  // every partition.
130  Graph* graph = nullptr; // not owned.
131  Device* device = nullptr; // not owned.
132  FunctionLibraryRuntime* flib = nullptr; // not owned.
133  std::unique_ptr<Executor> executor;
134  };
135 
136  // An ExecutorsAndKeys is created for a given set of feeds/fetches.
137  // 'step_count' is the number of times this graph is executed.
138  // 'graph' is the entire graph being executed. 'name_to_node'
139  // maps node name to node. We keep 'graph' and 'name_to_node' only in
140  // the case of partial runs. Each item in 'items' is the executor for
141  // a partition of the graph bundled with its dependent library runtime.
142  // 'input_keys' are the rendezvous keys for the feeds and 'output_keys'
143  // are rendezvous keys for the fetches.
145  ExecutorsAndKeys() : step_count(0) {}
146 
147  std::atomic_int_fast64_t step_count;
148  std::unique_ptr<Graph> graph;
149  NameNodeMap name_to_node;
150  std::vector<PerPartitionExecutorsAndLib> items;
151  std::unordered_map<string, size_t> input_name_to_index;
152  std::unordered_map<string, string> input_name_to_rendezvous_key;
153  std::unordered_map<string, size_t> output_name_to_index;
154  std::unordered_map<string, string> output_name_to_rendezvous_key;
155 
156  DataTypeVector input_types;
157  DataTypeVector output_types;
158  };
159 
160  // A FunctionInfo object is created for every unique set of feeds/fetches.
161  // This info could be folded into the ExecutorsAndKeys object but we would
162  // like to maintain a deletion order in which the OpKernels (owned by the
163  // executor) should be destroyed first, followed by the resources in the
164  // device and then followed by the function stuff.
165  // TODO(rohanj): Consolidate function library definitions so that we can
166  // instantiate only one ProcFLR and lib_def and make this just a member
167  // variable and not a vector.
168  // 'flib_def' is the function library used.
169  // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per
170  // device.
171  struct FunctionInfo {
172  std::unique_ptr<FunctionLibraryDefinition> flib_def;
173  std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
174  };
175 
176  // For each live partial execution, the session maintains a RunState.
177  // 'status' is the current status of this partial execution. 'executor_done'
178  // is "notified" when all executors are done. 'pending_inputs' are the set
179  // of pending feeds and 'pending_outputs' are the set of pending fetches.
180  struct RunState {
182  Status status GUARDED_BY(mu_);
183  IntraProcessRendezvous* rendez = nullptr;
184  std::unique_ptr<StepStatsCollector> collector;
185  Notification executors_done;
186  std::unordered_map<string, bool> pending_inputs; // true if fed
187  std::unordered_map<string, bool> pending_outputs; // true if fetched
188  TensorStore tensor_store;
189  ScopedStepContainer step_container;
190 
191  RunState(int64 step_id, const std::vector<Device*>* devices);
192 
193  RunState(const std::vector<string>& pending_input_names,
194  const std::vector<string>& pending_output_names,
195  int64 step_id,
196  const std::vector<Device*>* devices);
197 
198  // Returns true if all pending inputs and outputs have been completed.
199  bool PendingDone() const;
200 
201  ~RunState();
202  };
203 
204  struct RunStateArgs {
205  RunStateArgs(const DebugOptions& options) : debug_options(options) {}
206 
207  bool is_partial_run = false;
208  string handle;
209  std::unique_ptr<Graph> graph;
210  const DebugOptions& debug_options;
211  };
212 
213  // Initializes the base execution state given the 'graph',
214  // if not already initialized.
215  Status MaybeInitializeExecutionState(const GraphDef& graph, bool* out_already_initialized)
216  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
217 
218  // Retrieves an already existing set of executors to run 'inputs' and
219  // 'outputs', or creates and caches them for future use.
220  ::tensorflow::Status GetOrCreateExecutors(gtl::ArraySlice<string> inputs,
221  gtl::ArraySlice<string> outputs,
222  gtl::ArraySlice<string> target_nodes,
223  ExecutorsAndKeys** executors_and_keys,
224  RunStateArgs* run_state_args);
225 
226  // Creates several graphs given the existing graph_def_ and the
227  // input feeds and fetches, given 'devices'. The graphs share a common
228  // function library 'flib_def'.
229  ::tensorflow::Status CreateGraphs(const BuildGraphOptions& options,
230  std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
231  std::unique_ptr<FunctionLibraryDefinition>* flib_def,
232  RunStateArgs* run_state_args,
233  DataTypeVector* input_types,
234  DataTypeVector* output_types);
235 
236  ::tensorflow::Status ExtendLocked(const GraphDef& graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
237 
238  ::tensorflow::Status ResourceHandleToInputTensor(const Tensor& resource_tensor, Tensor* retrieved_tensor);
239 
240  // Feeds more inputs to the executors, triggering further execution.
241  ::tensorflow::Status SendPRunInputs(const std::vector<std::pair<string, Tensor>>& inputs,
242  const ExecutorsAndKeys* executors_and_keys,
243  IntraProcessRendezvous* rendez);
244 
245  // Fetches more outputs from the executors. It waits until the output
246  // tensors are computed.
247  ::tensorflow::Status RecvPRunOutputs(const std::vector<string>& output_names,
248  const ExecutorsAndKeys* executors_and_keys,
249  RunState* run_state,
250  std::vector<Tensor>* outputs);
251 
252  // Check if the specified fetches can be computed from the feeds
253  // that we have already provided.
254  ::tensorflow::Status CheckFetch(const std::vector<std::pair<string, Tensor>>& feeds,
255  const std::vector<string>& fetches,
256  const ExecutorsAndKeys* executors_and_keys,
257  const RunState* run_state);
258 
259  // Use the appropriate WaitForNotification function based on whether
260  // operation_timeout_in_ms is greater than 0.
261  //
262  // If the timeout expires, the `cm->StartCancel()` will be called.
263  ::tensorflow::Status WaitForNotification(Notification* n, int64 timeout_in_ms);
264  void WaitForNotification(RunState* run_state, CancellationManager* cm, int64 timeout_in_ms);
265 
267  mutex_lock l(closed_lock_);
268  if (closed_)
269  return errors::Cancelled("Session has been closed.");
271  }
272 
273  ::tensorflow::Status CreateDebuggerState(const DebugOptions& debug_options,
274  int64 session_run_index,
275  int64 executor_step_index,
276  const std::vector<string>& input_names,
277  const std::vector<string>& output_names,
278  const std::vector<string>& target_names,
279  std::unique_ptr<DebuggerStateInterface>* debugger_state);
280 
281  ::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
282  Graph* graph,
283  Device* device);
284 
285  const SessionOptions options_;
286 
287  // Device structures.
288  const std::unique_ptr<const DeviceMgr> device_mgr_;
289  std::vector<Device*> devices_; // not owned
290  DeviceSet device_set_;
291 
293  bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
294 
296  GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
297 
298  Status init_error_; // Set to an error if construction failed.
299 
300  // If true, blocks until device has finished all queued operations in a step.
301  bool sync_on_finish_ = true;
302  void SchedClosure(std::function<void()> c);
303 
304  std::vector<std::unique_ptr<FunctionInfo>> functions_ GUARDED_BY(executor_lock_);
305 
306  mutex executor_lock_; // protects executors_
307  // Holds mappings from signature to the executors that process
308  // it. The reason for a level of indirection around mapped_type is
309  // to guarantee address stability.
310  // The map value is a shared_ptr since multiple map keys can point to the
311  // same ExecutorsAndKey object.
312  std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_ GUARDED_BY(executor_lock_);
313 
314  // Holds mappings from handle to partial run state.
315  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_ GUARDED_BY(executor_lock_);
316 
317  // This holds all the tensors that are currently alive in the session.
318  SessionState session_state_;
319 
320  NTSessionFactory* const factory_; // not owned
321  CancellationManager* cancellation_manager_;
322 
323  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
324  // is true, such as "params" and "queue" nodes. Once placed these
325  // nodes can not be moved to a different device. Maps node names to
326  // device names.
327  std::unordered_map<string, string> stateful_placements_ GUARDED_BY(graph_def_lock_);
328 
329  // Execution_state; used when placing the entire graph.
330  std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(graph_def_lock_);
331 
332  // The function library, before any rewrites or optimizations have been
333  // performed. In particular, CreateGraphs() may need to modify the function
334  // library; it copies and modifies the function library.
335  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
336 
337  // true if the Session has been Closed.
339  bool closed_ GUARDED_BY(closed_lock_) = false;
340 
341  // For generating unique names for this session instance.
342  std::atomic<int64> edge_name_counter_ = {0};
343  std::atomic<int64> handle_name_counter_ = {0};
344 
345  // For generating step ids that are unique across all sessions.
346  static std::atomic_int_fast64_t step_id_counter_;
347 
348  // Global timeout for all blocking operations in this session.
349  const int64 operation_timeout_in_ms_ = 0;
350 
351  // Manages all the cost models for the graphs executed in this session.
352  CostModelManager cost_model_manager_;
353 
354  Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
355 
357 
358  // EXPERIMENTAL: debugger (tfdbg) related
359  friend class DebugGateway;
360  };
361 
362 } // end namespace tensorflow
363 
364 #endif // PHYSICSTOOLS_TENSORFLOW_NTSESSION_H
static boost::mutex mutex
Definition: Proxy.cc:9
::tensorflow::Status CreateDebuggerState(const DebugOptions &debug_options, int64 session_run_index, int64 executor_step_index, const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_names, std::unique_ptr< DebuggerStateInterface > *debugger_state)
Definition: NTSession.cc:321
std::vector< PerPartitionExecutorsAndLib > items
Definition: NTSession.h:150
static std::atomic_int_fast64_t step_id_counter_
Definition: NTSession.h:346
::tensorflow::Status Reset(const std::vector< string > &containers)
Definition: NTSession.cc:1320
TF_DISALLOW_COPY_AND_ASSIGN(NTSession)
Status MaybeInitializeExecutionState(const GraphDef &graph, bool *out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
Definition: NTSession.cc:253
const SessionOptions options_
Definition: NTSession.h:285
std::unique_ptr< StepStatsCollector > collector
Definition: NTSession.h:184
std::unordered_map< string, string > input_name_to_rendezvous_key
Definition: NTSession.h:152
::tensorflow::Status ResourceHandleToInputTensor(const Tensor &resource_tensor, Tensor *retrieved_tensor)
Definition: NTSession.cc:750
::tensorflow::Status PRunSetup(const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_nodes, string *handle) override
Definition: NTSession.cc:592
::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions &debug_options, Graph *graph, Device *device)
Definition: NTSession.cc:334
SessionState session_state_
Definition: NTSession.h:318
std::unordered_map< string, string > output_name_to_rendezvous_key
Definition: NTSession.h:154
RunStateArgs(const DebugOptions &options)
Definition: NTSession.h:205
::tensorflow::Status SendPRunInputs(const std::vector< std::pair< string, Tensor >> &inputs, const ExecutorsAndKeys *executors_and_keys, IntraProcessRendezvous *rendez)
Definition: NTSession.cc:775
std::vector< std::pair< string, Tensor > > NamedTensorList
Definition: NTSession.h:82
std::unique_ptr< ProcessFunctionLibraryRuntime > proc_flr
Definition: NTSession.h:173
::tensorflow::Status PRun(const string &handle, const NamedTensorList &inputs, const std::vector< string > &output_names, std::vector< Tensor > *outputs) override
Definition: NTSession.cc:659
~NTSession() override
Definition: NTSession.cc:231
DeviceSet device_set_
Definition: NTSession.h:290
std::unordered_map< string, bool > pending_outputs
Definition: NTSession.h:187
::tensorflow::Status Run(const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs) override
Definition: NTSession.cc:313
bool graph_created_ GUARDED_BY(graph_def_lock_)
::tensorflow::Status CheckNotClosed()
Definition: NTSession.h:266
::tensorflow::Status ExtendLocked(const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
Definition: NTSession.cc:299
std::function< void(Session *)> CloseCallback
Definition: NTSession.h:73
::tensorflow::Status WaitForNotification(Notification *n, int64 timeout_in_ms)
Definition: NTSession.cc:1398
::tensorflow::Status ListDevices(std::vector< DeviceAttributes > *response) override
Definition: NTSession.cc:1310
::tensorflow::Status CheckFetch(const std::vector< std::pair< string, Tensor >> &feeds, const std::vector< string > &fetches, const ExecutorsAndKeys *executors_and_keys, const RunState *run_state)
Definition: NTSession.cc:854
std::unordered_map< string, size_t > input_name_to_index
Definition: NTSession.h:151
std::atomic< int64 > edge_name_counter_
Definition: NTSession.h:342
::tensorflow::Status Close() override
Definition: NTSession.cc:1325
std::atomic< int64 > handle_name_counter_
Definition: NTSession.h:343
const std::unique_ptr< const DeviceMgr > device_mgr_
Definition: NTSession.h:288
std::unique_ptr< FunctionLibraryDefinition > flib_def
Definition: NTSession.h:172
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:126
std::unique_ptr< Graph > graph
Definition: NTSession.h:209
std::unordered_map< string, size_t > output_name_to_index
Definition: NTSession.h:153
::tensorflow::Status RecvPRunOutputs(const std::vector< string > &output_names, const ExecutorsAndKeys *executors_and_keys, RunState *run_state, std::vector< Tensor > *outputs)
Definition: NTSession.cc:813
::tensorflow::Status CreateGraphs(const BuildGraphOptions &options, std::unordered_map< string, std::unique_ptr< Graph >> *outputs, std::unique_ptr< FunctionLibraryDefinition > *flib_def, RunStateArgs *run_state_args, DataTypeVector *input_types, DataTypeVector *output_types)
Definition: NTSession.cc:1158
ScopedStepContainer step_container
Definition: NTSession.h:189
std::vector< Device * > devices_
Definition: NTSession.h:289
Executor::Args::NodeOutputsCallback node_outputs_callback_
Definition: NTSession.h:354
std::atomic_int_fast64_t step_count
Definition: NTSession.h:147
friend class DebugGateway
Definition: NTSession.h:359
void SchedClosure(std::function< void()> c)
Definition: NTSession.cc:189
::tensorflow::Status GetOrCreateExecutors(gtl::ArraySlice< string > inputs, gtl::ArraySlice< string > outputs, gtl::ArraySlice< string > target_nodes, ExecutorsAndKeys **executors_and_keys, RunStateArgs *run_state_args)
Definition: NTSession.cc:918
std::unordered_map< string, bool > pending_inputs
Definition: NTSession.h:186
NTSession(const SessionOptions &options, const DeviceMgr *device_mgr, NTSessionFactory *factory)
Definition: NTSession.cc:191
CancellationManager * cancellation_manager_
Definition: NTSession.h:321
::tensorflow::Status LocalDeviceManager(const DeviceMgr **output) override
Definition: NTSession.h:117
std::unique_ptr< Graph > graph
Definition: NTSession.h:148
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
Definition: NTSession.h:122
::tensorflow::Status Create(const GraphDef &graph) override
Definition: NTSession.cc:281
const int64 operation_timeout_in_ms_
Definition: NTSession.h:349
NTSessionFactory *const factory_
Definition: NTSession.h:320
std::unique_ptr< FunctionLibraryDefinition > flib_def_
Definition: NTSession.h:335
const DebugOptions & debug_options
Definition: NTSession.h:210
CostModelManager cost_model_manager_
Definition: NTSession.h:352
::tensorflow::Status Extend(const GraphDef &graph) override
Definition: NTSession.cc:293
std::unordered_map< StringPiece, Node *, StringPieceHasher > NameNodeMap
Definition: NTSession.h:83