Skip to content

Commit 68e51aa

Browse files
[SYCL] Optimize checkValueRange (#18296)
checkValueRange is used to determine if an nd_range is compatible with -fsycl-id-queries-fit-in-int, and is run as part of every kernel launch.

The previous implementation checked the size of each component of the global range, local range, offset, and global range + offset, and also checked the linearized version of each of these values.

The new implementation simplifies these checks, based on the following logic:
- The linear global range size must be >= every component of the global range. If the linear global range fits in int, we don't need to check anything else.
- Each value in the global range must be >= the value in the local range. If the global range fits in int, we don't need to check the local range.
- There is no need to check offset-related values if the offset is zero.

The new implementation also makes use of __builtin_mul_overflow where available. This shifts the burden of maintaining fast code for these checks to the compiler, and allows us to benefit from aggressive optimizations.

The new implementation could be optimized further if there was a quick way to check whether an nd_range has an offset.

---------

Signed-off-by: John Pennycook <[email protected]>
Co-authored-by: Udit Kumar Agarwal <[email protected]>
1 parent 6f48502 commit 68e51aa

File tree

2 files changed

+87
-65
lines changed

2 files changed

+87
-65
lines changed

sycl/include/sycl/detail/id_queries_fit_in_int.hpp

Lines changed: 83 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -34,46 +34,88 @@ inline namespace _V1 {
3434
namespace detail {
3535

3636
#if __SYCL_ID_QUERIES_FIT_IN_INT__
37-
template <typename T> struct NotIntMsg;
37+
constexpr static const char *Msg =
38+
"Provided range and/or offset does not fit in int. Pass "
39+
"`-fno-sycl-id-queries-fit-in-int' to remove this limit.";
3840

39-
template <int Dims> struct NotIntMsg<range<Dims>> {
40-
constexpr static const char *Msg =
41-
"Provided range is out of integer limits. Pass "
42-
"`-fno-sycl-id-queries-fit-in-int' to disable range check.";
43-
};
44-
45-
template <int Dims> struct NotIntMsg<id<Dims>> {
46-
constexpr static const char *Msg =
47-
"Provided offset is out of integer limits. Pass "
48-
"`-fno-sycl-id-queries-fit-in-int' to disable offset check.";
49-
};
50-
51-
template <typename T, typename ValT>
41+
template <typename ValT>
5242
typename std::enable_if_t<std::is_same<ValT, size_t>::value ||
5343
std::is_same<ValT, unsigned long long>::value>
5444
checkValueRangeImpl(ValT V) {
5545
static constexpr size_t Limit =
5646
static_cast<size_t>((std::numeric_limits<int>::max)());
5747
if (V > Limit)
58-
throw sycl::exception(make_error_code(errc::nd_range), NotIntMsg<T>::Msg);
48+
throw sycl::exception(make_error_code(errc::nd_range), Msg);
49+
}
50+
51+
inline void checkMulOverflow(size_t a, size_t b) {
52+
#ifndef _MSC_VER
53+
int Product;
54+
// Since we must fit in SIGNED int, we can ignore the upper 32 bits.
55+
if (__builtin_mul_overflow(unsigned(a), unsigned(b), &Product)) {
56+
throw sycl::exception(make_error_code(errc::nd_range), Msg);
57+
}
58+
#else
59+
checkValueRangeImpl(a);
60+
checkValueRangeImpl(b);
61+
size_t Product = a * b;
62+
checkValueRangeImpl(Product);
63+
#endif
64+
}
65+
66+
inline void checkMulOverflow(size_t a, size_t b, size_t c) {
67+
#ifndef _MSC_VER
68+
int Product;
69+
// Since we must fit in SIGNED int, we can ignore the upper 32 bits.
70+
if (__builtin_mul_overflow(unsigned(a), unsigned(b), &Product) ||
71+
__builtin_mul_overflow(Product, unsigned(c), &Product)) {
72+
throw sycl::exception(make_error_code(errc::nd_range), Msg);
73+
}
74+
#else
75+
checkValueRangeImpl(a);
76+
checkValueRangeImpl(b);
77+
size_t Product = a * b;
78+
checkValueRangeImpl(Product);
79+
80+
checkValueRangeImpl(c);
81+
Product *= c;
82+
checkValueRangeImpl(Product);
83+
#endif
84+
}
85+
86+
// TODO: Remove this function when offsets are removed.
87+
template <int Dims>
88+
inline bool hasNonZeroOffset(const sycl::nd_range<Dims> &V) {
89+
size_t Product = 1;
90+
for (int Dim = 0; Dim < Dims; ++Dim) {
91+
Product *= V.get_offset()[Dim];
92+
}
93+
return (Product != 0);
5994
}
95+
#endif //__SYCL_ID_QUERIES_FIT_IN_INT__
96+
97+
template <int Dims>
98+
void checkValueRange([[maybe_unused]] const sycl::range<Dims> &V) {
99+
#if __SYCL_ID_QUERIES_FIT_IN_INT__
100+
if constexpr (Dims == 1) {
101+
// For 1D range, just check the value against MAX_INT.
102+
checkValueRangeImpl(V[0]);
103+
} else if constexpr (Dims == 2) {
104+
// For 2D range, check if computing the linear range overflows.
105+
checkMulOverflow(V[0], V[1]);
106+
} else if constexpr (Dims == 3) {
107+
// For 3D range, check if computing the linear range overflows.
108+
checkMulOverflow(V[0], V[1], V[2]);
109+
}
60110
#endif
111+
}
61112

62-
template <int Dims, typename T>
63-
typename std::enable_if_t<std::is_same_v<T, range<Dims>> ||
64-
std::is_same_v<T, id<Dims>>>
65-
checkValueRange([[maybe_unused]] const T &V) {
113+
template <int Dims>
114+
void checkValueRange([[maybe_unused]] const sycl::id<Dims> &V) {
66115
#if __SYCL_ID_QUERIES_FIT_IN_INT__
67-
for (size_t Dim = 0; Dim < Dims; ++Dim)
68-
checkValueRangeImpl<T>(V[Dim]);
69-
70-
{
71-
unsigned long long Product = 1;
72-
for (size_t Dim = 0; Dim < Dims; ++Dim) {
73-
Product *= V[Dim];
74-
// check value now to prevent product overflow in the end
75-
checkValueRangeImpl<T>(Product);
76-
}
116+
// An id cannot be linearized without a range, so check each component.
117+
for (int Dim = 0; Dim < Dims; ++Dim) {
118+
checkValueRangeImpl(V[Dim]);
77119
}
78120
#endif
79121
}
@@ -87,21 +129,23 @@ void checkValueRange([[maybe_unused]] const range<Dims> &R,
87129

88130
for (size_t Dim = 0; Dim < Dims; ++Dim) {
89131
unsigned long long Sum = R[Dim] + O[Dim];
90-
91-
checkValueRangeImpl<range<Dims>>(Sum);
132+
checkValueRangeImpl(Sum);
92133
}
93134
#endif
94135
}
95136

96-
template <int Dims, typename T>
97-
typename std::enable_if_t<std::is_same_v<T, nd_range<Dims>>>
98-
checkValueRange([[maybe_unused]] const T &V) {
137+
template <int Dims>
138+
void checkValueRange([[maybe_unused]] const sycl::nd_range<Dims> &V) {
99139
#if __SYCL_ID_QUERIES_FIT_IN_INT__
100-
checkValueRange<Dims>(V.get_global_range());
101-
checkValueRange<Dims>(V.get_local_range());
102-
checkValueRange<Dims>(V.get_offset());
103-
104-
checkValueRange<Dims>(V.get_global_range(), V.get_offset());
140+
// In an ND-range, we only need to check the global linear size, because:
141+
// - The linear size must be greater than any of the dimensions.
142+
// - Each dimension of the global range is larger than the local range.
143+
// TODO: Remove this branch when offsets are removed.
144+
if (hasNonZeroOffset(V)) /*[[unlikely]]*/ {
145+
checkValueRange<Dims>(V.get_global_range(), V.get_offset());
146+
} else {
147+
checkValueRange<Dims>(V.get_global_range());
148+
}
105149
#endif
106150
}
107151

sycl/test-e2e/Basic/range_offset_fit_in_int.cpp

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,17 @@
88

99
namespace S = sycl;
1010

11-
void checkRangeException(S::exception &E) {
12-
constexpr char Msg[] = "Provided range is out of integer limits. "
13-
"Pass `-fno-sycl-id-queries-fit-in-int' to "
14-
"disable range check.";
11+
constexpr char Msg[] = "Provided range and/or offset does not fit in int. "
12+
"Pass `-fno-sycl-id-queries-fit-in-int' to "
13+
"remove this limit.";
1514

15+
void checkRangeException(S::exception &E) {
1616
std::cerr << E.what() << std::endl;
1717

1818
assert(std::string(E.what()).find(Msg) == 0 && "Unexpected message");
1919
}
2020

2121
void checkOffsetException(S::exception &E) {
22-
constexpr char Msg[] = "Provided offset is out of integer limits. "
23-
"Pass `-fno-sycl-id-queries-fit-in-int' to "
24-
"disable offset check.";
25-
2622
std::cerr << E.what() << std::endl;
2723

2824
assert(std::string(E.what()).find(Msg) == 0 && "Unexpected message");
@@ -48,8 +44,6 @@ void test() {
4844
S::id<2> OffsetInLimits_Large{(OutOfLimitsSize / 4) * 3, 1};
4945
S::nd_range<2> NDRange_ROL_LIL_OIL{RangeOutOfLimits, RangeInLimits,
5046
OffsetInLimits};
51-
S::nd_range<2> NDRange_RIL_LOL_OIL{RangeInLimits, RangeOutOfLimits,
52-
OffsetInLimits};
5347
S::nd_range<2> NDRange_RIL_LIL_OOL{RangeInLimits, RangeInLimits,
5448
OffsetOutOfLimits};
5549
S::nd_range<2> NDRange_RIL_LIL_OIL(RangeInLimits, RangeInLimits,
@@ -184,22 +178,6 @@ void test() {
184178
assert(false && "Unexpected exception catched");
185179
}
186180

187-
// small offset, local range is out of limits
188-
try {
189-
Queue.submit([&](S::handler &CGH) {
190-
auto Acc = Buf.get_access<sycl::access::mode::read_write>(CGH);
191-
192-
CGH.parallel_for<class PF_ND_GIL_LOL_OIL>(
193-
NDRange_RIL_LOL_OIL, [Acc](S::nd_item<2> Id) { Acc[0] += 1; });
194-
});
195-
196-
assert(false && "Exception expected");
197-
} catch (S::exception &E) {
198-
checkRangeException(E);
199-
} catch (...) {
200-
assert(false && "Unexpected exception catched");
201-
}
202-
203181
// large offset, ranges are in limits
204182
try {
205183
Queue.submit([&](S::handler &CGH) {

0 commit comments

Comments
 (0)