Skip to content

Commit 2de0f92

Browse files
[SYCL] Allow USM ptrs in exclusive and inclusive scans (#4310)
The group algorithms are currently accepting only multi_ptrs while the SYCL spec (4.17.4) states that data ranges can be described using pointers, iterators or instances of the multi_ptr class. The spec allows to introduce additional restrictions, but now that this implementation supports USM should we still keep enforcing the use of multi_ptr ? The [doc](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/GroupAlgorithms/SYCL_INTEL_group_algorithms.asciidoc) has no reference to the use of multi_ptr.
1 parent daae147 commit 2de0f92

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

sycl/include/CL/sycl/group_algorithm.hpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -624,14 +624,15 @@ joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, T init,
624624
const ptrdiff_t &divisor) -> ptrdiff_t {
625625
return ((v + divisor - 1) / divisor) * divisor;
626626
};
627-
typename InPtr::element_type x;
628-
typename OutPtr::element_type carry = init;
627+
typename std::remove_const<typename detail::remove_pointer<InPtr>::type>::type
628+
x;
629+
typename detail::remove_pointer<OutPtr>::type carry = init;
629630
for (ptrdiff_t chunk = 0; chunk < roundup(N, stride); chunk += stride) {
630631
ptrdiff_t i = chunk + offset;
631632
if (i < N) {
632633
x = first[i];
633634
}
634-
typename OutPtr::element_type out =
635+
typename detail::remove_pointer<OutPtr>::type out =
635636
exclusive_scan_over_group(g, x, carry, binary_op);
636637
if (i < N) {
637638
result[i] = out;
@@ -664,13 +665,15 @@ joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
664665
// FIXME: Do not special-case for half precision
665666
static_assert(
666667
std::is_same<decltype(binary_op(*first, *first)),
667-
typename OutPtr::element_type>::value ||
668-
(std::is_same<typename OutPtr::element_type, half>::value &&
668+
typename detail::remove_pointer<OutPtr>::type>::value ||
669+
(std::is_same<typename detail::remove_pointer<OutPtr>::type,
670+
half>::value &&
669671
std::is_same<decltype(binary_op(*first, *first)), float>::value),
670672
"Result type of binary_op must match scan accumulation type.");
671673
return joint_exclusive_scan(
672674
g, first, last, result,
673-
sycl::known_identity_v<BinaryOperation, typename OutPtr::element_type>,
675+
sycl::known_identity_v<BinaryOperation,
676+
typename detail::remove_pointer<OutPtr>::type>,
674677
binary_op);
675678
}
676679

@@ -791,14 +794,15 @@ joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
791794
const ptrdiff_t &divisor) -> ptrdiff_t {
792795
return ((v + divisor - 1) / divisor) * divisor;
793796
};
794-
typename InPtr::element_type x;
795-
typename OutPtr::element_type carry = init;
797+
typename std::remove_const<typename detail::remove_pointer<InPtr>::type>::type
798+
x;
799+
typename detail::remove_pointer<OutPtr>::type carry = init;
796800
for (ptrdiff_t chunk = 0; chunk < roundup(N, stride); chunk += stride) {
797801
ptrdiff_t i = chunk + offset;
798802
if (i < N) {
799803
x = first[i];
800804
}
801-
typename OutPtr::element_type out =
805+
typename detail::remove_pointer<OutPtr>::type out =
802806
inclusive_scan_over_group(g, x, binary_op, carry);
803807
if (i < N) {
804808
result[i] = out;
@@ -830,13 +834,15 @@ joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
830834
// FIXME: Do not special-case for half precision
831835
static_assert(
832836
std::is_same<decltype(binary_op(*first, *first)),
833-
typename OutPtr::element_type>::value ||
834-
(std::is_same<typename OutPtr::element_type, half>::value &&
837+
typename detail::remove_pointer<OutPtr>::type>::value ||
838+
(std::is_same<typename detail::remove_pointer<OutPtr>::type,
839+
half>::value &&
835840
std::is_same<decltype(binary_op(*first, *first)), float>::value),
836841
"Result type of binary_op must match scan accumulation type.");
837842
return joint_inclusive_scan(
838843
g, first, last, result, binary_op,
839-
sycl::known_identity_v<BinaryOperation, typename OutPtr::element_type>);
844+
sycl::known_identity_v<BinaryOperation,
845+
typename detail::remove_pointer<OutPtr>::type>);
840846
}
841847

842848
namespace detail {

0 commit comments

Comments
 (0)