Skip to content

Commit 8b83ba8

Browse files
[NFC][SYCL] Minor refactoring in sycl::vec (#16934)
We're planning big refactoring under preview breaking changes mode to implement latest changes to the SYCL specification (not yet merged but already reviewed with Khronos). This is of several patches to split the changes into NFC refactoring and localized functional changes under preview guard.
1 parent eaff40c commit 8b83ba8

File tree

4 files changed

+80
-65
lines changed

4 files changed

+80
-65
lines changed

sycl/include/sycl/detail/named_swizzles_mixin.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ namespace sycl {
1818
inline namespace _V1 {
1919
namespace detail {
2020

21+
// Will be defined in another header.
22+
template <typename T> struct from_incomplete;
23+
2124
#ifndef SYCL_SIMPLE_SWIZZLES
2225
#define __SYCL_SWIZZLE_MIXIN_SIMPLE_SWIZZLES
2326
#else
@@ -785,7 +788,8 @@ namespace detail {
785788
return (*static_cast<const Self_ *>(this))[INDEX]; \
786789
}
787790

788-
template <typename Self, int NumElements> struct NamedSwizzlesMixinConst {
791+
template <typename Self, int NumElements = from_incomplete<Self>::size()>
792+
struct NamedSwizzlesMixinConst {
789793
#define __SYCL_SWIZZLE_MIXIN_METHOD(COND, NAME, ...) \
790794
__SYCL_SWIZZLE_MIXIN_METHOD_CONST(COND, NAME, __VA_ARGS__)
791795

@@ -798,7 +802,8 @@ template <typename Self, int NumElements> struct NamedSwizzlesMixinConst {
798802
#undef __SYCL_SWIZZLE_MIXIN_METHOD
799803
};
800804

801-
template <typename Self, int NumElements> struct NamedSwizzlesMixinBoth {
805+
template <typename Self, int NumElements = from_incomplete<Self>::size()>
806+
struct NamedSwizzlesMixinBoth {
802807
#define __SYCL_SWIZZLE_MIXIN_METHOD(COND, NAME, ...) \
803808
__SYCL_SWIZZLE_MIXIN_METHOD_NON_CONST(COND, NAME, __VA_ARGS__) \
804809
__SYCL_SWIZZLE_MIXIN_METHOD_CONST(COND, NAME, __VA_ARGS__)

sycl/include/sycl/detail/vector_arith.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,19 @@ template <typename DataT, int NumElem> class __SYCL_EBO vec;
2121

2222
namespace detail {
2323

24+
template <typename T> struct from_incomplete;
25+
template <typename T>
26+
struct from_incomplete<const T> : public from_incomplete<T> {};
27+
28+
template <typename DataT, int NumElements>
29+
struct from_incomplete<vec<DataT, NumElements>> {
30+
using element_type = DataT;
31+
static constexpr size_t size() { return NumElements; }
32+
};
33+
34+
template <bool Cond, typename Mixin> struct ApplyIf {};
35+
template <typename Mixin> struct ApplyIf<true, Mixin> : Mixin {};
36+
2437
// We use std::plus<void> and similar to "map" template parameter to an
2538
// overloaded operator. These three below are missing from `<functional>`.
2639
struct ShiftLeft {

sycl/include/sycl/vector.hpp

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,11 @@ template <typename T> class GetOp {
113113
//
114114
// must go throw `v.x()` returning a swizzle, then its `operator==` returning
115115
// vec<int, 1> and we want that code to compile.
116-
template <typename Vec, typename T, int N, typename = void>
117-
struct ScalarConversionOperatorMixIn {};
116+
template <typename Self> class ScalarConversionOperatorMixIn {
117+
using T = typename from_incomplete<Self>::element_type;
118118

119-
template <typename Vec, typename T, int N>
120-
struct ScalarConversionOperatorMixIn<Vec, T, N, std::enable_if_t<N == 1>> {
121-
operator T() const { return (*static_cast<const Vec *>(this))[0]; }
119+
public:
120+
operator T() const { return (*static_cast<const Self *>(this))[0]; }
122121
};
123122

124123
template <typename T>
@@ -134,10 +133,10 @@ inline constexpr bool is_fundamental_or_half_or_bfloat16 =
134133
template <typename DataT, int NumElements>
135134
class __SYCL_EBO vec
136135
: public detail::vec_arith<DataT, NumElements>,
137-
public detail::ScalarConversionOperatorMixIn<vec<DataT, NumElements>,
138-
DataT, NumElements>,
139-
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>,
140-
NumElements> {
136+
public detail::ApplyIf<
137+
NumElements == 1,
138+
detail::ScalarConversionOperatorMixIn<vec<DataT, NumElements>>>,
139+
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>> {
141140
static_assert(std::is_same_v<DataT, std::remove_cv_t<DataT>>,
142141
"DataT must be cv-unqualified");
143142

@@ -177,6 +176,24 @@ class __SYCL_EBO vec
177176
element_type_for_vector_t __attribute__((
178177
ext_vector_type(NumElements)))>;
179178

179+
// Make it a template to avoid ambiguity with `vec(const DataT &)` when
180+
// `vector_t` is the same as `DataT`. Not that the other ctor isn't a template
181+
// so we don't even need a smart `enable_if` condition here, the mere fact of
182+
// this being a template makes the other ctor preferred.
183+
template <
184+
typename vector_t_ = vector_t,
185+
typename = typename std::enable_if_t<std::is_same_v<vector_t_, vector_t>>>
186+
constexpr vec(vector_t_ openclVector) {
187+
m_Data = sycl::bit_cast<DataType>(openclVector);
188+
}
189+
190+
/* @SYCL2020
191+
* Available only when: compiled for the device.
192+
* Converts this SYCL vec instance to the underlying backend-native vector
193+
* type defined by vector_t.
194+
*/
195+
operator vector_t() const { return sycl::bit_cast<vector_t>(m_Data); }
196+
180197
private:
181198
#endif // __SYCL_DEVICE_ONLY__
182199

@@ -299,26 +316,6 @@ class __SYCL_EBO vec
299316
return *this;
300317
}
301318

302-
#ifdef __SYCL_DEVICE_ONLY__
303-
// Make it a template to avoid ambiguity with `vec(const DataT &)` when
304-
// `vector_t` is the same as `DataT`. Not that the other ctor isn't a template
305-
// so we don't even need a smart `enable_if` condition here, the mere fact of
306-
// this being a template makes the other ctor preferred.
307-
template <
308-
typename vector_t_ = vector_t,
309-
typename = typename std::enable_if_t<std::is_same_v<vector_t_, vector_t>>>
310-
constexpr vec(vector_t_ openclVector) {
311-
m_Data = sycl::bit_cast<DataType>(openclVector);
312-
}
313-
314-
/* @SYCL2020
315-
* Available only when: compiled for the device.
316-
* Converts this SYCL vec instance to the underlying backend-native vector
317-
* type defined by vector_t.
318-
*/
319-
operator vector_t() const { return sycl::bit_cast<vector_t>(m_Data); }
320-
#endif // __SYCL_DEVICE_ONLY__
321-
322319
__SYCL2020_DEPRECATED("get_count() is deprecated, please use size() instead")
323320
static constexpr size_t get_count() { return size(); }
324321
static constexpr size_t size() noexcept { return NumElements; }

0 commit comments

Comments
 (0)