CMS 3D CMS Logo

TBBSession.h
Go to the documentation of this file.
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 /*
17 This file is an adaptation of the original direct_session.h file located at
18 https://github.com/tensorflow/tensorflow/blob/v1.6.0/tensorflow/core/common_runtime/direct_session.h
19 to meet the demands of the software environment developed and used by the CMS collaboration.
20 
21 Changes:
22  - Renamed the session class to TBBSession (TBB = Threading Building Blocks)
23  - Renamed some members to reflect that change
24  - Removed the thread_pools_ member
25  - Set the session handle to "tbb"
26  - Removed the PRunSetup, PRun, SendPRunInputs, RecvPRunOutputs and CheckFetch methods
27  - Removed the ThreadPool arguments from GetOrCreateExecutors and SchedClosure
28  - Let tbb arena and tbb group handle scheduling in WaitForNotification and SchedClosure
29  - Removed obsolete helper functions NumInterOpThreadsFromSessionOptions,
30  NewThreadPoolFromSessionOptions, NewThreadPoolFromThreadPoolOptions and GlobalThreadPool
31  - Renamed the session factory class to TBBSessionFactory
32  - Renamed the session registrar class to TBBSessionRegistrar
33  - Renamed include guard to reflect location within CMSSW
34 */
35 
36 #ifndef PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H
37 #define PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H
38 
39 #include <atomic>
40 #include <memory>
41 #include <string>
42 #include <unordered_map>
43 #include <unordered_set>
44 #include <vector>
45 
46 #include "tbb/task_arena.h"
47 
48 #include "tensorflow/core/common_runtime/costmodel_manager.h"
49 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
50 #include "tensorflow/core/common_runtime/device_mgr.h"
51 #include "tensorflow/core/common_runtime/device_set.h"
52 #include "tensorflow/core/common_runtime/executor.h"
53 #include "tensorflow/core/common_runtime/graph_execution_state.h"
54 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
55 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
56 #include "tensorflow/core/common_runtime/session_factory.h"
57 #include "tensorflow/core/framework/cancellation.h"
58 #include "tensorflow/core/framework/graph.pb.h"
59 #include "tensorflow/core/framework/session_state.h"
60 #include "tensorflow/core/framework/tensor.h"
61 #include "tensorflow/core/lib/core/errors.h"
62 #include "tensorflow/core/lib/core/status.h"
63 #include "tensorflow/core/platform/macros.h"
64 #include "tensorflow/core/platform/mutex.h"
65 #include "tensorflow/core/platform/types.h"
66 #include "tensorflow/core/public/session.h"
67 
// Forward declaration of tbb::task_group. This header only names the type by
// reference in member-function signatures (WaitForNotification, SchedClosure);
// the only TBB header included above is tbb/task_arena.h.
68 namespace tbb {
69 
70 class task_group;
71 
72 } // end namespace tbb
73 
74 namespace tensorflow {
75 
// Forward declarations. In the visible listing, Device and TBBSessionFactory
// are used through pointers and DebugGateway appears in a friend declaration;
// CostModel is not referenced directly in this header.
76 class CostModel;
77 class DebugGateway;
78 class Device;
79 class TBBSessionFactory;
80 
// Session subclass declared by this header. Per the file header above, it is
// an adaptation of TensorFlow's DirectSession in which scheduling is handled
// by a tbb task arena and task group instead of thread pools.
//
// NOTE(review): this is a line-numbered listing and several listing lines are
// absent (120, 130, 145, 182, 251, 254, 274, 277, 325). Each gap is flagged
// below; confirm the missing text against the original header.
81 class TBBSession : public Session {
82  public:
    // Callback type that receives a Session*; not referenced elsewhere in this
    // listing.
83  typedef std::function<void(Session*)> CloseCallback;
84 
85  // Takes ownership of 'device_mgr'.
86  // 'factory' is used to unregister the TBBSession with 'factory' when it is
87  // closed. This ensures that Reset requests from the 'factory' don't get sent
88  // to sessions that are already closed.
89  TBBSession(const SessionOptions& options, const DeviceMgr* device_mgr,
90  TBBSessionFactory* factory);
91  ~TBBSession() override;
92 
    // Feed list type: (name, tensor) pairs.
93  typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
    // Node-name to Node* lookup keyed by StringPiece.
94  typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap;
95 
    // Overrides of the Session interface; the definitions are not part of this
    // header.
96  ::tensorflow::Status Create(const GraphDef& graph) override;
97  ::tensorflow::Status Extend(const GraphDef& graph) override;
98  ::tensorflow::Status Run(const NamedTensorList& inputs,
99  const std::vector<string>& output_names,
100  const std::vector<string>& target_nodes,
101  std::vector<Tensor>* outputs) override;
102 
103  // NOTE: Experimental and subject to change.
104  ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
105  const NamedTensorList& inputs,
106  const std::vector<string>& output_names,
107  const std::vector<string>& target_nodes,
108  std::vector<Tensor>* outputs,
109  RunMetadata* run_metadata) override;
110 
111  // Reset clears 'containers' from the device_mgr of the TBBSession.
112  // If 'containers' is empty, then Reset clears the default container.
    // Unlike its neighbours, this declaration is not marked 'override'.
113  ::tensorflow::Status Reset(const std::vector<string>& containers);
114 
115  ::tensorflow::Status ListDevices(
116  std::vector<DeviceAttributes>* response) override;
117  ::tensorflow::Status Close() override;
    // Stores a non-owning pointer to this session's DeviceMgr in '*output'
    // (device_mgr_ keeps ownership).
118  ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override {
119  *output = device_mgr_.get();
    // NOTE(review): listing line 120 is missing from this extract; presumably
    // the return statement of this function (an OK status) - confirm.
121  }
122 
    // Delegates to CostModelManager::ExportCostModels on cost_model_manager_.
123  void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
124  cost_model_manager_.ExportCostModels(cost_models);
125  }
126 
127  private:
128  // We create one executor and its dependent library runtime for
129  // every partition.
    // NOTE(review): listing line 130 is missing from this extract; presumably
    // the opening `struct PerPartitionExecutorsAndLib {` (that type is used by
    // ExecutorsAndKeys::items below) - confirm.
