@@ -140,7 +140,6 @@ ur_result_t setCuMemAdvise(CUdeviceptr DevPtr, size_t Size,
140
140
// dimension.
141
141
void guessLocalWorkSize (ur_device_handle_t Device, size_t *ThreadsPerBlock,
142
142
const size_t *GlobalWorkSize, const uint32_t WorkDim,
143
- const size_t MaxThreadsPerBlock[3 ],
144
143
ur_kernel_handle_t Kernel) {
145
144
assert (ThreadsPerBlock != nullptr );
146
145
assert (GlobalWorkSize != nullptr );
@@ -154,14 +153,14 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
154
153
}
155
154
156
155
size_t MaxBlockDim[3 ];
157
- MaxBlockDim[0 ] = MaxThreadsPerBlock[ 0 ] ;
158
- MaxBlockDim[1 ] = Device->getMaxBlockDimY ( );
159
- MaxBlockDim[2 ] = Device->getMaxBlockDimZ ( );
156
+ MaxBlockDim[0 ] = Device-> getMaxWorkItemSizes ( 0 ) ;
157
+ MaxBlockDim[1 ] = Device->getMaxWorkItemSizes ( 1 );
158
+ MaxBlockDim[2 ] = Device->getMaxWorkItemSizes ( 2 );
160
159
161
160
int MinGrid, MaxBlockSize;
162
161
UR_CHECK_ERROR (cuOccupancyMaxPotentialBlockSize (
163
162
&MinGrid, &MaxBlockSize, Kernel->get (), NULL , Kernel->getLocalSize (),
164
- MaxThreadsPerBlock [0 ]));
163
+ MaxBlockDim [0 ]));
165
164
166
165
roundToHighestFactorOfGlobalSizeIn3d (ThreadsPerBlock, GlobalSizeNormalized,
167
166
MaxBlockDim, MaxBlockSize);
@@ -197,7 +196,6 @@ setKernelParams(const ur_context_handle_t Context,
197
196
size_t (&BlocksPerGrid)[3]) {
198
197
ur_result_t Result = UR_RESULT_SUCCESS;
199
198
size_t MaxWorkGroupSize = 0u ;
200
- size_t MaxThreadsPerBlock[3 ] = {};
201
199
bool ProvidedLocalWorkGroupSize = LocalWorkSize != nullptr ;
202
200
uint32_t LocalSize = Kernel->getLocalSize ();
203
201
@@ -207,16 +205,14 @@ setKernelParams(const ur_context_handle_t Context,
207
205
{
208
206
size_t *ReqdThreadsPerBlock = Kernel->ReqdThreadsPerBlock ;
209
207
MaxWorkGroupSize = Device->getMaxWorkGroupSize ();
210
- Device->getMaxWorkItemSizes (sizeof (MaxThreadsPerBlock),
211
- MaxThreadsPerBlock);
212
208
213
209
if (ProvidedLocalWorkGroupSize) {
214
210
auto IsValid = [&](int Dim) {
215
211
if (ReqdThreadsPerBlock[Dim] != 0 &&
216
212
LocalWorkSize[Dim] != ReqdThreadsPerBlock[Dim])
217
213
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
218
214
219
- if (LocalWorkSize[Dim] > MaxThreadsPerBlock[ Dim] )
215
+ if (LocalWorkSize[Dim] > Device-> getMaxWorkItemSizes ( Dim) )
220
216
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
221
217
// Checks that local work sizes are a divisor of the global work sizes
222
218
// which includes that the local work sizes are neither larger than
@@ -245,7 +241,7 @@ setKernelParams(const ur_context_handle_t Context,
245
241
}
246
242
} else {
247
243
guessLocalWorkSize (Device, ThreadsPerBlock, GlobalWorkSize, WorkDim,
248
- MaxThreadsPerBlock, Kernel);
244
+ Kernel);
249
245
}
250
246
}
251
247
0 commit comments