CMS 3D CMS Logo

TBBSession.h
Go to the documentation of this file.
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 /*
17 This file is an adaptation of the original direct_session.h file located at
18 https://github.com/tensorflow/tensorflow/blob/v1.5.0/tensorflow/core/common_runtime/direct_session.h
19 to meet the demands of the software environment developed and used by the CMS collaboration.
20 
21 Changes:
22  - Renamed the session class to TBBSession (TBB = Threading Building Blocks)
23  - Renamed some members to reflect that change
24  - Removed the thread_pools_ member
25  - Set the session handle to "tbb"
26  - Removed the PRunSetup, PRun, SendPRunInputs, RecvPRunOutputs and CheckFetch methods
27  - Removed the ThreadPool arguments from GetOrCreateExecutors and SchedClosure
28  - Let tbb arena and tbb group handle scheduling in WaitForNotification and SchedClosure
29  - Removed obsolete helper functions NumInterOpThreadsFromSessionOptions,
30  NewThreadPoolFromSessionOptions, NewThreadPoolFromThreadPoolOptions and GlobalThreadPool
31  - Renamed the session factory class to TBBSessionFactory
32  - Renamed the session registrar class to TBBSessionRegistrar
33  - Renamed include guard to reflect location within CMSSW
34 */
35 
36 #ifndef PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H
37 #define PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H
38 
39 #include <atomic>
40 #include <memory>
41 #include <string>
42 #include <unordered_map>
43 #include <unordered_set>
44 #include <vector>
45 
46 #include "tbb/task_arena.h"
47 
48 #include "tensorflow/core/common_runtime/costmodel_manager.h"
49 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
50 #include "tensorflow/core/common_runtime/device_mgr.h"
51 #include "tensorflow/core/common_runtime/device_set.h"
52 #include "tensorflow/core/common_runtime/executor.h"
53 #include "tensorflow/core/common_runtime/graph_execution_state.h"
54 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
55 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
56 #include "tensorflow/core/common_runtime/session_factory.h"
57 #include "tensorflow/core/framework/cancellation.h"
58 #include "tensorflow/core/framework/graph.pb.h"
59 #include "tensorflow/core/framework/session_state.h"
60 #include "tensorflow/core/framework/tensor.h"
61 #include "tensorflow/core/lib/core/errors.h"
62 #include "tensorflow/core/lib/core/status.h"
63 #include "tensorflow/core/platform/macros.h"
64 #include "tensorflow/core/platform/mutex.h"
65 #include "tensorflow/core/platform/types.h"
66 #include "tensorflow/core/public/session.h"
67 
68 namespace tbb {
69 
70 class task_group;
71 
72 } // end namespace tbb
73 
74 namespace tensorflow {
75 
76 class CostModel;
77 class DebugGateway;
78 class Device;
79 class TBBSessionFactory;
80 
81 class TBBSession : public Session {
82  public:
83  typedef std::function<void(Session*)> CloseCallback;
84 
85  // Takes ownership of 'device_mgr'.
86  // 'factory' is used to unregister the TBBSession with 'factory' when it is
87  // closed. This ensures that Reset requests from the 'factory' don't get sent
88  // to sessions that are already closed.
89  TBBSession(const SessionOptions& options, const DeviceMgr* device_mgr,
90  TBBSessionFactory* factory);
91  ~TBBSession() override;
92 
93  typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
94  typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap;
95 
96  ::tensorflow::Status Create(const GraphDef& graph) override;
97  ::tensorflow::Status Extend(const GraphDef& graph) override;
98  ::tensorflow::Status Run(const NamedTensorList& inputs,
99  const std::vector<string>& output_names,
100  const std::vector<string>& target_nodes,
101  std::vector<Tensor>* outputs) override;
102 
103  // NOTE: Experimental and subject to change.
104  ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
105  const NamedTensorList& inputs,
106  const std::vector<string>& output_names,
107  const std::vector<string>& target_nodes,
108  std::vector<Tensor>* outputs,
109  RunMetadata* run_metadata) override;
110 
111  // Reset clears 'containers' from the device_mgr of the TBBSession.
112  // If 'containers' is empty, then Reset clears the default container.
113  ::tensorflow::Status Reset(const std::vector<string>& containers);
114 
115  ::tensorflow::Status ListDevices(
116  std::vector<DeviceAttributes>* response) override;
117  ::tensorflow::Status Close() override;
118  ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override {
119  *output = device_mgr_.get();
121  }
122 
123  void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
124  cost_model_manager_.ExportCostModels(cost_models);
125  }
126 
127  private:
128  // We create one executor and its dependent library runtime for
129  // every partition.
131  Graph* graph = nullptr; // not owned.
132  Device* device = nullptr; // not owned.
133  FunctionLibraryRuntime* flib = nullptr; // not owned.
134  std::unique_ptr<Executor> executor;
135  };
136 
137  // An ExecutorsAndKeys is created for a given set of feeds/fetches.
138  // 'step_count' is the number of times this graph is executed.
139  // 'graph' is the entire graph being executed. 'name_to_node'
140  // maps node name to node. We keep 'graph' and 'name_to_node' only in
141  // the case of partial runs. Each item in 'items' is the executor for
142  // a partition of the graph bundled with its dependent library runtime.
143  // 'input_keys' are the rendezvous keys for the feeds and 'output_keys'
144  // are rendezvous keys for the fetches.
145  // 'flib_def' is the function library used by graphs in 'items'.
146  // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per
147  // device.
148  // TODO(phawkins): currently partitions always share the same function
149  // library. Consider giving each partition its own function library to enable
150  // per-partition rewrites.
152  ExecutorsAndKeys() : step_count(0) {}
153 
154  std::atomic_int_fast64_t step_count;
155  std::unique_ptr<Graph> graph;
156  NameNodeMap name_to_node;
157  std::unique_ptr<FunctionLibraryDefinition> flib_def;
158  std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
159  std::vector<PerPartitionExecutorsAndLib> items;
160  std::unordered_map<string, size_t> input_name_to_index;
161  std::unordered_map<string, string> input_name_to_rendezvous_key;
162  std::unordered_map<string, size_t> output_name_to_index;
163  std::unordered_map<string, string> output_name_to_rendezvous_key;
164 
165  DataTypeVector input_types;
166  DataTypeVector output_types;
167  };
168 
169  // For each live partial execution, the session maintains a RunState.
170  // 'status' is the current status of this partial execution. 'executor_done'
171  // is "notified" when all executors are done. 'pending_inputs' are the set
172  // of pending feeds and 'pending_outputs' are the set of pending fetches.
173  struct RunState {
175  Status status GUARDED_BY(mu_);
176  IntraProcessRendezvous* rendez = nullptr;
177  std::unique_ptr<StepStatsCollector> collector;
178  Notification executors_done;
179  std::unordered_map<string, bool> pending_inputs; // true if fed
180  std::unordered_map<string, bool> pending_outputs; // true if fetched
181  TensorStore tensor_store;
182  ScopedStepContainer step_container;
183 
184  RunState(int64 step_id, const std::vector<Device*>* devices);
185 
186  RunState(const std::vector<string>& pending_input_names,
187  const std::vector<string>& pending_output_names, int64 step_id,
188  const std::vector<Device*>* devices);
189 
190  // Returns true if all pending inputs and outputs have been completed.
191  bool PendingDone() const;
192 
193  ~RunState();
194  };
195 
196  struct RunStateArgs {
197  RunStateArgs(const DebugOptions& options) : debug_options(options) {}
198 
199  bool is_partial_run = false;
200  string handle;
201  std::unique_ptr<Graph> graph;
202  const DebugOptions& debug_options;
203  };
204 
205  // Initializes the base execution state given the 'graph',
206  // if not already initialized.
207  Status MaybeInitializeExecutionState(const GraphDef& graph,
208  bool* out_already_initialized)
209  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
210 
211  // Retrieves an already existing set of executors to run 'inputs' and
212  // 'outputs', or creates and caches them for future use.
213  ::tensorflow::Status GetOrCreateExecutors(
214  gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
215  gtl::ArraySlice<string> target_nodes,
216  ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);
217 
218  // Creates several graphs given the existing graph_def_ and the
219  // input feeds and fetches, given 'devices'. The graphs share a common
220  // function library 'flib_def'.
221  ::tensorflow::Status CreateGraphs(
222  const BuildGraphOptions& options,
223  std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
224  std::unique_ptr<FunctionLibraryDefinition>* flib_def,
225  RunStateArgs* run_state_args, DataTypeVector* input_types,
226  DataTypeVector* output_types);
227 
228  ::tensorflow::Status ExtendLocked(const GraphDef& graph)
229  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
230 
231  ::tensorflow::Status ResourceHandleToInputTensor(
232  const Tensor& resource_tensor, Tensor* retrieved_tensor);
233 
234  // Use the appropriate WaitForNotification function based on whether
235  // operation_timeout_in_ms is greater than 0.
236  //
237  // If the timeout expires, `cm->StartCancel()` will be called.
238  ::tensorflow::Status WaitForNotification(Notification* n,
239  int64 timeout_in_ms);
240  void WaitForNotification(tbb::task_arena& arena, tbb::task_group& group,
241  RunState* run_state, CancellationManager* cm, int64 timeout_in_ms);
242 
244  mutex_lock l(closed_lock_);
245  if (closed_) return errors::Cancelled("Session has been closed.");
247  }
248 
249  ::tensorflow::Status CreateDebuggerState(
250  const DebugOptions& debug_options, int64 session_run_index,
251  int64 executor_step_index, const std::vector<string>& input_names,
252  const std::vector<string>& output_names,
253  const std::vector<string>& target_names,
254  std::unique_ptr<DebuggerStateInterface>* debugger_state);
255 
256  ::tensorflow::Status DecorateAndPublishGraphForDebug(
257  const DebugOptions& debug_options, Graph* graph, Device* device);
258 
259  const SessionOptions options_;
260 
261  // Device structures.
262  const std::unique_ptr<const DeviceMgr> device_mgr_;
263  std::vector<Device*> devices_; // not owned
264  DeviceSet device_set_;
265 
267  bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
268 
270  GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
271 
272  Status init_error_; // Set to an error if construction failed.
273 
274  // If true, blocks until device has finished all queued operations in a step.
275  bool sync_on_finish_ = true;
276  void SchedClosure(tbb::task_arena& arena, tbb::task_group& g, std::function<void()> c);
277 
278  mutex executor_lock_; // protects executors_
279  // Holds mappings from signature to the executors that process
280  // it. The reason for a level of indirection around mapped_type is
281  // to guarantee address stability.
282  // The map value is a shared_ptr since multiple map keys can point to the
283  // same ExecutorsAndKeys object.
284  std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
285  GUARDED_BY(executor_lock_);
286 
287  // Holds mappings from handle to partial run state.
288  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
289  GUARDED_BY(executor_lock_);
290 
291  // This holds all the tensors that are currently alive in the session.
292  SessionState session_state_;
293 
294  TBBSessionFactory* const factory_; // not owned
295  CancellationManager* cancellation_manager_;
296 
297  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
298  // is true, such as "params" and "queue" nodes. Once placed these
299  // nodes can not be moved to a different device. Maps node names to
300  // device names.
301  std::unordered_map<string, string> stateful_placements_
302  GUARDED_BY(graph_def_lock_);
303 
304  // Execution_state; used when placing the entire graph.
305  std::unique_ptr<GraphExecutionState> execution_state_
306  GUARDED_BY(graph_def_lock_);
307 
308  // The function library, before any rewrites or optimizations have been
309  // performed. In particular, CreateGraphs() may need to modify the function
310  // library; it copies and modifies the function library.
311  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
312 
313  // true if the Session has been Closed.
315  bool closed_ GUARDED_BY(closed_lock_) = false;
316 
317  // For generating unique names for this session instance.
318  std::atomic<int64> edge_name_counter_ = {0};
319  std::atomic<int64> handle_name_counter_ = {0};
320 
321  // For generating step ids that are unique across all sessions.
322  static std::atomic_int_fast64_t step_id_counter_;
323 
324  // Global timeout for all blocking operations in this session.
325  const int64 operation_timeout_in_ms_ = 0;
326 
327  // Manages all the cost models for the graphs executed in this session.
328  CostModelManager cost_model_manager_;
329 
330  Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
331 
332  TF_DISALLOW_COPY_AND_ASSIGN(TBBSession);
333 
334  // EXPERIMENTAL: debugger (tfdbg) related
335  friend class DebugGateway;
336 };
337 
338 } // end namespace tensorflow
339 
340 #endif // PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H
std::unique_ptr< FunctionLibraryDefinition > flib_def_
Definition: TBBSession.h:311
static boost::mutex mutex
Definition: Proxy.cc:11
std::vector< PerPartitionExecutorsAndLib > items
Definition: TBBSession.h:159
std::unordered_map< string, size_t > input_name_to_index
Definition: TBBSession.h:160
ScopedStepContainer step_container
Definition: TBBSession.h:182
std::vector< Device * > devices_
Definition: TBBSession.h:263
std::vector< std::pair< string, Tensor > > NamedTensorList
Definition: TBBSession.h:93
RunStateArgs(const DebugOptions &options)
Definition: TBBSession.h:197
std::unordered_map< string, string > output_name_to_rendezvous_key
Definition: TBBSession.h:163
CostModelManager cost_model_manager_
Definition: TBBSession.h:328
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
Definition: TBBSession.h:123
std::unordered_map< string, bool > pending_inputs
Definition: TBBSession.h:179
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e g
Definition: Activities.doc:4
CancellationManager * cancellation_manager_
Definition: TBBSession.h:295
std::unique_ptr< StepStatsCollector > collector
Definition: TBBSession.h:177
std::unique_ptr< ProcessFunctionLibraryRuntime > proc_flr
Definition: TBBSession.h:158
std::unordered_map< string, size_t > output_name_to_index
Definition: TBBSession.h:162
std::unordered_map< StringPiece, Node *, StringPieceHasher > NameNodeMap
Definition: TBBSession.h:94
const DebugOptions & debug_options
Definition: TBBSession.h:202
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
std::unordered_map< string, bool > pending_outputs
Definition: TBBSession.h:180
::tensorflow::Status CheckNotClosed()
Definition: TBBSession.h:243
::tensorflow::Status LocalDeviceManager(const DeviceMgr **output) override
Definition: TBBSession.h:118
std::atomic_int_fast64_t step_count
Definition: TBBSession.h:154
const SessionOptions options_
Definition: TBBSession.h:259
TBBSessionFactory *const factory_
Definition: TBBSession.h:294
static std::atomic_int_fast64_t step_id_counter_
Definition: TBBSession.h:322
const std::unique_ptr< const DeviceMgr > device_mgr_
Definition: TBBSession.h:262
std::unique_ptr< FunctionLibraryDefinition > flib_def
Definition: TBBSession.h:157
std::function< void(Session *)> CloseCallback
Definition: TBBSession.h:83
Definition: TBBSession.h:68
std::unique_ptr< Graph > graph
Definition: TBBSession.h:201
SessionState session_state_
Definition: TBBSession.h:292
void Reset(std::vector< TH2F > &depth)
std::unordered_map< string, string > input_name_to_rendezvous_key
Definition: TBBSession.h:161
std::unique_ptr< Graph > graph
Definition: TBBSession.h:155