Skip to content

Commit 5515791

Browse files
Revert "[SYCL] Refactor sycl::vec's operators implementation" (#16554)
Reverts #16529 due to failure to build on Windows and macOS in post-commit.
1 parent 16c2c21 commit 5515791

File tree

3 files changed

+217
-214
lines changed

3 files changed

+217
-214
lines changed

sycl/include/sycl/detail/type_traits/vec_marray_traits.hpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,14 @@ template <typename T> constexpr bool is_vec_v = is_vec<T>::value;
3131

3232
template <typename T, typename = void>
3333
struct is_ext_vector : std::false_type {};
34-
template <typename T, typename = void>
35-
struct is_valid_type_for_ext_vector : std::false_type {};
3634
#if defined(__has_extension)
3735
#if __has_extension(attribute_ext_vector_type)
3836
template <typename T, int N>
39-
using ext_vector = T __attribute__((ext_vector_type(N)));
40-
template <typename T, int N>
41-
struct is_ext_vector<ext_vector<T, N>> : std::true_type {};
42-
template <typename T>
43-
struct is_valid_type_for_ext_vector<T, std::void_t<ext_vector<T, 2>>>
44-
: std::true_type {};
37+
struct is_ext_vector<T __attribute__((ext_vector_type(N)))> : std::true_type {};
4538
#endif
4639
#endif
4740
template <typename T>
4841
inline constexpr bool is_ext_vector_v = is_ext_vector<T>::value;
49-
template <typename T>
50-
inline constexpr bool is_valid_type_for_ext_vector_v =
51-
is_valid_type_for_ext_vector<T>::value;
5242

5343
template <typename> struct is_swizzle : std::false_type {};
5444
template <typename VecT, typename OperationLeftT, typename OperationRightT,

sycl/include/sycl/detail/vector_arith.hpp

Lines changed: 68 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88

99
#pragma once
1010

11-
#include <sycl/aliases.hpp>
12-
#include <sycl/detail/generic_type_traits.hpp>
13-
#include <sycl/detail/type_traits.hpp>
14-
#include <sycl/detail/type_traits/vec_marray_traits.hpp>
15-
#include <sycl/ext/oneapi/bfloat16.hpp>
11+
#include <sycl/aliases.hpp> // for half, cl_char, cl_int
12+
#include <sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
13+
#include <sycl/detail/type_traits.hpp> // for is_floating_point
14+
15+
#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16
16+
17+
#include <cstddef>
18+
#include <type_traits> // for enable_if_t, is_same
1619

1720
namespace sycl {
1821
inline namespace _V1 {
@@ -47,7 +50,13 @@ struct UnaryPlus {
4750
};
4851

4952
struct VecOperators {
50-
template <typename OpTy, typename... ArgTys>
53+
#ifdef __SYCL_DEVICE_ONLY__
54+
static constexpr bool is_host = false;
55+
#else
56+
static constexpr bool is_host = true;
57+
#endif
58+
59+
template <typename BinOp, typename... ArgTys>
5160
static constexpr auto apply(const ArgTys &...Args) {
5261
using Self = nth_type_t<0, ArgTys...>;
5362
static_assert(is_vec_v<Self>);
@@ -56,96 +65,88 @@ struct VecOperators {
5665
using element_type = typename Self::element_type;
5766
constexpr int N = Self::size();
5867
constexpr bool is_logical = check_type_in_v<
59-
OpTy, std::equal_to<void>, std::not_equal_to<void>, std::less<void>,
68+
BinOp, std::equal_to<void>, std::not_equal_to<void>, std::less<void>,
6069
std::greater<void>, std::less_equal<void>, std::greater_equal<void>,
6170
std::logical_and<void>, std::logical_or<void>, std::logical_not<void>>;
6271

6372
using result_t = std::conditional_t<
6473
is_logical, vec<fixed_width_signed<sizeof(element_type)>, N>, Self>;
6574

66-
OpTy Op{};
67-
#ifdef __has_extension
68-
#if __has_extension(attribute_ext_vector_type)
69-
// ext_vector_type's bool vectors are mapped onto <N x i1> and have
70-
// different memory layout than sycl::vec<bool ,N> (which has 1 byte per
71-
// element). As such we perform operation on int8_t and then need to
72-
// create bit pattern that can be bit-casted back to the original
73-
// sycl::vec<bool, N>. This is a hack actually, but we've been doing
74-
// that for a long time using sycl::vec::vector_t type.
75-
using vec_elem_ty =
76-
typename detail::map_type<element_type, //
77-
bool, /*->*/ std::int8_t,
78-
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
79-
std::byte, /*->*/ std::uint8_t,
80-
#endif
81-
#ifdef __SYCL_DEVICE_ONLY__
82-
half, /*->*/ _Float16,
83-
#endif
84-
element_type, /*->*/ element_type>::type;
85-
if constexpr (N != 1 &&
86-
detail::is_valid_type_for_ext_vector_v<vec_elem_ty>) {
87-
using vec_t = ext_vector<vec_elem_ty, N>;
88-
auto tmp = [&](auto... xs) {
75+
BinOp Op{};
76+
if constexpr (is_host || N == 1 ||
77+
std::is_same_v<element_type, ext::oneapi::bfloat16>) {
78+
result_t res{};
79+
for (size_t i = 0; i < N; ++i)
80+
if constexpr (is_logical)
81+
res[i] = Op(Args[i]...) ? -1 : 0;
82+
else
83+
res[i] = Op(Args[i]...);
84+
return res;
85+
} else {
86+
using vector_t = typename Self::vector_t;
87+
88+
auto res = [&](auto... xs) {
8989
// Workaround for https://github.com/llvm/llvm-project/issues/119617.
9090
if constexpr (sizeof...(Args) == 2) {
9191
return [&](auto x, auto y) {
92-
if constexpr (std::is_same_v<OpTy, std::equal_to<void>>)
92+
if constexpr (std::is_same_v<BinOp, std::equal_to<void>>)
9393
return x == y;
94-
else if constexpr (std::is_same_v<OpTy, std::not_equal_to<void>>)
94+
else if constexpr (std::is_same_v<BinOp, std::not_equal_to<void>>)
9595
return x != y;
96-
else if constexpr (std::is_same_v<OpTy, std::less<void>>)
96+
else if constexpr (std::is_same_v<BinOp, std::less<void>>)
9797
return x < y;
98-
else if constexpr (std::is_same_v<OpTy, std::less_equal<void>>)
98+
else if constexpr (std::is_same_v<BinOp, std::less_equal<void>>)
9999
return x <= y;
100-
else if constexpr (std::is_same_v<OpTy, std::greater<void>>)
100+
else if constexpr (std::is_same_v<BinOp, std::greater<void>>)
101101
return x > y;
102-
else if constexpr (std::is_same_v<OpTy, std::greater_equal<void>>)
102+
else if constexpr (std::is_same_v<BinOp, std::greater_equal<void>>)
103103
return x >= y;
104104
else
105105
return Op(x, y);
106106
}(xs...);
107107
} else {
108108
return Op(xs...);
109109
}
110-
}(bit_cast<vec_t>(Args)...);
110+
}(bit_cast<vector_t>(Args)...);
111+
111112
if constexpr (std::is_same_v<element_type, bool>) {
112-
// Some operations are known to produce the required bit patterns and
113-
// the following post-processing isn't necessary for them:
113+
// vec(vector_t) ctor does a simple bit_cast and the way "bool" is
114+
// stored is that only one bit matters. vector_t, however, is a char
115+
// type and it can have non-zero value with lowest bit unset. E.g.,
116+
// consider this:
117+
//
118+
// auto x = true + true; // int x = 2
119+
// bool y = true + true; // bool y = true
120+
//
121+
// and the vec<bool, N> has to behave in a similar way. As such, current
122+
// implementation needs to do some extra processing for operators that
123+
// can result in this scenario.
124+
//
114125
if constexpr (!is_logical &&
115-
!check_type_in_v<OpTy, std::multiplies<void>,
126+
!check_type_in_v<BinOp, std::multiplies<void>,
116127
std::divides<void>, std::bit_or<void>,
117128
std::bit_and<void>, std::bit_xor<void>,
118129
ShiftRight, UnaryPlus>) {
119-
// Extra cast is needed because:
120-
static_assert(std::is_same_v<int8_t, signed char>);
121-
static_assert(!std::is_same_v<
122-
decltype(std::declval<ext_vector<int8_t, 2>>() != 0),
123-
ext_vector<int8_t, 2>>);
124-
static_assert(std::is_same_v<
125-
decltype(std::declval<ext_vector<int8_t, 2>>() != 0),
126-
ext_vector<char, 2>>);
127-
128-
// `... * -1` is needed because ext_vector_type's comparison follows
129-
// OpenCL binary representation for "true" (-1).
130-
// `std::array<bool, N>` is different and LLVM annotates its
131-
// elements with [0, 2) range metadata when loaded, so we need to
132-
// ensure we generate 0/1 only (and not 2/-1/etc.).
133-
static_assert((ext_vector<int8_t, 2>{1, 0} == 0)[1] == -1);
134-
135-
tmp = reinterpret_cast<decltype(tmp)>((tmp != 0) * -1);
130+
// TODO: Not sure why the following doesn't work
131+
// (test-e2e/Basic/vector/bool.cpp fails).
132+
//
133+
// res = (decltype(res))(res != 0);
134+
for (size_t i = 0; i < N; ++i)
135+
res[i] = bit_cast<int8_t>(res[i]) != 0;
136136
}
137137
}
138-
return bit_cast<result_t>(tmp);
138+
// The following is true:
139+
//
140+
// using char2 = char __attribute__((ext_vector_type(2)));
141+
// using uchar2 = unsigned char __attribute__((ext_vector_type(2)));
142+
// static_assert(std::is_same_v<decltype(std::declval<uchar2>() ==
143+
// std::declval<uchar2>()),
144+
// char2>);
145+
//
146+
// so we need some extra casts. Also, static_cast<uchar2>(char2{})
147+
// isn't allowed either.
148+
return result_t{(typename result_t::vector_t)res};
139149
}
140-
#endif
141-
#endif
142-
result_t res{};
143-
for (size_t i = 0; i < N; ++i)
144-
if constexpr (is_logical)
145-
res[i] = Op(Args[i]...) ? -1 : 0;
146-
else
147-
res[i] = Op(Args[i]...);
148-
return res;
149150
}
150151
};
151152

0 commit comments

Comments
 (0)