@@ -49,19 +49,59 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t, hipStream_t Stream,
49
49
}
50
50
}
51
51
52
- void simpleGuessLocalWorkSize (size_t *ThreadsPerBlock,
53
- const size_t *GlobalWorkSize,
54
- const size_t MaxThreadsPerBlock[3 ]) {
52
+ // Determine local work sizes that result in uniform work groups.
53
+ // The default threadsPerBlock only require handling the first work_dim
54
+ // dimension.
55
+ void guessLocalWorkSize (ur_device_handle_t Device, size_t *ThreadsPerBlock,
56
+ const size_t *GlobalWorkSize, const uint32_t WorkDim,
57
+ const size_t MaxThreadsPerBlock[3 ],
58
+ ur_kernel_handle_t Kernel) {
55
59
assert (ThreadsPerBlock != nullptr );
56
60
assert (GlobalWorkSize != nullptr );
61
+ assert (Kernel != nullptr );
62
+ int MinGrid, MaxBlockSize;
63
+ size_t MaxBlockDim[3 ];
64
+
65
+ // The below assumes a three dimensional range but this is not guaranteed by
66
+ // UR.
67
+ size_t GlobalSizeNormalized[3 ] = {1 , 1 , 1 };
68
+ for (uint32_t i = 0 ; i < WorkDim; i++) {
69
+ GlobalSizeNormalized[i] = GlobalWorkSize[i];
70
+ }
57
71
58
- ThreadsPerBlock[0 ] = std::min (MaxThreadsPerBlock[0 ], GlobalWorkSize[0 ]);
72
+ MaxBlockDim[1 ] = Device->getMaxBlockDimY ();
73
+ MaxBlockDim[2 ] = Device->getMaxBlockDimZ ();
74
+
75
+ UR_CHECK_ERROR (hipOccupancyMaxPotentialBlockSize (
76
+ &MinGrid, &MaxBlockSize, Kernel->get (), Kernel->getLocalSize (),
77
+ MaxThreadsPerBlock[0 ]));
78
+
79
+ // Helper lambda to make sure each x, y, z dim divide the global dimension.
80
+ // Can optionally specify that we want the wg size to be a power of 2 in a
81
+ // given dimension, which is useful for the X dim for performance reasons.
82
+ static auto roundToHighestFactorOfGlobalSize =
83
+ [](size_t &ThreadsPerBlockInDim, const size_t GlobalWorkSizeInDim,
84
+ bool MakePowerOfTwo) {
85
+ auto IsPowerOf2 = [](size_t Value) -> bool {
86
+ return Value && !(Value & (Value - 1 ));
87
+ };
88
+ while (GlobalWorkSizeInDim % ThreadsPerBlockInDim ||
89
+ (MakePowerOfTwo && !IsPowerOf2 (ThreadsPerBlockInDim)))
90
+ --ThreadsPerBlockInDim;
91
+ };
59
92
60
- // Find a local work group size that is a divisor of the global
61
- // work group size to produce uniform work groups.
62
- while (GlobalWorkSize[0 ] % ThreadsPerBlock[0 ]) {
63
- --ThreadsPerBlock[0 ];
64
- }
93
+ ThreadsPerBlock[2 ] = std::min (GlobalSizeNormalized[2 ], MaxBlockDim[2 ]);
94
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[2 ], GlobalWorkSize[2 ],
95
+ false );
96
+ ThreadsPerBlock[1 ] =
97
+ std::min (GlobalSizeNormalized[1 ],
98
+ std::min (MaxBlockSize / ThreadsPerBlock[2 ], MaxBlockDim[1 ]));
99
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[1 ], GlobalWorkSize[1 ],
100
+ false );
101
+ MaxBlockDim[0 ] = MaxBlockSize / (ThreadsPerBlock[1 ] * ThreadsPerBlock[2 ]);
102
+ ThreadsPerBlock[0 ] = std::min (
103
+ MaxThreadsPerBlock[0 ], std::min (GlobalSizeNormalized[0 ], MaxBlockDim[0 ]));
104
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[0 ], GlobalWorkSize[0 ], true );
65
105
}
66
106
67
107
ur_result_t setHipMemAdvise (const void *DevPtr, const size_t Size ,
@@ -340,8 +380,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
340
380
return err;
341
381
}
342
382
} else {
343
- simpleGuessLocalWorkSize ( ThreadsPerBlock, pGlobalWorkSize,
344
- MaxThreadsPerBlock);
383
+ guessLocalWorkSize (hQueue-> getDevice (), ThreadsPerBlock, pGlobalWorkSize,
384
+ workDim, MaxThreadsPerBlock, hKernel );
345
385
}
346
386
}
347
387
0 commit comments