@@ -50,15 +50,14 @@ using rel_t = typename std::conditional_t<
50
50
friend std::enable_if_t <(COND), vec_t > operator BINOP (const vec_t & Lhs, \
51
51
const vec_t & Rhs) { \
52
52
vec_t Ret; \
53
- if constexpr (vec_t ::IsUsingArrayOnDevice ) { \
53
+ if constexpr (vec_t ::IsBfloat16 ) { \
54
54
for (size_t I = 0 ; I < NumElements; ++I) { \
55
- detail::VecAccess<vec_t >::setValue ( \
56
- Ret, I, \
57
- (detail::VecAccess<vec_t >::getValue (Lhs, I) \
58
- BINOP detail::VecAccess<vec_t >::getValue (Rhs, I))); \
55
+ Ret[I] = Lhs[I] BINOP Rhs[I]; \
59
56
} \
60
57
} else { \
61
- Ret.m_Data = Lhs.m_Data BINOP Rhs.m_Data ; \
58
+ auto ExtVecLhs = sycl::bit_cast<typename vec_t ::vector_t >(Lhs); \
59
+ auto ExtVecRhs = sycl::bit_cast<typename vec_t ::vector_t >(Rhs); \
60
+ Ret = vec<DataT, NumElements>(ExtVecLhs BINOP ExtVecRhs); \
62
61
if constexpr (std::is_same_v<DataT, bool > && CONVERT) { \
63
62
vec_arith_common<bool , NumElements>::ConvertToDataT (Ret); \
64
63
} \
@@ -72,13 +71,9 @@ using rel_t = typename std::conditional_t<
72
71
friend std::enable_if_t <(COND), vec_t > operator BINOP (const vec_t & Lhs, \
73
72
const vec_t & Rhs) { \
74
73
vec_t Ret{}; \
75
- for (size_t I = 0 ; I < NumElements; ++I) \
76
- detail::VecAccess<vec_t >::setValue ( \
77
- Ret, I, \
78
- (DataT)(vec_data<DataT>::get ( \
79
- detail::VecAccess<vec_t >::getValue (Lhs, I)) \
80
- BINOP vec_data<DataT>::get ( \
81
- detail::VecAccess<vec_t >::getValue (Rhs, I)))); \
74
+ for (size_t I = 0 ; I < NumElements; ++I) { \
75
+ Ret[I] = Lhs[I] BINOP Rhs[I]; \
76
+ } \
82
77
return Ret; \
83
78
}
84
79
#endif // __SYCL_DEVICE_ONLY__
@@ -130,83 +125,78 @@ template <typename DataT, int NumElements>
130
125
class vec_arith : public vec_arith_common <DataT, NumElements> {
131
126
protected:
132
127
using vec_t = vec<DataT, NumElements>;
133
- using ocl_t = rel_t <DataT>;
128
+ using ocl_t = detail:: select_cl_scalar_integral_signed_t <DataT>;
134
129
template <typename T> using vec_data = vec_helper<T>;
135
130
136
131
// operator!.
137
- friend vec<rel_t <DataT>, NumElements> operator !(const vec_t &Rhs) {
138
- if constexpr (vec_t ::IsUsingArrayOnDevice || vec_t ::IsUsingArrayOnHost) {
139
- vec_t Ret{};
132
+ friend vec<ocl_t , NumElements> operator !(const vec_t &Rhs) {
133
+ #ifdef __SYCL_DEVICE_ONLY__
134
+ if constexpr (!vec_t ::IsBfloat16) {
135
+ auto extVec = sycl::bit_cast<typename vec_t ::vector_t >(Rhs);
136
+ vec<ocl_t , NumElements> Ret{
137
+ (typename vec<ocl_t , NumElements>::vector_t ) !extVec};
138
+ return Ret;
139
+ } else
140
+ #endif // __SYCL_DEVICE_ONLY__
141
+ {
142
+ vec<ocl_t , NumElements> Ret{};
140
143
for (size_t I = 0 ; I < NumElements; ++I) {
141
- detail::VecAccess< vec_t >:: setValue (
142
- Ret, I,
143
- !vec_data<DataT>:: get (detail::VecAccess< vec_t >:: getValue ( Rhs, I) ));
144
+ // static_cast is safe here: -1 * (!Rhs[I]) is always either 0 (false)
145
+ // or -1 (true).
146
+ Ret[I] = static_cast < ocl_t >(- 1 * (! Rhs[I] ));
144
147
}
145
- return Ret.template as <vec<rel_t <DataT>, NumElements>>();
146
- } else {
147
- return vec_t {(typename vec<DataT, NumElements>::DataType) !Rhs.m_Data }
148
- .template as <vec<rel_t <DataT>, NumElements>>();
148
+ return Ret;
149
149
}
150
150
}
151
151
152
152
// operator +.
153
153
friend vec_t operator +(const vec_t &Lhs) {
154
- if constexpr (vec_t ::IsUsingArrayOnDevice || vec_t ::IsUsingArrayOnHost) {
155
- vec_t Ret{};
156
- for (size_t I = 0 ; I < NumElements; ++I)
157
- detail::VecAccess<vec_t >::setValue (
158
- Ret, I,
159
- vec_data<DataT>::get (+vec_data<DataT>::get (
160
- detail::VecAccess<vec_t >::getValue (Lhs, I))));
161
- return Ret;
162
- } else {
163
- return vec_t {+Lhs.m_Data };
164
- }
154
+ #ifdef __SYCL_DEVICE_ONLY__
155
+ auto extVec = sycl::bit_cast<typename vec_t ::vector_t >(Lhs);
156
+ return vec_t {+extVec};
157
+ #else
158
+ vec_t Ret{};
159
+ for (size_t I = 0 ; I < NumElements; ++I)
160
+ Ret[I] = +Lhs[I];
161
+ return Ret;
162
+ #endif
165
163
}
166
164
167
165
// operator -.
168
166
friend vec_t operator -(const vec_t &Lhs) {
169
- namespace oneapi = sycl::ext::oneapi;
170
167
vec_t Ret{};
171
- if constexpr (vec_t ::IsBfloat16 && NumElements == 1 ) {
172
- oneapi::bfloat16 v = oneapi::detail::bitsToBfloat16 (Lhs.m_Data );
173
- oneapi::bfloat16 w = -v;
174
- Ret.m_Data = oneapi::detail::bfloat16ToBits (w);
175
- } else if constexpr (vec_t ::IsBfloat16) {
176
- for (size_t I = 0 ; I < NumElements; I++) {
177
- oneapi::bfloat16 v = oneapi::detail::bitsToBfloat16 (Lhs.m_Data [I]);
178
- oneapi::bfloat16 w = -v;
179
- Ret.m_Data [I] = oneapi::detail::bfloat16ToBits (w);
180
- }
181
- } else if constexpr (vec_t ::IsUsingArrayOnDevice ||
182
- vec_t ::IsUsingArrayOnHost) {
183
- for (size_t I = 0 ; I < NumElements; ++I)
184
- detail::VecAccess<vec_t >::setValue (
185
- Ret, I,
186
- vec_data<DataT>::get (-vec_data<DataT>::get (
187
- detail::VecAccess<vec_t >::getValue (Lhs, I))));
188
- return Ret;
168
+ if constexpr (vec_t ::IsBfloat16) {
169
+ for (size_t I = 0 ; I < NumElements; I++)
170
+ Ret[I] = -Lhs[I];
189
171
} else {
190
- Ret = vec_t {-Lhs.m_Data };
172
+ #ifndef __SYCL_DEVICE_ONLY__
173
+ for (size_t I = 0 ; I < NumElements; ++I)
174
+ Ret[I] = -Lhs[I];
175
+ #else
176
+ auto extVec = sycl::bit_cast<typename vec_t ::vector_t >(Lhs);
177
+ Ret = vec_t {-extVec};
191
178
if constexpr (std::is_same_v<DataT, bool >) {
192
179
vec_arith_common<bool , NumElements>::ConvertToDataT (Ret);
193
180
}
194
- return Ret;
181
+ # endif
195
182
}
183
+ return Ret;
196
184
}
197
185
198
186
// Unary operations on sycl::vec
187
+ // FIXME: Don't allow Unary operators on vec<bool> after
188
+ // https://github.com/KhronosGroup/SYCL-CTS/issues/896 gets fixed.
199
189
#ifdef __SYCL_UOP
200
190
#error "Undefine __SYCL_UOP macro"
201
191
#endif
202
192
#define __SYCL_UOP (UOP, OPASSIGN ) \
203
193
friend vec_t &operator UOP (vec_t & Rhs) { \
204
- Rhs OPASSIGN vec_data< DataT>:: get ( 1 ); \
194
+ Rhs OPASSIGN DataT{ 1 }; \
205
195
return Rhs; \
206
196
} \
207
197
friend vec_t operator UOP (vec_t &Lhs, int ) { \
208
198
vec_t Ret (Lhs); \
209
- Lhs OPASSIGN vec_data< DataT>:: get ( 1 ); \
199
+ Lhs OPASSIGN DataT{ 1 }; \
210
200
return Ret; \
211
201
}
212
202
@@ -228,25 +218,24 @@ class vec_arith : public vec_arith_common<DataT, NumElements> {
228
218
friend std::enable_if_t <(COND), vec<ocl_t , NumElements>> operator RELLOGOP ( \
229
219
const vec_t & Lhs, const vec_t & Rhs) { \
230
220
vec<ocl_t , NumElements> Ret{}; \
231
- /* This special case is needed since there are no standard operator|| */ \
232
- /* or operator&& functions for std::array. */ \
233
- if constexpr (vec_t ::IsUsingArrayOnDevice && \
234
- (std::string_view (#RELLOGOP) == " ||" || \
235
- std::string_view (#RELLOGOP) == " &&" )) { \
221
+ /* ext_vector_type does not support bfloat16, so for these */ \
222
+ /* we do element-by-element operation on the underlying std::array. */ \
223
+ if constexpr (vec_t ::IsBfloat16) { \
236
224
for (size_t I = 0 ; I < NumElements; ++I) { \
237
- /* We cannot use SetValue here as the operator is not a friend of*/ \
238
- /* Ret on Windows. */ \
239
- Ret[I] = static_cast <ocl_t >( \
240
- -(vec_data<DataT>::get (detail::VecAccess<vec_t >::getValue (Lhs, I)) \
241
- RELLOGOP vec_data<DataT>::get ( \
242
- detail::VecAccess<vec_t >::getValue (Rhs, I)))); \
225
+ Ret[I] = static_cast <ocl_t >(-(Lhs[I] RELLOGOP Rhs[I])); \
243
226
} \
244
227
} else { \
228
+ auto ExtVecLhs = sycl::bit_cast<typename vec_t ::vector_t >(Lhs); \
229
+ auto ExtVecRhs = sycl::bit_cast<typename vec_t ::vector_t >(Rhs); \
230
+ /* Cast required to convert unsigned char ext_vec_type to */ \
231
+ /* char ext_vec_type. */ \
245
232
Ret = vec<ocl_t , NumElements>( \
246
233
(typename vec<ocl_t , NumElements>::vector_t )( \
247
- Lhs.m_Data RELLOGOP Rhs.m_Data )); \
248
- if (NumElements == 1 ) /* Scalar 0/1 logic was applied, invert*/ \
234
+ ExtVecLhs RELLOGOP ExtVecRhs)); \
235
+ /* NumElements == 1 uses a scalar, so RELLOGOP gives 0/1; negate it. */ \
236
+ if constexpr (NumElements == 1 ) { \
249
237
Ret *= -1 ; \
238
+ } \
250
239
} \
251
240
return Ret; \
252
241
}
@@ -257,12 +246,7 @@ class vec_arith : public vec_arith_common<DataT, NumElements> {
257
246
const vec_t & Lhs, const vec_t & Rhs) { \
258
247
vec<ocl_t , NumElements> Ret{}; \
259
248
for (size_t I = 0 ; I < NumElements; ++I) { \
260
- /* We cannot use SetValue here as the operator is not a friend of*/ \
261
- /* Ret on Windows. */ \
262
- Ret[I] = static_cast <ocl_t >( \
263
- -(vec_data<DataT>::get (detail::VecAccess<vec_t >::getValue (Lhs, I)) \
264
- RELLOGOP vec_data<DataT>::get ( \
265
- detail::VecAccess<vec_t >::getValue (Rhs, I)))); \
249
+ Ret[I] = static_cast <ocl_t >(-(Lhs[I] RELLOGOP Rhs[I])); \
266
250
} \
267
251
return Ret; \
268
252
}
@@ -376,34 +360,36 @@ template <typename DataT, int NumElements> class vec_arith_common {
376
360
protected:
377
361
using vec_t = vec<DataT, NumElements>;
378
362
363
+ static constexpr bool IsBfloat16 =
364
+ std::is_same_v<DataT, sycl::ext::oneapi::bfloat16>;
365
+
379
366
// operator~() available only when: dataT != float && dataT != double
380
367
// && dataT != half
381
368
template <typename T = DataT>
382
369
friend std::enable_if_t <!detail::is_vgenfloat_v<T>, vec_t >
383
370
operator ~(const vec_t &Rhs) {
384
- if constexpr (vec_t ::IsUsingArrayOnDevice || vec_t ::IsUsingArrayOnHost) {
385
- vec_t Ret{};
386
- for (size_t I = 0 ; I < NumElements; ++I) {
387
- detail::VecAccess<vec_t >::setValue (
388
- Ret, I, ~detail::VecAccess<vec_t >::getValue (Rhs, I));
389
- }
390
- return Ret;
391
- } else {
392
- vec_t Ret{(typename vec_t ::DataType) ~Rhs.m_Data };
393
- if constexpr (std::is_same_v<DataT, bool >) {
394
- vec_arith_common<bool , NumElements>::ConvertToDataT (Ret);
395
- }
396
- return Ret;
371
+ #ifdef __SYCL_DEVICE_ONLY__
372
+ auto extVec = sycl::bit_cast<typename vec_t ::vector_t >(Rhs);
373
+ vec_t Ret{~extVec};
374
+ if constexpr (std::is_same_v<DataT, bool >) {
375
+ ConvertToDataT (Ret);
397
376
}
377
+ return Ret;
378
+ #else
379
+ vec_t Ret{};
380
+ for (size_t I = 0 ; I < NumElements; ++I) {
381
+ Ret[I] = ~Rhs[I];
382
+ }
383
+ return Ret;
384
+ #endif
398
385
}
399
386
400
387
#ifdef __SYCL_DEVICE_ONLY__
401
388
using vec_bool_t = vec<bool , NumElements>;
402
389
// Required only for bool.
403
390
static void ConvertToDataT (vec_bool_t &Ret) {
404
391
for (size_t I = 0 ; I < NumElements; ++I) {
405
- DataT Tmp = detail::VecAccess<vec_bool_t >::getValue (Ret, I);
406
- detail::VecAccess<vec_bool_t >::setValue (Ret, I, Tmp);
392
+ Ret[I] = bit_cast<int8_t >(Ret[I]) != 0 ;
407
393
}
408
394
}
409
395
#endif
0 commit comments