Skip to content

[SYCL][ESIMD] Implement compile-time getNextPowerOf2 w/o recursion. #2133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions sycl/include/CL/sycl/intel/esimd/detail/esimd_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,42 @@ namespace __esimd {
/// Constant in number of bytes.
enum { BYTE = 1, WORD = 2, DWORD = 4, QWORD = 8, OWORD = 16, GRF = 32 };

/// Compute the next power of 2 at compile time.
static ESIMD_INLINE constexpr unsigned int getNextPowerOf2(unsigned int n,
unsigned int k = 1) {
return (k >= n) ? k : getNextPowerOf2(n, k * 2);
/// Helper for getNextPowerOf2: steps the candidate K through 1, 2, 4, ... by
/// template instantiation until it reaches or exceeds N, so the search itself
/// is carried out by the compiler.
///
/// \tparam N          value to round up.
/// \tparam K          current power-of-2 candidate.
/// \tparam K_gt_eq_N  true once K >= N; selects the terminating specialization.
template <unsigned int N, unsigned int K, bool K_gt_eq_N> struct NextPowerOf2;

/// Terminating case: K already covers N, so K is the answer.
template <unsigned int N, unsigned int K> struct NextPowerOf2<N, K, true> {
  static constexpr unsigned int get() { return K; }
};

/// Stepping case: K is still below N, so retry with the candidate doubled.
template <unsigned int N, unsigned int K> struct NextPowerOf2<N, K, false> {
  static constexpr unsigned int get() {
    return NextPowerOf2<N, K * 2, (K * 2 >= N)>::get();
  }
};

/// Returns the smallest power of 2 that is greater than or equal to N.
///
/// The function is constexpr so that its result can initialize constexpr
/// variables and be used in other constant expressions.
///
/// \tparam N value to round up; must not exceed the largest power of 2
///           representable in unsigned int.
template <unsigned int N> constexpr unsigned int getNextPowerOf2() {
  // Beyond the largest representable power of 2 the doubling K * 2 wraps to 0
  // and the helper would never terminate; report that case explicitly.
  static_assert(N <= (~0u >> 1) + 1u,
                "getNextPowerOf2: no power of 2 >= N fits in unsigned int");
  return NextPowerOf2<N, 1, (1 >= N)>::get();
}

/// Special case: 0 is mapped to 0 rather than to 1.
/// constexpr also makes this full specialization implicitly inline, which is
/// required for a function defined in a header.
template <> constexpr unsigned int getNextPowerOf2<0>() { return 0; }

/// Helper for log2: computes the floor of the base-2 logarithm of a constant
/// by halving it through successive template instantiations.
///
/// \tparam Value     constant whose logarithm is being computed.
/// \tparam AboveOne  true while Value > 1; selects the stepping specialization.
template <unsigned int Value, bool AboveOne> struct Log2;

/// Terminating case: Value is 0 or 1, which contributes nothing.
template <unsigned int Value> struct Log2<Value, false> {
  static constexpr unsigned int get() { return 0; }
};

/// Stepping case: one halving is counted, the rest is delegated.
template <unsigned int Value> struct Log2<Value, true> {
  static constexpr unsigned int get() {
    constexpr unsigned int Halved = Value >> 1;
    using Rest = Log2<Halved, (Halved > 1)>;
    return Rest::get() + 1;
  }
};

/// Returns floor(log2(N)) for a compile-time constant N; yields 0 for N equal
/// to 0 or 1.
template <unsigned int N> constexpr unsigned int log2() {
  using Impl = Log2<N, (N > 1)>;
  return Impl::get();
}

/// Check if a given 32-bit positive integer is a power of 2 at compile time.
Expand All @@ -33,10 +65,6 @@ static ESIMD_INLINE constexpr bool isPowerOf2(unsigned int n,
return (n & (n - 1)) == 0 && n <= limit;
}

/// Floor of the binary logarithm of n, usable in constant expressions.
/// Yields 0 when n is 0 or 1.
static ESIMD_INLINE constexpr unsigned log2(unsigned n) {
  unsigned shifts = 0;
  // Count how many times the value can be halved while it is still above 1.
  for (unsigned rest = n; rest > 1; rest >>= 1)
    ++shifts;
  return shifts;
}

} // namespace __esimd

__SYCL_INLINE_NAMESPACE(cl) {
Expand Down
2 changes: 1 addition & 1 deletion sycl/include/CL/sycl/intel/esimd/esimd_math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1884,7 +1884,7 @@ T1 esimd_reduce(simd<T1, SZ> v) {
if constexpr (isPowerOf2) {
return esimd_reduce_single<T0, T1, SZ, OpType>(v);
} else {
constexpr unsigned N1 = 1u << __esimd::log2(SZ);
constexpr unsigned N1 = 1u << __esimd::log2<SZ>();
constexpr unsigned N2 = SZ - N1;

simd<T1, N1> v1 = v.template select<N1, 1>(0);
Expand Down
4 changes: 2 additions & 2 deletions sycl/include/CL/sycl/intel/esimd/esimd_memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ media_block_load(AccessorTy acc, unsigned x, unsigned y) {
static_assert(plane <= 3u, "valid plane index is in range [0, 3]");
#if defined(__SYCL_DEVICE_ONLY__) && defined(__SYCL_EXPLICIT_SIMD__)
constexpr unsigned int RoundedWidth =
Width < 4 ? 4 : __esimd::getNextPowerOf2(Width);
Width < 4 ? 4 : __esimd::getNextPowerOf2<Width>();

if constexpr (Width < RoundedWidth) {
constexpr unsigned int n1 = RoundedWidth / sizeof(T);
Expand Down Expand Up @@ -617,7 +617,7 @@ media_block_store(AccessorTy acc, unsigned x, unsigned y, simd<T, m * n> vals) {
static_assert(plane <= 3u, "valid plane index is in range [0, 3]");
#if defined(__SYCL_DEVICE_ONLY__) && defined(__SYCL_EXPLICIT_SIMD__)
constexpr unsigned int RoundedWidth =
Width < 4 ? 4 : __esimd::getNextPowerOf2(Width);
Width < 4 ? 4 : __esimd::getNextPowerOf2<Width>();
constexpr unsigned int n1 = RoundedWidth / sizeof(T);

if constexpr (Width < RoundedWidth) {
Expand Down