@@ -161,26 +161,32 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
161
161
cuOccupancyMaxPotentialBlockSize (&MinGrid, &MaxBlockSize, Kernel->get (),
162
162
NULL , LocalSize, MaxThreadsPerBlock[0 ]));
163
163
164
// Helper lambda that shrinks ThreadsPerBlockInDim in place to the largest
// value, not exceeding its initial value, that evenly divides
// GlobalWorkSizeInDim. It is applied to each of the x, y and z dimensions.
// When MakePowerOfTwo is set, the result must additionally be a power of two,
// which is useful for the X dim for performance reasons.
// The search always terminates: 1 divides every size and is a power of two.
static auto roundToHighestFactorOfGlobalSize =
    [](size_t &ThreadsPerBlockInDim, const size_t GlobalWorkSizeInDim,
       bool MakePowerOfTwo) {
      // A zero candidate would make the modulo below divide by zero, so fall
      // back to 1, the smallest size that satisfies both constraints.
      if (ThreadsPerBlockInDim == 0) {
        ThreadsPerBlockInDim = 1;
        return;
      }
      auto IsPowerOf2 = [](size_t Value) -> bool {
        return Value && !(Value & (Value - 1));
      };
      // Step down one candidate at a time until it divides the global size
      // and, when requested, is also a power of two.
      while (GlobalWorkSizeInDim % ThreadsPerBlockInDim ||
             (MakePowerOfTwo && !IsPowerOf2(ThreadsPerBlockInDim)))
        --ThreadsPerBlockInDim;
    };
177
+
164
178
ThreadsPerBlock[2 ] = std::min (GlobalSizeNormalized[2 ], MaxBlockDim[2 ]);
179
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[2 ], GlobalWorkSize[2 ],
180
+ false );
165
181
ThreadsPerBlock[1 ] =
166
182
std::min (GlobalSizeNormalized[1 ],
167
183
std::min (MaxBlockSize / ThreadsPerBlock[2 ], MaxBlockDim[1 ]));
184
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[1 ], GlobalWorkSize[1 ],
185
+ false );
168
186
MaxBlockDim[0 ] = MaxBlockSize / (ThreadsPerBlock[1 ] * ThreadsPerBlock[2 ]);
169
187
ThreadsPerBlock[0 ] = std::min (
170
188
MaxThreadsPerBlock[0 ], std::min (GlobalSizeNormalized[0 ], MaxBlockDim[0 ]));
171
-
172
- static auto IsPowerOf2 = [](size_t Value) -> bool {
173
- return Value && !(Value & (Value - 1 ));
174
- };
175
-
176
- // Find a local work group size that is a divisor of the global
177
- // work group size to produce uniform work groups.
178
- // Additionally, for best compute utilisation, the local size has
179
- // to be a power of two.
180
- while (0u != (GlobalSizeNormalized[0 ] % ThreadsPerBlock[0 ]) ||
181
- !IsPowerOf2 (ThreadsPerBlock[0 ])) {
182
- --ThreadsPerBlock[0 ];
183
- }
189
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[0 ], GlobalWorkSize[0 ], true );
184
190
}
185
191
186
192
// Helper to verify out-of-registers case (exceeded block max registers).
0 commit comments