Skip to content

Commit e7defab

Browse files
[SYCL] Use std::array as storage for sycl::vec on device (#14130)
Replaces #13270. Changing the storage to std::array instead of Clang's extension fixes a strict ansi-aliasing violation and simplifies device code.
1 parent 13a7b3a commit e7defab

File tree

9 files changed

+488
-995
lines changed

9 files changed

+488
-995
lines changed

sycl/include/sycl/detail/generic_type_traits.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,8 @@ template <typename T> auto convertToOpenCLType(T &&x) {
342342
std::declval<ElemTy>()))>,
343343
no_ref::size()>;
344344
#ifdef __SYCL_DEVICE_ONLY__
345+
346+
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
345347
// TODO: for some mysterious reasons on NonUniformGroups E2E tests fail if
346348
// we use the "else" version only. I suspect that's an issue with
347349
// non-uniform groups implementation.
@@ -350,6 +352,10 @@ template <typename T> auto convertToOpenCLType(T &&x) {
350352
else
351353
return static_cast<typename MatchingVec::vector_t>(
352354
x.template as<MatchingVec>());
355+
#else // __INTEL_PREVIEW_BREAKING_CHANGES
356+
return sycl::bit_cast<typename MatchingVec::vector_t>(x);
357+
#endif // __INTEL_PREVIEW_BREAKING_CHANGES
358+
353359
#else
354360
return x.template as<MatchingVec>();
355361
#endif

sycl/include/sycl/detail/vector_arith.hpp

Lines changed: 78 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,14 @@ using rel_t = typename std::conditional_t<
5050
friend std::enable_if_t<(COND), vec_t> operator BINOP(const vec_t & Lhs, \
5151
const vec_t & Rhs) { \
5252
vec_t Ret; \
53-
if constexpr (vec_t::IsUsingArrayOnDevice) { \
53+
if constexpr (vec_t::IsBfloat16) { \
5454
for (size_t I = 0; I < NumElements; ++I) { \
55-
detail::VecAccess<vec_t>::setValue( \
56-
Ret, I, \
57-
(detail::VecAccess<vec_t>::getValue(Lhs, I) \
58-
BINOP detail::VecAccess<vec_t>::getValue(Rhs, I))); \
55+
Ret[I] = Lhs[I] BINOP Rhs[I]; \
5956
} \
6057
} else { \
61-
Ret.m_Data = Lhs.m_Data BINOP Rhs.m_Data; \
58+
auto ExtVecLhs = sycl::bit_cast<typename vec_t::vector_t>(Lhs); \
59+
auto ExtVecRhs = sycl::bit_cast<typename vec_t::vector_t>(Rhs); \
60+
Ret = vec<DataT, NumElements>(ExtVecLhs BINOP ExtVecRhs); \
6261
if constexpr (std::is_same_v<DataT, bool> && CONVERT) { \
6362
vec_arith_common<bool, NumElements>::ConvertToDataT(Ret); \
6463
} \
@@ -72,13 +71,9 @@ using rel_t = typename std::conditional_t<
7271
friend std::enable_if_t<(COND), vec_t> operator BINOP(const vec_t & Lhs, \
7372
const vec_t & Rhs) { \
7473
vec_t Ret{}; \
75-
for (size_t I = 0; I < NumElements; ++I) \
76-
detail::VecAccess<vec_t>::setValue( \
77-
Ret, I, \
78-
(DataT)(vec_data<DataT>::get( \
79-
detail::VecAccess<vec_t>::getValue(Lhs, I)) \
80-
BINOP vec_data<DataT>::get( \
81-
detail::VecAccess<vec_t>::getValue(Rhs, I)))); \
74+
for (size_t I = 0; I < NumElements; ++I) { \
75+
Ret[I] = Lhs[I] BINOP Rhs[I]; \
76+
} \
8277
return Ret; \
8378
}
8479
#endif // __SYCL_DEVICE_ONLY__
@@ -130,83 +125,78 @@ template <typename DataT, int NumElements>
130125
class vec_arith : public vec_arith_common<DataT, NumElements> {
131126
protected:
132127
using vec_t = vec<DataT, NumElements>;
133-
using ocl_t = rel_t<DataT>;
128+
using ocl_t = detail::select_cl_scalar_integral_signed_t<DataT>;
134129
template <typename T> using vec_data = vec_helper<T>;
135130

136131
// operator!.
137-
friend vec<rel_t<DataT>, NumElements> operator!(const vec_t &Rhs) {
138-
if constexpr (vec_t::IsUsingArrayOnDevice || vec_t::IsUsingArrayOnHost) {
139-
vec_t Ret{};
132+
friend vec<ocl_t, NumElements> operator!(const vec_t &Rhs) {
133+
#ifdef __SYCL_DEVICE_ONLY__
134+
if constexpr (!vec_t::IsBfloat16) {
135+
auto extVec = sycl::bit_cast<typename vec_t::vector_t>(Rhs);
136+
vec<ocl_t, NumElements> Ret{
137+
(typename vec<ocl_t, NumElements>::vector_t) !extVec};
138+
return Ret;
139+
} else
140+
#endif // __SYCL_DEVICE_ONLY__
141+
{
142+
vec<ocl_t, NumElements> Ret{};
140143
for (size_t I = 0; I < NumElements; ++I) {
141-
detail::VecAccess<vec_t>::setValue(
142-
Ret, I,
143-
!vec_data<DataT>::get(detail::VecAccess<vec_t>::getValue(Rhs, I)));
144+
// static_cast will work here as !Rhs[I] is either 0 or 1, so the product
145+
// with -1 is either 0 or -1.
146+
Ret[I] = static_cast<ocl_t>(-1 * (!Rhs[I]));
144147
}
145-
return Ret.template as<vec<rel_t<DataT>, NumElements>>();
146-
} else {
147-
return vec_t{(typename vec<DataT, NumElements>::DataType) !Rhs.m_Data}
148-
.template as<vec<rel_t<DataT>, NumElements>>();
148+
return Ret;
149149
}
150150
}
151151

152152
// operator +.
153153
friend vec_t operator+(const vec_t &Lhs) {
154-
if constexpr (vec_t::IsUsingArrayOnDevice || vec_t::IsUsingArrayOnHost) {
155-
vec_t Ret{};
156-
for (size_t I = 0; I < NumElements; ++I)
157-
detail::VecAccess<vec_t>::setValue(
158-
Ret, I,
159-
vec_data<DataT>::get(+vec_data<DataT>::get(
160-
detail::VecAccess<vec_t>::getValue(Lhs, I))));
161-
return Ret;
162-
} else {
163-
return vec_t{+Lhs.m_Data};
164-
}
154+
#ifdef __SYCL_DEVICE_ONLY__
155+
auto extVec = sycl::bit_cast<typename vec_t::vector_t>(Lhs);
156+
return vec_t{+extVec};
157+
#else
158+
vec_t Ret{};
159+
for (size_t I = 0; I < NumElements; ++I)
160+
Ret[I] = +Lhs[I];
161+
return Ret;
162+
#endif
165163
}
166164

167165
// operator -.
168166
friend vec_t operator-(const vec_t &Lhs) {
169-
namespace oneapi = sycl::ext::oneapi;
170167
vec_t Ret{};
171-
if constexpr (vec_t::IsBfloat16 && NumElements == 1) {
172-
oneapi::bfloat16 v = oneapi::detail::bitsToBfloat16(Lhs.m_Data);
173-
oneapi::bfloat16 w = -v;
174-
Ret.m_Data = oneapi::detail::bfloat16ToBits(w);
175-
} else if constexpr (vec_t::IsBfloat16) {
176-
for (size_t I = 0; I < NumElements; I++) {
177-
oneapi::bfloat16 v = oneapi::detail::bitsToBfloat16(Lhs.m_Data[I]);
178-
oneapi::bfloat16 w = -v;
179-
Ret.m_Data[I] = oneapi::detail::bfloat16ToBits(w);
180-
}
181-
} else if constexpr (vec_t::IsUsingArrayOnDevice ||
182-
vec_t::IsUsingArrayOnHost) {
183-
for (size_t I = 0; I < NumElements; ++I)
184-
detail::VecAccess<vec_t>::setValue(
185-
Ret, I,
186-
vec_data<DataT>::get(-vec_data<DataT>::get(
187-
detail::VecAccess<vec_t>::getValue(Lhs, I))));
188-
return Ret;
168+
if constexpr (vec_t::IsBfloat16) {
169+
for (size_t I = 0; I < NumElements; I++)
170+
Ret[I] = -Lhs[I];
189171
} else {
190-
Ret = vec_t{-Lhs.m_Data};
172+
#ifndef __SYCL_DEVICE_ONLY__
173+
for (size_t I = 0; I < NumElements; ++I)
174+
Ret[I] = -Lhs[I];
175+
#else
176+
auto extVec = sycl::bit_cast<typename vec_t::vector_t>(Lhs);
177+
Ret = vec_t{-extVec};
191178
if constexpr (std::is_same_v<DataT, bool>) {
192179
vec_arith_common<bool, NumElements>::ConvertToDataT(Ret);
193180
}
194-
return Ret;
181+
#endif
195182
}
183+
return Ret;
196184
}
197185

198186
// Unary operations on sycl::vec
187+
// FIXME: Don't allow Unary operators on vec<bool> after
188+
// https://github.com/KhronosGroup/SYCL-CTS/issues/896 gets fixed.
199189
#ifdef __SYCL_UOP
200190
#error "Undefine __SYCL_UOP macro"
201191
#endif
202192
#define __SYCL_UOP(UOP, OPASSIGN) \
203193
friend vec_t &operator UOP(vec_t & Rhs) { \
204-
Rhs OPASSIGN vec_data<DataT>::get(1); \
194+
Rhs OPASSIGN DataT{1}; \
205195
return Rhs; \
206196
} \
207197
friend vec_t operator UOP(vec_t &Lhs, int) { \
208198
vec_t Ret(Lhs); \
209-
Lhs OPASSIGN vec_data<DataT>::get(1); \
199+
Lhs OPASSIGN DataT{1}; \
210200
return Ret; \
211201
}
212202

@@ -228,25 +218,24 @@ class vec_arith : public vec_arith_common<DataT, NumElements> {
228218
friend std::enable_if_t<(COND), vec<ocl_t, NumElements>> operator RELLOGOP( \
229219
const vec_t & Lhs, const vec_t & Rhs) { \
230220
vec<ocl_t, NumElements> Ret{}; \
231-
/* This special case is needed since there are no standard operator|| */ \
232-
/* or operator&& functions for std::array. */ \
233-
if constexpr (vec_t::IsUsingArrayOnDevice && \
234-
(std::string_view(#RELLOGOP) == "||" || \
235-
std::string_view(#RELLOGOP) == "&&")) { \
221+
/* ext_vector_type does not support bfloat16, so for these */ \
222+
/* we do element-by-element operation on the underlying std::array. */ \
223+
if constexpr (vec_t::IsBfloat16) { \
236224
for (size_t I = 0; I < NumElements; ++I) { \
237-
/* We cannot use SetValue here as the operator is not a friend of*/ \
238-
/* Ret on Windows. */ \
239-
Ret[I] = static_cast<ocl_t>( \
240-
-(vec_data<DataT>::get(detail::VecAccess<vec_t>::getValue(Lhs, I)) \
241-
RELLOGOP vec_data<DataT>::get( \
242-
detail::VecAccess<vec_t>::getValue(Rhs, I)))); \
225+
Ret[I] = static_cast<ocl_t>(-(Lhs[I] RELLOGOP Rhs[I])); \
243226
} \
244227
} else { \
228+
auto ExtVecLhs = sycl::bit_cast<typename vec_t::vector_t>(Lhs); \
229+
auto ExtVecRhs = sycl::bit_cast<typename vec_t::vector_t>(Rhs); \
230+
/* Cast required to convert unsigned char ext_vec_type to */ \
231+
/* char ext_vec_type. */ \
245232
Ret = vec<ocl_t, NumElements>( \
246233
(typename vec<ocl_t, NumElements>::vector_t)( \
247-
Lhs.m_Data RELLOGOP Rhs.m_Data)); \
248-
if (NumElements == 1) /*Scalar 0/1 logic was applied, invert*/ \
234+
ExtVecLhs RELLOGOP ExtVecRhs)); \
235+
/* For NumElements == 1, we use scalar instead of ext_vector_type. */ \
236+
if constexpr (NumElements == 1) { \
249237
Ret *= -1; \
238+
} \
250239
} \
251240
return Ret; \
252241
}
@@ -257,12 +246,7 @@ class vec_arith : public vec_arith_common<DataT, NumElements> {
257246
const vec_t & Lhs, const vec_t & Rhs) { \
258247
vec<ocl_t, NumElements> Ret{}; \
259248
for (size_t I = 0; I < NumElements; ++I) { \
260-
/* We cannot use SetValue here as the operator is not a friend of*/ \
261-
/* Ret on Windows. */ \
262-
Ret[I] = static_cast<ocl_t>( \
263-
-(vec_data<DataT>::get(detail::VecAccess<vec_t>::getValue(Lhs, I)) \
264-
RELLOGOP vec_data<DataT>::get( \
265-
detail::VecAccess<vec_t>::getValue(Rhs, I)))); \
249+
Ret[I] = static_cast<ocl_t>(-(Lhs[I] RELLOGOP Rhs[I])); \
266250
} \
267251
return Ret; \
268252
}
@@ -376,34 +360,36 @@ template <typename DataT, int NumElements> class vec_arith_common {
376360
protected:
377361
using vec_t = vec<DataT, NumElements>;
378362

363+
static constexpr bool IsBfloat16 =
364+
std::is_same_v<DataT, sycl::ext::oneapi::bfloat16>;
365+
379366
// operator~() available only when: dataT != float && dataT != double
380367
// && dataT != half
381368
template <typename T = DataT>
382369
friend std::enable_if_t<!detail::is_vgenfloat_v<T>, vec_t>
383370
operator~(const vec_t &Rhs) {
384-
if constexpr (vec_t::IsUsingArrayOnDevice || vec_t::IsUsingArrayOnHost) {
385-
vec_t Ret{};
386-
for (size_t I = 0; I < NumElements; ++I) {
387-
detail::VecAccess<vec_t>::setValue(
388-
Ret, I, ~detail::VecAccess<vec_t>::getValue(Rhs, I));
389-
}
390-
return Ret;
391-
} else {
392-
vec_t Ret{(typename vec_t::DataType) ~Rhs.m_Data};
393-
if constexpr (std::is_same_v<DataT, bool>) {
394-
vec_arith_common<bool, NumElements>::ConvertToDataT(Ret);
395-
}
396-
return Ret;
371+
#ifdef __SYCL_DEVICE_ONLY__
372+
auto extVec = sycl::bit_cast<typename vec_t::vector_t>(Rhs);
373+
vec_t Ret{~extVec};
374+
if constexpr (std::is_same_v<DataT, bool>) {
375+
ConvertToDataT(Ret);
397376
}
377+
return Ret;
378+
#else
379+
vec_t Ret{};
380+
for (size_t I = 0; I < NumElements; ++I) {
381+
Ret[I] = ~Rhs[I];
382+
}
383+
return Ret;
384+
#endif
398385
}
399386

400387
#ifdef __SYCL_DEVICE_ONLY__
401388
using vec_bool_t = vec<bool, NumElements>;
402389
// Required only for bool.
403390
static void ConvertToDataT(vec_bool_t &Ret) {
404391
for (size_t I = 0; I < NumElements; ++I) {
405-
DataT Tmp = detail::VecAccess<vec_bool_t>::getValue(Ret, I);
406-
detail::VecAccess<vec_bool_t>::setValue(Ret, I, Tmp);
392+
Ret[I] = bit_cast<int8_t>(Ret[I]) != 0;
407393
}
408394
}
409395
#endif

sycl/include/sycl/detail/vector_convert.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,15 @@ NativeToT convertImpl(NativeFromT Value) {
558558
}
559559
}
560560

561+
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
562+
template <typename FromT, typename ToT, sycl::rounding_mode RoundingMode,
563+
int VecSize, typename NativeFromT, typename NativeToT>
564+
auto ConvertImpl(std::byte val) {
565+
return convertImpl<FromT, ToT, RoundingMode, VecSize, NativeFromT, NativeToT>(
566+
(std::int8_t)val);
567+
}
568+
#endif
569+
561570
} // namespace detail
562571
} // namespace _V1
563572
} // namespace sycl

sycl/include/sycl/ext/oneapi/bfloat16.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ template <int N> void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) {
102102

103103
// sycl::vec support
104104
namespace bf16 {
105+
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
105106
#ifdef __SYCL_DEVICE_ONLY__
106107
using Vec2StorageT = Bfloat16StorageT __attribute__((ext_vector_type(2)));
107108
using Vec3StorageT = Bfloat16StorageT __attribute__((ext_vector_type(3)));
@@ -115,6 +116,7 @@ using Vec4StorageT = std::array<Bfloat16StorageT, 4>;
115116
using Vec8StorageT = std::array<Bfloat16StorageT, 8>;
116117
using Vec16StorageT = std::array<Bfloat16StorageT, 16>;
117118
#endif
119+
#endif // __INTEL_PREVIEW_BREAKING_CHANGES
118120
} // namespace bf16
119121
} // namespace detail
120122

sycl/include/sycl/half_type.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,18 +249,22 @@ using StorageT = _Float16;
249249
using BIsRepresentationT = _Float16;
250250
using VecElemT = _Float16;
251251

252+
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
252253
using Vec2StorageT = VecElemT __attribute__((ext_vector_type(2)));
253254
using Vec3StorageT = VecElemT __attribute__((ext_vector_type(3)));
254255
using Vec4StorageT = VecElemT __attribute__((ext_vector_type(4)));
255256
using Vec8StorageT = VecElemT __attribute__((ext_vector_type(8)));
256257
using Vec16StorageT = VecElemT __attribute__((ext_vector_type(16)));
258+
#endif // __INTEL_PREVIEW_BREAKING_CHANGES
259+
257260
#else // SYCL_DEVICE_ONLY
258261
using StorageT = detail::host_half_impl::half;
259262
// No need to extract underlying data type for built-in functions operating on
260263
// host
261264
using BIsRepresentationT = half;
262265
using VecElemT = half;
263266

267+
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
264268
// On the host side we cannot use OpenCL cl_half# types as an underlying type
265269
// for vec because they are actually defined as an integer type under the
266270
// hood. As a result half values will be converted to the integer and passed
@@ -270,6 +274,8 @@ using Vec3StorageT = std::array<VecElemT, 3>;
270274
using Vec4StorageT = std::array<VecElemT, 4>;
271275
using Vec8StorageT = std::array<VecElemT, 8>;
272276
using Vec16StorageT = std::array<VecElemT, 16>;
277+
#endif // __INTEL_PREVIEW_BREAKING_CHANGES
278+
273279
#endif // SYCL_DEVICE_ONLY
274280

275281
#ifndef __SYCL_DEVICE_ONLY__

0 commit comments

Comments
 (0)