Skip to content

Commit 71e612a

Browse files
authored
[SYCL][ESIMD] Implement compile-time evaluated math util. (#2133)
Can't have real recursion, even if constexpr - SYCL device compiler does not allow that according to the spec. Signed-off-by: Konstantin S Bobrovsky <[email protected]>
1 parent 14fd363 commit 71e612a

File tree

3 files changed

+39
-11
lines changed

3 files changed

+39
-11
lines changed

sycl/include/CL/sycl/intel/esimd/detail/esimd_util.hpp

+36-8
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,42 @@ namespace __esimd {
1717
/// Constant in number of bytes.
1818
enum { BYTE = 1, WORD = 2, DWORD = 4, QWORD = 8, OWORD = 16, GRF = 32 };
1919

20-
/// Compute the next power of 2 at compile time.
21-
static ESIMD_INLINE constexpr unsigned int getNextPowerOf2(unsigned int n,
22-
unsigned int k = 1) {
23-
return (k >= n) ? k : getNextPowerOf2(n, k * 2);
20+
/// Compute next power of 2 of a constexpr with guaranteed compile-time
21+
/// evaluation.
22+
template <unsigned int N, unsigned int K, bool K_gt_eq_N> struct NextPowerOf2;
23+
24+
template <unsigned int N, unsigned int K> struct NextPowerOf2<N, K, true> {
25+
static constexpr unsigned int get() { return K; }
26+
};
27+
28+
template <unsigned int N, unsigned int K> struct NextPowerOf2<N, K, false> {
29+
static constexpr unsigned int get() {
30+
return NextPowerOf2<N, K * 2, K * 2 >= N>::get();
31+
}
32+
};
33+
34+
template <unsigned int N> unsigned int getNextPowerOf2() {
35+
return NextPowerOf2<N, 1, (1 >= N)>::get();
36+
}
37+
38+
template <> unsigned int getNextPowerOf2<0>() { return 0; }
39+
40+
/// Compute binary logarithm of a constexpr with guaranteed compile-time
41+
/// evaluation.
42+
template <unsigned int N, bool N_gt_1> struct Log2;
43+
44+
template <unsigned int N> struct Log2<N, false> {
45+
static constexpr unsigned int get() { return 0; }
46+
};
47+
48+
template <unsigned int N> struct Log2<N, true> {
49+
static constexpr unsigned int get() {
50+
return 1 + Log2<(N >> 1), ((N >> 1) > 1)>::get();
51+
}
52+
};
53+
54+
template <unsigned int N> constexpr unsigned int log2() {
55+
return Log2<N, (N > 1)>::get();
2456
}
2557

2658
/// Check if a given 32 bit positive integer is a power of 2 at compile time.
@@ -33,10 +65,6 @@ static ESIMD_INLINE constexpr bool isPowerOf2(unsigned int n,
3365
return (n & (n - 1)) == 0 && n <= limit;
3466
}
3567

36-
static ESIMD_INLINE constexpr unsigned log2(unsigned n) {
37-
return (n > 1) ? 1 + log2(n >> 1) : 0;
38-
}
39-
4068
} // namespace __esimd
4169

4270
__SYCL_INLINE_NAMESPACE(cl) {

sycl/include/CL/sycl/intel/esimd/esimd_math.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1884,7 +1884,7 @@ T1 esimd_reduce(simd<T1, SZ> v) {
18841884
if constexpr (isPowerOf2) {
18851885
return esimd_reduce_single<T0, T1, SZ, OpType>(v);
18861886
} else {
1887-
constexpr unsigned N1 = 1u << __esimd::log2(SZ);
1887+
constexpr unsigned N1 = 1u << __esimd::log2<SZ>();
18881888
constexpr unsigned N2 = SZ - N1;
18891889

18901890
simd<T1, N1> v1 = v.template select<N1, 1>(0);

sycl/include/CL/sycl/intel/esimd/esimd_memory.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ media_block_load(AccessorTy acc, unsigned x, unsigned y) {
567567
static_assert(plane <= 3u, "valid plane index is in range [0, 3]");
568568
#if defined(__SYCL_DEVICE_ONLY__) && defined(__SYCL_EXPLICIT_SIMD__)
569569
constexpr unsigned int RoundedWidth =
570-
Width < 4 ? 4 : __esimd::getNextPowerOf2(Width);
570+
Width < 4 ? 4 : __esimd::getNextPowerOf2<Width>();
571571

572572
if constexpr (Width < RoundedWidth) {
573573
constexpr unsigned int n1 = RoundedWidth / sizeof(T);
@@ -617,7 +617,7 @@ media_block_store(AccessorTy acc, unsigned x, unsigned y, simd<T, m * n> vals) {
617617
static_assert(plane <= 3u, "valid plane index is in range [0, 3]");
618618
#if defined(__SYCL_DEVICE_ONLY__) && defined(__SYCL_EXPLICIT_SIMD__)
619619
constexpr unsigned int RoundedWidth =
620-
Width < 4 ? 4 : __esimd::getNextPowerOf2(Width);
620+
Width < 4 ? 4 : __esimd::getNextPowerOf2<Width>();
621621
constexpr unsigned int n1 = RoundedWidth / sizeof(T);
622622

623623
if constexpr (Width < RoundedWidth) {

0 commit comments

Comments
 (0)