CMS 3D CMS Logo

NTSession.h
Go to the documentation of this file.
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 /*
17 This file is an adaptation of the original direct_session.h file located at
18 https://github.com/tensorflow/tensorflow/blob/v1.3.0/tensorflow/core/common_runtime/direct_session.h
19 to meet the demands of the software environment developed and used by the CMS collaboration.
20 
21 Changes:
22  - Renamed the session class to NTSession (NT = non-threading)
23  - Renamed some members to reflect that change
24  - Removed the thread_pools_ member
25  - Set the session handle to "no_threads"
26  - Removed the ThreadPool arguments from GetOrCreateExecutors and SchedClosure
27  - Removed obsolete helper functions NumInterOpThreadsFromSessionOptions,
28  NewThreadPoolFromSessionOptions, NewThreadPoolFromThreadPoolOptions and GlobalThreadPool
29  - Renamed the session factory class to NTSessionFactory
30  - Renamed the session registrar class to NTSessionRegistrar
31  - Renamed include guard to reflect location within CMSSW
32 */
33 
34 #ifndef PHYSICSTOOLS_TENSORFLOW_NTSESSION_H
35 #define PHYSICSTOOLS_TENSORFLOW_NTSESSION_H
36 
37 #include <atomic>
38 #include <memory>
39 #include <string>
40 #include <unordered_map>
41 #include <unordered_set>
42 #include <vector>
43 
44 #include "tensorflow/core/common_runtime/costmodel_manager.h"
45 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
46 #include "tensorflow/core/common_runtime/device_mgr.h"
47 #include "tensorflow/core/common_runtime/device_set.h"
48 #include "tensorflow/core/common_runtime/executor.h"
49 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
50 #include "tensorflow/core/common_runtime/session_factory.h"
51 #include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
52 #include "tensorflow/core/framework/cancellation.h"
53 #include "tensorflow/core/framework/graph.pb.h"
54 #include "tensorflow/core/framework/session_state.h"
55 #include "tensorflow/core/framework/tensor.h"
56 #include "tensorflow/core/lib/core/errors.h"
57 #include "tensorflow/core/lib/core/status.h"
58 #include "tensorflow/core/platform/macros.h"
59 #include "tensorflow/core/platform/mutex.h"
60 #include "tensorflow/core/platform/types.h"
61 #include "tensorflow/core/public/session.h"
62 
63 namespace tensorflow {
64 
// Forward declarations. NTSessionFactory is needed by the NTSession
// constructor and the factory_ member; DebugGateway by the friend declaration.
65 class CostModel;
66 class DebugGateway;
67 class Device;
68 class NTSessionFactory;
69 
// NTSession: a Session implementation adapted from TensorFlow's DirectSession
// (v1.3.0) with the inter-op thread pools removed ("NT" = non-threading; see the
// change list in the file header).
//
// NOTE(review): this text is a Doxygen source-browser listing, not the raw
// header. Source lines 84, 115, 128, 146, 151, 168, 201, 207, 215, 225, 229,
// 236, 243, 252, 257, 260, 263, 270, 280, 283, 328 and 346 are absent (the
// line numbering skips them), so several declarations below are truncated.
// Consult the actual NTSession.h before relying on this listing.
70 class NTSession : public Session {
71  public:
  // Callback type taking the Session*; no declaration visible here uses it.
72  typedef std::function<void(Session*)> CloseCallback;
73 
74  // Takes ownership of 'device_mgr'.
75  // 'factory' is used to unregister the NTSession with 'factory' when it's
76  // closed. This ensures that Reset requests from the 'factory' don't get sent
77  // to sessions that are already closed.
  // 'factory' is kept in factory_, which is not owned by the session.
  // Construction failures are recorded in init_error_ (per that member's
  // comment) rather than being reported by this constructor's signature.
78  NTSession(const SessionOptions& options, const DeviceMgr* device_mgr,
79  NTSessionFactory* factory);
80  ~NTSession() override;
81 
  // (name, tensor) pairs; the type of the 'inputs' argument of Run() and PRun().
82  typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
  // NOTE(review): source line 84 is missing from this listing, so the typedef
  // below is cut off; the cross-reference index at the end of this page names
  // it NameNodeMap (NTSession.h:84).
83  typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher>
85 
  // Session interface overrides; the definitions live in NTSession.cc.
86  ::tensorflow::Status Create(const GraphDef& graph) override;
87  ::tensorflow::Status Extend(const GraphDef& graph) override;
88  ::tensorflow::Status Run(const NamedTensorList& inputs,
89  const std::vector<string>& output_names,
90  const std::vector<string>& target_nodes,
91  std::vector<Tensor>* outputs) override;
92 
93  // NOTE: Experimental and subject to change.
94  ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
95  const NamedTensorList& inputs,
96  const std::vector<string>& output_names,
97  const std::vector<string>& target_nodes,
98  std::vector<Tensor>* outputs,
99  RunMetadata* run_metadata) override;
100 
101  // NOTE: PRunSetup and PRun are added to support partial execution. This
102  // feature is experimental and subject to change.
103  ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
104  const std::vector<string>& output_names,
105  const std::vector<string>& target_nodes,
106  string* handle) override;
107  ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs,
108  const std::vector<string>& output_names,
109  std::vector<Tensor>* outputs) override;
110 
111  // Reset clears 'containers' from the device_mgr of the NTSession.
112  // If 'containers' is empty, then Reset clears the default container.
113  ::tensorflow::Status Reset(const std::vector<string>& containers);
114 
  // NOTE(review): source line 115 is missing, truncating the declaration below;
  // the cross-reference index lists it as
  // `ListDevices(std::vector<DeviceAttributes>* response) override`.
