Skip to content

Commit 01ac033

Browse files
Pennycook, aelovikov-intel, steffenlarsen
authored
[SYCL] Add fixed_size_group support to algorithms (#9181)
Enables the following functions to be used with fixed_size_group arguments:
- group_barrier
- group_broadcast
- any_of_group
- all_of_group
- none_of_group
- reduce_over_group
- exclusive_scan_over_group
- inclusive_scan_over_group

Signed-off-by: John Pennycook [email protected]
---------
Signed-off-by: John Pennycook [email protected]
Co-authored-by: aelovikov-intel <[email protected]>
Co-authored-by: Steffen Larsen <[email protected]>
1 parent a1e38b5 commit 01ac033

File tree

6 files changed

+298
-1
lines changed

6 files changed

+298
-1
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,10 @@ template <typename ValueT, typename IdT>
972972
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
973973
__spirv_GroupNonUniformBroadcast(__spv::Scope::Flag, ValueT, IdT);
974974

975+
// Declaration of the non-uniform shuffle intrinsic: takes a scope flag, a
// value and an id, and returns a value of the same type as the input value.
// In this change it is called by the fixed_size_group GroupBroadcast overload,
// which passes an id that has been remapped into the parent group's numbering.
template <typename ValueT, typename IdT>
976+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
977+
__spirv_GroupNonUniformShuffle(__spv::Scope::Flag, ValueT, IdT) noexcept;
978+
975979
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT bool
976980
__spirv_GroupNonUniformAll(__spv::Scope::Flag, bool);
977981

@@ -1030,6 +1034,71 @@ template <typename ValueT>
10301034
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
10311035
__spirv_GroupNonUniformBitwiseAnd(__spv::Scope::Flag, unsigned int, ValueT);
10321036

1037+
// Overloads of the non-uniform group arithmetic and bitwise intrinsics that
// take one extra trailing `unsigned int` compared with the three-argument
// declarations above them in this header. Parameters, in order: scope flag,
// group operation (passed as an unsigned int), the value, and the extra
// unsigned int. In this change the callers pass
// GroupOperation::ClusteredReduce as the operation and the fixed_size_group
// partition size as the trailing argument.
template <typename ValueT>
1038+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1039+
__spirv_GroupNonUniformSMin(__spv::Scope::Flag, unsigned int, ValueT,
1040+
unsigned int);
1041+
1042+
template <typename ValueT>
1043+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1044+
__spirv_GroupNonUniformUMin(__spv::Scope::Flag, unsigned int, ValueT,
1045+
unsigned int);
1046+
1047+
template <typename ValueT>
1048+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1049+
__spirv_GroupNonUniformFMin(__spv::Scope::Flag, unsigned int, ValueT,
1050+
unsigned int);
1051+
1052+
template <typename ValueT>
1053+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1054+
__spirv_GroupNonUniformSMax(__spv::Scope::Flag, unsigned int, ValueT,
1055+
unsigned int);
1056+
1057+
template <typename ValueT>
1058+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1059+
__spirv_GroupNonUniformUMax(__spv::Scope::Flag, unsigned int, ValueT,
1060+
unsigned int);
1061+
1062+
template <typename ValueT>
1063+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1064+
__spirv_GroupNonUniformFMax(__spv::Scope::Flag, unsigned int, ValueT,
1065+
unsigned int);
1066+
1067+
template <typename ValueT>
1068+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1069+
__spirv_GroupNonUniformIAdd(__spv::Scope::Flag, unsigned int, ValueT,
1070+
unsigned int);
1071+
1072+
template <typename ValueT>
1073+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1074+
__spirv_GroupNonUniformFAdd(__spv::Scope::Flag, unsigned int, ValueT,
1075+
unsigned int);
1076+
1077+
template <typename ValueT>
1078+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1079+
__spirv_GroupNonUniformIMul(__spv::Scope::Flag, unsigned int, ValueT,
1080+
unsigned int);
1081+
1082+
template <typename ValueT>
1083+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1084+
__spirv_GroupNonUniformFMul(__spv::Scope::Flag, unsigned int, ValueT,
1085+
unsigned int);
1086+
1087+
template <typename ValueT>
1088+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1089+
__spirv_GroupNonUniformBitwiseOr(__spv::Scope::Flag, unsigned int, ValueT,
1090+
unsigned int);
1091+
1092+
template <typename ValueT>
1093+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1094+
__spirv_GroupNonUniformBitwiseXor(__spv::Scope::Flag, unsigned int, ValueT,
1095+
unsigned int);
1096+
1097+
template <typename ValueT>
1098+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1099+
__spirv_GroupNonUniformBitwiseAnd(__spv::Scope::Flag, unsigned int, ValueT,
1100+
unsigned int);
1101+
10331102
extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT void
10341103
__clc_BarrierInitialize(int64_t *state, int32_t expected_count) noexcept;
10351104

sycl/include/CL/__spirv/spirv_types.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ struct MemorySemanticsMask {
109109
// Group operation selector that callers cast to an integer and pass to the
// __spirv_Group* / __spirv_GroupNonUniform* intrinsics.
// This hunk adds ClusteredReduce (and a trailing comma after ExclusiveScan);
// the fixed_size_group overloads pass it together with a partition size.
// NOTE(review): numeric values presumably mirror the SPIR-V "Group Operation"
// enumerants — confirm against the SPIR-V specification.
enum class GroupOperation : uint32_t {
110110
Reduce = 0,
111111
InclusiveScan = 1,
112-
ExclusiveScan = 2
112+
ExclusiveScan = 2,
113+
ClusteredReduce = 3,
113114
};
114115

115116
#if (SYCL_EXT_ONEAPI_MATRIX_VERSION > 1)

sycl/include/sycl/detail/spirv.hpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ namespace oneapi {
2626
struct sub_group;
2727
namespace experimental {
2828
template <typename ParentGroup> class ballot_group;
29+
template <size_t PartitionSize, typename ParentGroup> class fixed_size_group;
2930
} // namespace experimental
3031
} // namespace oneapi
3132
} // namespace ext
@@ -65,6 +66,12 @@ struct group_scope<sycl::ext::oneapi::experimental::ballot_group<ParentGroup>> {
6566
static constexpr __spv::Scope::Flag value = group_scope<ParentGroup>::value;
6667
};
6768

69+
// group_scope for a fixed_size_group: the scope is taken unchanged from the
// group being partitioned (ParentGroup), independent of PartitionSize.
template <size_t PartitionSize, typename ParentGroup>
70+
struct group_scope<sycl::ext::oneapi::experimental::fixed_size_group<
71+
PartitionSize, ParentGroup>> {
72+
static constexpr __spv::Scope::Flag value = group_scope<ParentGroup>::value;
73+
};
74+
6875
// Generic shuffles and broadcasts may require multiple calls to
6976
// intrinsics, and should use the fewest broadcasts possible
7077
// - Loop over chunks until remaining bytes < chunk size
@@ -118,6 +125,16 @@ bool GroupAll(ext::oneapi::experimental::ballot_group<ParentGroup> g,
118125
return __spirv_GroupNonUniformAll(group_scope<ParentGroup>::value, pred);
119126
}
120127
}
128+
template <size_t PartitionSize, typename ParentGroup>
129+
bool GroupAll(
130+
ext::oneapi::experimental::fixed_size_group<PartitionSize, ParentGroup>,
131+
bool pred) {
132+
// GroupNonUniformAll doesn't support cluster size, so use a reduction
133+
return __spirv_GroupNonUniformBitwiseAnd(
134+
group_scope<ParentGroup>::value,
135+
static_cast<uint32_t>(__spv::GroupOperation::ClusteredReduce),
136+
static_cast<uint32_t>(pred), PartitionSize);
137+
}
121138

