Skip to content

Commit 80acd9b

Browse files
[SYCLCompat] Fix compare_mask implementations and test (#16768)
The `compare_mask` and `unordered_compare_mask` implementations were placing the results of the comparison operations in the wrong 2-byte segments of the 4-byte output. The `math_compare.cpp` test has also been fixed, where the "expected" results were previously incorrect, they now reflect the values returned by the corresponding CUDA math functions.
1 parent 3ff6428 commit 80acd9b

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

sycl/include/syclcompat/math.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -598,8 +598,8 @@ template <typename ValueT, class BinaryOperation>
598598
inline std::enable_if_t<ValueT::size() == 2, unsigned>
599599
compare_mask(const ValueT a, const ValueT b, const BinaryOperation binary_op) {
600600
// Since compare returns 0 or 1, -compare will be 0x00000000 or 0xFFFFFFFF
601-
return ((-compare(a[0], b[0], binary_op)) << 16) |
602-
((-compare(a[1], b[1], binary_op)) & 0xFFFF);
601+
return ((-compare(a[0], b[0], binary_op)) & 0xFFFF) |
602+
((-compare(a[1], b[1], binary_op)) << 16u);
603603
}
604604

605605
/// Performs 2 elements unordered comparison, compare result of each element is
@@ -613,8 +613,8 @@ template <typename ValueT, class BinaryOperation>
613613
inline std::enable_if_t<ValueT::size() == 2, unsigned>
614614
unordered_compare_mask(const ValueT a, const ValueT b,
615615
const BinaryOperation binary_op) {
616-
return ((-unordered_compare(a[0], b[0], binary_op)) << 16) |
617-
((-unordered_compare(a[1], b[1], binary_op)) & 0xFFFF);
616+
return ((-unordered_compare(a[0], b[0], binary_op)) & 0xFFFF) |
617+
((-unordered_compare(a[1], b[1], binary_op)) << 16);
618618
}
619619

620620
/// Compute vectorized max for two values, with each value treated as a vector

sycl/test-e2e/syclcompat/math/math_compare.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -301,12 +301,12 @@ typename ValueT> void test_compare_mask() {
301301
// 1.0 == 1.0, 2.0 == 3.0 -> 0xffff0000
302302
BinaryOpTestLauncher<Container, Container, unsigned>(grid, threads)
303303
.template launch_test<compare_mask_kernel<Container>>(op1, op3,
304-
0xffff0000);
304+
0x0000ffff);
305305

306306
// 1.0 == 3.0, 2.0 == 2.0 -> 0x0000ffff
307307
BinaryOpTestLauncher<Container, Container, unsigned>(grid, threads)
308308
.template launch_test<compare_mask_kernel<Container>>(op1, op4,
309-
0x0000ffff);
309+
0xffff0000);
310310

311311
// 1.0 == NaN, 2.0 == NaN -> 0x00000000
312312
BinaryOpTestLauncher<Container, Container, unsigned>(grid, threads)
@@ -350,12 +350,12 @@ typename ValueT> void test_unordered_compare_mask() {
350350
// 1.0 == 1.0, 2.0 == 3.0 -> 0xffff0000
351351
BinaryOpTestLauncher<Container, Container, unsigned>(grid, threads)
352352
.template launch_test<unordered_compare_mask_kernel<Container>>(
353-
op1, op3, 0xffff0000);
353+
op1, op3, 0x0000ffff);
354354

355355
// 1.0 == 3.0, 2.0 == 2.0 -> 0x0000ffff
356356
BinaryOpTestLauncher<Container, Container, unsigned>(grid, threads)
357357
.template launch_test<unordered_compare_mask_kernel<Container>>(
358-
op1, op4, 0x0000ffff);
358+
op1, op4, 0xffff0000);
359359

360360
// 1.0 == NaN, 2.0 == NaN -> 0xffffffff
361361
BinaryOpTestLauncher<Container, Container, unsigned>(grid, threads)

0 commit comments

Comments
 (0)