CMS 3D CMS Logo

NTSession.h
Go to the documentation of this file.
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 /*
17 This file is an adaptation of the original direct_session.h file located at
18 https://github.com/tensorflow/tensorflow/blob/v1.5.0/tensorflow/core/common_runtime/direct_session.h
19 to meet the demands of the software environment developed and used by the CMS collaboration.
20 
21 Changes:
22  - Renamed the session class to NTSession (NT = non-threading)
23  - Renamed some members to reflect that change
24  - Removed the thread_pools_ member
25  - Set the session handle to "no_threads"
26  - Removed the ThreadPool arguments from GetOrCreateExecutors and SchedClosure
27  - Removed obsolete helper functions NumInterOpThreadsFromSessionOptions,
28  NewThreadPoolFromSessionOptions, NewThreadPoolFromThreadPoolOptions and GlobalThreadPool
29  - Renamed the session factory class to NTSessionFactory
30  - Renamed the session registrar class to NTSessionRegistrar
31  - Renamed include guard to reflect location within CMSSW
32 */
33 
34 #ifndef PHYSICSTOOLS_TENSORFLOW_NTSESSION_H
35 #define PHYSICSTOOLS_TENSORFLOW_NTSESSION_H
36 
37 #include <atomic>
38 #include <memory>
39 #include <string>
40 #include <unordered_map>
41 #include <unordered_set>
42 #include <vector>
43 
44 #include "tensorflow/core/common_runtime/costmodel_manager.h"
45 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
46 #include "tensorflow/core/common_runtime/device_mgr.h"
47 #include "tensorflow/core/common_runtime/device_set.h"
48 #include "tensorflow/core/common_runtime/executor.h"
49 #include "tensorflow/core/common_runtime/graph_execution_state.h"
50 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
51 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
52 #include "tensorflow/core/common_runtime/session_factory.h"
53 #include "tensorflow/core/framework/cancellation.h"
54 #include "tensorflow/core/framework/graph.pb.h"
55 #include "tensorflow/core/framework/session_state.h"
56 #include "tensorflow/core/framework/tensor.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/core/status.h"
59 #include "tensorflow/core/platform/macros.h"
60 #include "tensorflow/core/platform/mutex.h"
61 #include "tensorflow/core/platform/types.h"
62 #include "tensorflow/core/public/session.h"
63 
64 namespace tensorflow {
65 
66 class CostModel;
67 class DebugGateway;
68 class Device;
69 class NTSessionFactory;
70 
71 class NTSession : public Session {
72  public:
73  typedef std::function<void(Session*)> CloseCallback;
74 
75  // Takes ownership of 'device_mgr'.
76  // 'factory' is used to unregister the NTSession with 'factory' when its
77  // closed. This ensures that Reset requests from the 'factory' don't get sent
78  // to sessions that are already closed.
79  NTSession(const SessionOptions& options, const DeviceMgr* device_mgr,
80  NTSessionFactory* factory);
81  ~NTSession() override;
82 
83  typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
84  typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap;
85 
86  ::tensorflow::Status Create(const GraphDef& graph) override;
87  ::tensorflow::Status Extend(const GraphDef& graph) override;
88  ::tensorflow::Status Run(const NamedTensorList& inputs,
89  const std::vector<string>& output_names,
90  const std::vector<string>& target_nodes,
91  std::vector<Tensor>* outputs) override;
92 
93  // NOTE: Experimental and subject to change.
94  ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
95  const NamedTensorList& inputs,
96  const std::vector<string>& output_names,
97  const std::vector<string>& target_nodes,
98  std::vector<Tensor>* outputs,
99  RunMetadata* run_metadata) override;
100 
101  // NOTE: PRunSetup and PRun are added to support partial execution. This
102  // feature is experimental and subject to change.
103  ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
104  const std::vector<string>& output_names,
105  const std::vector<string>& target_nodes,
106  string* handle) override;
107  ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs,
108  const std::vector<string>& output_names,
109  std::vector<Tensor>* outputs) override;
110 
111  // Reset clears 'containers' from the device_mgr of the NTSession.
112  // If 'containers' is empty, then Reset clears the default container.
113  ::tensorflow::Status Reset(const std::vector<string>& containers);
114 
116  std::vector<DeviceAttributes>* response) override;
117  ::tensorflow::Status Close() override;
118  ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override {
119  *output = device_mgr_.get();
121  }
122 
123  void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
124  cost_model_manager_.ExportCostModels(cost_models);
125  }
126 
127  private:
128  // We create one executor and its dependent library runtime for
129  // every partition.
131  Graph* graph = nullptr; // not owned.
132  Device* device = nullptr; // not owned.
133  FunctionLibraryRuntime* flib = nullptr; // not owned.
134  std::unique_ptr<Executor> executor;
135  };
136 
137  // An ExecutorsAndKeys is created for a given set of feeds/fetches.
138  // 'step_count' is the number of times this graph is executed.
139  // 'graph' is the entire graph being executed. 'name_to_node'
140  // maps node name to node. We keep 'graph' and 'name_to_node' only in
141  // the case of partial runs. Each item in 'items' is the executor for
142  // a partition of the graph bundled with its dependent library runtime.
143  // 'input_keys' are the rendezvous keys for the feeds and 'output_keys'
144  // are rendezvous keys for the fetches.
145  // 'flib_def' is the function library used by graphs in 'items'.
146  // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per
147  // device.
148  // TODO(phawkins): currently partitions always share the same function
149  // library. Consider giving each partition its own function library to enable
150  // per-partition rewrites.
152  ExecutorsAndKeys() : step_count(0) {}
153 
154  std::atomic_int_fast64_t step_count;
155  std::unique_ptr<Graph> graph;
156  NameNodeMap name_to_node;
157  std::unique_ptr<FunctionLibraryDefinition> flib_def;
158  std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
159  std::vector<PerPartitionExecutorsAndLib> items;
160  std::unordered_map<string, size_t> input_name_to_index;
161  std::unordered_map<string, string> input_name_to_rendezvous_key;
162  std::unordered_map<string, size_t> output_name_to_index;
163  std::unordered_map<string, string> output_name_to_rendezvous_key;
164 
165  DataTypeVector input_types;
166  DataTypeVector output_types;
167  };
168 
169  // For each live partial execution, the session maintains a RunState.
170  // 'status' is the current status of this partial execution. 'executor_done'
171  // is "notified" when all executors are done. 'pending_inputs' are the set
172  // of pending feeds and 'pending_outputs' are the set of pending fetches.
173  struct RunState {
175  Status status GUARDED_BY(mu_);
176  IntraProcessRendezvous* rendez = nullptr;
177  std::unique_ptr<StepStatsCollector> collector;
178  Notification executors_done;
179  std::unordered_map<string, bool> pending_inputs; // true if fed
180  std::unordered_map<string, bool> pending_outputs; // true if fetched
181  TensorStore tensor_store;
182  ScopedStepContainer step_container;
183 
184  RunState(int64 step_id, const std::vector<Device*>* devices);
185 
186  RunState(const std::vector<string>& pending_input_names,
187  const std::vector<string>& pending_output_names, int64 step_id,
188  const std::vector<Device*>* devices);
189 
190  // Returns true if all pending inputs and outputs have been completed.
191  bool PendingDone() const;
192 
193  ~RunState();
194  };
195 
196  struct RunStateArgs {
197  RunStateArgs(const DebugOptions& options) : debug_options(options) {}
198 
199  bool is_partial_run = false;
200  string handle;
201  std::unique_ptr<Graph> graph;
202  const DebugOptions& debug_options;
203  };
204 
205  // Initializes the base execution state given the 'graph',
206  // if not already initialized.
208  bool* out_already_initialized)
209  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
210 
211  // Retrieves an already existing set of executors to run 'inputs' and
212  // 'outputs', or creates and caches them for future use.
214  gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
215  gtl::ArraySlice<string> target_nodes,
216  ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);
217 
218  // Creates several graphs given the existing graph_def_ and the
219  // input feeds and fetches, given 'devices'. The graphs share a common
220  // function library 'flib_def'.
222  const BuildGraphOptions& options,
223  std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
224  std::unique_ptr<FunctionLibraryDefinition>* flib_def,
225  RunStateArgs* run_state_args, DataTypeVector* input_types,
226  DataTypeVector* output_types);
227 
228  ::tensorflow::Status ExtendLocked(const GraphDef& graph)
229  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
230 
232  const Tensor& resource_tensor, Tensor* retrieved_tensor);
233 
234  // Feeds more inputs to the executors, triggering further execution.
236  const std::vector<std::pair<string, Tensor>>& inputs,
237  const ExecutorsAndKeys* executors_and_keys,
238  IntraProcessRendezvous* rendez);
239 
240  // Fetches more outputs from the executors. It waits until the output
241  // tensors are computed.
243  const std::vector<string>& output_names,
244  const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
245  std::vector<Tensor>* outputs);
246 
247  // Check if the specified fetches can be computed from the feeds
248  // that we have already provided.
250  const std::vector<std::pair<string, Tensor>>& feeds,
251  const std::vector<string>& fetches,
252  const ExecutorsAndKeys* executors_and_keys, const RunState* run_state);
253 
254  // Use the appropriate WaitForNotification function based on whether
255  // operation_timeout_in_ms is greater than 0.
256  //
257  // If the timeout expires, the `cm->StartCancel()` will be called.
259  int64 timeout_in_ms);
260  void WaitForNotification(RunState* run_state, CancellationManager* cm,
261  int64 timeout_in_ms);
262 
264  mutex_lock l(closed_lock_);
265  if (closed_) return errors::Cancelled("Session has been closed.");
267  }
268 
270  const DebugOptions& debug_options, int64 session_run_index,
271  int64 executor_step_index, const std::vector<string>& input_names,
272  const std::vector<string>& output_names,
273  const std::vector<string>& target_names,
274  std::unique_ptr<DebuggerStateInterface>* debugger_state);
275 
277  const DebugOptions& debug_options, Graph* graph, Device* device);
278 
279  const SessionOptions options_;
280 
281  // Device structures.
282  const std::unique_ptr<const DeviceMgr> device_mgr_;
283  std::vector<Device*> devices_; // not owned
284  DeviceSet device_set_;
285 
287  bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
288 
290  GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
291 
292  Status init_error_; // Set to an error if construction failed.
293 
294  // If true, blocks until device has finished all queued operations in a step.
295  bool sync_on_finish_ = true;
296  void SchedClosure(std::function<void()> c);
297 
298  mutex executor_lock_; // protects executors_
299  // Holds mappings from signature to the executors that process
300  // it. The reason for a level of indirection around mapped_type is
301  // to guarantee address stability.
302  // The map value is a shared_ptr since multiple map keys can point to the
303  // same ExecutorsAndKey object.
304  std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
305  GUARDED_BY(executor_lock_);
306 
307  // Holds mappings from handle to partial run state.
308  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
309  GUARDED_BY(executor_lock_);
310 
311  // This holds all the tensors that are currently alive in the session.
312  SessionState session_state_;
313 
314  NTSessionFactory* const factory_; // not owned
315  CancellationManager* cancellation_manager_;
316 
317  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
318  // is true, such as "params" and "queue" nodes. Once placed these
319  // nodes can not be moved to a different device. Maps node names to
320  // device names.
321  std::unordered_map<string, string> stateful_placements_
322  GUARDED_BY(graph_def_lock_);
323 
324  // Execution_state; used when placing the entire graph.
325  std::unique_ptr<GraphExecutionState> execution_state_
326  GUARDED_BY(graph_def_lock_);
327 
328  // The function library, before any rewrites or optimizations have been
329  // performed. In particular, CreateGraphs() may need to modify the function
330  // library; it copies and modifies the function library.
331  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
332 
333  // true if the Session has been Closed.
335  bool closed_ GUARDED_BY(closed_lock_) = false;
336 
337  // For generating unique names for this session instance.
338  std::atomic<int64> edge_name_counter_ = {0};
339  std::atomic<int64> handle_name_counter_ = {0};
340 
341  // For generating step ids that are unique across all sessions.
342  static std::atomic_int_fast64_t step_id_counter_;
343 
344  // Global timeout for all blocking operations in this session.
345  const int64 operation_timeout_in_ms_ = 0;
346 
347  // Manages all the cost models for the graphs executed in this session.
348  CostModelManager cost_model_manager_;
349 
350  Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
351 
353 
354  // EXPERIMENTAL: debugger (tfdbg) related
355  friend class DebugGateway;
356 };
357 
358 } // end namespace tensorflow
359 
360 #endif // PHYSICSTOOLS_TENSORFLOW_NTSESSION_H
std::unique_ptr< FunctionLibraryDefinition > flib_def
Definition: NTSession.h:157
static boost::mutex mutex
Definition: Proxy.cc:11
::tensorflow::Status CreateDebuggerState(const DebugOptions &debug_options, int64 session_run_index, int64 executor_step_index, const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_names, std::unique_ptr< DebuggerStateInterface > *debugger_state)
Definition: NTSession.cc:331
std::vector< PerPartitionExecutorsAndLib > items
Definition: NTSession.h:159
static std::atomic_int_fast64_t step_id_counter_
Definition: NTSession.h:342
::tensorflow::Status Reset(const std::vector< string > &containers)
Definition: NTSession.cc:1380
TF_DISALLOW_COPY_AND_ASSIGN(NTSession)
Status MaybeInitializeExecutionState(const GraphDef &graph, bool *out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
Definition: NTSession.cc:257
const SessionOptions options_
Definition: NTSession.h:279
std::unique_ptr< StepStatsCollector > collector
Definition: NTSession.h:177
std::unordered_map< string, string > input_name_to_rendezvous_key
Definition: NTSession.h:161
::tensorflow::Status ResourceHandleToInputTensor(const Tensor &resource_tensor, Tensor *retrieved_tensor)
Definition: NTSession.cc:803
::tensorflow::Status PRunSetup(const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_nodes, string *handle) override
Definition: NTSession.cc:626
::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions &debug_options, Graph *graph, Device *device)
Definition: NTSession.cc:345
SessionState session_state_
Definition: NTSession.h:312
std::unordered_map< string, string > output_name_to_rendezvous_key
Definition: NTSession.h:163
std::unique_ptr< ProcessFunctionLibraryRuntime > proc_flr
Definition: NTSession.h:158
RunStateArgs(const DebugOptions &options)
Definition: NTSession.h:197
::tensorflow::Status SendPRunInputs(const std::vector< std::pair< string, Tensor >> &inputs, const ExecutorsAndKeys *executors_and_keys, IntraProcessRendezvous *rendez)
Definition: NTSession.cc:829
std::vector< std::pair< string, Tensor > > NamedTensorList
Definition: NTSession.h:83
::tensorflow::Status PRun(const string &handle, const NamedTensorList &inputs, const std::vector< string > &output_names, std::vector< Tensor > *outputs) override
Definition: NTSession.cc:703
~NTSession() override
Definition: NTSession.cc:240
DeviceSet device_set_
Definition: NTSession.h:284
std::unordered_map< string, bool > pending_outputs
Definition: NTSession.h:180
::tensorflow::Status Run(const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs) override
Definition: NTSession.cc:322
bool graph_created_ GUARDED_BY(graph_def_lock_)
::tensorflow::Status CheckNotClosed()
Definition: NTSession.h:263
::tensorflow::Status ExtendLocked(const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
Definition: NTSession.cc:307
std::function< void(Session *)> CloseCallback
Definition: NTSession.h:73
::tensorflow::Status WaitForNotification(Notification *n, int64 timeout_in_ms)
Definition: NTSession.cc:1459
::tensorflow::Status ListDevices(std::vector< DeviceAttributes > *response) override
Definition: NTSession.cc:1369
::tensorflow::Status CheckFetch(const std::vector< std::pair< string, Tensor >> &feeds, const std::vector< string > &fetches, const ExecutorsAndKeys *executors_and_keys, const RunState *run_state)
Definition: NTSession.cc:914
std::unordered_map< string, size_t > input_name_to_index
Definition: NTSession.h:160
std::atomic< int64 > edge_name_counter_
Definition: NTSession.h:338
::tensorflow::Status Close() override
Definition: NTSession.cc:1386
std::atomic< int64 > handle_name_counter_
Definition: NTSession.h:339
const std::unique_ptr< const DeviceMgr > device_mgr_
Definition: NTSession.h:282
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
std::unique_ptr< Graph > graph
Definition: NTSession.h:201
std::unordered_map< string, size_t > output_name_to_index
Definition: NTSession.h:162
::tensorflow::Status RecvPRunOutputs(const std::vector< string > &output_names, const ExecutorsAndKeys *executors_and_keys, RunState *run_state, std::vector< Tensor > *outputs)
Definition: NTSession.cc:868
::tensorflow::Status CreateGraphs(const BuildGraphOptions &options, std::unordered_map< string, std::unique_ptr< Graph >> *outputs, std::unique_ptr< FunctionLibraryDefinition > *flib_def, RunStateArgs *run_state_args, DataTypeVector *input_types, DataTypeVector *output_types)
Definition: NTSession.cc:1214
ScopedStepContainer step_container
Definition: NTSession.h:182
std::vector< Device * > devices_
Definition: NTSession.h:283
Executor::Args::NodeOutputsCallback node_outputs_callback_
Definition: NTSession.h:350
std::atomic_int_fast64_t step_count
Definition: NTSession.h:154
friend class DebugGateway
Definition: NTSession.h:355
void SchedClosure(std::function< void()> c)
Definition: NTSession.cc:193
::tensorflow::Status GetOrCreateExecutors(gtl::ArraySlice< string > inputs, gtl::ArraySlice< string > outputs, gtl::ArraySlice< string > target_nodes, ExecutorsAndKeys **executors_and_keys, RunStateArgs *run_state_args)
Definition: NTSession.cc:975
std::unordered_map< string, bool > pending_inputs
Definition: NTSession.h:179
NTSession(const SessionOptions &options, const DeviceMgr *device_mgr, NTSessionFactory *factory)
Definition: NTSession.cc:197
CancellationManager * cancellation_manager_
Definition: NTSession.h:315
::tensorflow::Status LocalDeviceManager(const DeviceMgr **output) override
Definition: NTSession.h:118
std::unique_ptr< Graph > graph
Definition: NTSession.h:155
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
Definition: NTSession.h:123
::tensorflow::Status Create(const GraphDef &graph) override
Definition: NTSession.cc:288
const int64 operation_timeout_in_ms_
Definition: NTSession.h:345
NTSessionFactory *const factory_
Definition: NTSession.h:314
std::unique_ptr< FunctionLibraryDefinition > flib_def_
Definition: NTSession.h:331
const DebugOptions & debug_options
Definition: NTSession.h:202
CostModelManager cost_model_manager_
Definition: NTSession.h:348
::tensorflow::Status Extend(const GraphDef &graph) override
Definition: NTSession.cc:301
std::unordered_map< StringPiece, Node *, StringPieceHasher > NameNodeMap
Definition: NTSession.h:84