|
16 | 16 | #include "memory.hpp"
|
17 | 17 | #include "queue.hpp"
|
18 | 18 |
|
| 19 | +#include <ur/ur.hpp> |
| 20 | + |
19 | 21 | extern size_t imageElementByteSize(hipArray_Format ArrayFormat);
|
20 | 22 |
|
21 | 23 | ur_result_t enqueueEventsWait(ur_queue_handle_t, hipStream_t Stream,
|
@@ -48,23 +50,36 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t, hipStream_t Stream,
|
48 | 50 | }
|
49 | 51 | }
|
50 | 52 |
|
51 |
| -void simpleGuessLocalWorkSize(size_t *ThreadsPerBlock, |
52 |
| - const size_t *GlobalWorkSize, |
53 |
| - const size_t MaxThreadsPerBlock[3], |
54 |
| - ur_kernel_handle_t Kernel) { |
| 53 | +// Determine local work sizes that result in uniform work groups. |
| 54 | +// The default threadsPerBlock only require handling the first work_dim |
| 55 | +// dimension. |
| 56 | +void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock, |
| 57 | + const size_t *GlobalWorkSize, const uint32_t WorkDim, |
| 58 | + const size_t MaxThreadsPerBlock[3], |
| 59 | + ur_kernel_handle_t Kernel) { |
55 | 60 | assert(ThreadsPerBlock != nullptr);
|
56 | 61 | assert(GlobalWorkSize != nullptr);
|
57 | 62 | assert(Kernel != nullptr);
|
58 | 63 |
|
59 |
| - std::ignore = Kernel; |
| 64 | + // FIXME: The below assumes a three dimensional range but this is not |
| 65 | + // guaranteed by UR. |
| 66 | + size_t GlobalSizeNormalized[3] = {1, 1, 1}; |
| 67 | + for (uint32_t i = 0; i < WorkDim; i++) { |
| 68 | + GlobalSizeNormalized[i] = GlobalWorkSize[i]; |
| 69 | + } |
| 70 | + |
| 71 | + size_t MaxBlockDim[3]; |
| 72 | + MaxBlockDim[0] = MaxThreadsPerBlock[0]; |
| 73 | + MaxBlockDim[1] = Device->getMaxBlockDimY(); |
| 74 | + MaxBlockDim[2] = Device->getMaxBlockDimZ(); |
60 | 75 |
|
61 |
| - ThreadsPerBlock[0] = std::min(MaxThreadsPerBlock[0], GlobalWorkSize[0]); |
| 76 | + int MinGrid, MaxBlockSize; |
| 77 | + UR_CHECK_ERROR(hipOccupancyMaxPotentialBlockSize( |
| 78 | + &MinGrid, &MaxBlockSize, Kernel->get(), Kernel->getLocalSize(), |
| 79 | + MaxThreadsPerBlock[0])); |
62 | 80 |
|
63 |
| - // Find a local work group size that is a divisor of the global |
64 |
| - // work group size to produce uniform work groups. |
65 |
| - while (GlobalWorkSize[0] % ThreadsPerBlock[0]) { |
66 |
| - --ThreadsPerBlock[0]; |
67 |
| - } |
| 81 | + roundToHighestFactorOfGlobalSizeIn3d(ThreadsPerBlock, GlobalSizeNormalized, |
| 82 | + MaxBlockDim, MaxBlockSize); |
68 | 83 | }
|
69 | 84 |
|
70 | 85 | namespace {
|
@@ -1793,8 +1808,8 @@ setKernelParams(const ur_device_handle_t Device, const uint32_t WorkDim,
|
1793 | 1808 | return err;
|
1794 | 1809 | }
|
1795 | 1810 | } else {
|
1796 |
| - simpleGuessLocalWorkSize(ThreadsPerBlock, GlobalWorkSize, |
1797 |
| - MaxThreadsPerBlock, Kernel); |
| 1811 | + guessLocalWorkSize(Device, ThreadsPerBlock, GlobalWorkSize, WorkDim, |
| 1812 | + MaxThreadsPerBlock, Kernel); |
1798 | 1813 | }
|
1799 | 1814 | }
|
1800 | 1815 |
|
|
0 commit comments