Skip to content

Commit 9fa0ce5

Browse files
[NFC][SYCL] Minor refactoring in sycl::vec
1 parent a7bf1ca commit 9fa0ce5

File tree

4 files changed

+80
-65
lines changed

4 files changed

+80
-65
lines changed

sycl/include/sycl/detail/named_swizzles_mixin.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ namespace sycl {
1818
inline namespace _V1 {
1919
namespace detail {
2020

21+
// Will be defined in another header.
22+
template <typename T> struct from_incomplete;
23+
2124
#ifndef SYCL_SIMPLE_SWIZZLES
2225
#define __SYCL_SWIZZLE_MIXIN_SIMPLE_SWIZZLES
2326
#else
@@ -785,7 +788,8 @@ namespace detail {
785788
return (*static_cast<const Self_ *>(this))[INDEX]; \
786789
}
787790

788-
template <typename Self, int NumElements> struct NamedSwizzlesMixinConst {
791+
template <typename Self, int NumElements = from_incomplete<Self>::size()>
792+
struct NamedSwizzlesMixinConst {
789793
#define __SYCL_SWIZZLE_MIXIN_METHOD(COND, NAME, ...) \
790794
__SYCL_SWIZZLE_MIXIN_METHOD_CONST(COND, NAME, __VA_ARGS__)
791795

@@ -798,7 +802,8 @@ template <typename Self, int NumElements> struct NamedSwizzlesMixinConst {
798802
#undef __SYCL_SWIZZLE_MIXIN_METHOD
799803
};
800804

801-
template <typename Self, int NumElements> struct NamedSwizzlesMixinBoth {
805+
template <typename Self, int NumElements = from_incomplete<Self>::size()>
806+
struct NamedSwizzlesMixinBoth {
802807
#define __SYCL_SWIZZLE_MIXIN_METHOD(COND, NAME, ...) \
803808
__SYCL_SWIZZLE_MIXIN_METHOD_NON_CONST(COND, NAME, __VA_ARGS__) \
804809
__SYCL_SWIZZLE_MIXIN_METHOD_CONST(COND, NAME, __VA_ARGS__)

sycl/include/sycl/detail/vector_arith.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,19 @@ template <typename DataT, int NumElem> class __SYCL_EBO vec;
2121

2222
namespace detail {
2323

24+
template <typename T> struct from_incomplete;
25+
template <typename T>
26+
struct from_incomplete<const T> : public from_incomplete<T> {};
27+
28+
template <typename DataT, int NumElements>
29+
struct from_incomplete<vec<DataT, NumElements>> {
30+
using element_type = DataT;
31+
static constexpr size_t size() { return NumElements; }
32+
};
33+
34+
template <bool Cond, typename Mixin> struct ApplyIf {};
35+
template <typename Mixin> struct ApplyIf<true, Mixin> : Mixin {};
36+
2437
// We use std::plus<void> and similar to "map" template parameter to an
2538
// overloaded operator. These three below are missing from `<functional>`.
2639
struct ShiftLeft {

sycl/include/sycl/vector.hpp

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,11 @@ template <typename T> class GetOp {
113113
//
114114
// must go throw `v.x()` returning a swizzle, then its `operator==` returning
115115
// vec<int, 1> and we want that code to compile.
116-
template <typename Vec, typename T, int N, typename = void>
117-
struct ScalarConversionOperatorMixIn {};
116+
template <typename Self> class ScalarConversionOperatorMixIn {
117+
using T = typename from_incomplete<Self>::element_type;
118118

119-
template <typename Vec, typename T, int N>
120-
struct ScalarConversionOperatorMixIn<Vec, T, N, std::enable_if_t<N == 1>> {
121-
operator T() const { return (*static_cast<const Vec *>(this))[0]; }
119+
public:
120+
operator T() const { return (*static_cast<const Self *>(this))[0]; }
122121
};
123122

124123
template <typename T>
@@ -134,10 +133,10 @@ inline constexpr bool is_fundamental_or_half_or_bfloat16 =
134133
template <typename DataT, int NumElements>
135134
class __SYCL_EBO vec
136135
: public detail::vec_arith<DataT, NumElements>,
137-
public detail::ScalarConversionOperatorMixIn<vec<DataT, NumElements>,
138-
DataT, NumElements>,
139-
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>,
140-
NumElements> {
136+
public detail::ApplyIf<
137+
NumElements == 1,
138+
detail::ScalarConversionOperatorMixIn<vec<DataT, NumElements>>>,
139+
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>> {
141140
static_assert(std::is_same_v<DataT, std::remove_cv_t<DataT>>,
142141
"DataT must be cv-unqualified");
143142

@@ -177,6 +176,24 @@ class __SYCL_EBO vec
177176
element_type_for_vector_t __attribute__((
178177
ext_vector_type(NumElements)))>;
179178

179+
// Make it a template to avoid ambiguity with `vec(const DataT &)` when
180+
// `vector_t` is the same as `DataT`. Not that the other ctor isn't a template
181+
// so we don't even need a smart `enable_if` condition here, the mere fact of
182+
// this being a template makes the other ctor preferred.
183+
template <
184+
typename vector_t_ = vector_t,
185+
typename = typename std::enable_if_t<std::is_same_v<vector_t_, vector_t>>>
186+
constexpr vec(vector_t_ openclVector) {
187+
m_Data = sycl::bit_cast<DataType>(openclVector);
188+
}
189+
190+
/* @SYCL2020
191+
* Available only when: compiled for the device.
192+
* Converts this SYCL vec instance to the underlying backend-native vector
193+
* type defined by vector_t.
194+
*/
195+
operator vector_t() const { return sycl::bit_cast<vector_t>(m_Data); }
196+
180197
private:
181198
#endif // __SYCL_DEVICE_ONLY__
182199

@@ -299,26 +316,6 @@ class __SYCL_EBO vec
299316
return *this;
300317
}
301318

302-
#ifdef __SYCL_DEVICE_ONLY__
303-
// Make it a template to avoid ambiguity with `vec(const DataT &)` when
304-
// `vector_t` is the same as `DataT`. Not that the other ctor isn't a template
305-
// so we don't even need a smart `enable_if` condition here, the mere fact of
306-
// this being a template makes the other ctor preferred.
307-
template <
308-
typename vector_t_ = vector_t,
309-
typename = typename std::enable_if_t<std::is_same_v<vector_t_, vector_t>>>
310-
constexpr vec(vector_t_ openclVector) {
311-
m_Data = sycl::bit_cast<DataType>(openclVector);
312-
}
313-
314-
/* @SYCL2020
315-
* Available only when: compiled for the device.
316-
* Converts this SYCL vec instance to the underlying backend-native vector
317-
* type defined by vector_t.
318-
*/
319-
operator vector_t() const { return sycl::bit_cast<vector_t>(m_Data); }
320-
#endif // __SYCL_DEVICE_ONLY__
321-
322319
__SYCL2020_DEPRECATED("get_count() is deprecated, please use size() instead")
323320
static constexpr size_t get_count() { return size(); }
324321
static constexpr size_t size() noexcept { return NumElements; }

0 commit comments

Comments
 (0)