CMS 3D CMS Logo

workdivision.h
Go to the documentation of this file.
1 #ifndef HeterogeneousCore_AlpakaInterface_interface_workdivision_h
2 #define HeterogeneousCore_AlpakaInterface_interface_workdivision_h
3 
4 #include <type_traits>
5 
6 #include <alpaka/alpaka.hpp>
7 
10 
11 namespace cms::alpakatools {
12 
13  using namespace alpaka_common;
14 
15  // If the first argument is not a multiple of the second argument, round it up to the next multiple
16  inline constexpr Idx round_up_by(Idx value, Idx divisor) { return (value + divisor - 1) / divisor * divisor; }
17 
18  // Return the integer division of the first argument by the second argument, rounded up to the next integer
19  inline constexpr Idx divide_up_by(Idx value, Idx divisor) { return (value + divisor - 1) / divisor; }
20 
21  // Trait describing whether or not the accelerator expects the threads-per-block and elements-per-thread to be swapped
22  template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>>
23  struct requires_single_thread_per_block : public std::true_type {};
24 
25 #ifdef ALPAKA_ACC_GPU_CUDA_ENABLED
26  template <typename TDim>
27  struct requires_single_thread_per_block<alpaka::AccGpuCudaRt<TDim, Idx>> : public std::false_type {};
28 #endif // ALPAKA_ACC_GPU_CUDA_ENABLED
29 
30 #ifdef ALPAKA_ACC_GPU_HIP_ENABLED
31  template <typename TDim>
32  struct requires_single_thread_per_block<alpaka::AccGpuHipRt<TDim, Idx>> : public std::false_type {};
33 #endif // ALPAKA_ACC_GPU_HIP_ENABLED
34 
35 #ifdef ALPAKA_ACC_CPU_B_SEQ_T_THREADS_ENABLED
36  template <typename TDim>
37  struct requires_single_thread_per_block<alpaka::AccCpuThreads<TDim, Idx>> : public std::false_type {};
38 #endif // ALPAKA_ACC_CPU_B_SEQ_T_THREADS_ENABLED
39 
40  // Whether or not the accelerator expects the threads-per-block and elements-per-thread to be swapped
41  template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>>
43 
44  // Create an accelerator-dependent work division for 1-dimensional kernels
45  template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc> and alpaka::Dim<TAcc>::value == 1>>
47  if constexpr (not requires_single_thread_per_block_v<TAcc>) {
48  // On GPU backends, each thread is looking at a single element:
49  // - the number of threads per block is "elements";
50  // - the number of elements per thread is always 1.
51  return WorkDiv<Dim1D>(blocks, elements, Idx{1});
52  } else {
53  // On CPU backends, run serially with a single thread per block:
54  // - the number of threads per block is always 1;
55  // - the number of elements per thread is "elements".
56  return WorkDiv<Dim1D>(blocks, Idx{1}, elements);
57  }
58  }
59 
60  // Create the accelerator-dependent workdiv for N-dimensional kernels
61  template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>>
62  inline WorkDiv<alpaka::Dim<TAcc>> make_workdiv(const Vec<alpaka::Dim<TAcc>>& blocks,
63  const Vec<alpaka::Dim<TAcc>>& elements) {
64  using Dim = alpaka::Dim<TAcc>;
65  if constexpr (not requires_single_thread_per_block_v<TAcc>) {
66  // On GPU backends, each thread is looking at a single element:
67  // - the number of threads per block is "elements";
68  // - the number of elements per thread is always 1.
70  } else {
71  // On CPU backends, run serially with a single thread per block:
72  // - the number of threads per block is always 1;
73  // - the number of elements per thread is "elements".
75  }
76  }
77 
78  template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc> and alpaka::Dim<TAcc>::value == 1>>
80  public:
81  ALPAKA_FN_ACC inline elements_with_stride(TAcc const& acc)
82  : elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u]},
83  thread_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
84  stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
85  extent_{stride_} {}
86 
87  ALPAKA_FN_ACC inline elements_with_stride(TAcc const& acc, Idx extent)
88  : elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u]},
89  thread_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
90  stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
91  extent_{extent} {}
92 
93  class iterator {
94  friend class elements_with_stride;
95 
// Construct an iterator for a thread whose first index is "first".
// first_ (and index_, which starts at first_) is clamped to extent, so a thread starting at or past the
// extent produces an iterator equal to the end iterator; range_, the exclusive upper bound of the current
// chunk of elements, is min(first + elements, extent).
// NOTE(review): index_{first_} relies on first_ being declared before index_; the member declarations
// are not visible here — confirm. "first + elements" is not guarded against wrap-around of Idx.
96  ALPAKA_FN_ACC inline iterator(Idx elements, Idx stride, Idx extent, Idx first)
97  : elements_{elements},
98  stride_{stride},
99  extent_{extent},
100  first_{std::min(first, extent)},
101  index_{first_},
102  range_{std::min(first + elements, extent)} {}
103 
104  public:
105  ALPAKA_FN_ACC inline Idx operator*() const { return index_; }
106 
107  // pre-increment the iterator
108  ALPAKA_FN_ACC inline iterator& operator++() {
109  if constexpr (requires_single_thread_per_block_v<TAcc>) {
110  // increment the index along the elements processed by the current thread
111  ++index_;
112  if (index_ < range_)
113  return *this;
114  }
115 
116  // increment the thread index with the grid stride
117  first_ += stride_;
118  index_ = first_;
119  range_ = std::min(first_ + elements_, extent_);
120  if (index_ < extent_)
121  return *this;
122 
123  // the iterator has reached or passed the end of the extent, clamp it to the extent
124  first_ = extent_;
125  index_ = extent_;
126  range_ = extent_;
127  return *this;
128  }
129 
130  // post-increment the iterator
131  ALPAKA_FN_ACC inline iterator operator++(int) {
132  iterator old = *this;
133  ++(*this);
134  return old;
135  }
136 
137  ALPAKA_FN_ACC inline bool operator==(iterator const& other) const {
138  return (index_ == other.index_) and (first_ == other.first_);
139  }
140 
141  ALPAKA_FN_ACC inline bool operator!=(iterator const& other) const { return not(*this == other); }
142 
143  private:
144  // non-const to support iterator copy and assignment
148  // modified by the pre/post-increment operator
152  };
153 
154  ALPAKA_FN_ACC inline iterator begin() const { return iterator(elements_, stride_, extent_, thread_); }
155 
156  ALPAKA_FN_ACC inline iterator end() const { return iterator(elements_, stride_, extent_, extent_); }
157 
158  private:
159  const Idx elements_;
160  const Idx thread_;
161  const Idx stride_;
162  const Idx extent_;
163  };
164 
165  template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc> and (alpaka::Dim<TAcc>::value > 0)>>
167  public:
168  using Dim = alpaka::Dim<TAcc>;
169  using Vec = alpaka::Vec<Dim, Idx>;
170 
171  ALPAKA_FN_ACC inline elements_with_stride_nd(TAcc const& acc)
172  : elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)},
173  thread_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc) * elements_},
174  stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc) * elements_},
175  extent_{stride_} {}
176 
177  ALPAKA_FN_ACC inline elements_with_stride_nd(TAcc const& acc, Vec extent)
178  : elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)},
179  thread_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc) * elements_},
180  stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc) * elements_},
181  extent_{extent} {}
182 
183  // tag used to construct an end iterator
184  struct at_end_t {};
185 
186  class iterator {
188 
189  public:
190  ALPAKA_FN_ACC inline Vec operator*() const { return index_; }
191 
192  // pre-increment the iterator
193  ALPAKA_FN_ACC constexpr inline iterator operator++() {
194  increment();
195  return *this;
196  }
197 
198  // post-increment the iterator
199  ALPAKA_FN_ACC constexpr inline iterator operator++(int) {
200  iterator old = *this;
201  increment();
202  return old;
203  }
204 
205  ALPAKA_FN_ACC constexpr inline bool operator==(iterator const& other) const { return (index_ == other.index_); }
206 
207  ALPAKA_FN_ACC constexpr inline bool operator!=(iterator const& other) const { return not(*this == other); }
208 
209  private:
210  // construct an iterator pointing to the first element to be processed by the current thread
211  ALPAKA_FN_ACC inline iterator(elements_with_stride_nd const* loop, Vec first)
212  : loop_{loop},
213  first_{alpaka::elementwise_min(first, loop->extent_)},
214  range_{alpaka::elementwise_min(first + loop->elements_, loop->extent_)},
215  index_{first_} {}
216 
217  // construct an end iterator, pointing post the end of the extent
218  ALPAKA_FN_ACC inline iterator(elements_with_stride_nd const* loop, at_end_t const&)
219  : loop_{loop}, first_{loop_->extent_}, range_{loop_->extent_}, index_{loop_->extent_} {}
220 
221  template <size_t I>
222  ALPAKA_FN_ACC inline constexpr bool nth_elements_loop() {
223  bool overflow = false;
224  ++index_[I];
225  if (index_[I] >= range_[I]) {
226  index_[I] = first_[I];
227  overflow = true;
228  }
229  return overflow;
230  }
231 
232  template <size_t N>
233  ALPAKA_FN_ACC inline constexpr bool do_elements_loops() {
234  if constexpr (N == 0) {
235  // overflow
236  return true;
237  } else {
238  if (not nth_elements_loop<N - 1>()) {
239  return false;
240  } else {
241  return do_elements_loops<N - 1>();
242  }
243  }
244  }
245 
246  template <size_t I>
247  ALPAKA_FN_ACC inline constexpr bool nth_strided_loop() {
248  bool overflow = false;
249  first_[I] += loop_->stride_[I];
250  if (first_[I] >= loop_->extent_[I]) {
251  first_[I] = loop_->thread_[I];
252  overflow = true;
253  }
254  index_[I] = first_[I];
255  range_[I] = std::min(first_[I] + loop_->elements_[I], loop_->extent_[I]);
256  return overflow;
257  }
258 
259  template <size_t N>
260  ALPAKA_FN_ACC inline constexpr bool do_strided_loops() {
261  if constexpr (N == 0) {
262  // overflow
263  return true;
264  } else {
265  if (not nth_strided_loop<N - 1>()) {
266  return false;
267  } else {
268  return do_strided_loops<N - 1>();
269  }
270  }
271  }
272 
273  // increment the iterator
274  ALPAKA_FN_ACC inline constexpr void increment() {
275  if constexpr (requires_single_thread_per_block_v<TAcc>) {
276  // linear N-dimensional loops over the elements associated to the thread;
277  // do_elements_loops<>() returns true if any of those loops overflows
278  if (not do_elements_loops<Dim::value>()) {
279  // the elements loops did not overflow, return the next index
280  return;
281  }
282  }
283 
284  // strided N-dimensional loop over the threads in the kernel launch grid;
285  // do_strided_loops<>() returns true if any of those loops overflows
286  if (not do_strided_loops<Dim::value>()) {
287  // the strided loops did not overflow, return the next index
288  return;
289  }
290 
291  // the iterator has reached or passed the end of the extent, clamp it to the extent
292  first_ = loop_->extent_;
293  range_ = loop_->extent_;
294  index_ = loop_->extent_;
295  }
296 
297  // const pointer to the elements_with_stride_nd that the iterator refers to
299 
300  // modified by the pre/post-increment operator
301  Vec first_; // first element processed by this thread
302  Vec range_; // last element processed by this thread
303  Vec index_; // current element processed by this thread
304  };
305 
306  ALPAKA_FN_ACC inline iterator begin() const {
307  // check that all dimensions of the current thread index are within the extent
308  if ((thread_ < extent_).all()) {
309  // construct an iterator pointing to the first element to be processed by the current thread
310  return iterator{this, thread_};
311  } else {
312  // construct an end iterator, pointing post the end of the extent
313  return iterator{this, at_end_t{}};
314  }
315  }
316 
317  ALPAKA_FN_ACC inline iterator end() const {
318  // construct an end iterator, pointing post the end of the extent
319  return iterator{this, at_end_t{}};
320  }
321 
322  private:
323  const Vec elements_;
324  const Vec thread_;
325  const Vec stride_;
326  const Vec extent_;
327  };
328 
329 } // namespace cms::alpakatools
330 
331 #endif // HeterogeneousCore_AlpakaInterface_interface_workdivision_h
ALPAKA_FN_ACC constexpr bool nth_strided_loop()
Definition: workdivision.h:247
ALPAKA_FN_ACC elements_with_stride_nd(TAcc const &acc, Vec extent)
Definition: workdivision.h:177
ALPAKA_FN_ACC constexpr bool operator==(iterator const &other) const
Definition: workdivision.h:205
ALPAKA_FN_ACC elements_with_stride(TAcc const &acc)
Definition: workdivision.h:81
WorkDiv< Dim1D > make_workdiv(Idx blocks, Idx elements)
Definition: workdivision.h:46
def all(container)
workaround iterator generators for ROOT classes
Definition: cmstools.py:25
constexpr Idx divide_up_by(Idx value, Idx divisor)
Definition: workdivision.h:19
ALPAKA_FN_ACC elements_with_stride(TAcc const &acc, Idx extent)
Definition: workdivision.h:87
uint32_t Idx
Definition: config.h:13
ALPAKA_FN_ACC iterator begin() const
Definition: workdivision.h:154
ALPAKA_FN_ACC elements_with_stride_nd(TAcc const &acc)
Definition: workdivision.h:171
alpaka::WorkDivMembers< TDim, Idx > WorkDiv
Definition: config.h:30
ALPAKA_FN_ACC iterator(elements_with_stride_nd const *loop, at_end_t const &)
Definition: workdivision.h:218
ALPAKA_FN_ACC constexpr bool nth_elements_loop()
Definition: workdivision.h:222
ALPAKA_FN_ACC iterator end() const
Definition: workdivision.h:156
constexpr Idx round_up_by(Idx value, Idx divisor)
Definition: workdivision.h:16
ALPAKA_FN_ACC iterator(Idx elements, Idx stride, Idx extent, Idx first)
Definition: workdivision.h:96
const std::complex< double > I
Definition: I.h:8
Definition: value.py:1
ALPAKA_FN_ACC iterator end() const
Definition: workdivision.h:317
ALPAKA_FN_ACC constexpr iterator operator++()
Definition: workdivision.h:193
#define N
Definition: blowfish.cc:9
alpaka::Vec< TDim, Idx > Vec
Definition: config.h:23
ALPAKA_FN_ACC constexpr bool do_elements_loops()
Definition: workdivision.h:233
ALPAKA_FN_ACC bool operator==(iterator const &other) const
Definition: workdivision.h:137
ALPAKA_FN_ACC iterator begin() const
Definition: workdivision.h:306
ALPAKA_FN_ACC constexpr iterator operator++(int)
Definition: workdivision.h:199
constexpr bool requires_single_thread_per_block_v
Definition: workdivision.h:42
ALPAKA_FN_ACC iterator(elements_with_stride_nd const *loop, Vec first)
Definition: workdivision.h:211
ALPAKA_FN_ACC constexpr bool do_strided_loops()
Definition: workdivision.h:260
ALPAKA_FN_ACC constexpr bool operator!=(iterator const &other) const
Definition: workdivision.h:207
ALPAKA_FN_ACC bool operator!=(iterator const &other) const
Definition: workdivision.h:141