122139
template <typename Group> bool GroupAny(Group, bool pred) {
123140
return __spirv_GroupAny(group_scope<Group>::value, pred);
@@ -134,6 +151,16 @@ bool GroupAny(ext::oneapi::experimental::ballot_group<ParentGroup> g,
134151
return __spirv_GroupNonUniformAny(group_scope<ParentGroup>::value, pred);
135152
}
136153
}
154+
template <size_t PartitionSize, typename ParentGroup>
155+
bool GroupAny(
156+
ext::oneapi::experimental::fixed_size_group<PartitionSize, ParentGroup>,
157+
bool pred) {
158+
// GroupNonUniformAny doesn't support cluster size, so use a reduction
159+
return __spirv_GroupNonUniformBitwiseOr(
160+
group_scope<ParentGroup>::value,
161+
static_cast<uint32_t>(__spv::GroupOperation::ClusteredReduce),
162+
static_cast<uint32_t>(pred), PartitionSize);
163+
}
137164

138165
// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
139166
// FIXME: Do not special-case for half or vec once all backends support all data
@@ -231,6 +258,29 @@ GroupBroadcast(sycl::ext::oneapi::experimental::ballot_group<ParentGroup> g,
231258
OCLX, OCLId);
232259
}
233260
}
261+
template <size_t PartitionSize, typename ParentGroup, typename T, typename IdT>
262+
EnableIfNativeBroadcast<T, IdT> GroupBroadcast(
263+
ext::oneapi::experimental::fixed_size_group<PartitionSize, ParentGroup> g,
264+
T x, IdT local_id) {
265+
// Remap local_id to its original numbering in ParentGroup
266+
auto LocalId = g.get_group_linear_id() * PartitionSize + local_id;
267+
268+
// TODO: Refactor to avoid duplication after design settles.
269+
using GroupIdT = typename GroupId<ParentGroup>::type;
270+
GroupIdT GroupLocalId = static_cast<GroupIdT>(LocalId);
271+
using OCLT = detail::ConvertToOpenCLType_t<T>;
272+
using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
273+
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
274+
WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
275+
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
276+
277+
// NonUniformBroadcast requires Id to be dynamically uniform, which does not
278+
// hold here; each partition is broadcasting a separate index. We could
279+
// fallback to either NonUniformShuffle or a NonUniformBroadcast per
280+
// partition, and it's unclear which will be faster in practice.
281+
return __spirv_GroupNonUniformShuffle(group_scope<ParentGroup>::value, OCLX,
282+
OCLId);
283+
}
234284

