@@ -119,34 +119,13 @@ class vectorized_binary {
119
119
}
120
120
};
121
121
122
- // Vectorized_binary for logical operations
123
122
template <typename VecT, class BinaryOperation >
124
123
class vectorized_binary <
125
124
VecT, BinaryOperation,
126
- std::enable_if_t <std::is_same_v<
127
- bool , decltype(std::declval<BinaryOperation>()(
128
- std::declval<typename VecT::element_type>(),
129
- std::declval<typename VecT::element_type>()))>>> {
125
+ std::void_t <std::invoke_result_t <BinaryOperation, VecT, VecT>>> {
130
126
public:
131
127
inline VecT operator ()(VecT a, VecT b, const BinaryOperation binary_op) {
132
- unsigned result = 0 ;
133
- constexpr size_t elem_size = 8 * sizeof (typename VecT::element_type);
134
- static_assert (elem_size < 32 ,
135
- " Vector element size must be less than 4 bytes" );
136
- constexpr unsigned bool_mask = (1U << elem_size) - 1 ;
137
-
138
- for (size_t i = 0 ; i < a.size (); ++i) {
139
- bool comp_result = binary_op (a[i], b[i]);
140
- result |= (comp_result ? bool_mask : 0U ) << (i * elem_size);
141
- }
142
-
143
- VecT v4;
144
- for (size_t i = 0 ; i < v4.size (); ++i) {
145
- v4[i] = static_cast <typename VecT::element_type>(
146
- (result >> (i * elem_size)) & bool_mask);
147
- }
148
-
149
- return v4;
128
+ return binary_op (a, b).template as <VecT>();
150
129
}
151
130
};
152
131
@@ -694,8 +673,9 @@ inline unsigned vectorized_unary(unsigned a, const UnaryOperation unary_op) {
694
673
template <typename VecT>
695
674
inline unsigned vectorized_sum_abs_diff (unsigned a, unsigned b) {
696
675
sycl::vec<unsigned , 1 > v0{a}, v1{b};
697
- auto v2 = v0.as <VecT>();
698
- auto v3 = v1.as <VecT>();
676
+ // Need convert element type to wider signed type to avoid overflow.
677
+ auto v2 = v0.as <VecT>().template convert <int >();
678
+ auto v3 = v1.as <VecT>().template convert <int >();
699
679
auto v4 = sycl::abs_diff (v2, v3);
700
680
unsigned sum = 0 ;
701
681
for (size_t i = 0 ; i < v4.size (); ++i) {
@@ -1095,13 +1075,8 @@ inline unsigned vectorized_binary(unsigned a, unsigned b,
1095
1075
auto v3 = v1.as <VecT>();
1096
1076
auto v4 =
1097
1077
detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
1098
- if constexpr (!std::is_same_v<
1099
- bool , decltype (std::declval<BinaryOperation>()(
1100
- std::declval<typename VecT::element_type>(),
1101
- std::declval<typename VecT::element_type>()))>) {
1102
- if (need_relu)
1103
- v4 = relu (v4);
1104
- }
1078
+ if (need_relu)
1079
+ v4 = relu (v4);
1105
1080
v0 = v4.template as <sycl::vec<unsigned , 1 >>();
1106
1081
return v0;
1107
1082
}
0 commit comments