|
15 | 15 | #include "memory.hpp"
|
16 | 16 | #include "queue.hpp"
|
17 | 17 |
|
| 18 | +#include <ur/ur.hpp> |
| 19 | + |
18 | 20 | extern size_t imageElementByteSize(hipArray_Format ArrayFormat);
|
19 | 21 |
|
20 | 22 | namespace {
|
@@ -49,23 +51,36 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t, hipStream_t Stream,
|
49 | 51 | }
|
50 | 52 | }
|
51 | 53 |
|
52 |
| -void simpleGuessLocalWorkSize(size_t *ThreadsPerBlock, |
53 |
| - const size_t *GlobalWorkSize, |
54 |
| - const size_t MaxThreadsPerBlock[3], |
55 |
| - ur_kernel_handle_t Kernel) { |
| 54 | +// Determine local work sizes that result in uniform work groups. |
| 55 | +// The default threadsPerBlock only require handling the first work_dim |
| 56 | +// dimension. |
| 57 | +void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock, |
| 58 | + const size_t *GlobalWorkSize, const uint32_t WorkDim, |
| 59 | + const size_t MaxThreadsPerBlock[3], |
| 60 | + ur_kernel_handle_t Kernel) { |
56 | 61 | assert(ThreadsPerBlock != nullptr);
|
57 | 62 | assert(GlobalWorkSize != nullptr);
|
58 | 63 | assert(Kernel != nullptr);
|
59 | 64 |
|
60 |
| - std::ignore = Kernel; |
| 65 | + // FIXME: The below assumes a three dimensional range but this is not |
| 66 | + // guaranteed by UR. |
| 67 | + size_t GlobalSizeNormalized[3] = {1, 1, 1}; |
| 68 | + for (uint32_t i = 0; i < WorkDim; i++) { |
| 69 | + GlobalSizeNormalized[i] = GlobalWorkSize[i]; |
| 70 | + } |
| 71 | + |
| 72 | + size_t MaxBlockDim[3]; |
| 73 | + MaxBlockDim[0] = MaxThreadsPerBlock[0]; |
| 74 | + MaxBlockDim[1] = Device->getMaxBlockDimY(); |
| 75 | + MaxBlockDim[2] = Device->getMaxBlockDimZ(); |
61 | 76 |
|
62 |
| - ThreadsPerBlock[0] = std::min(MaxThreadsPerBlock[0], GlobalWorkSize[0]); |
| 77 | + int MinGrid, MaxBlockSize; |
| 78 | + UR_CHECK_ERROR(hipOccupancyMaxPotentialBlockSize( |
| 79 | + &MinGrid, &MaxBlockSize, Kernel->get(), Kernel->getLocalSize(), |
| 80 | + MaxThreadsPerBlock[0])); |
63 | 81 |
|
64 |
| - // Find a local work group size that is a divisor of the global |
65 |
| - // work group size to produce uniform work groups. |
66 |
| - while (GlobalWorkSize[0] % ThreadsPerBlock[0]) { |
67 |
| - --ThreadsPerBlock[0]; |
68 |
| - } |
| 82 | + roundToHighestFactorOfGlobalSizeIn3d(ThreadsPerBlock, GlobalSizeNormalized, |
| 83 | + MaxBlockDim, MaxBlockSize); |
69 | 84 | }
|
70 | 85 |
|
71 | 86 | ur_result_t setHipMemAdvise(const void *DevPtr, const size_t Size,
|
@@ -344,8 +359,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
|
344 | 359 | return err;
|
345 | 360 | }
|
346 | 361 | } else {
|
347 |
| - simpleGuessLocalWorkSize(ThreadsPerBlock, pGlobalWorkSize, |
348 |
| - MaxThreadsPerBlock, hKernel); |
| 362 | + guessLocalWorkSize(hQueue->getDevice(), ThreadsPerBlock, pGlobalWorkSize, |
| 363 | + workDim, MaxThreadsPerBlock, hKernel); |
349 | 364 | }
|
350 | 365 | }
|
351 | 366 |
|
|
0 commit comments