
Commit b7f09d8

[SYCL][CUDA] Non-uniform algorithm implementations for ext_oneapi_cuda. (#9671)
This PR adds CUDA support for the fixed_size_group, ballot_group, and opportunistic_group algorithms. All of the group algorithm support added for the SPIR-V implementations (e.g. in #9181) is correspondingly added here for the CUDA backend. Everything except the reduces/scans uses the same implementation for all non-uniform groups. Reduce algorithms additionally share a single implementation across all group types on sm_80 for the special type/op pairs matched by IsRedux. Otherwise, the reduces/scans fall into two implementation categories: (1) fixed_size_group; (2) opportunistic_group, ballot_group, and (once it is supported) tangle_group, which all use the same implementations. Note that tangle_group is still not supported; however, the algorithms implemented for ballot_group/opportunistic_group should also be appropriate for tangle_group when it is supported.

---------

Signed-off-by: JackAKirk <[email protected]>
1 parent dc48d7c commit b7f09d8
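For orientation, here is a minimal usage sketch (not part of this commit) of the kind of kernel these CUDA implementations now serve. It assumes the sycl_ext_oneapi_non_uniform_groups extension API (get_fixed_size_group and the group algorithms over the partitioned group) and USM device pointers; treat the exact spellings as assumptions rather than code from this PR.

#include <sycl/sycl.hpp>

namespace exp = sycl::ext::oneapi::experimental;

// Sum over each 8-wide partition of a sub-group; on the CUDA backend this now
// lowers to the masked redux/shfl implementations added by this commit.
void partial_sums(sycl::queue &q, const int *in, int *out) {
  q.parallel_for(
      sycl::nd_range<1>{sycl::range<1>{32}, sycl::range<1>{32}},
      [=](sycl::nd_item<1> it) {
        auto sg = it.get_sub_group();
        auto fsg = exp::get_fixed_size_group<8>(sg);
        size_t gid = it.get_global_id(0);
        int sum = sycl::reduce_over_group(fsg, in[gid], sycl::plus<int>());
        if (fsg.get_local_id()[0] == 0)
          out[gid / 8] = sum;
      });
}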

File tree: 8 files changed, +410 -8 lines changed

sycl/include/sycl/detail/spirv.hpp

Lines changed: 5 additions & 5 deletions
@@ -152,7 +152,7 @@ template <typename ParentGroup>
 bool GroupAll(ext::oneapi::experimental::tangle_group<ParentGroup>, bool pred) {
   return __spirv_GroupNonUniformAll(group_scope<ParentGroup>::value, pred);
 }
-template <typename Group>
+
 bool GroupAll(const ext::oneapi::experimental::opportunistic_group &,
               bool pred) {
   return __spirv_GroupNonUniformAll(
@@ -1022,8 +1022,10 @@ ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
 template <typename Group>
 typename std::enable_if_t<
     ext::oneapi::experimental::is_user_constructed_group_v<Group>>
-ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
-#if defined(__SPIR__)
+ControlBarrier(Group g, memory_scope FenceScope, memory_order Order) {
+#if defined(__NVPTX__)
+  __nvvm_bar_warp_sync(detail::ExtractMask(detail::GetMask(g))[0]);
+#else
   // SPIR-V does not define an instruction to synchronize partial groups.
   // However, most (possibly all?) of the current SPIR-V targets execute
   // work-items in lockstep, so we can probably get away with a MemoryBarrier.
@@ -1033,8 +1035,6 @@ ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
       __spv::MemorySemanticsMask::SubgroupMemory |
       __spv::MemorySemanticsMask::WorkgroupMemory |
       __spv::MemorySemanticsMask::CrossWorkgroupMemory);
-#elif defined(__NVPTX__)
-  // TODO: Call syncwarp with appropriate mask extracted from the group
 #endif
 }
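
In CUDA terms, the new NVPTX branch is roughly a warp-level sync over the group's member mask; an illustrative CUDA C++ analogue (not part of the diff) is:

// Rough CUDA C++ analogue of the NVPTX branch above (illustration only): a
// barrier on a user-constructed group synchronizes exactly the lanes named by
// the group's 32-bit member mask.
__device__ void user_group_barrier(unsigned member_mask) {
  __syncwarp(member_mask); // bar.warp.sync, same effect as __nvvm_bar_warp_sync
}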

sycl/include/sycl/detail/type_traits.hpp

Lines changed: 7 additions & 0 deletions
@@ -20,6 +20,13 @@

 namespace sycl {
 __SYCL_INLINE_VER_NAMESPACE(_V1) {
+namespace detail {
+template <class T> struct is_fixed_size_group : std::false_type {};
+
+template <class T>
+inline constexpr bool is_fixed_size_group_v = is_fixed_size_group<T>::value;
+} // namespace detail
+
 template <int Dimensions> class group;
 namespace ext::oneapi {
 struct sub_group;
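
The trait exists so device code can pick a shuffle strategy at compile time; a small illustrative sketch (not part of the diff) of the intended query pattern:

#include <sycl/sycl.hpp>

// Illustration only: the primary template above answers false for any type;
// fixed_size_group.hpp (the last hunk below) specializes it to true, so device
// code can branch with `if constexpr`, as the new CUDA header does in
// non_uniform_shfl_T.
template <typename Group>
constexpr const char *shuffle_strategy() {
  if constexpr (sycl::detail::is_fixed_size_group_v<Group>)
    return "contiguous partition: shfl.up / shfl.bfly over the member mask";
  else
    return "arbitrary member mask: shfl.idx to a lane located via __nvvm_fns";
}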

non_uniform_algorithms.hpp (new file)

Lines changed: 337 additions & 0 deletions
@@ -0,0 +1,337 @@
//==----- non_uniform_algorithms.hpp - cuda masked subgroup algorithms -----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once
#include <sycl/known_identity.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace detail {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)

template <typename T, class BinaryOperation>
using IsRedux = std::bool_constant<
    std::is_integral<T>::value && IsBitAND<T, BinaryOperation>::value ||
    IsBitOR<T, BinaryOperation>::value || IsBitXOR<T, BinaryOperation>::value ||
    IsPlus<T, BinaryOperation>::value || IsMinimum<T, BinaryOperation>::value ||
    IsMaximum<T, BinaryOperation>::value>;

//// Masked reductions using redux.sync, requires integer types

template <typename Group, typename T, class BinaryOperation>
std::enable_if_t<
    is_sugeninteger<T>::value && IsMinimum<T, BinaryOperation>::value, T>
masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
                           const uint32_t MemberMask) {
  return __nvvm_redux_sync_umin(x, MemberMask);
}

template <typename Group, typename T, class BinaryOperation>
std::enable_if_t<
    is_sigeninteger<T>::value && IsMinimum<T, BinaryOperation>::value, T>
masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
                           const uint32_t MemberMask) {
  return __nvvm_redux_sync_min(x, MemberMask);
}

template <typename Group, typename T, class BinaryOperation>
std::enable_if_t<
    is_sugeninteger<T>::value && IsMaximum<T, BinaryOperation>::value, T>
masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
                           const uint32_t MemberMask) {
  return __nvvm_redux_sync_umax(x, MemberMask);
}

template <typename Group, typename T, class BinaryOperation>
std::enable_if_t<
    is_sigeninteger<T>::value && IsMaximum<T, BinaryOperation>::value, T>
masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
                           const uint32_t MemberMask) {
  return __nvvm_redux_sync_max(x, MemberMask);
}

template <typename Group, typename T, class BinaryOperation>
std::enable_if_t<(is_sugeninteger<T>::value || is_sigeninteger<T>::value) &&
                     IsPlus<T, BinaryOperation>::value,
                 T>
masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
                           const uint32_t MemberMask) {
  return __nvvm_redux_sync_add(x, MemberMask);
}

template <typename Group, typename T, class BinaryOperation>
std::enable_if_t<(is_sugeninteger<T>::value || is_sigeninteger<T>::value) &&
                     IsBitAND<T, BinaryOperation>::value,
                 T>
masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
                           const uint32_t MemberMask) {
  return __nvvm_redux_sync_and(x, MemberMask);
}

template <typename Group, typename T, class BinaryOperation>
std::enable_if_t<(is_sugeninteger<T>::value || is_sigeninteger<T>::value) &&
                     IsBitOR<T, BinaryOperation>::value,
                 T>
masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
                           const uint32_t MemberMask) {
  return __nvvm_redux_sync_or(x, MemberMask);
}

template <typename Group, typename T, class BinaryOperation>
std::enable_if_t<(is_sugeninteger<T>::value || is_sigeninteger<T>::value) &&
                     IsBitXOR<T, BinaryOperation>::value,
                 T>
masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
                           const uint32_t MemberMask) {
  return __nvvm_redux_sync_xor(x, MemberMask);
}
////

//// Shuffle based masked reduction impls

// fixed_size_group group reduction using shfls
template <typename Group, typename T, class BinaryOperation>
inline __SYCL_ALWAYS_INLINE std::enable_if_t<is_fixed_size_group_v<Group>, T>
masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op,
                            const uint32_t MemberMask) {
  for (int i = g.get_local_range()[0] / 2; i > 0; i /= 2) {
    T tmp;
    if constexpr (std::is_same_v<T, double>) {
      int x_a, x_b;
      asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "d"(x));
      auto tmp_a = __nvvm_shfl_sync_bfly_i32(MemberMask, x_a, -1, i);
      auto tmp_b = __nvvm_shfl_sync_bfly_i32(MemberMask, x_b, -1, i);
      asm volatile("mov.b64 %0,{%1,%2};" : "=d"(tmp) : "r"(tmp_a), "r"(tmp_b));
    } else if constexpr (std::is_same_v<T, long> ||
                         std::is_same_v<T, unsigned long>) {
      int x_a, x_b;
      asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "l"(x));
      auto tmp_a = __nvvm_shfl_sync_bfly_i32(MemberMask, x_a, -1, i);
      auto tmp_b = __nvvm_shfl_sync_bfly_i32(MemberMask, x_b, -1, i);
      asm volatile("mov.b64 %0,{%1,%2};" : "=l"(tmp) : "r"(tmp_a), "r"(tmp_b));
    } else if constexpr (std::is_same_v<T, half>) {
      short tmp_b16;
      asm volatile("mov.b16 %0,%1;" : "=h"(tmp_b16) : "h"(x));
      auto tmp_b32 = __nvvm_shfl_sync_bfly_i32(
          MemberMask, static_cast<int>(tmp_b16), -1, i);
      asm volatile("mov.b16 %0,%1;"
                   : "=h"(tmp)
                   : "h"(static_cast<short>(tmp_b32)));
    } else if constexpr (std::is_same_v<T, float>) {
      auto tmp_b32 =
          __nvvm_shfl_sync_bfly_i32(MemberMask, __nvvm_bitcast_f2i(x), -1, i);
      tmp = __nvvm_bitcast_i2f(tmp_b32);
    } else {
      tmp = __nvvm_shfl_sync_bfly_i32(MemberMask, x, -1, i);
    }
    x = binary_op(x, tmp);
  }
  return x;
}

template <typename Group, typename T>
inline __SYCL_ALWAYS_INLINE std::enable_if_t<
    ext::oneapi::experimental::is_user_constructed_group_v<Group>, T>
non_uniform_shfl_T(const uint32_t MemberMask, T x, int shfl_param) {
  if constexpr (is_fixed_size_group_v<Group>) {
    return __nvvm_shfl_sync_up_i32(MemberMask, x, shfl_param, 0);
  } else {
    return __nvvm_shfl_sync_idx_i32(MemberMask, x, shfl_param, 31);
  }
}

template <typename Group, typename T>
inline __SYCL_ALWAYS_INLINE std::enable_if_t<
    ext::oneapi::experimental::is_user_constructed_group_v<Group>, T>
non_uniform_shfl(Group g, const uint32_t MemberMask, T x, int shfl_param) {
  T res;
  if constexpr (std::is_same_v<T, double>) {
    int x_a, x_b;
    asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "d"(x));
    auto tmp_a = non_uniform_shfl_T<Group>(MemberMask, x_a, shfl_param);
    auto tmp_b = non_uniform_shfl_T<Group>(MemberMask, x_b, shfl_param);
    asm volatile("mov.b64 %0,{%1,%2};" : "=d"(res) : "r"(tmp_a), "r"(tmp_b));
  } else if constexpr (std::is_same_v<T, long> ||
                       std::is_same_v<T, unsigned long>) {
    int x_a, x_b;
    asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "l"(x));
    auto tmp_a = non_uniform_shfl_T<Group>(MemberMask, x_a, shfl_param);
    auto tmp_b = non_uniform_shfl_T<Group>(MemberMask, x_b, shfl_param);
    asm volatile("mov.b64 %0,{%1,%2};" : "=l"(res) : "r"(tmp_a), "r"(tmp_b));
  } else if constexpr (std::is_same_v<T, half>) {
    short tmp_b16;
    asm volatile("mov.b16 %0,%1;" : "=h"(tmp_b16) : "h"(x));
    auto tmp_b32 = non_uniform_shfl_T<Group>(
        MemberMask, static_cast<int>(tmp_b16), shfl_param);
    asm volatile("mov.b16 %0,%1;"
                 : "=h"(res)
                 : "h"(static_cast<short>(tmp_b32)));
  } else if constexpr (std::is_same_v<T, float>) {
    auto tmp_b32 = non_uniform_shfl_T<Group>(MemberMask, __nvvm_bitcast_f2i(x),
                                             shfl_param);
    res = __nvvm_bitcast_i2f(tmp_b32);
  } else {
    res = non_uniform_shfl_T<Group>(MemberMask, x, shfl_param);
  }
  return res;
}
// Opportunistic/Ballot group reduction using shfls
template <typename Group, typename T, class BinaryOperation>
inline __SYCL_ALWAYS_INLINE std::enable_if_t<
    ext::oneapi::experimental::is_user_constructed_group_v<Group> &&
        !is_fixed_size_group_v<Group>,
    T>
masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op,
                            const uint32_t MemberMask) {

  unsigned localSetBit = g.get_local_id()[0] + 1;

  // number of elements requiring binary operations each loop iteration
  auto opRange = g.get_local_range()[0];

  // stride between local_ids forming a binary op
  unsigned stride = opRange / 2;
  while (stride >= 1) {

    // if (remainder == 1), there is a WI without a binary op partner
    unsigned remainder = opRange % 2;

    // unfolded position of set bit in mask of shfl src lane
    int unfoldedSrcSetBit = localSetBit + stride;

    // __nvvm_fns automatically wraps around to the correct bit position.
    // There is no performance impact on src_set_bit position wrt localSetBit
    auto tmp = non_uniform_shfl(g, MemberMask, x,
                                __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit));

    if (!(localSetBit == 1 && remainder != 0)) {
      x = binary_op(x, tmp);
    }

    opRange = stride + remainder;
    stride = opRange / 2;
  }
  unsigned broadID;
  asm volatile(".reg .u32 rev;\n\t"
               "brev.b32 rev, %1;\n\t" // reverse mask bits
               "clz.b32 %0, rev;"
               : "=r"(broadID)
               : "r"(MemberMask));

  return non_uniform_shfl(g, MemberMask, x, broadID);
}

// Non Redux types must fall back to shfl based implementations.
template <typename Group, typename T, class BinaryOperation>
std::enable_if_t<
    std::is_same<IsRedux<T, BinaryOperation>, std::false_type>::value &&
        ext::oneapi::experimental::is_user_constructed_group_v<Group>,
    T>
masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
                           const uint32_t MemberMask) {
  return masked_reduction_cuda_shfls(g, x, binary_op, MemberMask);
}

// get_identity is only currently used in this cuda specific header. If in the
// future it has more general use it should be moved to a more appropriate
// header.
template <typename T, class BinaryOperation>
inline __SYCL_ALWAYS_INLINE
    std::enable_if_t<IsPlus<T, BinaryOperation>::value ||
                         IsBitOR<T, BinaryOperation>::value ||
                         IsBitXOR<T, BinaryOperation>::value,
                     T>
    get_identity() {
  return 0;
}

template <typename T, class BinaryOperation>
inline __SYCL_ALWAYS_INLINE
    std::enable_if_t<IsMultiplies<T, BinaryOperation>::value, T>
    get_identity() {
  return 1;
}

template <typename T, class BinaryOperation>
inline __SYCL_ALWAYS_INLINE
    std::enable_if_t<IsBitAND<T, BinaryOperation>::value, T>
    get_identity() {
  return ~0;
}

#define GET_ID(OP_CHECK, OP)                                                   \
  template <typename T, class BinaryOperation>                                 \
  inline __SYCL_ALWAYS_INLINE                                                  \
      std::enable_if_t<OP_CHECK<T, BinaryOperation>::value, T>                 \
      get_identity() {                                                         \
    return std::numeric_limits<T>::OP();                                       \
  }

GET_ID(IsMinimum, max)
GET_ID(IsMaximum, min)

#undef GET_ID

//// Shuffle based masked reduction impls

// fixed_size_group group scan using shfls
template <__spv::GroupOperation Op, typename Group, typename T,
          class BinaryOperation>
inline __SYCL_ALWAYS_INLINE std::enable_if_t<is_fixed_size_group_v<Group>, T>
masked_scan_cuda_shfls(Group g, T x, BinaryOperation binary_op,
                       const uint32_t MemberMask) {
  unsigned localIdVal = g.get_local_id()[0];
  for (int i = 1; i < g.get_local_range()[0]; i *= 2) {
    auto tmp = non_uniform_shfl(g, MemberMask, x, i);
    if (localIdVal >= i)
      x = binary_op(x, tmp);
  }
  if constexpr (Op == __spv::GroupOperation::ExclusiveScan) {

    x = non_uniform_shfl(g, MemberMask, x, 1);
    if (localIdVal == 0) {
      return get_identity<T, BinaryOperation>();
    }
  }
  return x;
}

template <__spv::GroupOperation Op, typename Group, typename T,
          class BinaryOperation>
inline __SYCL_ALWAYS_INLINE std::enable_if_t<
    ext::oneapi::experimental::is_user_constructed_group_v<Group> &&
        !is_fixed_size_group_v<Group>,
    T>
masked_scan_cuda_shfls(Group g, T x, BinaryOperation binary_op,
                       const uint32_t MemberMask) {
  unsigned localIdVal = g.get_local_id()[0];
  unsigned localSetBit = localIdVal + 1;

  for (int i = 1; i < g.get_local_range()[0]; i *= 2) {
    int unfoldedSrcSetBit = localSetBit - i;

    auto tmp = non_uniform_shfl(g, MemberMask, x,
                                __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit));
    if (localIdVal >= i)
      x = binary_op(x, tmp);
  }
  if constexpr (Op == __spv::GroupOperation::ExclusiveScan) {
    x = non_uniform_shfl(g, MemberMask, x,
                         __nvvm_fns(MemberMask, 0, localSetBit - 1));
    if (localIdVal == 0) {
      return get_identity<T, BinaryOperation>();
    }
  }
  return x;
}

#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
} // namespace detail
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
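
To ground the opportunistic/ballot code paths above, here is an illustrative sketch (not part of the diff) of a divergent scan that exercises them. It assumes the extension's get_ballot_group API and USM device pointers.

#include <sycl/sycl.hpp>

namespace exp = sycl::ext::oneapi::experimental;

// Work-items are split by a data-dependent predicate, so each ballot_group has
// an arbitrary member mask; on CUDA the scan runs through the __nvvm_fns +
// shfl.idx paths above, and lane 0 of each group receives the identity (0.0f).
void divergent_prefix_sums(sycl::queue &q, const float *in, float *out) {
  q.parallel_for(
      sycl::nd_range<1>{sycl::range<1>{32}, sycl::range<1>{32}},
      [=](sycl::nd_item<1> it) {
        auto sg = it.get_sub_group();
        size_t gid = it.get_global_id(0);
        auto bg = exp::get_ballot_group(sg, in[gid] > 0.0f);
        out[gid] =
            sycl::exclusive_scan_over_group(bg, in[gid], sycl::plus<float>());
      });
}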

sycl/include/sycl/ext/oneapi/experimental/fixed_size_group.hpp

Lines changed: 7 additions & 0 deletions
@@ -163,6 +163,13 @@ struct is_user_constructed_group<fixed_size_group<PartitionSize, ParentGroup>>
 
 } // namespace ext::oneapi::experimental
 
+namespace detail {
+template <size_t PartitionSize, typename ParentGroup>
+struct is_fixed_size_group<
+    ext::oneapi::experimental::fixed_size_group<PartitionSize, ParentGroup>>
+    : std::true_type {};
+} // namespace detail
+
 template <size_t PartitionSize, typename ParentGroup>
 struct is_group<
     ext::oneapi::experimental::fixed_size_group<PartitionSize, ParentGroup>>
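
With this specialization visible, the trait from type_traits.hpp reports true for any fixed_size_group instantiation; a hypothetical compile-time check (not part of the diff, and assuming both headers are included) would be:

// Illustration only.
static_assert(sycl::detail::is_fixed_size_group_v<
              sycl::ext::oneapi::experimental::fixed_size_group<
                  8, sycl::ext::oneapi::sub_group>>);
static_assert(!sycl::detail::is_fixed_size_group_v<sycl::ext::oneapi::sub_group>);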
