Skip to content

Commit f5178eb

Browse files
[SYCL] Allow raw pointers in SYCL vec load and store (#13895)
In accordance with KhronosGroup/SYCL-Docs#555 proposal, this commit allows raw pointers in the `load` and `store` member functions on `sycl::vec`. --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 3767874 commit f5178eb

File tree

3 files changed

+226
-0
lines changed

3 files changed

+226
-0
lines changed

sycl/include/sycl/vector.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,11 @@ template <typename Type, int NumElements> class vec {
985985
MultiPtr(Acc);
986986
load(Offset, MultiPtr);
987987
}
988+
void load(size_t Offset, const DataT *Ptr) {
989+
for (int I = 0; I < NumElements; ++I)
990+
setValue(I, Ptr[Offset * NumElements + I]);
991+
}
992+
988993
template <access::address_space Space, access::decorated DecorateAddress>
989994
void store(size_t Offset,
990995
multi_ptr<DataT, Space, DecorateAddress> Ptr) const {
@@ -1004,6 +1009,10 @@ template <typename Type, int NumElements> class vec {
10041009
MultiPtr(Acc);
10051010
store(Offset, MultiPtr);
10061011
}
1012+
void store(size_t Offset, DataT *Ptr) const {
1013+
for (int I = 0; I < NumElements; ++I)
1014+
Ptr[Offset * NumElements + I] = getValue(I);
1015+
}
10071016

10081017
void ConvertToDataT() {
10091018
for (size_t i = 0; i < NumElements; ++i) {

sycl/include/sycl/vector_preview.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,11 @@ class vec : public detail::vec_arith<DataT, NumElements> {
572572
MultiPtr(Acc);
573573
load(Offset, MultiPtr);
574574
}
575+
void load(size_t Offset, const DataT *Ptr) {
576+
for (int I = 0; I < NumElements; ++I)
577+
m_Data[I] = Ptr[Offset * NumElements + I];
578+
}
579+
575580
template <access::address_space Space, access::decorated DecorateAddress>
576581
void store(size_t Offset,
577582
multi_ptr<DataT, Space, DecorateAddress> Ptr) const {
@@ -591,6 +596,10 @@ class vec : public detail::vec_arith<DataT, NumElements> {
591596
MultiPtr(Acc);
592597
store(Offset, MultiPtr);
593598
}
599+
void store(size_t Offset, DataT *Ptr) const {
600+
for (int I = 0; I < NumElements; ++I)
601+
Ptr[Offset * NumElements + I] = m_Data[I];
602+
}
594603

595604
private:
596605
// fields
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
// RUN: %if preview-breaking-changes-supported %{ %{build} -fpreview-breaking-changes -o %t2.out %}
5+
// RUN: %if preview-breaking-changes-supported %{ %{run} %t2.out %}
6+
7+
// Tests load and store on sycl::vec.
8+
9+
#include <sycl/detail/core.hpp>
10+
#include <sycl/ext/oneapi/bfloat16.hpp>
11+
#include <sycl/ext/oneapi/experimental/bfloat16_math.hpp>
12+
#include <sycl/types.hpp>
13+
14+
namespace syclex = sycl::ext::oneapi;
15+
16+
template <size_t N, typename T0, typename T1>
17+
int CheckResult(const T0 &Actual, const T1 &Reference, const char *Category) {
18+
int Failures = 0;
19+
for (size_t I = 0; I < N; ++I) {
20+
if (Actual[I] == Reference[I])
21+
continue;
22+
23+
std::cout << "Failed at index " << I << ": " << Category << " - "
24+
<< Actual[I] << " != " << Reference[I] << std::endl;
25+
++Failures;
26+
}
27+
return Failures;
28+
}
29+
30+
template <typename VecT> int RunTest(sycl::queue &Q) {
31+
using ElemT = typename VecT::element_type;
32+
33+
int Failures = 0;
34+
// Load on host.
35+
// Note: multi_ptr is not usable on host, so only raw pointer is tested.
36+
{
37+
const ElemT Ref[] = {0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13};
38+
VecT V{0};
39+
V.load(2, Ref);
40+
Failures += CheckResult<4>(V, Ref + 8, "load with raw pointer on host");
41+
}
42+
43+
// Store on host.
44+
// Note: multi_ptr is not usable on host, so only raw pointer is tested.
45+
{
46+
ElemT Out[] = {0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13};
47+
const VecT V{4, 3, 2, 1};
48+
V.store(1, Out);
49+
const ElemT Ref[] = {0, 2, 1, 4, 4, 3, 2, 1, 7, 10, 9, 12, 11, 14, 13};
50+
Failures +=
51+
CheckResult<std::size(Ref)>(Out, Ref, "store in raw pointer on host");
52+
}
53+
54+
// Load on device.
55+
{
56+
const ElemT Ref[] = {0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12,
57+
11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24};
58+
VecT V[6] = {VecT{0}};
59+
60+
{
61+
sycl::buffer<const ElemT, 1> RefBuff{Ref, std::size(Ref)};
62+
sycl::buffer<VecT, 1> VBuff{V, std::size(V)};
63+
64+
Q.submit([&](sycl::handler &CGH) {
65+
sycl::accessor GlobalRefAcc{RefBuff, CGH, sycl::read_only};
66+
sycl::accessor VAcc{VBuff, CGH, sycl::read_write};
67+
sycl::local_accessor<ElemT, 1> LocalRefAcc{std::size(Ref), CGH};
68+
CGH.parallel_for(sycl::nd_range<1>{1, 1}, [=](sycl::nd_item<1>) {
69+
// Initialize the local and private memory copies.
70+
ElemT PrivateRef[std::size(Ref)] = {0};
71+
for (size_t I = 0; I < GlobalRefAcc.size(); ++I) {
72+
PrivateRef[I] = GlobalRefAcc[I];
73+
LocalRefAcc[I] = GlobalRefAcc[I];
74+
}
75+
76+
// Load with global multi_ptr.
77+
auto GlobalMPtr =
78+
GlobalRefAcc
79+
.template get_multi_ptr<sycl::access::decorated::no>();
80+
VAcc[0].load(0, GlobalMPtr);
81+
82+
// Load with local multi_ptr.
83+
auto LocalMPtr =
84+
LocalRefAcc.template get_multi_ptr<sycl::access::decorated::no>();
85+
VAcc[1].load(1, LocalMPtr);
86+
87+
// Load with private multi_ptr.
88+
auto PrivateMPtr = sycl::address_space_cast<
89+
sycl::access::address_space::private_space,
90+
sycl::access::decorated::no>(PrivateRef);
91+
VAcc[2].load(2, PrivateMPtr);
92+
93+
// Load with global raw pointer.
94+
const ElemT *GlobalRawPtr = GlobalMPtr.get_raw();
95+
VAcc[3].load(3, GlobalRawPtr);
96+
97+
// Load with local raw pointer.
98+
const ElemT *LocalRawPtr = LocalMPtr.get_raw();
99+
VAcc[4].load(4, LocalRawPtr);
100+
101+
// Load with private raw pointer.
102+
VAcc[5].load(5, PrivateRef);
103+
});
104+
});
105+
}
106+
107+
Failures +=
108+
CheckResult<4>(V[0], Ref, "load with global multi_ptr on device");
109+
Failures +=
110+
CheckResult<4>(V[1], Ref + 4, "load with local multi_ptr on device");
111+
Failures +=
112+
CheckResult<4>(V[2], Ref + 8, "load with private multi_ptr on device");
113+
Failures += CheckResult<4>(V[3], Ref + 12,
114+
"load with global raw pointer on device");
115+
Failures +=
116+
CheckResult<4>(V[4], Ref + 16, "load with local raw pointer on device");
117+
Failures += CheckResult<4>(V[5], Ref + 20,
118+
"load with private raw pointer on device");
119+
}
120+
121+
// Store on device.
122+
{
123+
ElemT Out[24] = {0};
124+
const VecT V[] = {{0, 2, 1, 4}, {3, 6, 5, 8}, {7, 10, 9, 12},
125+
{11, 14, 13, 16}, {15, 18, 17, 20}, {19, 22, 21, 24}};
126+
127+
{
128+
sycl::buffer<ElemT, 1> OutBuff{Out, std::size(Out)};
129+
130+
Q.submit([&](sycl::handler &CGH) {
131+
sycl::accessor OutAcc{OutBuff, CGH, sycl::read_write};
132+
sycl::local_accessor<ElemT, 1> LocalOutAcc{std::size(Out), CGH};
133+
CGH.parallel_for(sycl::nd_range<1>{1, 1}, [=](sycl::nd_item<1>) {
134+
ElemT PrivateVal[std::size(Out)] = {0};
135+
136+
// Store in global multi_ptr.
137+
auto GlobalMPtr =
138+
OutAcc.template get_multi_ptr<sycl::access::decorated::no>();
139+
V[0].store(0, GlobalMPtr);
140+
141+
// Store in local multi_ptr.
142+
auto LocalMPtr =
143+
LocalOutAcc.template get_multi_ptr<sycl::access::decorated::no>();
144+
V[1].store(1, LocalMPtr);
145+
146+
// Store in private multi_ptr.
147+
auto PrivateMPtr = sycl::address_space_cast<
148+
sycl::access::address_space::private_space,
149+
sycl::access::decorated::no>(PrivateVal);
150+
V[2].store(2, PrivateMPtr);
151+
152+
// Store in global raw pointer.
153+
ElemT *GlobalRawPtr = GlobalMPtr.get_raw();
154+
V[3].store(3, GlobalRawPtr);
155+
156+
// Store in local raw pointer.
157+
ElemT *LocalRawPtr = LocalMPtr.get_raw();
158+
V[4].store(4, LocalRawPtr);
159+
160+
// Store in private raw pointer.
161+
V[5].store(5, PrivateVal);
162+
163+
// Write local and private results back to the global buffer.
164+
for (size_t I = 0; I < 4; ++I) {
165+
OutAcc[4 + I] = LocalMPtr[4 + I];
166+
OutAcc[8 + I] = PrivateVal[8 + I];
167+
OutAcc[16 + I] = LocalMPtr[16 + I];
168+
OutAcc[20 + I] = PrivateVal[20 + I];
169+
}
170+
});
171+
});
172+
}
173+
174+
const ElemT Ref[] = {0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12,
175+
11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24};
176+
177+
Failures += CheckResult<4>(Out, Ref, "store in global multi_ptr on device");
178+
Failures +=
179+
CheckResult<4>(Out + 4, Ref + 4, "store in local multi_ptr on device");
180+
Failures += CheckResult<4>(Out + 8, Ref + 8,
181+
"store in private multi_ptr on device");
182+
Failures += CheckResult<4>(Out + 12, Ref + 12,
183+
"store in global raw pointer on device");
184+
Failures += CheckResult<4>(Out + 16, Ref + 16,
185+
"store in local raw pointer on device");
186+
Failures += CheckResult<4>(Out + 20, Ref + 20,
187+
"store in private raw pointer on device");
188+
}
189+
190+
return Failures;
191+
}
192+
193+
int main() {
194+
sycl::queue Q;
195+
196+
int Failures = 0;
197+
198+
Failures += RunTest<sycl::int4>(Q);
199+
Failures += RunTest<sycl::float4>(Q);
200+
Failures += RunTest<sycl::vec<syclex::bfloat16, 4>>(Q);
201+
202+
if (Q.get_device().has(sycl::aspect::fp16))
203+
Failures += RunTest<sycl::half4>(Q);
204+
if (Q.get_device().has(sycl::aspect::fp64))
205+
Failures += RunTest<sycl::double4>(Q);
206+
207+
return Failures;
208+
}

0 commit comments

Comments
 (0)