CMS 3D CMS Logo

TBBThreadPool.h
Go to the documentation of this file.
1 /*
2  * Custom TensorFlow thread pool implementation that schedules tasks in TBB.
3  * Based on TensorFlow 2.1.
4  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
5  *
6  * Author: Marcel Rieger
7  */
8 
9 #ifndef PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H
10 #define PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H
11 
13 
#include <utility>

14 #include "tensorflow/core/lib/core/threadpool.h"
15
16 #include "tbb/task_arena.h"
17 #include "tbb/task_group.h"
18 #include "tbb/global_control.h"
19 
20 namespace tensorflow {
21 
// TensorFlow thread pool (tensorflow::thread::ThreadPoolInterface) whose
// Schedule() runs each submitted function as a TBB task and waits for it.
22  class TBBThreadPool : public tensorflow::thread::ThreadPoolInterface {
23  public:
// Returns a reference to the pool object named `pool`, whose declaration is
// missing from this listing.
// NOTE(review): presumably `pool` is a function-local static constructed from
// `nThreads` (the cross-reference index mentions CMS_THREAD_SAFE). If so, only
// the first call's `nThreads` takes effect and later values are ignored — confirm.
24  static TBBThreadPool& instance(int nThreads = -1) {
26  return pool;
27  }
28 
// Creates a pool; `nThreads` is presumably the value later reported by
// NumThreads() (the initializer for nThreads_ is partly missing from this
// listing — verify). Per the comment in the body, a value of zero or less
// falls back to the default determined by TBB, which the visible initializer
// reads via tbb::global_control::active_value(max_allowed_parallelism).
29  explicit TBBThreadPool(int nThreads = -1)
31  : tbb::global_control::active_value(tbb::global_control::max_allowed_parallelism)),
33  // when nThreads is zero or smaller, use the default value determined by tbb
34  }
35 
36  void Schedule(std::function<void()> fn) override {
37  numScheduleCalled_ += 1;
38 
39  // use a task arena to avoid having unrelated tasks start
40  // running on this thread, which could potentially start deadlocks
41  tbb::task_arena taskArena;
42  tbb::task_group taskGroup;
43 
44  // we are required to always call wait before destructor
45  auto doneWithTaskGroup = [&taskArena, &taskGroup](void*) {
46  taskArena.execute([&taskGroup]() { taskGroup.wait(); });
47  };
48  std::unique_ptr<tbb::task_group, decltype(doneWithTaskGroup)> taskGuard(&taskGroup, doneWithTaskGroup);
49 
50  // schedule the task
51  taskArena.execute([&taskGroup, &fn] { taskGroup.run(fn); });
52 
53  // reset the task guard which will call wait
54  taskGuard.reset();
55  }
56 
57  void ScheduleWithHint(std::function<void()> fn, int start, int end) override { Schedule(fn); }
58 
59  void Cancel() override {}
60 
61  int NumThreads() const override { return nThreads_; }
62 
63  int CurrentThreadId() const override {
64  static std::atomic<int> idCounter{0};
65  thread_local const int id = idCounter++;
66  return id;
67  }
68 
70 
71  private:
// thread count reported by NumThreads()
72  const int nThreads_;
// incremented once at the start of every Schedule() call
73  std::atomic<int> numScheduleCalled_;
74  };
75 
76 } // namespace tensorflow
77 
78 #endif // PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H
start
Definition: start.py:1
tensorflow::TBBThreadPool::TBBThreadPool
TBBThreadPool(int nThreads=-1)
Definition: TBBThreadPool.h:29
tensorflow::TBBThreadPool
Definition: TBBThreadPool.h:22
tensorflow::TBBThreadPool::ScheduleWithHint
void ScheduleWithHint(std::function< void()> fn, int start, int end) override
Definition: TBBThreadPool.h:57
tensorflow::TBBThreadPool::CurrentThreadId
int CurrentThreadId() const override
Definition: TBBThreadPool.h:63
mps_fire.end
end
Definition: mps_fire.py:242
CMS_THREAD_SAFE
#define CMS_THREAD_SAFE
Definition: thread_safety_macros.h:4
runTheMatrix.nThreads
nThreads
Definition: runTheMatrix.py:371
tensorflow::TBBThreadPool::GetNumScheduleCalled
int GetNumScheduleCalled()
Definition: TBBThreadPool.h:69
thread_safety_macros.h
tensorflow::TBBThreadPool::instance
static TBBThreadPool & instance(int nThreads=-1)
Definition: TBBThreadPool.h:24
cms::cuda::device::unique_ptr
std::unique_ptr< T, impl::DeviceDeleter > unique_ptr
Definition: device_unique_ptr.h:33
tensorflow::TBBThreadPool::Cancel
void Cancel() override
Definition: TBBThreadPool.h:59
triggerObjects_cff.id
id
Definition: triggerObjects_cff.py:29
tensorflow::TBBThreadPool::nThreads_
const int nThreads_
Definition: TBBThreadPool.h:72
HiBiasedCentrality_cfi.function
function
Definition: HiBiasedCentrality_cfi.py:4
personalPlayback.fn
fn
Definition: personalPlayback.py:515
tensorflow
Definition: NoThreadPool.h:18
tensorflow::TBBThreadPool::NumThreads
int NumThreads() const override
Definition: TBBThreadPool.h:61
tensorflow::TBBThreadPool::Schedule
void Schedule(std::function< void()> fn) override
Definition: TBBThreadPool.h:36
tensorflow::TBBThreadPool::numScheduleCalled_
std::atomic< int > numScheduleCalled_
Definition: TBBThreadPool.h:73
submitPVResolutionJobs.pool
pool
Definition: submitPVResolutionJobs.py:351