116  std::vector<DeviceAttributes>* response) override;
117  ::tensorflow::Status Close() override;
118 
  // Exports the collected cost models into 'cost_models' by delegating to
  // cost_model_manager_.ExportCostModels().
119  void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
120  cost_model_manager_.ExportCostModels(cost_models);
121  }
122 
123  private:
  // Short alias for this class.
124  typedef NTSession ME;
125 
126  // We create one executor and its dependent library runtime for
127  // every partition.
  // NOTE(review): source line 128 (the opening line of this struct) is missing;
  // ExecutorsAndKeys::items below holds PerPartitionExecutorsAndLib elements,
  // so presumably that is the struct opened there.
  // NOTE(review): 'graph' is a raw pointer; its ownership is not visible here.
129  Graph* graph = nullptr;
130  std::unique_ptr<FunctionLibraryRuntime> flib;
131  std::unique_ptr<Executor> executor;
132  };
133 
134  // An ExecutorsAndKeys is created for a given set of feeds/fetches.
135  // 'step_count' is the number of times this graph is executed.
136  // 'graph' is the entire graph being executed. 'name_to_node'
137  // maps node name to node. We keep 'graph' and 'name_to_node' only in
138  // the case of partial runs. Each item in 'items' is the executor for
139  // a partition of the graph bundled with its dependent library runtime.
140  // 'input_name_to_rendezvous_key' maps feed names to their rendezvous keys and
141  // 'output_name_to_rendezvous_key' maps fetch names to their rendezvous keys.
142  // 'flib_def' is the function library used by graphs in 'items'.
143  // TODO(phawkins): currently partitions always share the same function
144  // library. Consider giving each partition its own function library to enable
145  // per-partition rewrites.
  // NOTE(review): source line 146 (presumably `struct ExecutorsAndKeys {`, given
  // the constructor below) is missing from this listing.
147  ExecutorsAndKeys() : step_count(0) {}
148 
149  std::atomic_int_fast64_t step_count;
150  std::unique_ptr<Graph> graph;
  // NOTE(review): source line 151 is missing; the comment above describes a
  // 'name_to_node' member that is not otherwise shown, so it was presumably here.
