Skip to content

[SYCL] Allow USM ptrs in exclusive and inclusive scans #4310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 15, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions sycl/include/CL/sycl/group_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,14 +624,15 @@ joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, T init,
const ptrdiff_t &divisor) -> ptrdiff_t {
return ((v + divisor - 1) / divisor) * divisor;
};
typename InPtr::element_type x;
typename OutPtr::element_type carry = init;
typename std::remove_const<typename detail::remove_pointer<InPtr>::type>::type
x;
typename detail::remove_pointer<OutPtr>::type carry = init;
for (ptrdiff_t chunk = 0; chunk < roundup(N, stride); chunk += stride) {
ptrdiff_t i = chunk + offset;
if (i < N) {
x = first[i];
}
typename OutPtr::element_type out =
typename detail::remove_pointer<OutPtr>::type out =
exclusive_scan_over_group(g, x, carry, binary_op);
if (i < N) {
result[i] = out;
Expand Down Expand Up @@ -664,13 +665,15 @@ joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
// FIXME: Do not special-case for half precision
static_assert(
std::is_same<decltype(binary_op(*first, *first)),
typename OutPtr::element_type>::value ||
(std::is_same<typename OutPtr::element_type, half>::value &&
typename detail::remove_pointer<OutPtr>::type>::value ||
(std::is_same<typename detail::remove_pointer<OutPtr>::type,
half>::value &&
std::is_same<decltype(binary_op(*first, *first)), float>::value),
"Result type of binary_op must match scan accumulation type.");
return joint_exclusive_scan(
g, first, last, result,
sycl::known_identity_v<BinaryOperation, typename OutPtr::element_type>,
sycl::known_identity_v<BinaryOperation,
typename detail::remove_pointer<OutPtr>::type>,
binary_op);
}

Expand Down Expand Up @@ -791,14 +794,15 @@ joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
const ptrdiff_t &divisor) -> ptrdiff_t {
return ((v + divisor - 1) / divisor) * divisor;
};
typename InPtr::element_type x;
typename OutPtr::element_type carry = init;
typename std::remove_const<typename detail::remove_pointer<InPtr>::type>::type
x;
typename detail::remove_pointer<OutPtr>::type carry = init;
for (ptrdiff_t chunk = 0; chunk < roundup(N, stride); chunk += stride) {
ptrdiff_t i = chunk + offset;
if (i < N) {
x = first[i];
}
typename OutPtr::element_type out =
typename detail::remove_pointer<OutPtr>::type out =
inclusive_scan_over_group(g, x, binary_op, carry);
if (i < N) {
result[i] = out;
Expand Down Expand Up @@ -830,13 +834,15 @@ joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
// FIXME: Do not special-case for half precision
static_assert(
std::is_same<decltype(binary_op(*first, *first)),
typename OutPtr::element_type>::value ||
(std::is_same<typename OutPtr::element_type, half>::value &&
typename detail::remove_pointer<OutPtr>::type>::value ||
(std::is_same<typename detail::remove_pointer<OutPtr>::type,
half>::value &&
std::is_same<decltype(binary_op(*first, *first)), float>::value),
"Result type of binary_op must match scan accumulation type.");
return joint_inclusive_scan(
g, first, last, result, binary_op,
sycl::known_identity_v<BinaryOperation, typename OutPtr::element_type>);
sycl::known_identity_v<BinaryOperation,
typename detail::remove_pointer<OutPtr>::type>);
}

namespace detail {
Expand Down