235285
template <typename Group, typename T, typename IdT>
236286
EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(Group g, T x, IdT local_id) {
@@ -950,6 +1000,43 @@ ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
9501000
} else { \
9511001
return __spirv_GroupNonUniform##Instruction(Scope, OpInt, Arg); \
9521002
} \
1003+
} \
1004+
\
1005+
template <__spv::GroupOperation Op, size_t PartitionSize, \
1006+
typename ParentGroup, typename T> \
1007+
inline T Group##Instruction( \
1008+
ext::oneapi::experimental::fixed_size_group<PartitionSize, ParentGroup> \
1009+
g, \
1010+
T x) { \
1011+
using ConvertedT = detail::ConvertToOpenCLType_t<T>; \
1012+
\
1013+
using OCLT = std::conditional_t< \
1014+
std::is_same<ConvertedT, cl_char>() || \
1015+
std::is_same<ConvertedT, cl_short>(), \
1016+
cl_int, \
1017+
std::conditional_t<std::is_same<ConvertedT, cl_uchar>() || \
1018+
std::is_same<ConvertedT, cl_ushort>(), \
1019+
cl_uint, ConvertedT>>; \
1020+
OCLT Arg = x; \
1021+
constexpr auto Scope = group_scope<ParentGroup>::value; \
1022+
/* SPIR-V only defines a ClusteredReduce, with no equivalents for scan. */ \
1023+
/* Emulate Clustered*Scan using control flow to separate clusters. */ \
1024+
if constexpr (Op == __spv::GroupOperation::Reduce) { \
1025+
constexpr auto OpInt = \
1026+
static_cast<unsigned int>(__spv::GroupOperation::ClusteredReduce); \
1027+
return __spirv_GroupNonUniform##Instruction(Scope, OpInt, Arg, \
1028+
PartitionSize); \
1029+
} else { \
1030+
T tmp; \
1031+
for (size_t Cluster = 0; Cluster < g.get_group_linear_range(); \
1032+
++Cluster) { \
1033+
if (Cluster == g.get_group_linear_id()) { \
1034+
constexpr auto OpInt = static_cast<unsigned int>(Op); \
1035+
tmp = __spirv_GroupNonUniform##Instruction(Scope, OpInt, Arg); \
1036+
} \
1037+
} \
1038+
return tmp; \
1039+
} \
9531040
}
9541041

