CMS 3D CMS Logo

TBBThreadPool.h
Go to the documentation of this file.
1 /*
2  * Custom TensorFlow thread pool implementation that schedules tasks in TBB.
3  * Based on TensorFlow 2.1.
4  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
5  *
6  * Author: Marcel Rieger
7  */
8 
9 #ifndef PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H
10 #define PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H
11 
13 
#include <atomic>
#include <functional>
#include <memory>
#include <utility>

#include "tensorflow/core/lib/core/threadpool.h"

#include "oneapi/tbb/global_control.h"
#include "oneapi/tbb/task_arena.h"
#include "oneapi/tbb/task_group.h"
19 
20 namespace tensorflow {
21 
22  class TBBThreadPool : public tensorflow::thread::ThreadPoolInterface {
23  public:
24  static TBBThreadPool& instance(int nThreads = -1) {
26  return pool;
27  }
28 
29  explicit TBBThreadPool(int nThreads = -1)
31  : tbb::global_control::active_value(tbb::global_control::max_allowed_parallelism)),
33  // when nThreads is zero or smaller, use the default value determined by tbb
34  }
35 
36  void Schedule(std::function<void()> fn) override {
37  numScheduleCalled_ += 1;
38 
39  // use a task arena to avoid having unrelated tasks start
40  // running on this thread, which could potentially start deadlocks
41  tbb::task_arena taskArena;
42  tbb::task_group taskGroup;
43 
44  // we are required to always call wait before destructor
45  auto doneWithTaskGroup = [&taskArena, &taskGroup](void*) {
46  taskArena.execute([&taskGroup]() { taskGroup.wait(); });
47  };
48  std::unique_ptr<tbb::task_group, decltype(doneWithTaskGroup)> taskGuard(&taskGroup, doneWithTaskGroup);
49 
50  // schedule the task
51  taskArena.execute([&taskGroup, &fn] { taskGroup.run(fn); });
52 
53  // reset the task guard which will call wait
54  taskGuard.reset();
55  }
56 
57  void ScheduleWithHint(std::function<void()> fn, int start, int end) override { Schedule(fn); }
58 
59  void Cancel() override {}
60 
61  int NumThreads() const override { return nThreads_; }
62 
63  int CurrentThreadId() const override {
64  static std::atomic<int> idCounter{0};
65  thread_local const int id = idCounter++;
66  return id;
67  }
68 
70 
71  private:
72  const int nThreads_;
73  std::atomic<int> numScheduleCalled_;
74  };
75 
76 } // namespace tensorflow
77 
78 #endif // PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H
Definition: start.py:1
void ScheduleWithHint(std::function< void()> fn, int start, int end) override
Definition: TBBThreadPool.h:57
std::atomic< int > numScheduleCalled_
Definition: TBBThreadPool.h:73
void Cancel() override
Definition: TBBThreadPool.h:59
TBBThreadPool(int nThreads=-1)
Definition: TBBThreadPool.h:29
int CurrentThreadId() const override
Definition: TBBThreadPool.h:63
static TBBThreadPool & instance(int nThreads=-1)
Definition: TBBThreadPool.h:24
#define CMS_THREAD_SAFE
int NumThreads() const override
Definition: TBBThreadPool.h:61
void Schedule(std::function< void()> fn) override
Definition: TBBThreadPool.h:36