CMS 3D CMS Logo

TBBThreadPool.h
Go to the documentation of this file.
1 /*
2  * Custom TensorFlow thread pool implementation that schedules tasks in TBB.
3  * Based on TensorFlow 2.1.
4  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
5  *
6  * Author: Marcel Rieger
7  */
8 
9 #ifndef PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H
10 #define PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H
11 
13 
#include <atomic>
#include <functional>
#include <memory>
#include <utility>

#include "FWCore/Utilities/interface/thread_safety_macros.h"

#include "tensorflow/core/lib/core/threadpool.h"

#include "tbb/task_scheduler_init.h"
#include "tbb/task_arena.h"
#include "tbb/task_group.h"
19 
20 namespace tensorflow {
21 
22  class TBBThreadPool : public tensorflow::thread::ThreadPoolInterface {
23  public:
24  static TBBThreadPool& instance(int nThreads = -1) {
26  return pool;
27  }
28 
29  explicit TBBThreadPool(int nThreads = -1)
30  : nThreads_(nThreads > 0 ? nThreads : tbb::task_scheduler_init::default_num_threads()), numScheduleCalled_(0) {
31  // when nThreads is zero or smaller, use the default value determined by tbb
32  }
33 
34  void Schedule(std::function<void()> fn) override {
35  numScheduleCalled_ += 1;
36 
37  // use a task arena to avoid having unrelated tasks start
38  // running on this thread, which could potentially start deadlocks
39  tbb::task_arena taskArena;
40  tbb::task_group taskGroup;
41 
42  // we are required to always call wait before destructor
43  auto doneWithTaskGroup = [&taskArena, &taskGroup](void*) {
44  taskArena.execute([&taskGroup]() { taskGroup.wait(); });
45  };
46  std::unique_ptr<tbb::task_group, decltype(doneWithTaskGroup)> taskGuard(&taskGroup, doneWithTaskGroup);
47 
48  // schedule the task
49  taskArena.execute([&taskGroup, &fn] { taskGroup.run(fn); });
50 
51  // reset the task guard which will call wait
52  taskGuard.reset();
53  }
54 
55  void ScheduleWithHint(std::function<void()> fn, int start, int end) override { Schedule(fn); }
56 
57  void Cancel() override {}
58 
59  int NumThreads() const override { return nThreads_; }
60 
61  int CurrentThreadId() const override {
62  static std::atomic<int> idCounter{0};
63  thread_local const int id = idCounter++;
64  return id;
65  }
66 
68 
69  private:
70  const int nThreads_;
71  std::atomic<int> numScheduleCalled_;
72  };
73 
74 } // namespace tensorflow
75 
76 #endif // PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H
start
Definition: start.py:1
tensorflow::TBBThreadPool::TBBThreadPool
TBBThreadPool(int nThreads=-1)
Definition: TBBThreadPool.h:29
tensorflow::TBBThreadPool
Definition: TBBThreadPool.h:22
tensorflow::TBBThreadPool::ScheduleWithHint
void ScheduleWithHint(std::function< void()> fn, int start, int end) override
Definition: TBBThreadPool.h:55
tensorflow::TBBThreadPool::CurrentThreadId
int CurrentThreadId() const override
Definition: TBBThreadPool.h:61
mps_fire.end
end
Definition: mps_fire.py:242
CMS_THREAD_SAFE
#define CMS_THREAD_SAFE
Definition: thread_safety_macros.h:4
runTheMatrix.nThreads
nThreads
Definition: runTheMatrix.py:355
tensorflow::TBBThreadPool::GetNumScheduleCalled
int GetNumScheduleCalled()
Definition: TBBThreadPool.h:67
thread_safety_macros.h
tensorflow::TBBThreadPool::instance
static TBBThreadPool & instance(int nThreads=-1)
Definition: TBBThreadPool.h:24
cms::cuda::device::unique_ptr
std::unique_ptr< T, impl::DeviceDeleter > unique_ptr
Definition: device_unique_ptr.h:33
tensorflow::TBBThreadPool::Cancel
void Cancel() override
Definition: TBBThreadPool.h:57
triggerObjects_cff.id
id
Definition: triggerObjects_cff.py:31
tensorflow::TBBThreadPool::nThreads_
const int nThreads_
Definition: TBBThreadPool.h:70
HiBiasedCentrality_cfi.function
function
Definition: HiBiasedCentrality_cfi.py:4
personalPlayback.fn
fn
Definition: personalPlayback.py:515
tensorflow
Definition: NoThreadPool.h:18
tensorflow::TBBThreadPool::NumThreads
int NumThreads() const override
Definition: TBBThreadPool.h:59
tensorflow::TBBThreadPool::Schedule
void Schedule(std::function< void()> fn) override
Definition: TBBThreadPool.h:34
tensorflow::TBBThreadPool::numScheduleCalled_
std::atomic< int > numScheduleCalled_
Definition: TBBThreadPool.h:71
submitPVResolutionJobs.pool
pool
Definition: submitPVResolutionJobs.py:351