9551042
__SYCL_GROUP_COLLECTIVE_OVERLOAD(SMin)

sycl/include/sycl/ext/oneapi/experimental/fixed_size_group.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,5 +137,11 @@ struct is_user_constructed_group<fixed_size_group<PartitionSize, ParentGroup>>
137137
: std::true_type {};
138138

139139
} // namespace ext::oneapi::experimental
140+
141+
// is_group specialization for fixed_size_group: derives from std::true_type
// for every PartitionSize / ParentGroup combination. Declared outside the
// ext::oneapi::experimental namespace, which is closed just above.
template <size_t PartitionSize, typename ParentGroup>
142+
struct is_group<
143+
ext::oneapi::experimental::fixed_size_group<PartitionSize, ParentGroup>>
144+
: std::true_type {};
145+
140146
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
141147
} // namespace sycl

sycl/include/sycl/ext/oneapi/experimental/non_uniform_groups.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ namespace ext::oneapi::experimental {
6363

6464
// Forward declarations of non-uniform group types for algorithm definitions
6565
template <typename ParentGroup> class ballot_group;
66+
template <size_t PartitionSize, typename ParentGroup> class fixed_size_group;
6667

6768
} // namespace ext::oneapi::experimental
6869

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
// RUN: %clangxx -fsycl -fsycl-device-code-split=per_kernel -fsycl-targets=%sycl_triple %s -o %t.out
2+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
3+
//
4+
// UNSUPPORTED: cpu || cuda || hip
5+
6+
#include <sycl/sycl.hpp>
7+
#include <vector>
8+
namespace syclex = sycl::ext::oneapi::experimental;
9+
10+
template <size_t PartitionSize> class TestKernel;
11+
12+
template <size_t PartitionSize> void test() {
13+
sycl::queue Q;
14+
15+
constexpr uint32_t SGSize = 32;
16+
auto SGSizes = Q.get_device().get_info<sycl::info::device::sub_group_sizes>();
17+
if (std::find(SGSizes.begin(), SGSizes.end(), SGSize) == SGSizes.end()) {
18+
std::cout << "Test skipped due to missing support for sub-group size 32."
19+
<< std::endl;
20+
}
21+
22+
sycl::buffer<size_t, 1> TmpBuf{sycl::range{SGSize}};
23+
sycl::buffer<bool, 1> BarrierBuf{sycl::range{SGSize}};
24+
sycl::buffer<bool, 1> BroadcastBuf{sycl::range{SGSize}};
25+
sycl::buffer<bool, 1> AnyBuf{sycl::range{SGSize}};
26+
sycl::buffer<bool, 1> AllBuf{sycl::range{SGSize}};
27+
sycl::buffer<bool, 1> NoneBuf{sycl::range{SGSize}};
28+
sycl::buffer<bool, 1> ReduceBuf{sycl::range{SGSize}};
29+
sycl::buffer<bool, 1> ExScanBuf{sycl::range{SGSize}};
30+
sycl::buffer<bool, 1> IncScanBuf{sycl::range{SGSize}};
31+
32+
const auto NDR = sycl::nd_range<1>{SGSize, SGSize};
33+
Q.submit([&](sycl::handler &CGH) {
34+
sycl::accessor TmpAcc{TmpBuf, CGH, sycl::write_only};
35+
sycl::accessor BarrierAcc{BarrierBuf, CGH, sycl::write_only};
36+
sycl::accessor BroadcastAcc{BroadcastBuf, CGH, sycl::write_only};
37+
sycl::accessor AnyAcc{AnyBuf, CGH, sycl::write_only};
38+
sycl::accessor AllAcc{AllBuf, CGH, sycl::write_only};
39+
sycl::accessor NoneAcc{NoneBuf, CGH, sycl::write_only};
40+
sycl::accessor ReduceAcc{ReduceBuf, CGH, sycl::write_only};
41+
sycl::accessor ExScanAcc{ExScanBuf, CGH, sycl::write_only};
42+
sycl::accessor IncScanAcc{IncScanBuf, CGH, sycl::write_only};
43+
const auto KernelFunc =
44+
[=](sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(SGSize)]] {
45+
auto WI = item.get_global_id();
46+
auto SG = item.get_sub_group();
47+
48+
// Split into partitions of fixed size
49+
auto Partition = syclex::get_fixed_size_group<PartitionSize>(SG);
50+
51+
// Check all other members' writes are visible after a barrier.
52+
TmpAcc[WI] = 1;
53+
sycl::group_barrier(Partition);
54+
size_t Visible = 0;
55+
for (size_t Other = 0; Other < SGSize; ++Other) {
56+
if ((WI / PartitionSize) == (Other / PartitionSize)) {
57+
Visible += TmpAcc[Other];
58+
}
59+
}
60+
BarrierAcc[WI] = (Visible == PartitionSize);
61+
62+
// Simple check of group algorithms.
63+
uint32_t OriginalLID = SG.get_local_linear_id();
64+
uint32_t LID = Partition.get_local_linear_id();
65+
66+
uint32_t PartitionLeader =
67+
(OriginalLID / PartitionSize) * PartitionSize;
68+
uint32_t BroadcastResult =
69+
sycl::group_broadcast(Partition, OriginalLID, 0);
70+
BroadcastAcc[WI] = (BroadcastResult == PartitionLeader);
71+
72+
bool AnyResult = sycl::any_of_group(Partition, (LID == 0));
73+
AnyAcc[WI] = (AnyResult == true);
74+
75+
bool Predicate = ((OriginalLID / PartitionSize) % 2 == 0);
76+
bool AllResult = sycl::all_of_group(Partition, Predicate);
77+
if (Predicate) {
78+
AllAcc[WI] = (AllResult == true);
79+
} else {
80+
AllAcc[WI] = (AllResult == false);
81+
}
82+
83+
bool NoneResult = sycl::none_of_group(Partition, Predicate);
84+
if (Predicate) {
85+
NoneAcc[WI] = (NoneResult == false);
86+
} else {
87+
NoneAcc[WI] = (NoneResult == true);
88+
}
89+
90+
uint32_t ReduceResult =
91+
sycl::reduce_over_group(Partition, 1, sycl::plus<>());
92+
ReduceAcc[WI] = (ReduceResult == PartitionSize);
93+
94+
uint32_t ExScanResult =
95+
sycl::exclusive_scan_over_group(Partition, 1, sycl::plus<>());
96+
ExScanAcc[WI] = (ExScanResult == LID);
97+
98+
uint32_t IncScanResult =
99+
sycl::inclusive_scan_over_group(Partition, 1, sycl::plus<>());
100+
IncScanAcc[WI] = (IncScanResult == LID + 1);
101+
};
102+
CGH.parallel_for<TestKernel<PartitionSize>>(NDR, KernelFunc);
103+
});
104+
105+
sycl::host_accessor BarrierAcc{BarrierBuf, sycl::read_only};
106+
sycl::host_accessor BroadcastAcc{BroadcastBuf, sycl::read_only};
107+
sycl::host_accessor AnyAcc{AnyBuf, sycl::read_only};
108+
sycl::host_accessor AllAcc{AllBuf, sycl::read_only};
109+
sycl::host_accessor NoneAcc{NoneBuf, sycl::read_only};
110+
sycl::host_accessor ReduceAcc{ReduceBuf, sycl::read_only};
111+
sycl::host_accessor ExScanAcc{ExScanBuf, sycl::read_only};
112+
sycl::host_accessor IncScanAcc{IncScanBuf, sycl::read_only};
113+
for (int WI = 0; WI < SGSize; ++WI) {
114+
assert(BarrierAcc[WI] == true);
115+
assert(BroadcastAcc[WI] == true);
116+
assert(AnyAcc[WI] == true);
117+
assert(AllAcc[WI] == true);
118+
assert(NoneAcc[WI] == true);
119+
assert(ReduceAcc[WI] == true);
120+
assert(ExScanAcc[WI] == true);
121+
assert(IncScanAcc[WI] == true);
122+
}
123+
}
124+
125+
int main() {
126+
test<1>();
127+
test<2>();
128+
test<4>();
129+
test<8>();
130+
test<16>();
131+
test<32>();
132+
return 0;
133+
}

0 commit comments

Comments
 (0)