Skip to content

[SYCL] Optimize checkValueRange #18296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 83 additions & 39 deletions sycl/include/sycl/detail/id_queries_fit_in_int.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,46 +34,88 @@ inline namespace _V1 {
namespace detail {

#if __SYCL_ID_QUERIES_FIT_IN_INT__
template <typename T> struct NotIntMsg;
constexpr static const char *Msg =
"Provided range and/or offset does not fit in int. Pass "
"`-fno-sycl-id-queries-fit-in-int' to remove this limit.";

template <int Dims> struct NotIntMsg<range<Dims>> {
constexpr static const char *Msg =
"Provided range is out of integer limits. Pass "
"`-fno-sycl-id-queries-fit-in-int' to disable range check.";
};

template <int Dims> struct NotIntMsg<id<Dims>> {
constexpr static const char *Msg =
"Provided offset is out of integer limits. Pass "
"`-fno-sycl-id-queries-fit-in-int' to disable offset check.";
};

template <typename T, typename ValT>
template <typename ValT>
typename std::enable_if_t<std::is_same<ValT, size_t>::value ||
std::is_same<ValT, unsigned long long>::value>
checkValueRangeImpl(ValT V) {
static constexpr size_t Limit =
static_cast<size_t>((std::numeric_limits<int>::max)());
if (V > Limit)
throw sycl::exception(make_error_code(errc::nd_range), NotIntMsg<T>::Msg);
throw sycl::exception(make_error_code(errc::nd_range), Msg);
}

inline void checkMulOverflow(size_t a, size_t b) {
#ifndef _MSC_VER
int Product;
// Since we must fit in SIGNED int, we can ignore the upper 32 bits.
if (__builtin_mul_overflow(unsigned(a), unsigned(b), &Product)) {
throw sycl::exception(make_error_code(errc::nd_range), Msg);
}
#else
checkValueRangeImpl(a);
checkValueRangeImpl(b);
size_t Product = a * b;
checkValueRangeImpl(Product);
#endif
}

inline void checkMulOverflow(size_t a, size_t b, size_t c) {
#ifndef _MSC_VER
int Product;
// Since we must fit in SIGNED int, we can ignore the upper 32 bits.
if (__builtin_mul_overflow(unsigned(a), unsigned(b), &Product) ||
__builtin_mul_overflow(Product, unsigned(c), &Product)) {
throw sycl::exception(make_error_code(errc::nd_range), Msg);
}
#else
checkValueRangeImpl(a);
checkValueRangeImpl(b);
size_t Product = a * b;
checkValueRangeImpl(Product);

checkValueRangeImpl(c);
Product *= c;
checkValueRangeImpl(Product);
#endif
}

// TODO: Remove this function when offsets are removed.
template <int Dims>
inline bool hasNonZeroOffset(const sycl::nd_range<Dims> &V) {
size_t Product = 1;
for (int Dim = 0; Dim < Dims; ++Dim) {
Product *= V.get_offset()[Dim];
}
return (Product != 0);
}
#endif //__SYCL_ID_QUERIES_FIT_IN_INT__

template <int Dims>
void checkValueRange([[maybe_unused]] const sycl::range<Dims> &V) {
#if __SYCL_ID_QUERIES_FIT_IN_INT__
if constexpr (Dims == 1) {
// For 1D range, just check the value against MAX_INT.
checkValueRangeImpl(V[0]);
} else if constexpr (Dims == 2) {
// For 2D range, check if computing the linear range overflows.
checkMulOverflow(V[0], V[1]);
} else if constexpr (Dims == 3) {
// For 3D range, check if computing the linear range overflows.
checkMulOverflow(V[0], V[1], V[2]);
}
#endif
}

template <int Dims, typename T>
typename std::enable_if_t<std::is_same_v<T, range<Dims>> ||
std::is_same_v<T, id<Dims>>>
checkValueRange([[maybe_unused]] const T &V) {
template <int Dims>
void checkValueRange([[maybe_unused]] const sycl::id<Dims> &V) {
#if __SYCL_ID_QUERIES_FIT_IN_INT__
for (size_t Dim = 0; Dim < Dims; ++Dim)
checkValueRangeImpl<T>(V[Dim]);

{
unsigned long long Product = 1;
for (size_t Dim = 0; Dim < Dims; ++Dim) {
Product *= V[Dim];
// check value now to prevent product overflow in the end
checkValueRangeImpl<T>(Product);
}
// An id cannot be linearized without a range, so check each component.
for (int Dim = 0; Dim < Dims; ++Dim) {
checkValueRangeImpl(V[Dim]);
}
#endif
}
Expand All @@ -87,21 +129,23 @@ void checkValueRange([[maybe_unused]] const range<Dims> &R,

for (size_t Dim = 0; Dim < Dims; ++Dim) {
unsigned long long Sum = R[Dim] + O[Dim];

checkValueRangeImpl<range<Dims>>(Sum);
checkValueRangeImpl(Sum);
}
#endif
}

template <int Dims, typename T>
typename std::enable_if_t<std::is_same_v<T, nd_range<Dims>>>
checkValueRange([[maybe_unused]] const T &V) {
template <int Dims>
void checkValueRange([[maybe_unused]] const sycl::nd_range<Dims> &V) {
#if __SYCL_ID_QUERIES_FIT_IN_INT__
checkValueRange<Dims>(V.get_global_range());
checkValueRange<Dims>(V.get_local_range());
checkValueRange<Dims>(V.get_offset());

checkValueRange<Dims>(V.get_global_range(), V.get_offset());
// In an ND-range, we only need to check the global linear size, because:
// - The linear size must be greater than any of the dimensions.
// - Each dimension of the global range is larger than the local range.
// TODO: Remove this branch when offsets are removed.
if (hasNonZeroOffset(V)) /*[[unlikely]]*/ {
checkValueRange<Dims>(V.get_global_range(), V.get_offset());
} else {
checkValueRange<Dims>(V.get_global_range());
}
#endif
}

Expand Down
30 changes: 4 additions & 26 deletions sycl/test-e2e/Basic/range_offset_fit_in_int.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,17 @@

namespace S = sycl;

void checkRangeException(S::exception &E) {
constexpr char Msg[] = "Provided range is out of integer limits. "
"Pass `-fno-sycl-id-queries-fit-in-int' to "
"disable range check.";
constexpr char Msg[] = "Provided range and/or offset does not fit in int. "
"Pass `-fno-sycl-id-queries-fit-in-int' to "
"remove this limit.";

void checkRangeException(S::exception &E) {
std::cerr << E.what() << std::endl;

assert(std::string(E.what()).find(Msg) == 0 && "Unexpected message");
}

void checkOffsetException(S::exception &E) {
constexpr char Msg[] = "Provided offset is out of integer limits. "
"Pass `-fno-sycl-id-queries-fit-in-int' to "
"disable offset check.";

std::cerr << E.what() << std::endl;

assert(std::string(E.what()).find(Msg) == 0 && "Unexpected message");
Expand All @@ -48,8 +44,6 @@ void test() {
S::id<2> OffsetInLimits_Large{(OutOfLimitsSize / 4) * 3, 1};
S::nd_range<2> NDRange_ROL_LIL_OIL{RangeOutOfLimits, RangeInLimits,
OffsetInLimits};
S::nd_range<2> NDRange_RIL_LOL_OIL{RangeInLimits, RangeOutOfLimits,
OffsetInLimits};
S::nd_range<2> NDRange_RIL_LIL_OOL{RangeInLimits, RangeInLimits,
OffsetOutOfLimits};
S::nd_range<2> NDRange_RIL_LIL_OIL(RangeInLimits, RangeInLimits,
Expand Down Expand Up @@ -184,22 +178,6 @@ void test() {
assert(false && "Unexpected exception catched");
}

// small offset, local range is out of limits
try {
Queue.submit([&](S::handler &CGH) {
auto Acc = Buf.get_access<sycl::access::mode::read_write>(CGH);

CGH.parallel_for<class PF_ND_GIL_LOL_OIL>(
NDRange_RIL_LOL_OIL, [Acc](S::nd_item<2> Id) { Acc[0] += 1; });
});

assert(false && "Exception expected");
} catch (S::exception &E) {
checkRangeException(E);
} catch (...) {
assert(false && "Unexpected exception catched");
}

// large offset, ranges are in limits
try {
Queue.submit([&](S::handler &CGH) {
Expand Down