131  Graph* graph = nullptr; // not owned.
132  Device* device = nullptr; // not owned.
133  FunctionLibraryRuntime* flib = nullptr; // not owned.
134  std::unique_ptr<Executor> executor;
135  };
136 
137  // An ExecutorsAndKeys is created for a given set of feeds/fetches.
138  // 'step_count' is the number of times this graph is executed.
139  // 'graph' is the entire graph being executed. 'name_to_node'
140  // maps node name to node. We keep 'graph' and 'name_to_node' only in
141  // the case of partial runs. Each item in 'items' is the executor for
142  // a partition of the graph bundled with its dependent library runtime.
143  // 'input_keys' are the rendezvous keys for the feeds and 'output_keys'
144  // are rendezvous keys for the fetches.
    // NOTE(review): listing line 145 is missing from this extract; presumably
    // the opening `struct ExecutorsAndKeys {` (the next line is that type's
    // constructor) - confirm.
146  ExecutorsAndKeys() : step_count(0) {}
147 
148  std::atomic_int_fast64_t step_count;
149  std::unique_ptr<Graph> graph;
150  NameNodeMap name_to_node;
151  std::vector<PerPartitionExecutorsAndLib> items;
152  std::unordered_map<string, size_t> input_name_to_index;
153  std::unordered_map<string, string> input_name_to_rendezvous_key;
154  std::unordered_map<string, size_t> output_name_to_index;
155  std::unordered_map<string, string> output_name_to_rendezvous_key;
156 
157  DataTypeVector input_types;
158  DataTypeVector output_types;
159  };
160 
161  // A FunctionInfo object is created for every unique set of feeds/fetches.
162  // This info could be folded into the ExecutorsAndKeys object but we would
163  // like to maintain a deletion order in which the OpKernels (owned by the
164  // executor) should be destroyed first, followed by the resources in the
165  // device and then followed by the function stuff.
166  // TODO(rohanj): Consolidate function library definitions so that we can
167  // instantiate only one ProcFLR and lib_def and make this just a member
168  // variable and not a vector.
169  // 'flib_def' is the function library used.
170  // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per
171  // device.
172  struct FunctionInfo {
173  std::unique_ptr<FunctionLibraryDefinition> flib_def;
174  std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
175  };
176 
177  // For each live partial execution, the session maintains a RunState.
178  // 'status' is the current status of this partial execution. 'executors_done'
179  // is "notified" when all executors are done. 'pending_inputs' are the set
180  // of pending feeds and 'pending_outputs' are the set of pending fetches.
181  struct RunState {
    // NOTE(review): listing line 182 is missing from this extract; presumably
    // the declaration of the mutex 'mu_' that guards 'status' below - confirm.
183  Status status GUARDED_BY(mu_);
184  IntraProcessRendezvous* rendez = nullptr;
185  std::unique_ptr<StepStatsCollector> collector;
186  Notification executors_done;
187  std::unordered_map<string, bool> pending_inputs; // true if fed
188  std::unordered_map<string, bool> pending_outputs; // true if fetched
189  TensorStore tensor_store;
190  ScopedStepContainer step_container;
191 
192  RunState(int64 step_id, const std::vector<Device*>* devices);
193 
194  RunState(const std::vector<string>& pending_input_names,
195  const std::vector<string>& pending_output_names, int64 step_id,
196  const std::vector<Device*>* devices);
197 
198  // Returns true if all pending inputs and outputs have been completed.
199  bool PendingDone() const;
200 
201  ~RunState();
202  };
203 
    // Per-call arguments passed to GetOrCreateExecutors and CreateGraphs.
    // 'debug_options' is stored as a reference, so the DebugOptions object
    // given to the constructor must outlive this struct.
