Skip to content

[NFC][SYCL] Minor refactoring in sycl::vec #16934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions sycl/include/sycl/detail/named_swizzles_mixin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ namespace sycl {
inline namespace _V1 {
namespace detail {

// Will be defined in another header.
template <typename T> struct from_incomplete;

#ifndef SYCL_SIMPLE_SWIZZLES
#define __SYCL_SWIZZLE_MIXIN_SIMPLE_SWIZZLES
#else
Expand Down Expand Up @@ -785,7 +788,8 @@ namespace detail {
return (*static_cast<const Self_ *>(this))[INDEX]; \
}

template <typename Self, int NumElements> struct NamedSwizzlesMixinConst {
template <typename Self, int NumElements = from_incomplete<Self>::size()>
struct NamedSwizzlesMixinConst {
#define __SYCL_SWIZZLE_MIXIN_METHOD(COND, NAME, ...) \
__SYCL_SWIZZLE_MIXIN_METHOD_CONST(COND, NAME, __VA_ARGS__)

Expand All @@ -798,7 +802,8 @@ template <typename Self, int NumElements> struct NamedSwizzlesMixinConst {
#undef __SYCL_SWIZZLE_MIXIN_METHOD
};

template <typename Self, int NumElements> struct NamedSwizzlesMixinBoth {
template <typename Self, int NumElements = from_incomplete<Self>::size()>
struct NamedSwizzlesMixinBoth {
#define __SYCL_SWIZZLE_MIXIN_METHOD(COND, NAME, ...) \
__SYCL_SWIZZLE_MIXIN_METHOD_NON_CONST(COND, NAME, __VA_ARGS__) \
__SYCL_SWIZZLE_MIXIN_METHOD_CONST(COND, NAME, __VA_ARGS__)
Expand Down
13 changes: 13 additions & 0 deletions sycl/include/sycl/detail/vector_arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@ template <typename DataT, int NumElem> class __SYCL_EBO vec;

namespace detail {

template <typename T> struct from_incomplete;
template <typename T>
struct from_incomplete<const T> : public from_incomplete<T> {};

template <typename DataT, int NumElements>
struct from_incomplete<vec<DataT, NumElements>> {
using element_type = DataT;
static constexpr size_t size() { return NumElements; }
};

template <bool Cond, typename Mixin> struct ApplyIf {};
template <typename Mixin> struct ApplyIf<true, Mixin> : Mixin {};

// We use std::plus<void> and similar to "map" template parameter to an
// overloaded operator. These three below are missing from `<functional>`.
struct ShiftLeft {
Expand Down
55 changes: 26 additions & 29 deletions sycl/include/sycl/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,11 @@ template <typename T> class GetOp {
//
// must go throw `v.x()` returning a swizzle, then its `operator==` returning
// vec<int, 1> and we want that code to compile.
template <typename Vec, typename T, int N, typename = void>
struct ScalarConversionOperatorMixIn {};
template <typename Self> class ScalarConversionOperatorMixIn {
using T = typename from_incomplete<Self>::element_type;

template <typename Vec, typename T, int N>
struct ScalarConversionOperatorMixIn<Vec, T, N, std::enable_if_t<N == 1>> {
operator T() const { return (*static_cast<const Vec *>(this))[0]; }
public:
operator T() const { return (*static_cast<const Self *>(this))[0]; }
};

template <typename T>
Expand All @@ -134,10 +133,10 @@ inline constexpr bool is_fundamental_or_half_or_bfloat16 =
template <typename DataT, int NumElements>
class __SYCL_EBO vec
: public detail::vec_arith<DataT, NumElements>,
public detail::ScalarConversionOperatorMixIn<vec<DataT, NumElements>,
DataT, NumElements>,
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>,
NumElements> {
public detail::ApplyIf<
NumElements == 1,
detail::ScalarConversionOperatorMixIn<vec<DataT, NumElements>>>,
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>> {
static_assert(std::is_same_v<DataT, std::remove_cv_t<DataT>>,
"DataT must be cv-unqualified");

Expand Down Expand Up @@ -177,6 +176,24 @@ class __SYCL_EBO vec
element_type_for_vector_t __attribute__((
ext_vector_type(NumElements)))>;

// Make it a template to avoid ambiguity with `vec(const DataT &)` when
// `vector_t` is the same as `DataT`. Not that the other ctor isn't a template
// so we don't even need a smart `enable_if` condition here, the mere fact of
// this being a template makes the other ctor preferred.
template <
typename vector_t_ = vector_t,
typename = typename std::enable_if_t<std::is_same_v<vector_t_, vector_t>>>
constexpr vec(vector_t_ openclVector) {
m_Data = sycl::bit_cast<DataType>(openclVector);
}

/* @SYCL2020
* Available only when: compiled for the device.
* Converts this SYCL vec instance to the underlying backend-native vector
* type defined by vector_t.
*/
operator vector_t() const { return sycl::bit_cast<vector_t>(m_Data); }

private:
#endif // __SYCL_DEVICE_ONLY__

Expand Down Expand Up @@ -299,26 +316,6 @@ class __SYCL_EBO vec
return *this;
}

#ifdef __SYCL_DEVICE_ONLY__
// Make it a template to avoid ambiguity with `vec(const DataT &)` when
// `vector_t` is the same as `DataT`. Not that the other ctor isn't a template
// so we don't even need a smart `enable_if` condition here, the mere fact of
// this being a template makes the other ctor preferred.
template <
typename vector_t_ = vector_t,
typename = typename std::enable_if_t<std::is_same_v<vector_t_, vector_t>>>
constexpr vec(vector_t_ openclVector) {
m_Data = sycl::bit_cast<DataType>(openclVector);
}

/* @SYCL2020
* Available only when: compiled for the device.
* Converts this SYCL vec instance to the underlying backend-native vector
* type defined by vector_t.
*/
operator vector_t() const { return sycl::bit_cast<vector_t>(m_Data); }
#endif // __SYCL_DEVICE_ONLY__

__SYCL2020_DEPRECATED("get_count() is deprecated, please use size() instead")
static constexpr size_t get_count() { return size(); }
static constexpr size_t size() noexcept { return NumElements; }
Expand Down
Loading