152  std::unique_ptr<FunctionLibraryDefinition> flib_def;
153  std::vector<PerPartitionExecutorsAndLib> items;
154  std::unordered_map<string, size_t> input_name_to_index;
155  std::unordered_map<string, string> input_name_to_rendezvous_key;
156  std::unordered_map<string, size_t> output_name_to_index;
157  std::unordered_map<string, string> output_name_to_rendezvous_key;
158 
159  DataTypeVector input_types;
160  DataTypeVector output_types;
161  };
162 
163  // For each live partial execution, the session maintains a RunState.
164  // 'status' is the current status of this partial execution. 'executors_done'
165  // is "notified" when all executors are done. 'pending_inputs' are the set
166  // of pending feeds and 'pending_outputs' are the set of pending fetches.
167  struct RunState {
  // NOTE(review): source line 168 is missing; 'status' is GUARDED_BY(mu_) but
  // mu_ is not declared anywhere visible, so presumably it was declared there.
169  Status status GUARDED_BY(mu_);
170  IntraProcessRendezvous* rendez = nullptr;
171  std::unique_ptr<StepStatsCollector> collector;
172  Notification executors_done;
173  std::unordered_map<string, bool> pending_inputs; // true if fed
174  std::unordered_map<string, bool> pending_outputs; // true if fetched
175  TensorStore tensor_store;
176  ScopedStepContainer step_container;
177 
178  RunState(int64 step_id, const std::vector<Device*>* devices);
179 
180  RunState(const std::vector<string>& pending_input_names,
181  const std::vector<string>& pending_output_names, int64 step_id,
182  const std::vector<Device*>* devices);
183 
184  // Returns true if all pending inputs and outputs have been completed.
185  bool PendingDone() const;
186 
187  ~RunState();
188  };
189 
  // Per-call arguments passed to GetOrCreateExecutors() and CreateGraphs().
  // NOTE(review): 'debug_options' is stored as a reference, so the DebugOptions
  // given to the constructor must outlive this object. The single-argument
  // constructor is also not marked explicit.
190  struct RunStateArgs {
191  RunStateArgs(const DebugOptions& options) : debug_options(options) {}
192 
193  bool is_partial_run = false;
194  string handle;
195  std::unique_ptr<Graph> graph;
196  const DebugOptions& debug_options;
197  };
198 
199  // Initializes the base execution state given the 'graph',
200  // if not already initialized.
  // NOTE(review): source line 201 is missing; the cross-reference index gives
  // the full signature as MaybeInitializeExecutionState(const GraphDef& graph,
  // bool* out_already_initialized).
202  bool* out_already_initialized)
203  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
204 
205  // Retrieves an already existing set of executors to run 'inputs' and
206  // 'outputs', or creates and caches them for future use.
  // NOTE(review): source line 207 (the start of GetOrCreateExecutors, per the
  // cross-reference index) is missing from this listing.
208  gtl::ArraySlice<string> inputs,
209  gtl::ArraySlice<string> outputs, gtl::ArraySlice<string> target_nodes,
210  ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);
211 
212  // Creates several graphs given the existing graph_def_ and the
213  // input feeds and fetches, given 'devices'. The graphs share a common
214  // function library 'flib_def'.
  // NOTE(review): source line 215 (the start of CreateGraphs, per the
  // cross-reference index) is missing from this listing.
216  const BuildGraphOptions& options,
217  std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
218  std::unique_ptr<FunctionLibraryDefinition>* flib_def,
219  RunStateArgs* run_state_args, DataTypeVector* input_types,
220  DataTypeVector* output_types);
221 
222  ::tensorflow::Status ExtendLocked(const GraphDef& graph)
223  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
224 
  // NOTE(review): source line 225 is missing; the cross-reference index
  // identifies this declaration as ResourceHandleToInputTensor(
  // const Tensor& resource_tensor, Tensor* retrieved_tensor).
226  const Tensor& resource_tensor, Tensor* retrieved_tensor);
227 
228  // Feeds more inputs to the executors, triggering further execution.
  // NOTE(review): source line 229 (the start of SendPRunInputs, per the
  // cross-reference index) is missing from this listing.
230  const std::vector<std::pair<string, Tensor>>& inputs,
231  const ExecutorsAndKeys* executors_and_keys,
232  IntraProcessRendezvous* rendez);
233 
234  // Fetches more outputs from the executors. It waits until the output
235  // tensors are computed.
  // NOTE(review): source line 236 (the start of RecvPRunOutputs, per the
  // cross-reference index) is missing from this listing.