204  struct RunStateArgs {
205  RunStateArgs(const DebugOptions& options) : debug_options(options) {}
206 
207  bool is_partial_run = false;
208  string handle;
209  std::unique_ptr<Graph> graph;
210  const DebugOptions& debug_options;
211  };
212 
213  // Initializes the base execution state given the 'graph',
214  // if not already initialized.
215  Status MaybeInitializeExecutionState(const GraphDef& graph,
216  bool* out_already_initialized)
217  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
218 
219  // Retrieves an already existing set of executors to run 'inputs' and
220  // 'outputs', or creates and caches them for future use.
221  ::tensorflow::Status GetOrCreateExecutors(
222  gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
223  gtl::ArraySlice<string> target_nodes,
224  ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);
225 
226  // Creates several graphs given the existing graph_def_ and the
227  // input feeds and fetches, given 'devices'. The graphs share a common
228  // function library 'flib_def'.
229  ::tensorflow::Status CreateGraphs(
230  const BuildGraphOptions& options,
231  std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
232  std::unique_ptr<FunctionLibraryDefinition>* flib_def,
233  RunStateArgs* run_state_args, DataTypeVector* input_types,
234  DataTypeVector* output_types);
235 
    // Variant of Extend that requires the caller to already hold
    // graph_def_lock_ (see the lock annotation).
