Skip to content

Commit bb78d2c

Browse files
authored
[SYCL] Enable async_work_group_copy for scalar and vector bool types (#2582)
* [SYCL] Enable async_work_group_copy for scalar and vector bool types Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent d5058e2 commit bb78d2c

File tree

3 files changed

+244
-30
lines changed

3 files changed

+244
-30
lines changed

sycl/include/CL/sycl/detail/type_traits.hpp

+13
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,19 @@ template <typename T>
207207
struct is_vector_arithmetic
208208
: bool_constant<is_vec<T>::value && is_arithmetic<T>::value> {};
209209

210+
// is_bool
211+
template <typename T>
212+
struct is_scalar_bool
213+
: bool_constant<std::is_same<remove_cv_t<T>, bool>::value> {};
214+
215+
template <typename T>
216+
struct is_vector_bool
217+
: bool_constant<is_vec<T>::value &&
218+
is_scalar_bool<vector_element_t<T>>::value> {};
219+
220+
template <typename T>
221+
struct is_bool : bool_constant<is_scalar_bool<vector_element_t<T>>::value> {};
222+
210223
// is_pointer
211224
template <typename T> struct is_pointer_impl : std::false_type {};
212225

sycl/include/CL/sycl/group.hpp

+71-30
Original file line numberDiff line numberDiff line change
@@ -274,58 +274,99 @@ template <int Dimensions = 1> class group {
274274
__spirv_MemoryBarrier(__spv::Scope::Workgroup, flags);
275275
}
276276

277+
/// Asynchronously copies a number of elements specified by \p numElements
278+
/// from the source pointed by \p src to destination pointed by \p dest
279+
/// with a source stride specified by \p srcStride, and returns a SYCL
280+
/// device_event which can be used to wait on the completion of the copy.
281+
/// Permitted types for dataT are all scalar and vector types, except boolean.
277282
template <typename dataT>
278-
device_event async_work_group_copy(local_ptr<dataT> dest,
279-
global_ptr<dataT> src,
280-
size_t numElements) const {
283+
detail::enable_if_t<!detail::is_bool<dataT>::value, device_event>
284+
async_work_group_copy(local_ptr<dataT> dest, global_ptr<dataT> src,
285+
size_t numElements, size_t srcStride) const {
281286
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
282287
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;
283288

284-
__ocl_event_t e = OpGroupAsyncCopyGlobalToLocal(
289+
__ocl_event_t E = OpGroupAsyncCopyGlobalToLocal(
285290
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
286-
numElements, 1, 0);
287-
return device_event(&e);
291+
numElements, srcStride, 0);
292+
return device_event(&E);
288293
}
289294

295+
/// Asynchronously copies a number of elements specified by \p numElements
296+
/// from the source pointed by \p src to destination pointed by \p dest with
297+
/// the destination stride specified by \p destStride, and returns a SYCL
298+
/// device_event which can be used to wait on the completion of the copy.
299+
/// Permitted types for dataT are all scalar and vector types, except boolean.
290300
template <typename dataT>
291-
device_event async_work_group_copy(global_ptr<dataT> dest,
292-
local_ptr<dataT> src,
293-
size_t numElements) const {
301+
detail::enable_if_t<!detail::is_bool<dataT>::value, device_event>
302+
async_work_group_copy(global_ptr<dataT> dest, local_ptr<dataT> src,
303+
size_t numElements, size_t destStride) const {
294304
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
295305
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;
296306

297-
__ocl_event_t e = OpGroupAsyncCopyLocalToGlobal(
307+
__ocl_event_t E = OpGroupAsyncCopyLocalToGlobal(
298308
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
299-
numElements, 1, 0);
300-
return device_event(&e);
309+
numElements, destStride, 0);
310+
return device_event(&E);
311+
}
312+
313+
/// Specialization for scalar bool type.
314+
/// Asynchronously copies a number of elements specified by \p NumElements
315+
/// from the source pointed by \p Src to destination pointed by \p Dest
316+
/// with a stride specified by \p Stride, and returns a SYCL device_event
317+
/// which can be used to wait on the completion of the copy.
318+
template <typename T, access::address_space DestS, access::address_space SrcS>
319+
detail::enable_if_t<detail::is_scalar_bool<T>::value, device_event>
320+
async_work_group_copy(multi_ptr<T, DestS> Dest, multi_ptr<T, SrcS> Src,
321+
size_t NumElements, size_t Stride) const {
322+
static_assert(sizeof(bool) == sizeof(uint8_t),
323+
"Async copy to/from bool memory is not supported.");
324+
auto DestP =
325+
multi_ptr<uint8_t, DestS>(reinterpret_cast<uint8_t *>(Dest.get()));
326+
auto SrcP =
327+
multi_ptr<uint8_t, SrcS>(reinterpret_cast<uint8_t *>(Src.get()));
328+
return async_work_group_copy(DestP, SrcP, NumElements, Stride);
329+
}
330+
331+
/// Specialization for vector bool type.
332+
/// Asynchronously copies a number of elements specified by \p NumElements
333+
/// from the source pointed by \p Src to destination pointed by \p Dest
334+
/// with a stride specified by \p Stride, and returns a SYCL device_event
335+
/// which can be used to wait on the completion of the copy.
336+
template <typename T, access::address_space DestS, access::address_space SrcS>
337+
detail::enable_if_t<detail::is_vector_bool<T>::value, device_event>
338+
async_work_group_copy(multi_ptr<T, DestS> Dest, multi_ptr<T, SrcS> Src,
339+
size_t NumElements, size_t Stride) const {
340+
static_assert(sizeof(bool) == sizeof(uint8_t),
341+
"Async copy to/from bool memory is not supported.");
342+
using VecT = detail::change_base_type_t<T, uint8_t>;
343+
auto DestP = multi_ptr<VecT, DestS>(reinterpret_cast<VecT *>(Dest.get()));
344+
auto SrcP = multi_ptr<VecT, SrcS>(reinterpret_cast<VecT *>(Src.get()));
345+
return async_work_group_copy(DestP, SrcP, NumElements, Stride);
301346
}
302347

348+
/// Asynchronously copies a number of elements specified by \p numElements
349+
/// from the source pointed by \p src to destination pointed by \p dest and
350+
/// returns a SYCL device_event which can be used to wait on the completion
351+
/// of the copy.
352+
/// Permitted types for dataT are all scalar and vector types.
303353
template <typename dataT>
304354
device_event async_work_group_copy(local_ptr<dataT> dest,
305355
global_ptr<dataT> src,
306-
size_t numElements,
307-
size_t srcStride) const {
308-
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
309-
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;
310-
311-
__ocl_event_t e = OpGroupAsyncCopyGlobalToLocal(
312-
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
313-
numElements, srcStride, 0);
314-
return device_event(&e);
356+
size_t numElements) const {
357+
return async_work_group_copy(dest, src, numElements, 1);
315358
}
316359

360+
/// Asynchronously copies a number of elements specified by \p numElements
361+
/// from the source pointed by \p src to destination pointed by \p dest and
362+
/// returns a SYCL device_event which can be used to wait on the completion
363+
/// of the copy.
364+
/// Permitted types for dataT are all scalar and vector types.
317365
template <typename dataT>
318366
device_event async_work_group_copy(global_ptr<dataT> dest,
319367
local_ptr<dataT> src,
320-
size_t numElements,
321-
size_t destStride) const {
322-
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
323-
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;
324-
325-
__ocl_event_t e = OpGroupAsyncCopyLocalToGlobal(
326-
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
327-
numElements, destStride, 0);
328-
return device_event(&e);
368+
size_t numElements) const {
369+
return async_work_group_copy(dest, src, numElements, 1);
329370
}
330371

331372
template <typename... eventTN>
+160
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.run
2+
// RUN: %GPU_RUN_PLACEHOLDER %t.run
3+
// RUN: %CPU_RUN_PLACEHOLDER %t.run
4+
// RUN: %ACC_RUN_PLACEHOLDER %t.run
5+
// RUN: env SYCL_DEVICE_FILTER=host %t.run
6+
7+
#include <CL/sycl.hpp>
8+
#include <iostream>
9+
#include <typeinfo>
10+
11+
using namespace cl::sycl;
12+
13+
template <typename T> class KernelName;
14+
15+
// Define the number of work items to enqueue.
16+
const size_t NElems = 32;
17+
const size_t WorkGroupSize = 8;
18+
const size_t NWorkGroups = NElems / WorkGroupSize;
19+
20+
template <typename T> void initInputBuffer(buffer<T, 1> &Buf, size_t Stride) {
21+
auto Acc = Buf.template get_access<access::mode::write>();
22+
for (size_t I = 0; I < Buf.get_count(); I += WorkGroupSize) {
23+
for (size_t J = 0; J < WorkGroupSize; J++)
24+
Acc[I + J] = I + J + ((J % Stride == 0) ? 100 : 0);
25+
}
26+
}
27+
28+
template <typename T> void initOutputBuffer(buffer<T, 1> &Buf) {
29+
auto Acc = Buf.template get_access<access::mode::write>();
30+
for (size_t I = 0; I < Buf.get_count(); I++)
31+
Acc[I] = 0;
32+
}
33+
34+
template <typename T> struct is_vec : std::false_type {};
35+
template <typename T, size_t N> struct is_vec<vec<T, N>> : std::true_type {};
36+
37+
template <typename T> bool checkEqual(vec<T, 1> A, size_t B) {
38+
T TB = B;
39+
return A.s0() == TB;
40+
}
41+
42+
template <typename T> bool checkEqual(vec<T, 4> A, size_t B) {
43+
T TB = B;
44+
return A.x() == TB && A.y() == TB && A.z() == TB && A.w() == TB;
45+
}
46+
47+
template <typename T>
48+
typename std::enable_if<!is_vec<T>::value, bool>::type checkEqual(T A,
49+
size_t B) {
50+
T TB = B;
51+
return A == TB;
52+
}
53+
54+
template <typename T> std::string toString(vec<T, 1> A) {
55+
std::string R("(");
56+
return R + std::to_string(A.s0()) + ")";
57+
}
58+
59+
template <typename T> std::string toString(vec<T, 4> A) {
60+
std::string R("(");
61+
R += std::to_string(A.x()) + "," + std::to_string(A.y()) + "," +
62+
std::to_string(A.z()) + "," + std::to_string(A.w()) + ")";
63+
return R;
64+
}
65+
66+
template <typename T = void>
67+
typename std::enable_if<!is_vec<T>::value, std::string>::type toString(T A) {
68+
return std::to_string(A);
69+
}
70+
71+
template <typename T> int checkResults(buffer<T, 1> &OutBuf, size_t Stride) {
72+
auto Out = OutBuf.template get_access<access::mode::read>();
73+
int EarlyFailout = 20;
74+
75+
for (size_t I = 0; I < OutBuf.get_count(); I += WorkGroupSize) {
76+
for (size_t J = 0; J < WorkGroupSize; J++) {
77+
size_t ExpectedVal = (J % Stride == 0) ? (100 + I + J) : 0;
78+
if (!checkEqual(Out[I + J], ExpectedVal)) {
79+
std::cerr << std::string(typeid(T).name()) + ": Stride=" << Stride
80+
<< " : Incorrect value at index " << I + J
81+
<< " : Expected: " << toString(ExpectedVal)
82+
<< ", Computed: " << toString(Out[I + J]) << "\n";
83+
if (--EarlyFailout == 0)
84+
return 1;
85+
}
86+
}
87+
}
88+
return EarlyFailout - 20;
89+
}
90+
91+
template <typename T> int test(size_t Stride) {
92+
queue Q;
93+
94+
buffer<T, 1> InBuf(NElems);
95+
buffer<T, 1> OutBuf(NElems);
96+
97+
initInputBuffer(InBuf, Stride);
98+
initOutputBuffer(OutBuf);
99+
100+
Q.submit([&](handler &CGH) {
101+
auto In = InBuf.template get_access<access::mode::read>(CGH);
102+
auto Out = OutBuf.template get_access<access::mode::write>(CGH);
103+
accessor<T, 1, access::mode::read_write, access::target::local> Local(
104+
range<1>{WorkGroupSize}, CGH);
105+
106+
nd_range<1> NDR{range<1>(NElems), range<1>(WorkGroupSize)};
107+
CGH.parallel_for<KernelName<T>>(NDR, [=](nd_item<1> NDId) {
108+
auto GrId = NDId.get_group_linear_id();
109+
auto Group = NDId.get_group();
110+
size_t NElemsToCopy =
111+
WorkGroupSize / Stride + ((WorkGroupSize % Stride) ? 1 : 0);
112+
size_t Offset = GrId * WorkGroupSize;
113+
if (Stride == 1) { // Check the version without stride arg.
114+
auto E = NDId.async_work_group_copy(
115+
Local.get_pointer(), In.get_pointer() + Offset, NElemsToCopy);
116+
E.wait();
117+
} else {
118+
auto E = NDId.async_work_group_copy(Local.get_pointer(),
119+
In.get_pointer() + Offset,
120+
NElemsToCopy, Stride);
121+
E.wait();
122+
}
123+
124+
if (Stride == 1) { // Check the version without stride arg.
125+
auto E = Group.async_work_group_copy(
126+
Out.get_pointer() + Offset, Local.get_pointer(), NElemsToCopy);
127+
Group.wait_for(E);
128+
} else {
129+
auto E = Group.async_work_group_copy(Out.get_pointer() + Offset,
130+
Local.get_pointer(), NElemsToCopy,
131+
Stride);
132+
Group.wait_for(E);
133+
}
134+
});
135+
}).wait();
136+
137+
return checkResults(OutBuf, Stride);
138+
}
139+
140+
int main() {
141+
for (int Stride = 1; Stride < WorkGroupSize; Stride++) {
142+
if (test<int>(Stride))
143+
return 1;
144+
if (test<vec<int, 1>>(Stride))
145+
return 1;
146+
if (test<int4>(Stride))
147+
return 1;
148+
if (test<bool>(Stride))
149+
return 1;
150+
if (test<vec<bool, 1>>(Stride))
151+
return 1;
152+
if (test<vec<bool, 4>>(Stride))
153+
return 1;
154+
if (test<cl::sycl::cl_bool>(Stride))
155+
return 1;
156+
}
157+
158+
std::cout << "Test passed.\n";
159+
return 0;
160+
}

0 commit comments

Comments
 (0)