CMS 3D CMS Logo

TBBSession.h
Go to the documentation of this file.
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 /*
17 This file is an adaptation of the original direct_session.h file located at
18 https://github.com/tensorflow/tensorflow/blob/v1.3.0/tensorflow/core/common_runtime/direct_session.h
19 to meet the demands of the software environment developed and used by the CMS collaboration.
20 
21 Changes:
22  - Renamed the session class to TBBSession (TBB = Threading Building Blocks)
23  - Renamed some members to reflect that change
24  - Removed the thread_pools_ member
25  - Set the session handle to "tbb"
26  - Removed the PRunSetup, PRun, SendPRunInputs, RecvPRunOutputs and CheckFetch methods
27  - Removed the ThreadPool arguments from GetOrCreateExecutors and SchedClosure
28  - Let tbb arena and tbb group handle scheduling in WaitForNotification and SchedClosure
29  - Removed obsolete helper functions NumInterOpThreadsFromSessionOptions,
30  NewThreadPoolFromSessionOptions, NewThreadPoolFromThreadPoolOptions and GlobalThreadPool
31  - Renamed the session factory class to TBBSessionFactory
32  - Renamed the session registrar class to TBBSessionRegistrar
33  - Renamed include guard to reflect location within CMSSW
34 */
35 
36 #ifndef PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H
37 #define PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H
38 
39 #include <atomic>
40 #include <memory>
41 #include <string>
42 #include <unordered_map>
43 #include <unordered_set>
44 #include <vector>
45 
46 #include "tbb/task_arena.h"
47 
48 #include "tensorflow/core/common_runtime/costmodel_manager.h"
49 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
50 #include "tensorflow/core/common_runtime/device_mgr.h"
51 #include "tensorflow/core/common_runtime/device_set.h"
52 #include "tensorflow/core/common_runtime/executor.h"
53 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
54 #include "tensorflow/core/common_runtime/session_factory.h"
55 #include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
56 #include "tensorflow/core/framework/cancellation.h"
57 #include "tensorflow/core/framework/graph.pb.h"
58 #include "tensorflow/core/framework/session_state.h"
59 #include "tensorflow/core/framework/tensor.h"
60 #include "tensorflow/core/lib/core/errors.h"
61 #include "tensorflow/core/lib/core/status.h"
62 #include "tensorflow/core/platform/macros.h"
63 #include "tensorflow/core/platform/mutex.h"
64 #include "tensorflow/core/platform/types.h"
65 #include "tensorflow/core/public/session.h"
66 
67 namespace tbb {
68 class task_group;
69 }
70 
71 namespace tensorflow {
72 
73 class CostModel;
74 class DebugGateway;
75 class Device;
76 class TBBSessionFactory;
77 
78 class TBBSession : public Session {
79  public:
80  typedef std::function<void(Session*)> CloseCallback;
81 
82  // Takes ownership of 'device_mgr'.
83  // 'factory' is used to unregister the TBBSession with 'factory' when its
84  // closed. This ensures that Reset requests from the 'factory' don't get sent
85  // to sessions that are already closed.
86  TBBSession(const SessionOptions& options, const DeviceMgr* device_mgr,
87  TBBSessionFactory* factory);
88  ~TBBSession() override;
89 
90  typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
91  typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher>
93 
94  ::tensorflow::Status Create(const GraphDef& graph) override;
95  ::tensorflow::Status Extend(const GraphDef& graph) override;
96  ::tensorflow::Status Run(const NamedTensorList& inputs,
97  const std::vector<string>& output_names,
98  const std::vector<string>& target_nodes,
99  std::vector<Tensor>* outputs) override;
100 
101  // NOTE: Experimental and subject to change.
102  ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
103  const NamedTensorList& inputs,
104  const std::vector<string>& output_names,
105  const std::vector<string>& target_nodes,
106  std::vector<Tensor>* outputs,
107  RunMetadata* run_metadata) override;
108 
109 
110  // Reset clears 'containers' from the device_mgr of the TBBSession.
111  // If 'containers' is empty, then Reset clears the default container.
112  ::tensorflow::Status Reset(const std::vector<string>& containers);
113 
114  ::tensorflow::Status ListDevices(
115  std::vector<DeviceAttributes>* response) override;
116  ::tensorflow::Status Close() override;
117 
118  void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
119  cost_model_manager_.ExportCostModels(cost_models);
120  }
121 
122  private:
123  typedef TBBSession ME;
124 
125  // We create one executor and its dependent library runtime for
126  // every partition.
128  Graph* graph = nullptr;
129  std::unique_ptr<FunctionLibraryRuntime> flib;
130  std::unique_ptr<Executor> executor;
131  };
132 
133  // An ExecutorsAndKeys is created for a given set of feeds/fetches.
134  // 'step_count' is the number of times this graph is executed.
135  // 'graph' is the entire graph being executed. 'name_to_node'
136  // maps node name to node. We keep 'graph' and 'name_to_node' only in
137  // the case of partial runs. Each item in 'items' is the executor for
138  // a partition of the graph bundled with its dependent library runtime.
139  // 'input_keys' are the rendezvous keys for the feeds and 'output_keys'
140  // are rendezvous keys for the fetches.
141  // 'flib_def' is the function library used by graphs in 'items'.
142  // TODO(phawkins): currently partitions always share the same function
143  // library. Consider giving each partition its own function library to enable
144  // per-partition rewrites.
146  ExecutorsAndKeys() : step_count(0) {}
147 
148  std::atomic_int_fast64_t step_count;
149  std::unique_ptr<Graph> graph;
151  std::unique_ptr<FunctionLibraryDefinition> flib_def;
152  std::vector<PerPartitionExecutorsAndLib> items;
153  std::unordered_map<string, size_t> input_name_to_index;
154  std::unordered_map<string, string> input_name_to_rendezvous_key;
155  std::unordered_map<string, size_t> output_name_to_index;
156  std::unordered_map<string, string> output_name_to_rendezvous_key;
157 
158  DataTypeVector input_types;
159  DataTypeVector output_types;
160  };
161 
162  // For each live partial execution, the session maintains a RunState.
163  // 'status' is the current status of this partial execution. 'executor_done'
164  // is "notified" when all executors are done. 'pending_inputs' are the set
165  // of pending feeds and 'pending_outputs' are the set of pending fetches.
166  struct RunState {
168  Status status GUARDED_BY(mu_);
169  IntraProcessRendezvous* rendez = nullptr;
170  std::unique_ptr<StepStatsCollector> collector;
171  Notification executors_done;
172  std::unordered_map<string, bool> pending_inputs; // true if fed
173  std::unordered_map<string, bool> pending_outputs; // true if fetched
174  TensorStore tensor_store;
175  ScopedStepContainer step_container;
176 
177  RunState(int64 step_id, const std::vector<Device*>* devices);
178 
179  RunState(const std::vector<string>& pending_input_names,
180  const std::vector<string>& pending_output_names, int64 step_id,
181  const std::vector<Device*>* devices);
182 
183  // Returns true if all pending inputs and outputs have been completed.
184  bool PendingDone() const;
185 
186  ~RunState();
187  };
188 
189  struct RunStateArgs {
190  RunStateArgs(const DebugOptions& options) : debug_options(options) {}
191 
192  bool is_partial_run = false;
193  string handle;
194  std::unique_ptr<Graph> graph;
195  const DebugOptions& debug_options;
196  };
197 
198  // Initializes the base execution state given the 'graph',
199  // if not already initialized.
200  Status MaybeInitializeExecutionState(const GraphDef& graph,
201  bool* out_already_initialized)
202  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
203 
204  // Retrieves an already existing set of executors to run 'inputs' and
205  // 'outputs', or creates and caches them for future use.
206  ::tensorflow::Status GetOrCreateExecutors(
207  gtl::ArraySlice<string> inputs,
208  gtl::ArraySlice<string> outputs, gtl::ArraySlice<string> target_nodes,
209  ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);
210 
211  // Creates several graphs given the existing graph_def_ and the
212  // input feeds and fetches, given 'devices'. The graphs share a common
213  // function library 'flib_def'.
214  ::tensorflow::Status CreateGraphs(
215  const BuildGraphOptions& options,
216  std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
217  std::unique_ptr<FunctionLibraryDefinition>* flib_def,
218  RunStateArgs* run_state_args, DataTypeVector* input_types,
219  DataTypeVector* output_types);
220 
221  ::tensorflow::Status ExtendLocked(const GraphDef& graph)
222  EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
223 
224  ::tensorflow::Status ResourceHandleToInputTensor(
225  const Tensor& resource_tensor, Tensor* retrieved_tensor);
226 
227  // Use the appropriate WaitForNotification function based on whether
228  // operation_timeout_in_ms is greater than 0.
229  //
230  // If the timeout expires, the `cm->StartCancel()` will be called.
231  ::tensorflow::Status WaitForNotification(Notification* n,
232  int64 timeout_in_ms);
233  void WaitForNotification(tbb::task_arena& arena, tbb::task_group& group,
234  RunState* run_state, CancellationManager* cm,
235  int64 timeout_in_ms);
236 
238  mutex_lock l(closed_lock_);
239  if (closed_) return errors::Cancelled("Session has been closed.");
241  }
242 
243  ::tensorflow::Status CreateDebuggerState(
244  const DebugOptions& debug_options, int64 session_run_index,
245  int64 executor_step_index, const std::vector<string>& input_names,
246  const std::vector<string>& output_names,
247  const std::vector<string>& target_names,
248  std::unique_ptr<DebuggerStateInterface>* debugger_state);
249 
250  ::tensorflow::Status DecorateAndPublishGraphForDebug(
251  const DebugOptions& debug_options, Graph* graph, Device* device);
252 
253  const SessionOptions options_;
254 
255  // Device structures.
256  const std::unique_ptr<const DeviceMgr> device_mgr_;
257  std::vector<Device*> devices_; // not owned
258  DeviceSet device_set_;
259 
261  bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
262 
264  GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
265 
266  Status init_error_; // Set to an error if construction failed.
267 
268  // If true, blocks until device has finished all queued operations in a step.
269  bool sync_on_finish_ = true;
270  void SchedClosure(tbb::task_arena& arena, tbb::task_group& g, std::function<void()> c);
271 
272  mutex executor_lock_; // protects executors_
273  // Holds mappings from signature to the executors that process
274  // it. The reason for a level of indirection around mapped_type is
275  // to guarantee address stability.
276  // The map value is a shared_ptr since multiple map keys can point to the
277  // same ExecutorsAndKey object.
278  std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
279  GUARDED_BY(executor_lock_);
280 
281  // Holds mappings from handle to partial run state.
282  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
283  GUARDED_BY(executor_lock_);
284 
285  // This holds all the tensors that are currently alive in the session.
286  SessionState session_state_;
287 
288  TBBSessionFactory* const factory_; // not owned
289  CancellationManager* cancellation_manager_;
290 
291  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
292  // is true, such as "params" and "queue" nodes. Once placed these
293  // nodes can not be moved to a different device. Maps node names to
294  // device names.
295  std::unordered_map<string, string> stateful_placements_
296  GUARDED_BY(graph_def_lock_);
297 
298  // Execution_state; used when placing the entire graph.
299  std::unique_ptr<SimpleGraphExecutionState> execution_state_
300  GUARDED_BY(graph_def_lock_);
301 
302  // The function library, before any rewrites or optimizations have been
303  // performed. In particular, CreateGraphs() may need to modify the function
304  // library; it copies and modifies the function library.
305  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
306 
307  // true if the Session has been Closed.
309  bool closed_ GUARDED_BY(closed_lock_) = false;
310 
311  // For generating unique names for this session instance.
312  std::atomic<int64> edge_name_counter_ = {0};
313  std::atomic<int64> handle_name_counter_ = {0};
314 
315  // For generating step ids that are unique across all sessions.
316  static std::atomic_int_fast64_t step_id_counter_;
317 
318  // Global timeout for all blocking operations in this session.
319  const int64 operation_timeout_in_ms_ = 0;
320 
321  // Manages all the cost models for the graphs executed in this session.
322  CostModelManager cost_model_manager_;
323 
324  Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
325 
326  TF_DISALLOW_COPY_AND_ASSIGN(TBBSession);
327 
328  // EXPERIMENTAL: debugger (tfdbg) related
329  friend class DebugGateway;
330 };
331 
332 } // end namespace tensorflow
333 
334 #endif // PHYSICSTOOLS_TENSORFLOW_TBBSESSION_H
std::unique_ptr< FunctionLibraryDefinition > flib_def_
Definition: TBBSession.h:305
std::vector< PerPartitionExecutorsAndLib > items
Definition: TBBSession.h:152
std::unordered_map< string, size_t > input_name_to_index
Definition: TBBSession.h:153
ScopedStepContainer step_container
Definition: TBBSession.h:175
static boost::mutex mutex
Definition: LHEProxy.cc:11
std::vector< Device * > devices_
Definition: TBBSession.h:257
std::vector< std::pair< string, Tensor > > NamedTensorList
Definition: TBBSession.h:90
std::unique_ptr< FunctionLibraryRuntime > flib
Definition: TBBSession.h:129
RunStateArgs(const DebugOptions &options)
Definition: TBBSession.h:190
std::unordered_map< string, string > output_name_to_rendezvous_key
Definition: TBBSession.h:156
CostModelManager cost_model_manager_
Definition: TBBSession.h:322
void ExportCostModels(CostModelManager::CostModelMap *cost_models)
Definition: TBBSession.h:118
std::unordered_map< string, bool > pending_inputs
Definition: TBBSession.h:172
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e g
Definition: Activities.doc:4
std::unordered_map< StringPiece, Node *, StringPiece::Hasher > NameNodeMap
Definition: TBBSession.h:92
CancellationManager * cancellation_manager_
Definition: TBBSession.h:289
std::unique_ptr< StepStatsCollector > collector
Definition: TBBSession.h:170
std::unordered_map< string, size_t > output_name_to_index
Definition: TBBSession.h:155
const DebugOptions & debug_options
Definition: TBBSession.h:195
std::pair< int, edm::FunctionWithDict > OK
Definition: findMethod.cc:136
std::unordered_map< string, bool > pending_outputs
Definition: TBBSession.h:173
::tensorflow::Status CheckNotClosed()
Definition: TBBSession.h:237
std::atomic_int_fast64_t step_count
Definition: TBBSession.h:148
const SessionOptions options_
Definition: TBBSession.h:253
TBBSessionFactory *const factory_
Definition: TBBSession.h:288
static std::atomic_int_fast64_t step_id_counter_
Definition: TBBSession.h:316
const std::unique_ptr< const DeviceMgr > device_mgr_
Definition: TBBSession.h:256
std::unique_ptr< FunctionLibraryDefinition > flib_def
Definition: TBBSession.h:151
std::function< void(Session *)> CloseCallback
Definition: TBBSession.h:80
Definition: TBBSession.h:67
std::unique_ptr< Graph > graph
Definition: TBBSession.h:194
SessionState session_state_
Definition: TBBSession.h:286
void Reset(std::vector< TH2F > &depth)
std::unordered_map< string, string > input_name_to_rendezvous_key
Definition: TBBSession.h:154
std::unique_ptr< Graph > graph
Definition: TBBSession.h:149