CMS 3D CMS Logo

TBBSession.h
Go to the documentation of this file.
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 /*
17 This file is an adaptation of the original direct_session.h file located at
18 https://github.com/tensorflow/tensorflow/blob/v1.6.0/tensorflow/core/common_runtime/direct_session.h
19 to meet the demands of the software environment developed and used by the CMS collaboration.
20 
21 Changes:
22  - Renamed the session class to TBBSession (TBB = Threading Building Blocks)
23  - Renamed some members to reflect that change
24  - Removed the thread_pools_ member
25  - Set the session handle to "tbb"
26  - Removed the PRunSetup, PRun, SendPRunInputs, RecvPRunOutputs and CheckFetch methods
27  - Removed the ThreadPool arguments from GetOrCreateExecutors and SchedClosure
28  - Let tbb arena and tbb group handle scheduling in WaitForNotification and SchedClosure
29  - Removed obsolete helper functions NumInterOpThreadsFromSessionOptions,
30  NewThreadPoolFromSessionOptions, NewThreadPoolFromThreadPoolOptions and GlobalThreadPool
31  - Renamed the session factory class to TBBSessionFactory
32  - Renamed the session registrar class to TBBSessionRegistrar
33  - Renamed include guard to reflect location within CMSSW
34 */
35 
36 #ifndef PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H
37 #define PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H
38 
39 #include <atomic>
40 #include <memory>
41 #include <string>
42 #include <unordered_map>
43 #include <unordered_set>
44 #include <vector>
45 
46 #include "tbb/task_arena.h"
47 
48 #include "tensorflow/core/common_runtime/costmodel_manager.h"
49 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
50 #include "tensorflow/core/common_runtime/device_mgr.h"
51 #include "tensorflow/core/common_runtime/device_set.h"
52 #include "tensorflow/core/common_runtime/executor.h"
53 #include "tensorflow/core/common_runtime/graph_execution_state.h"
54 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
55 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
56 #include "tensorflow/core/common_runtime/session_factory.h"
57 #include "tensorflow/core/framework/cancellation.h"
58 #include "tensorflow/core/framework/graph.pb.h"
59 #include "tensorflow/core/framework/session_state.h"
60 #include "tensorflow/core/framework/tensor.h"
61 #include "tensorflow/core/lib/core/errors.h"
62 #include "tensorflow/core/lib/core/status.h"
63 #include "tensorflow/core/platform/macros.h"
64 #include "tensorflow/core/platform/mutex.h"
65 #include "tensorflow/core/platform/types.h"
66 #include "tensorflow/core/public/session.h"
67 
68 namespace tbb {
69 
70  class task_group;
71 
72 } // end namespace tbb
73 
74 namespace tensorflow {
75 
76  class CostModel;
77  class DebugGateway;
78  class Device;
79  class TBBSessionFactory;
80 
81  class TBBSession : public Session {
82  public:
83  typedef std::function<void(Session*)> CloseCallback;
84 
85  // Takes ownership of 'device_mgr'.
86  // 'factory' is used to unregister the TBBSession with 'factory' when it's
87  // closed. This ensures that Reset requests from the 'factory' don't get sent
88  // to sessions that are already closed.
89  TBBSession(const SessionOptions& options, const DeviceMgr* device_mgr, TBBSessionFactory* factory);
90  ~TBBSession() override;
91 
92  typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
93  typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap;
94 
95  ::tensorflow::Status Create(const GraphDef& graph) override;
96  ::tensorflow::Status Extend(const GraphDef& graph) override;
97  ::tensorflow::Status Run(const NamedTensorList& inputs,
98  const std::vector<string>& output_names,
99  const std::vector<string>& target_nodes,
100  std::vector<Tensor>* outputs) override;
101 
102  // NOTE: Experimental and subject to change.
103  ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
104  const NamedTensorList& inputs,
105  const std::vector<string>& output_names,
106  const std::vector<string>& target_nodes,
107  std::vector<Tensor>* outputs,
108  RunMetadata* run_metadata) override;
109 
110  // Reset clears 'containers' from the device_mgr of the TBBSession.
111  // If 'containers' is empty, then Reset clears the default container.
112  ::tensorflow::Status Reset(const std::vector<string>& containers);
113 
114  ::tensorflow::Status ListDevices(std::vector<DeviceAttributes>* response) override;
115  ::tensorflow::Status Close() override;
116  ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override {
117  *output = device_mgr_.get();
119  }
120 
121  void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
122  cost_model_manager_.ExportCostModels(cost_models);
123  }
124 
125  private:
126  // We create one executor and its dependent library runtime for
127  // every partition.
129  Graph* graph = nullptr; // not owned.
130  Device* device = nullptr; // not owned.
131  FunctionLibraryRuntime* flib = nullptr; // not owned.
132  std::unique_ptr<Executor> executor;
133  };
134 
135  // An ExecutorsAndKeys is created for a given set of feeds/fetches.
136  // 'step_count' is the number of times this graph is executed.
137  // 'graph' is the entire graph being executed. 'name_to_node'
138  // maps node name to node. We keep 'graph' and 'name_to_node' only in
139  // the case of partial runs. Each item in 'items' is the executor for
140  // a partition of the graph bundled with its dependent library runtime.
141  // 'input_name_to_rendezvous_key' maps each feed name to its rendezvous key and
142  // 'output_name_to_rendezvous_key' maps each fetch name to its rendezvous key.
144  ExecutorsAndKeys() : step_count(0) {}
145 
146  std::atomic_int_fast64_t step_count;
147  std::unique_ptr<Graph> graph;
148  NameNodeMap name_to_node;
149  std::vector<PerPartitionExecutorsAndLib> items;
150  std::unordered_map<string, size_t> input_name_to_index;
151  std::unordered_map<string, string> input_name_to_rendezvous_key;
152  std::unordered_map<string, size_t> output_name_to_index;
153  std::unordered_map<string, string> output_name_to_rendezvous_key;
154 
155  DataTypeVector input_types;
156  DataTypeVector output_types;
157  };
158 
159  // A FunctionInfo object is created for every unique set of feeds/fetches.
160  // This info could be folded into the ExecutorsAndKeys object but we would
161  // like to maintain a deletion order in which the OpKernels (owned by the
162  // executor) should be destroyed first, followed by the resources in the
163  // device and then followed by the function stuff.
164  // TODO(rohanj): Consolidate function library definitions so that we can
165  // instantiate only one ProcFLR and lib_def and make this just a member
166  // variable and not a vector.
167  // 'flib_def' is the function library used.
168  // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per
169  // device.
170  struct FunctionInfo {
171  std::unique_ptr<FunctionLibraryDefinition> flib_def;
172  std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
173  };
174 
175  // For each live partial execution, the session maintains a RunState.
176  // 'status' is the current status of this partial execution. 'executors_done'
177  // is "notified" when all executors are done. 'pending_inputs' are the set
178  // of pending feeds and 'pending_outputs' are the set of pending fetches.
179  struct RunState {
181  Status status GUARDED_BY(mu_);
182  IntraProcessRendezvous* rendez = nullptr;
183  std::unique_ptr<StepStatsCollector> collector;
184  Notification executors_done;
185  std::unordered_map<string, bool> pending_inputs; // true if fed
186  std::unordered_map<string, bool> pending_outputs; // true if fetched
187  TensorStore tensor_store;
188  ScopedStepContainer step_container;
189 
190  RunState(int64 step_id, const std::vector<Device*>* devices);
191 
192  RunState(const std::vector<string>& pending_input_names,
193  const std::vector<string>& pending_output_names,
194  int64 step_id,
195  const std::vector<Device*>* devices);
196 
197  // Returns true if all pending inputs and outputs have been completed.
198  bool PendingDone() const;
199 
200  ~RunState();
201  };
202 
203  struct RunStateArgs {
204  RunStateArgs(const DebugOptions& options) : debug_options(options) {}
205 
206  bool is_partial_run = false;
207  string handle;
208  std::unique_ptr<Graph> graph;
209  const DebugOptions& debug_options;
210  };
211 
212  // Initializes the base execution state given the 'graph',
213  // if not already initialized.
214  Status MaybeInitializeExecutionState(const GraphDef& graph, bool* out_already_initialized)
215  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
216 
217  // Retrieves an already existing set of executors to run 'inputs' and
218  // 'outputs', or creates and caches them for future use.
219  ::tensorflow::Status GetOrCreateExecutors(gtl::ArraySlice<string> inputs,
220  gtl::ArraySlice<string> outputs,
221  gtl::ArraySlice<string> target_nodes,
222  ExecutorsAndKeys** executors_and_keys,
223  RunStateArgs* run_state_args);
224 
225  // Creates several graphs given the existing graph_def_ and the
226  // input feeds and fetches, given 'devices'. The graphs share a common
227  // function library 'flib_def'.
228  ::tensorflow::Status CreateGraphs(const BuildGraphOptions& options,
229  std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
230  std::unique_ptr<FunctionLibraryDefinition>* flib_def,
231  RunStateArgs* run_state_args,
232  DataTypeVector* input_types,
233  DataTypeVector* output_types);
234 
235  ::tensorflow::Status ExtendLocked(const GraphDef& graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
236 
237  ::tensorflow::Status ResourceHandleToInputTensor(const Tensor& resource_tensor, Tensor* retrieved_tensor);
238 
239  // Use the appropriate WaitForNotification function based on whether
240  // operation_timeout_in_ms is greater than 0.
241  //
242  // If the timeout expires, `cm->StartCancel()` will be called.
243  ::tensorflow::Status WaitForNotification(Notification* n, int64 timeout_in_ms);
244  void WaitForNotification(tbb::task_arena& arena,
245  tbb::task_group& group,
246  RunState* run_state,
247  CancellationManager* cm,
248  int64 timeout_in_ms);
249 
251  mutex_lock l(closed_lock_);
252  if (closed_)
253  return errors::Cancelled("Session has been closed.");
255  }
256 
257  ::tensorflow::Status CreateDebuggerState(const DebugOptions& debug_options,
258  int64 session_run_index,
259  int64 executor_step_index,
260  const std::vector<string>& input_names,
261  const std::vector<string>& output_names,
262  const std::vector<string>& target_names,
263  std::unique_ptr<DebuggerStateInterface>* debugger_state);
264 
265  ::tensorflow::Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
266  Graph* graph,
267  Device* device);
268 
269  const SessionOptions options_;
270 
271  // Device structures.
272  const std::unique_ptr<const DeviceMgr> device_mgr_;
273  std::vector<Device*> devices_; // not owned
274  DeviceSet device_set_;
275 
277  bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
278 
280  GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
281 
282  Status init_error_; // Set to an error if construction failed.
283 
284  // If true, blocks until device has finished all queued operations in a step.
285  bool sync_on_finish_ = true;
286  void SchedClosure(tbb::task_arena& arena, tbb::task_group& g, std::function<void()> c);
287 
288  std::vector<std::unique_ptr<FunctionInfo>> functions_ GUARDED_BY(executor_lock_);
289 
290  mutex executor_lock_; // protects executors_
291  // Holds mappings from signature to the executors that process
292  // it. The reason for a level of indirection around mapped_type is
293  // to guarantee address stability.
294  // The map value is a shared_ptr since multiple map keys can point to the
295  // same ExecutorsAndKeys object.
296  std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_ GUARDED_BY(executor_lock_);
297 
298  // Holds mappings from handle to partial run state.
299  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_ GUARDED_BY(executor_lock_);
300 
301  // This holds all the tensors that are currently alive in the session.
302  SessionState session_state_;
303 
304  TBBSessionFactory* const factory_; // not owned
305  CancellationManager* cancellation_manager_;
306 
307  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
308  // is true, such as "params" and "queue" nodes. Once placed these
309  // nodes cannot be moved to a different device. Maps node names to
310  // device names.
311  std::unordered_map<string, string> stateful_placements_ GUARDED_BY(graph_def_lock_);
312 
313  // Execution_state; used when placing the entire graph.
314  std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(graph_def_lock_);
315 
316  // The function library, before any rewrites or optimizations have been
317  // performed. In particular, CreateGraphs() may need to modify the function
318  // library; it copies and modifies the function library.
319  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
320 
321  // true if the Session has been Closed.
323  bool closed_ GUARDED_BY(closed_lock_) = false;
324 
325  // For generating unique names for this session instance.
326  std::atomic<int64> edge_name_counter_ = {0};
327  std::atomic<int64> handle_name_counter_ = {0};
328 
329  // For generating step ids that are unique across all sessions.
330  static std::atomic_int_fast64_t step_id_counter_;
331 
332  // Global timeout for all blocking operations in this session.
333  const int64 operation_timeout_in_ms_ = 0;
334 
335  // Manages all the cost models for the graphs executed in this session.
336  CostModelManager cost_model_manager_;
337 
338  Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
339 
340  TF_DISALLOW_COPY_AND_ASSIGN(TBBSession);
341 
342  // EXPERIMENTAL: debugger (tfdbg) related
343  friend class DebugGateway;
344  };
345 
346 } // end namespace tensorflow
347 
348 #endif // PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H
std::unique_ptr< FunctionLibraryDefinition > flib_def_
Definition: TBBSession.h:319
static boost::mutex mutex
Definition: Proxy.cc:9
std::vector< PerPartitionExecutorsAndLib > items
Definition: TBBSession.h:149
std::unordered_map< string, size_t > input_name_to_index
Definition: TBBSession.h:150
ScopedStepContainer step_container
Definition: TBBSession.h:188
std::vector< Device * > devices_
Definition: TBBSession.h:273
std::vector< std::pair< string, Tensor > > NamedTensorList
Definition: TBBSession.h:92
std::unique_ptr< ProcessFunctionLibraryRuntime > proc_flr
Definition: TBBSession.h:172
std::unique_ptr< FunctionLibraryDefinition > flib_def
Definition: TBBSession.h:171
RunStateArgs(const DebugOptions &options)
Definition: TBBSession.h:204
std::unordered_map< string, string > output_name_to_rendezvous_key
Definition: TBBSession.h:153
CostModelManager cost_model_manager_
Definition: TBBSession.h:336
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
Definition: TBBSession.h:121
std::unordered_map< string, bool > pending_inputs
Definition: TBBSession.h:185
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e g
Definition: Activities.doc:4
CancellationManager * cancellation_manager_
Definition: TBBSession.h:305
std::unique_ptr< StepStatsCollector > collector
Definition: TBBSession.h:183
std::unordered_map< string, size_t > output_name_to_index
Definition: TBBSession.h:152
std::unordered_map< StringPiece, Node *, StringPieceHasher > NameNodeMap
Definition: TBBSession.h:93
const DebugOptions & debug_options
Definition: TBBSession.h:209
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:126
std::unordered_map< string, bool > pending_outputs
Definition: TBBSession.h:186
::tensorflow::Status CheckNotClosed()
Definition: TBBSession.h:250
::tensorflow::Status LocalDeviceManager(const DeviceMgr **output) override
Definition: TBBSession.h:116
std::atomic_int_fast64_t step_count
Definition: TBBSession.h:146
const SessionOptions options_
Definition: TBBSession.h:269
TBBSessionFactory *const factory_
Definition: TBBSession.h:304
static std::atomic_int_fast64_t step_id_counter_
Definition: TBBSession.h:330
const std::unique_ptr< const DeviceMgr > device_mgr_
Definition: TBBSession.h:272
std::function< void(Session *)> CloseCallback
Definition: TBBSession.h:83
Definition: TBBSession.h:68
std::unique_ptr< Graph > graph
Definition: TBBSession.h:208
SessionState session_state_
Definition: TBBSession.h:302
void Reset(std::vector< TH2F > &depth)
std::unordered_map< string, string > input_name_to_rendezvous_key
Definition: TBBSession.h:151
std::unique_ptr< Graph > graph
Definition: TBBSession.h:147