Skip to content

Commit 8a5e9a3

Browse files
committed
Fix bug in guessLocalWorkSize in CUDA adapter
The guessLocalWorkSize func for cuda adapter was erroneously giving a factor of 2 even if it creates a global dim which is not the same as the user range without using any range rounding kernels. This reverts that.
1 parent 1ac6f95 commit 8a5e9a3

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

source/adapters/cuda/enqueue.cpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -161,26 +161,32 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
161161
cuOccupancyMaxPotentialBlockSize(&MinGrid, &MaxBlockSize, Kernel->get(),
162162
NULL, LocalSize, MaxThreadsPerBlock[0]));
163163

164+
// Helper lambda to make sure each x, y, z dim divide the global dimension.
165+
// Can optionally specify that we want the wg size to be a power of 2 in a
166+
// given dimension, which is useful for the X dim for performance reasons.
167+
static auto roundToHighestFactorOfGlobalSize =
168+
[](size_t &ThreadsPerBlockInDim, const size_t GlobalWorkSizeInDim,
169+
bool MakePowerOfTwo) {
170+
auto IsPowerOf2 = [](size_t Value) -> bool {
171+
return Value && !(Value & (Value - 1));
172+
};
173+
while (GlobalWorkSizeInDim % ThreadsPerBlockInDim ||
174+
(MakePowerOfTwo && !IsPowerOf2(ThreadsPerBlockInDim)))
175+
--ThreadsPerBlockInDim;
176+
};
177+
164178
ThreadsPerBlock[2] = std::min(GlobalSizeNormalized[2], MaxBlockDim[2]);
179+
roundToHighestFactorOfGlobalSize(ThreadsPerBlock[2], GlobalWorkSize[2],
180+
false);
165181
ThreadsPerBlock[1] =
166182
std::min(GlobalSizeNormalized[1],
167183
std::min(MaxBlockSize / ThreadsPerBlock[2], MaxBlockDim[1]));
184+
roundToHighestFactorOfGlobalSize(ThreadsPerBlock[1], GlobalWorkSize[1],
185+
false);
168186
MaxBlockDim[0] = MaxBlockSize / (ThreadsPerBlock[1] * ThreadsPerBlock[2]);
169187
ThreadsPerBlock[0] = std::min(
170188
MaxThreadsPerBlock[0], std::min(GlobalSizeNormalized[0], MaxBlockDim[0]));
171-
172-
static auto IsPowerOf2 = [](size_t Value) -> bool {
173-
return Value && !(Value & (Value - 1));
174-
};
175-
176-
// Find a local work group size that is a divisor of the global
177-
// work group size to produce uniform work groups.
178-
// Additionally, for best compute utilisation, the local size has
179-
// to be a power of two.
180-
while (0u != (GlobalSizeNormalized[0] % ThreadsPerBlock[0]) ||
181-
!IsPowerOf2(ThreadsPerBlock[0])) {
182-
--ThreadsPerBlock[0];
183-
}
189+
roundToHighestFactorOfGlobalSize(ThreadsPerBlock[0], GlobalWorkSize[0], true);
184190
}
185191

186192
// Helper to verify out-of-registers case (exceeded block max registers).

0 commit comments

Comments
 (0)