237  const std::vector<string>& output_names,
238  const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
239  std::vector<Tensor>* outputs);
240 
241  // Check if the specified fetches can be computed from the feeds
242  // that we have already provided.
  // NOTE(review): source line 243 (the start of CheckFetch, per the
  // cross-reference index) is missing from this listing.
244  const std::vector<std::pair<string, Tensor>>& feeds,
245  const std::vector<string>& fetches,
246  const ExecutorsAndKeys* executors_and_keys, const RunState* run_state);
247 
248  // Use the appropriate WaitForNotification function based on whether
249  // operation_timeout_in_ms is greater than 0.
250  //
251  // If the timeout expires, the `cm->StartCancel()` will be called.
  // NOTE(review): source line 252 is missing; the cross-reference index lists
  // this overload as WaitForNotification(Notification* n, int64 timeout_in_ms).
253  int64 timeout_in_ms);
254  void WaitForNotification(RunState* run_state, CancellationManager* cm,
255  int64 timeout_in_ms);
256 
  // Returns errors::Cancelled("Session has been closed.") when closed_ is set;
  // closed_ is read while holding closed_lock_.
  // NOTE(review): source lines 257 (the signature; the cross-reference index
  // lists CheckNotClosed() at NTSession.h:257) and 260 (presumably the success
  // return) are missing from this listing.
258  mutex_lock l(closed_lock_);
259  if (closed_) return errors::Cancelled("Session has been closed.");
261  }
262 
  // NOTE(review): source line 263 (the start of CreateDebuggerState, per the
  // cross-reference index) is missing from this listing.
264  const DebugOptions& debug_options, int64 session_run_index,
265  int64 executor_step_index, const std::vector<string>& input_names,
266  const std::vector<string>& output_names,
267  const std::vector<string>& target_names,
268  std::unique_ptr<DebuggerStateInterface>* debugger_state);
269 
  // NOTE(review): source line 270 (the start of
  // DecorateAndPublishGraphForDebug, per the cross-reference index) is missing.
271  const DebugOptions& debug_options, Graph* graph, Device* device);
272 
273  const SessionOptions options_;
274 
275  // Device structures.
276  const std::unique_ptr<const DeviceMgr> device_mgr_;
277  std::vector<Device*> devices_; // not owned
278  DeviceSet device_set_;
279 
  // NOTE(review): source lines 280 and 283 are missing. graph_def_lock_ is named
  // in GUARDED_BY / EXCLUSIVE_LOCKS_REQUIRED annotations but is never declared
  // in the visible text, so presumably it is declared on one of those lines.
281  bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
282 
284  GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
285 
286  Status init_error_; // Set to an error if construction failed.
287 
288  // If true, blocks until device has finished all queued operations in a step.
289  bool sync_on_finish_ = true;
  // Schedules closure 'c' for execution. Unlike DirectSession, it takes no
  // ThreadPool argument (see the file header); how 'c' is actually run is
  // defined in NTSession.cc.
290  void SchedClosure(std::function<void()> c);
291 
292  mutex executor_lock_; // protects executors_
293  // Holds mappings from signature to the executors that process
294  // it. The reason for a level of indirection around mapped_type is
295  // to guarantee address stability.
296  // The map value is a shared_ptr since multiple map keys can point to the
297  // same ExecutorsAndKeys object.
298  std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
299  GUARDED_BY(executor_lock_);
300 
301  // Holds mappings from handle to partial run state.
302  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
303  GUARDED_BY(executor_lock_);
304 
305  // This holds all the tensors that are currently alive in the session.
306  SessionState session_state_;
307 
308  NTSessionFactory* const factory_; // not owned
  // NOTE(review): raw pointer; whether NTSession owns the CancellationManager is
  // not visible in this header. Check the constructor/destructor in NTSession.cc.
