8
8
9
9
#pragma once
10
10
11
- #include < sycl/aliases.hpp>
12
- #include < sycl/detail/generic_type_traits.hpp>
13
- #include < sycl/detail/type_traits.hpp>
14
- #include < sycl/detail/type_traits/vec_marray_traits.hpp>
15
- #include < sycl/ext/oneapi/bfloat16.hpp>
11
+ #include < sycl/aliases.hpp> // for half, cl_char, cl_int
12
+ #include < sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
13
+ #include < sycl/detail/type_traits.hpp> // for is_floating_point
14
+
15
+ #include < sycl/ext/oneapi/bfloat16.hpp> // bfloat16
16
+
17
+ #include < cstddef>
18
+ #include < type_traits> // for enable_if_t, is_same
16
19
17
20
namespace sycl {
18
21
inline namespace _V1 {
@@ -47,7 +50,13 @@ struct UnaryPlus {
47
50
};
48
51
49
52
struct VecOperators {
50
- template <typename OpTy, typename ... ArgTys>
53
+ #ifdef __SYCL_DEVICE_ONLY__
54
+ static constexpr bool is_host = false ;
55
+ #else
56
+ static constexpr bool is_host = true ;
57
+ #endif
58
+
59
+ template <typename BinOp, typename ... ArgTys>
51
60
static constexpr auto apply (const ArgTys &...Args) {
52
61
using Self = nth_type_t <0 , ArgTys...>;
53
62
static_assert (is_vec_v<Self>);
@@ -56,96 +65,88 @@ struct VecOperators {
56
65
using element_type = typename Self::element_type;
57
66
constexpr int N = Self::size ();
58
67
constexpr bool is_logical = check_type_in_v<
59
- OpTy , std::equal_to<void >, std::not_equal_to<void >, std::less<void >,
68
+ BinOp , std::equal_to<void >, std::not_equal_to<void >, std::less<void >,
60
69
std::greater<void >, std::less_equal<void >, std::greater_equal<void >,
61
70
std::logical_and<void >, std::logical_or<void >, std::logical_not<void >>;
62
71
63
72
using result_t = std::conditional_t <
64
73
is_logical, vec<fixed_width_signed<sizeof (element_type)>, N>, Self>;
65
74
66
- OpTy Op{};
67
- #ifdef __has_extension
68
- #if __has_extension(attribute_ext_vector_type)
69
- // ext_vector_type's bool vectors are mapped onto <N x i1> and have
70
- // different memory layout than sycl::vec<bool ,N> (which has 1 byte per
71
- // element). As such we perform operation on int8_t and then need to
72
- // create bit pattern that can be bit-casted back to the original
73
- // sycl::vec<bool, N>. This is a hack actually, but we've been doing
74
- // that for a long time using sycl::vec::vector_t type.
75
- using vec_elem_ty =
76
- typename detail::map_type<element_type, //
77
- bool , /* ->*/ std::int8_t ,
78
- #if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
79
- std::byte, /* ->*/ std::uint8_t ,
80
- #endif
81
- #ifdef __SYCL_DEVICE_ONLY__
82
- half, /* ->*/ _Float16,
83
- #endif
84
- element_type, /* ->*/ element_type>::type;
85
- if constexpr (N != 1 &&
86
- detail::is_valid_type_for_ext_vector_v<vec_elem_ty>) {
87
- using vec_t = ext_vector<vec_elem_ty, N>;
88
- auto tmp = [&](auto ... xs) {
75
+ BinOp Op{};
76
+ if constexpr (is_host || N == 1 ||
77
+ std::is_same_v<element_type, ext::oneapi::bfloat16>) {
78
+ result_t res{};
79
+ for (size_t i = 0 ; i < N; ++i)
80
+ if constexpr (is_logical)
81
+ res[i] = Op (Args[i]...) ? -1 : 0 ;
82
+ else
83
+ res[i] = Op (Args[i]...);
84
+ return res;
85
+ } else {
86
+ using vector_t = typename Self::vector_t ;
87
+
88
+ auto res = [&](auto ... xs) {
89
89
// Workaround for https://github.com/llvm/llvm-project/issues/119617.
90
90
if constexpr (sizeof ...(Args) == 2 ) {
91
91
return [&](auto x, auto y) {
92
- if constexpr (std::is_same_v<OpTy , std::equal_to<void >>)
92
+ if constexpr (std::is_same_v<BinOp , std::equal_to<void >>)
93
93
return x == y;
94
- else if constexpr (std::is_same_v<OpTy , std::not_equal_to<void >>)
94
+ else if constexpr (std::is_same_v<BinOp , std::not_equal_to<void >>)
95
95
return x != y;
96
- else if constexpr (std::is_same_v<OpTy , std::less<void >>)
96
+ else if constexpr (std::is_same_v<BinOp , std::less<void >>)
97
97
return x < y;
98
- else if constexpr (std::is_same_v<OpTy , std::less_equal<void >>)
98
+ else if constexpr (std::is_same_v<BinOp , std::less_equal<void >>)
99
99
return x <= y;
100
- else if constexpr (std::is_same_v<OpTy , std::greater<void >>)
100
+ else if constexpr (std::is_same_v<BinOp , std::greater<void >>)
101
101
return x > y;
102
- else if constexpr (std::is_same_v<OpTy , std::greater_equal<void >>)
102
+ else if constexpr (std::is_same_v<BinOp , std::greater_equal<void >>)
103
103
return x >= y;
104
104
else
105
105
return Op (x, y);
106
106
}(xs...);
107
107
} else {
108
108
return Op (xs...);
109
109
}
110
- }(bit_cast<vec_t >(Args)...);
110
+ }(bit_cast<vector_t >(Args)...);
111
+
111
112
if constexpr (std::is_same_v<element_type, bool >) {
112
- // Some operations are known to produce the required bit patterns and
113
- // the following post-processing isn't necessary for them:
113
+ // vec(vector_t) ctor does a simple bit_cast and the way "bool" is
114
+ // stored is that only one bit matters. vector_t, however, is a char
115
+ // type and it can have non-zero value with lowest bit unset. E.g.,
116
+ // consider this:
117
+ //
118
+ // auto x = true + true; // int x = 2
119
+ // bool y = true + true; // bool y = true
120
+ //
121
+ // and the vec<bool, N> has to behave in a similar way. As such, current
122
+ // implementation needs to do some extra processing for operators that
123
+ // can result in this scenario.
124
+ //
114
125
if constexpr (!is_logical &&
115
- !check_type_in_v<OpTy , std::multiplies<void >,
126
+ !check_type_in_v<BinOp , std::multiplies<void >,
116
127
std::divides<void >, std::bit_or<void >,
117
128
std::bit_and<void >, std::bit_xor<void >,
118
129
ShiftRight, UnaryPlus>) {
119
- // Extra cast is needed because:
120
- static_assert (std::is_same_v<int8_t , signed char >);
121
- static_assert (!std::is_same_v<
122
- decltype (std::declval<ext_vector<int8_t , 2 >>() != 0 ),
123
- ext_vector<int8_t , 2 >>);
124
- static_assert (std::is_same_v<
125
- decltype (std::declval<ext_vector<int8_t , 2 >>() != 0 ),
126
- ext_vector<char , 2 >>);
127
-
128
- // `... * -1` is needed because ext_vector_type's comparison follows
129
- // OpenCL binary representation for "true" (-1).
130
- // `std::array<bool, N>` is different and LLVM annotates its
131
- // elements with [0, 2) range metadata when loaded, so we need to
132
- // ensure we generate 0/1 only (and not 2/-1/etc.).
133
- static_assert ((ext_vector<int8_t , 2 >{1 , 0 } == 0 )[1 ] == -1 );
134
-
135
- tmp = reinterpret_cast <decltype (tmp)>((tmp != 0 ) * -1 );
130
+ // TODO: Not sure why the following doesn't work
131
+ // (test-e2e/Basic/vector/bool.cpp fails).
132
+ //
133
+ // res = (decltype(res))(res != 0);
134
+ for (size_t i = 0 ; i < N; ++i)
135
+ res[i] = bit_cast<int8_t >(res[i]) != 0 ;
136
136
}
137
137
}
138
- return bit_cast<result_t >(tmp);
138
+ // The following is true:
139
+ //
140
+ // using char2 = char __attribute__((ext_vector_type(2)));
141
+ // using uchar2 = unsigned char __attribute__((ext_vector_type(2)));
142
+ // static_assert(std::is_same_v<decltype(std::declval<uchar2>() ==
143
+ // std::declval<uchar2>()),
144
+ // char2>);
145
+ //
146
+ // so we need some extra casts. Also, static_cast<uchar2>(char2{})
147
+ // isn't allowed either.
148
+ return result_t {(typename result_t ::vector_t )res};
139
149
}
140
- #endif
141
- #endif
142
- result_t res{};
143
- for (size_t i = 0 ; i < N; ++i)
144
- if constexpr (is_logical)
145
- res[i] = Op (Args[i]...) ? -1 : 0 ;
146
- else
147
- res[i] = Op (Args[i]...);
148
- return res;
149
150
}
150
151
};
151
152
0 commit comments