Skip to content

Commit 6e0d90e

Browse files
authored
[SYCLCompat] Fix vectorized_binary impl to make SYCLomatic migrated code run pass (#16553)
--------- Signed-off-by: Jiang, Zhiwei <[email protected]>
1 parent a5f83c2 commit 6e0d90e

File tree

3 files changed

+2992
-32
lines changed

3 files changed

+2992
-32
lines changed

sycl/doc/syclcompat/README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2090,6 +2090,25 @@ struct sub_sat {
20902090
} // namespace syclcompat
20912091
```
20922092
2093+
`vectorized_binary` also supports comparison operators from the standard library (`std::equal_to`, `std::not_equal_to`, etc)
2094+
and the semantics can be modified by changing the comparison operator template instantiation. For example:
2095+
2096+
```cpp
2097+
unsigned int Input1;
2098+
unsigned int Input2;
2099+
// initialize inputs...
2100+
2101+
// Performs comparison on sycl::ushort2, following sycl::vec semantics
2102+
// Returns unsigned int containing, per vector element, 0xFFFF if true, and 0x0000 if false
2103+
syclcompat::vectorized_binary<sycl::ushort2>(
2104+
Input1, Input2, std::equal_to<>());
2105+
2106+
// Performs element-wise comparison on unsigned short
2107+
// Returns unsigned int containing, per vector element, 1 if true, and 0 if false
2108+
syclcompat::vectorized_binary<sycl::ushort2>(
2109+
Input1, Input2, std::equal_to<unsigned short>());
2110+
```
2111+
20932112
The math header provides a set of functions to extend 32-bit operations
20942113
to 33 bit, and handle sign extension internally. There is support for `add`,
20952114
`sub`, `absdiff`, `min` and `max` operations. Each operation provides overloads

sycl/include/syclcompat/math.hpp

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -119,34 +119,13 @@ class vectorized_binary {
119119
}
120120
};
121121

122-
// Vectorized_binary for logical operations
123122
template <typename VecT, class BinaryOperation>
124123
class vectorized_binary<
125124
VecT, BinaryOperation,
126-
std::enable_if_t<std::is_same_v<
127-
bool, decltype(std::declval<BinaryOperation>()(
128-
std::declval<typename VecT::element_type>(),
129-
std::declval<typename VecT::element_type>()))>>> {
125+
std::void_t<std::invoke_result_t<BinaryOperation, VecT, VecT>>> {
130126
public:
131127
inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) {
132-
unsigned result = 0;
133-
constexpr size_t elem_size = 8 * sizeof(typename VecT::element_type);
134-
static_assert(elem_size < 32,
135-
"Vector element size must be less than 4 bytes");
136-
constexpr unsigned bool_mask = (1U << elem_size) - 1;
137-
138-
for (size_t i = 0; i < a.size(); ++i) {
139-
bool comp_result = binary_op(a[i], b[i]);
140-
result |= (comp_result ? bool_mask : 0U) << (i * elem_size);
141-
}
142-
143-
VecT v4;
144-
for (size_t i = 0; i < v4.size(); ++i) {
145-
v4[i] = static_cast<typename VecT::element_type>(
146-
(result >> (i * elem_size)) & bool_mask);
147-
}
148-
149-
return v4;
128+
return binary_op(a, b).template as<VecT>();
150129
}
151130
};
152131

@@ -694,8 +673,9 @@ inline unsigned vectorized_unary(unsigned a, const UnaryOperation unary_op) {
694673
template <typename VecT>
695674
inline unsigned vectorized_sum_abs_diff(unsigned a, unsigned b) {
696675
sycl::vec<unsigned, 1> v0{a}, v1{b};
697-
auto v2 = v0.as<VecT>();
698-
auto v3 = v1.as<VecT>();
676+
// Need convert element type to wider signed type to avoid overflow.
677+
auto v2 = v0.as<VecT>().template convert<int>();
678+
auto v3 = v1.as<VecT>().template convert<int>();
699679
auto v4 = sycl::abs_diff(v2, v3);
700680
unsigned sum = 0;
701681
for (size_t i = 0; i < v4.size(); ++i) {
@@ -1095,13 +1075,8 @@ inline unsigned vectorized_binary(unsigned a, unsigned b,
10951075
auto v3 = v1.as<VecT>();
10961076
auto v4 =
10971077
detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
1098-
if constexpr (!std::is_same_v<
1099-
bool, decltype(std::declval<BinaryOperation>()(
1100-
std::declval<typename VecT::element_type>(),
1101-
std::declval<typename VecT::element_type>()))>) {
1102-
if (need_relu)
1103-
v4 = relu(v4);
1104-
}
1078+
if (need_relu)
1079+
v4 = relu(v4);
11051080
v0 = v4.template as<sycl::vec<unsigned, 1>>();
11061081
return v0;
11071082
}

0 commit comments

Comments
 (0)