309  CancellationManager* cancellation_manager_;
310 
311  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
312  // is true, such as "params" and "queue" nodes. Once placed these
313  // nodes can not be moved to a different device. Maps node names to
314  // device names.
315  std::unordered_map<string, string> stateful_placements_
316  GUARDED_BY(graph_def_lock_);
317 
318  // Execution_state; used when placing the entire graph.
319  std::unique_ptr<SimpleGraphExecutionState> execution_state_
320  GUARDED_BY(graph_def_lock_);
321 
322  // The function library, before any rewrites or optimizations have been
323  // performed. In particular, CreateGraphs() may need to modify the function
324  // library; it copies and modifies the function library.
325  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
326 
327  // true if the Session has been Closed.
  // NOTE(review): source line 328 is missing; closed_lock_ (used by the
  // CheckNotClosed() body and the GUARDED_BY below) is presumably declared there.
329  bool closed_ GUARDED_BY(closed_lock_) = false;
330 
331  // For generating unique names for this session instance.
332  std::atomic<int64> edge_name_counter_ = {0};
333  std::atomic<int64> handle_name_counter_ = {0};
334 
335  // For generating step ids that are unique across all sessions.
336  static std::atomic_int_fast64_t step_id_counter_;
337 
338  // Global timeout for all blocking operations in this session.
339  const int64 operation_timeout_in_ms_ = 0;
340 
341  // Manages all the cost models for the graphs executed in this session.
342  CostModelManager cost_model_manager_;
343 
344  Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
345 
  // NOTE(review): source line 346 is missing; the cross-reference index lists
  // TF_DISALLOW_COPY_AND_ASSIGN(NTSession), presumably declared there.
