@@ -139,7 +139,6 @@ ur_result_t setCuMemAdvise(CUdeviceptr DevPtr, size_t Size,
139
139
// dimension.
140
140
void guessLocalWorkSize (ur_device_handle_t Device, size_t *ThreadsPerBlock,
141
141
const size_t *GlobalWorkSize, const uint32_t WorkDim,
142
- const size_t MaxThreadsPerBlock[3 ],
143
142
ur_kernel_handle_t Kernel, uint32_t LocalSize) {
144
143
assert (ThreadsPerBlock != nullptr );
145
144
assert (GlobalWorkSize != nullptr );
@@ -154,20 +153,21 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
154
153
GlobalSizeNormalized[i] = GlobalWorkSize[i];
155
154
}
156
155
157
- MaxBlockDim[1 ] = Device->getMaxBlockDimY ( );
158
- MaxBlockDim[2 ] = Device->getMaxBlockDimZ ( );
156
+ MaxBlockDim[1 ] = Device->getMaxWorkItemSizes ( 1 );
157
+ MaxBlockDim[2 ] = Device->getMaxWorkItemSizes ( 2 );
159
158
160
- UR_CHECK_ERROR (
161
- cuOccupancyMaxPotentialBlockSize ( &MinGrid, &MaxBlockSize, Kernel->get (),
162
- NULL , LocalSize, MaxThreadsPerBlock[ 0 ] ));
159
+ UR_CHECK_ERROR (cuOccupancyMaxPotentialBlockSize (
160
+ &MinGrid, &MaxBlockSize, Kernel->get (), NULL , LocalSize ,
161
+ Device-> getMaxWorkItemSizes ( 0 ) ));
163
162
164
163
ThreadsPerBlock[2 ] = std::min (GlobalSizeNormalized[2 ], MaxBlockDim[2 ]);
165
164
ThreadsPerBlock[1 ] =
166
165
std::min (GlobalSizeNormalized[1 ],
167
166
std::min (MaxBlockSize / ThreadsPerBlock[2 ], MaxBlockDim[1 ]));
168
167
MaxBlockDim[0 ] = MaxBlockSize / (ThreadsPerBlock[1 ] * ThreadsPerBlock[2 ]);
169
- ThreadsPerBlock[0 ] = std::min (
170
- MaxThreadsPerBlock[0 ], std::min (GlobalSizeNormalized[0 ], MaxBlockDim[0 ]));
168
+ ThreadsPerBlock[0 ] =
169
+ std::min (Device->getMaxWorkItemSizes (0 ),
170
+ std::min (GlobalSizeNormalized[0 ], MaxBlockDim[0 ]));
171
171
172
172
static auto IsPowerOf2 = [](size_t Value) -> bool {
173
173
return Value && !(Value & (Value - 1 ));
@@ -213,7 +213,6 @@ setKernelParams(const ur_context_handle_t Context,
213
213
size_t (&BlocksPerGrid)[3]) {
214
214
ur_result_t Result = UR_RESULT_SUCCESS;
215
215
size_t MaxWorkGroupSize = 0u ;
216
- size_t MaxThreadsPerBlock[3 ] = {};
217
216
bool ProvidedLocalWorkGroupSize = LocalWorkSize != nullptr ;
218
217
uint32_t LocalSize = Kernel->getLocalSize ();
219
218
@@ -223,16 +222,14 @@ setKernelParams(const ur_context_handle_t Context,
223
222
{
224
223
size_t *ReqdThreadsPerBlock = Kernel->ReqdThreadsPerBlock ;
225
224
MaxWorkGroupSize = Device->getMaxWorkGroupSize ();
226
- Device->getMaxWorkItemSizes (sizeof (MaxThreadsPerBlock),
227
- MaxThreadsPerBlock);
228
225
229
226
if (ProvidedLocalWorkGroupSize) {
230
227
auto IsValid = [&](int Dim) {
231
228
if (ReqdThreadsPerBlock[Dim] != 0 &&
232
229
LocalWorkSize[Dim] != ReqdThreadsPerBlock[Dim])
233
230
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
234
231
235
- if (LocalWorkSize[Dim] > MaxThreadsPerBlock[ Dim] )
232
+ if (LocalWorkSize[Dim] > Device-> getMaxWorkItemSizes ( Dim) )
236
233
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
237
234
// Checks that local work sizes are a divisor of the global work sizes
238
235
// which includes that the local work sizes are neither larger than
@@ -261,7 +258,7 @@ setKernelParams(const ur_context_handle_t Context,
261
258
}
262
259
} else {
263
260
guessLocalWorkSize (Device, ThreadsPerBlock, GlobalWorkSize, WorkDim,
264
- MaxThreadsPerBlock, Kernel, LocalSize);
261
+ Kernel, LocalSize);
265
262
}
266
263
}
267
264
0 commit comments