Skip to content

Commit 25a80f9

Browse files
committed
Always make the inner dimension a power of 2
The inner dimension may be at index [0], [1] or [2], depending on the dimensionality of the global range. This change makes the inner dimension a power of 2 regardless of which index it occupies.
1 parent 3e2effd commit 25a80f9

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

source/adapters/cuda/enqueue.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
164164
MaxThreadsPerBlock[0]));
165165

166166
roundToHighestFactorOfGlobalSizeIn3d(ThreadsPerBlock, GlobalSizeNormalized,
167-
MaxBlockDim, MaxBlockSize);
167+
MaxBlockDim, MaxBlockSize, WorkDim);
168168
}
169169

170170
// Helper to verify out-of-registers case (exceeded block max registers).

source/adapters/hip/enqueue.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
7474
MaxBlockDim[2] = Device->getMaxBlockDimZ();
7575

7676
roundToHighestFactorOfGlobalSizeIn3d(ThreadsPerBlock, GlobalSizeNormalized,
77-
MaxBlockDim, MaxThreadsPerBlock[0]);
77+
MaxBlockDim, MaxThreadsPerBlock[0],
78+
WorkDim);
7879
}
7980

8081
namespace {

source/ur/ur.hpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -347,23 +347,32 @@ template <typename T> inline bool isPowerOf2(const T &Value) {
347347
// dims == 1)
348348
// In: MaxBlockDim - The max size of block in 3d
349349
// In: MaxBlockSize - The max total size of block in all dimensions
350+
// In: WorkDim - The workdim (1, 2 or 3)
350351
static inline void roundToHighestFactorOfGlobalSizeIn3d(
351352
size_t *ThreadsPerBlock, const size_t *GlobalSize,
352-
const size_t *MaxBlockDim, const size_t MaxBlockSize) {
353+
const size_t *MaxBlockDim, const size_t MaxBlockSize,
354+
const size_t WorkDim) {
353355
ThreadsPerBlock[0] = std::min(GlobalSize[0], MaxBlockDim[0]);
354356
// Make the X dim a factor of 2
355357
do {
356358
roundToHighestFactorOfGlobalSize(ThreadsPerBlock[0], GlobalSize[0]);
357-
} while (!isPowerOf2(ThreadsPerBlock[0]) && ThreadsPerBlock[0] > 32 &&
358-
--ThreadsPerBlock[0]);
359+
} while (WorkDim == 3 && !isPowerOf2(ThreadsPerBlock[0]) &&
360+
ThreadsPerBlock[0] > 32 && --ThreadsPerBlock[0]);
359361

360362
ThreadsPerBlock[1] =
361363
std::min(GlobalSize[1],
362364
std::min(MaxBlockSize / ThreadsPerBlock[0], MaxBlockDim[1]));
363-
roundToHighestFactorOfGlobalSize(ThreadsPerBlock[1], GlobalSize[1]);
365+
do {
366+
roundToHighestFactorOfGlobalSize(ThreadsPerBlock[1], GlobalSize[1]);
367+
} while (WorkDim == 2 && !isPowerOf2(ThreadsPerBlock[1]) &&
368+
ThreadsPerBlock[1] > 32 && --ThreadsPerBlock[1]);
364369

365370
ThreadsPerBlock[2] = std::min(
366-
GlobalSize[2], MaxBlockSize / (ThreadsPerBlock[1] * ThreadsPerBlock[0]));
367-
roundToHighestFactorOfGlobalSize(ThreadsPerBlock[2], GlobalSize[2]);
368-
371+
GlobalSize[2],
372+
std::min(MaxBlockSize / (ThreadsPerBlock[1] * ThreadsPerBlock[0]),
373+
MaxBlockDim[2]));
374+
do {
375+
roundToHighestFactorOfGlobalSize(ThreadsPerBlock[2], GlobalSize[2]);
376+
} while (WorkDim == 1 && !isPowerOf2(ThreadsPerBlock[2]) &&
377+
ThreadsPerBlock[2] > 32 && --ThreadsPerBlock[2]);
369378
}

0 commit comments

Comments
 (0)