236  ::tensorflow::Status ExtendLocked(const GraphDef& graph)
237  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
238 
239  ::tensorflow::Status ResourceHandleToInputTensor(
240  const Tensor& resource_tensor, Tensor* retrieved_tensor);
241 
242  // Use the appropriate WaitForNotification function based on whether
243  // operation_timeout_in_ms is greater than 0.
244  //
245  // If the timeout expires, the `cm->StartCancel()` will be called.
246  ::tensorflow::Status WaitForNotification(Notification* n,
247  int64 timeout_in_ms);
248  void WaitForNotification(tbb::task_arena& arena, tbb::task_group& group,
249  RunState* run_state, CancellationManager* cm, int64 timeout_in_ms);
250 
    // NOTE(review): listing line 251 is missing from this extract. The
    // cross-reference list at the end of the page places
    // `::tensorflow::Status CheckNotClosed()` at line 251, so the fragment
    // below is presumably that function's body: it takes closed_lock_ and
    // returns a Cancelled error when closed_ is set.
252  mutex_lock l(closed_lock_);
253  if (closed_) return errors::Cancelled("Session has been closed.");
    // NOTE(review): listing line 254 is missing from this extract; presumably
    // the return for the not-closed case (an OK status) - confirm.
255  }
256 
257  ::tensorflow::Status CreateDebuggerState(
258  const DebugOptions& debug_options, int64 session_run_index,
259  int64 executor_step_index, const std::vector<string>& input_names,
260  const std::vector<string>& output_names,
261  const std::vector<string>& target_names,
262  std::unique_ptr<DebuggerStateInterface>* debugger_state);
263 
264  ::tensorflow::Status DecorateAndPublishGraphForDebug(
265  const DebugOptions& debug_options, Graph* graph, Device* device);
266 
267  const SessionOptions options_;
268 
269  // Device structures.
270  const std::unique_ptr<const DeviceMgr> device_mgr_;
271  std::vector<Device*> devices_; // not owned
272  DeviceSet device_set_;
273 
    // NOTE(review): listing line 274 is missing from this extract; its content
    // cannot be determined from the visible listing.
275  bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
276 
    // NOTE(review): listing line 277 is missing from this extract. The
    // declaration of graph_def_lock_ does not appear anywhere in the visible
    // listing, so it is presumably on this or the other missing line - confirm.
278  GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
279 
280  Status init_error_; // Set to an error if construction failed.
281 
282  // If true, blocks until device has finished all queued operations in a step.
283  bool sync_on_finish_ = true;
    // Schedules closure 'c'. Per the change list in the file header, scheduling
    // is left to the given tbb arena and task group.
284  void SchedClosure(tbb::task_arena& arena, tbb::task_group& g, std::function<void()> c);
285 
286  std::vector<std::unique_ptr<FunctionInfo>> functions_
287  GUARDED_BY(executor_lock_);
288 
289  mutex executor_lock_; // protects functions_, executors_ and partial_runs_
290  // Holds mappings from signature to the executors that process
291  // it. The reason for a level of indirection around mapped_type is
292  // to guarantee address stability.
293  // The map value is a shared_ptr since multiple map keys can point to the
294  // same ExecutorsAndKeys object.
295  std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
296  GUARDED_BY(executor_lock_);
297 
298  // Holds mappings from handle to partial run state.
299  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
300  GUARDED_BY(executor_lock_);
301 
302  // This holds all the tensors that are currently alive in the session.
303  SessionState session_state_;
304 
305  TBBSessionFactory* const factory_; // not owned
    // NOTE(review): ownership of this raw pointer is not visible in this
    // header - confirm where it is allocated and released.
306  CancellationManager* cancellation_manager_;
307 
308  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
309  // is true, such as "params" and "queue" nodes. Once placed these
310  // nodes can not be moved to a different device. Maps node names to
311  // device names.
312  std::unordered_map<string, string> stateful_placements_
313  GUARDED_BY(graph_def_lock_);
314 
315  // Execution_state; used when placing the entire graph.
316  std::unique_ptr<GraphExecutionState> execution_state_
317  GUARDED_BY(graph_def_lock_);
318 
319  // The function library, before any rewrites or optimizations have been
320  // performed. In particular, CreateGraphs() may need to modify the function
321  // library; it copies and modifies the function library.
322  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
323 
324  // true if the Session has been Closed.
    // NOTE(review): listing line 325 is missing from this extract; presumably
    // the declaration of the mutex 'closed_lock_' that guards 'closed_' below
    // - confirm.