347 
348  // EXPERIMENTAL: debugger (tfdbg) related
349  friend class DebugGateway;
350 };
351 
352 } // end namespace tensorflow
353 
354 #endif // PHYSICSTOOLS_TENSORFLOW_NTSESSION_H
std::unique_ptr< FunctionLibraryDefinition > flib_def
Definition: NTSession.h:152
::tensorflow::Status CreateDebuggerState(const DebugOptions &debug_options, int64 session_run_index, int64 executor_step_index, const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_names, std::unique_ptr< DebuggerStateInterface > *debugger_state)
Definition: NTSession.cc:342
std::vector< PerPartitionExecutorsAndLib > items
Definition: NTSession.h:153
static std::atomic_int_fast64_t step_id_counter_
Definition: NTSession.h:336
::tensorflow::Status Reset(const std::vector< string > &containers)
Definition: NTSession.cc:1350
TF_DISALLOW_COPY_AND_ASSIGN(NTSession)
Status MaybeInitializeExecutionState(const GraphDef &graph, bool *out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
Definition: NTSession.cc:268
const SessionOptions options_
Definition: NTSession.h:273
std::unique_ptr< StepStatsCollector > collector
Definition: NTSession.h:171
static boost::mutex mutex
Definition: LHEProxy.cc:11
std::unordered_map< string, string > input_name_to_rendezvous_key
Definition: NTSession.h:155
::tensorflow::Status ResourceHandleToInputTensor(const Tensor &resource_tensor, Tensor *retrieved_tensor)
Definition: NTSession.cc:794
::tensorflow::Status PRunSetup(const std::vector< string > &input_names, const std::vector< string > &output_names, const std::vector< string > &target_nodes, string *handle) override
Definition: NTSession.cc:617
::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions &debug_options, Graph *graph, Device *device)
Definition: NTSession.cc:356
SessionState session_state_
Definition: NTSession.h:306
std::unordered_map< string, string > output_name_to_rendezvous_key
Definition: NTSession.h:157
RunStateArgs(const DebugOptions &options)
Definition: NTSession.h:191
::tensorflow::Status SendPRunInputs(const std::vector< std::pair< string, Tensor >> &inputs, const ExecutorsAndKeys *executors_and_keys, IntraProcessRendezvous *rendez)
Definition: NTSession.cc:815
std::vector< std::pair< string, Tensor > > NamedTensorList
Definition: NTSession.h:82
::tensorflow::Status PRun(const string &handle, const NamedTensorList &inputs, const std::vector< string > &output_names, std::vector< Tensor > *outputs) override
Definition: NTSession.cc:694
~NTSession() override
Definition: NTSession.cc:251
DeviceSet device_set_
Definition: NTSession.h:278
std::unordered_map< string, bool > pending_outputs
Definition: NTSession.h:174
::tensorflow::Status Run(const NamedTensorList &inputs, const std::vector< string > &output_names, const std::vector< string > &target_nodes, std::vector< Tensor > *outputs) override
Definition: NTSession.cc:333
bool graph_created_ GUARDED_BY(graph_def_lock_)
::tensorflow::Status CheckNotClosed()
Definition: NTSession.h:257
::tensorflow::Status ExtendLocked(const GraphDef &graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_)
Definition: NTSession.cc:318
std::function< void(Session *)> CloseCallback
Definition: NTSession.h:72
::tensorflow::Status WaitForNotification(Notification *n, int64 timeout_in_ms)
Definition: NTSession.cc:1429
::tensorflow::Status ListDevices(std::vector< DeviceAttributes > *response) override
Definition: NTSession.cc:1339
::tensorflow::Status CheckFetch(const std::vector< std::pair< string, Tensor >> &feeds, const std::vector< string > &fetches, const ExecutorsAndKeys *executors_and_keys, const RunState *run_state)
Definition: NTSession.cc:900
std::unordered_map< string, size_t > input_name_to_index
Definition: NTSession.h:154
std::atomic< int64 > edge_name_counter_
Definition: NTSession.h:332
::tensorflow::Status Close() override
Definition: NTSession.cc:1356
std::atomic< int64 > handle_name_counter_
Definition: NTSession.h:333
const std::unique_ptr< const DeviceMgr > device_mgr_
Definition: NTSession.h:276
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
std::unique_ptr< Graph > graph
Definition: NTSession.h:195
std::unordered_map< string, size_t > output_name_to_index
Definition: NTSession.h:156
::tensorflow::Status RecvPRunOutputs(const std::vector< string > &output_names, const ExecutorsAndKeys *executors_and_keys, RunState *run_state, std::vector< Tensor > *outputs)
Definition: NTSession.cc:854
::tensorflow::Status CreateGraphs(const BuildGraphOptions &options, std::unordered_map< string, std::unique_ptr< Graph >> *outputs, std::unique_ptr< FunctionLibraryDefinition > *flib_def, RunStateArgs *run_state_args, DataTypeVector *input_types, DataTypeVector *output_types)
Definition: NTSession.cc:1181
ScopedStepContainer step_container
Definition: NTSession.h:176
std::vector< Device * > devices_
Definition: NTSession.h:277
Executor::Args::NodeOutputsCallback node_outputs_callback_
Definition: NTSession.h:344
std::atomic_int_fast64_t step_count
Definition: NTSession.h:149
std::unique_ptr< FunctionLibraryRuntime > flib
Definition: NTSession.h:130
friend class DebugGateway
Definition: NTSession.h:349
void SchedClosure(std::function< void()> c)
Definition: NTSession.cc:195
::tensorflow::Status GetOrCreateExecutors(gtl::ArraySlice< string > inputs, gtl::ArraySlice< string > outputs, gtl::ArraySlice< string > target_nodes, ExecutorsAndKeys **executors_and_keys, RunStateArgs *run_state_args)
Definition: NTSession.cc:961
std::unordered_map< string, bool > pending_inputs
Definition: NTSession.h:173
NTSession(const SessionOptions &options, const DeviceMgr *device_mgr, NTSessionFactory *factory)
Definition: NTSession.cc:208
CancellationManager * cancellation_manager_
Definition: NTSession.h:309
std::unique_ptr< Graph > graph
Definition: NTSession.h:150
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
Definition: NTSession.h:119
::tensorflow::Status Create(const GraphDef &graph) override
Definition: NTSession.cc:299
std::unordered_map< StringPiece, Node *, StringPiece::Hasher > NameNodeMap
Definition: NTSession.h:84
const int64 operation_timeout_in_ms_
Definition: NTSession.h:339
NTSessionFactory *const factory_
Definition: NTSession.h:308
std::unique_ptr< FunctionLibraryDefinition > flib_def_
Definition: NTSession.h:325
const DebugOptions & debug_options
Definition: NTSession.h:196
CostModelManager cost_model_manager_
Definition: NTSession.h:342
::tensorflow::Status Extend(const GraphDef &graph) override
Definition: NTSession.cc:312