326  bool closed_ GUARDED_BY(closed_lock_) = false;
327 
328  // For generating unique names for this session instance.
329  std::atomic<int64> edge_name_counter_ = {0};
330  std::atomic<int64> handle_name_counter_ = {0};
331 
332  // For generating step ids that are unique across all sessions.
333  static std::atomic_int_fast64_t step_id_counter_;
334 
335  // Global timeout for all blocking operations in this session.
336  const int64 operation_timeout_in_ms_ = 0;
337 
338  // Manages all the cost models for the graphs executed in this session.
339  CostModelManager cost_model_manager_;
340 
    // Per-node output callback type from Executor::Args; defaults to nullptr.
341  Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
342 
343  TF_DISALLOW_COPY_AND_ASSIGN(TBBSession);
344 
345  // EXPERIMENTAL: debugger (tfdbg) related
346  friend class DebugGateway;
347 };
348 
349 } // end namespace tensorflow
350 
351 #endif // PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H
std::unique_ptr< FunctionLibraryDefinition > flib_def_
Definition: TBBSession.h:322
static boost::mutex mutex
Definition: Proxy.cc:11
std::vector< PerPartitionExecutorsAndLib > items
Definition: TBBSession.h:151
std::unordered_map< string, size_t > input_name_to_index
Definition: TBBSession.h:152
ScopedStepContainer step_container
Definition: TBBSession.h:190
std::vector< Device * > devices_
Definition: TBBSession.h:271
std::vector< std::pair< string, Tensor > > NamedTensorList
Definition: TBBSession.h:93
std::unique_ptr< ProcessFunctionLibraryRuntime > proc_flr
Definition: TBBSession.h:174
std::unique_ptr< FunctionLibraryDefinition > flib_def
Definition: TBBSession.h:173
RunStateArgs(const DebugOptions &options)
Definition: TBBSession.h:205
std::unordered_map< string, string > output_name_to_rendezvous_key
Definition: TBBSession.h:155
CostModelManager cost_model_manager_
Definition: TBBSession.h:339
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
Definition: TBBSession.h:123
std::unordered_map< string, bool > pending_inputs
Definition: TBBSession.h:187
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e g
Definition: Activities.doc:4
CancellationManager * cancellation_manager_
Definition: TBBSession.h:306
std::unique_ptr< StepStatsCollector > collector
Definition: TBBSession.h:185
std::unordered_map< string, size_t > output_name_to_index
Definition: TBBSession.h:154
std::unordered_map< StringPiece, Node *, StringPieceHasher > NameNodeMap
Definition: TBBSession.h:94
const DebugOptions & debug_options
Definition: TBBSession.h:210
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
std::unordered_map< string, bool > pending_outputs
Definition: TBBSession.h:188
::tensorflow::Status CheckNotClosed()
Definition: TBBSession.h:251
::tensorflow::Status LocalDeviceManager(const DeviceMgr **output) override
Definition: TBBSession.h:118
std::atomic_int_fast64_t step_count
Definition: TBBSession.h:148
const SessionOptions options_
Definition: TBBSession.h:267
TBBSessionFactory *const factory_
Definition: TBBSession.h:305
static std::atomic_int_fast64_t step_id_counter_
Definition: TBBSession.h:333
const std::unique_ptr< const DeviceMgr > device_mgr_
Definition: TBBSession.h:270
std::function< void(Session *)> CloseCallback
Definition: TBBSession.h:83
Definition: TBBSession.h:68
std::unique_ptr< Graph > graph
Definition: TBBSession.h:209
SessionState session_state_
Definition: TBBSession.h:303
void Reset(std::vector< TH2F > &depth)
std::unordered_map< string, string > input_name_to_rendezvous_key
Definition: TBBSession.h:153
std::unique_ptr< Graph > graph
Definition: TBBSession.h:149