Skip to content

Commit a75f7d0

Browse files
aarongreig authored and kbenzie committed
Merge pull request oneapi-src#1363 from hdelan/refactor-device-initialization
[CUDA] Refactor device initialization
1 parent f67c6e4 commit a75f7d0

File tree

4 files changed

+21
-58
lines changed

4 files changed

+21
-58
lines changed

source/adapters/cuda/device.hpp

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ struct ur_device_handle_t_ {
2727
size_t MaxWorkItemSizes[MaxWorkItemDimensions];
2828
size_t MaxWorkGroupSize{0};
2929
size_t MaxAllocSize{0};
30-
int MaxBlockDimY{0};
31-
int MaxBlockDimZ{0};
3230
int MaxRegsPerBlock{0};
3331
int MaxCapacityLocalMem{0};
3432
int MaxChosenLocalMem{0};
@@ -40,17 +38,21 @@ struct ur_device_handle_t_ {
4038
: CuDevice(cuDevice), CuContext(cuContext), EvBase(evBase), RefCount{1},
4139
Platform(platform) {
4240

43-
UR_CHECK_ERROR(cuDeviceGetAttribute(
44-
&MaxBlockDimY, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, cuDevice));
45-
UR_CHECK_ERROR(cuDeviceGetAttribute(
46-
&MaxBlockDimZ, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z, cuDevice));
4741
UR_CHECK_ERROR(cuDeviceGetAttribute(
4842
&MaxRegsPerBlock, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK,
4943
cuDevice));
5044
UR_CHECK_ERROR(cuDeviceGetAttribute(
5145
&MaxCapacityLocalMem,
5246
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, cuDevice));
5347

48+
UR_CHECK_ERROR(urDeviceGetInfo(this, UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES,
49+
sizeof(MaxWorkItemSizes), MaxWorkItemSizes,
50+
nullptr));
51+
52+
UR_CHECK_ERROR(urDeviceGetInfo(this, UR_DEVICE_INFO_MAX_WORK_GROUP_SIZE,
53+
sizeof(MaxWorkGroupSize), &MaxWorkGroupSize,
54+
nullptr));
55+
5456
// Set local mem max size if env var is present
5557
static const char *LocalMemSizePtrUR =
5658
std::getenv("UR_CUDA_MAX_LOCAL_MEM_SIZE");
@@ -91,24 +93,12 @@ struct ur_device_handle_t_ {
9193

9294
uint64_t getElapsedTime(CUevent) const;
9395

94-
void saveMaxWorkItemSizes(size_t Size,
95-
size_t *SaveMaxWorkItemSizes) noexcept {
96-
memcpy(MaxWorkItemSizes, SaveMaxWorkItemSizes, Size);
97-
};
98-
99-
void saveMaxWorkGroupSize(int Value) noexcept { MaxWorkGroupSize = Value; };
100-
101-
void getMaxWorkItemSizes(size_t RetSize,
102-
size_t *RetMaxWorkItemSizes) const noexcept {
103-
memcpy(RetMaxWorkItemSizes, MaxWorkItemSizes, RetSize);
104-
};
96+
size_t getMaxWorkItemSizes(int index) const noexcept {
97+
return MaxWorkItemSizes[index];
98+
}
10599

106100
size_t getMaxWorkGroupSize() const noexcept { return MaxWorkGroupSize; };
107101

108-
size_t getMaxBlockDimY() const noexcept { return MaxBlockDimY; };
109-
110-
size_t getMaxBlockDimZ() const noexcept { return MaxBlockDimZ; };
111-
112102
size_t getMaxRegsPerBlock() const noexcept { return MaxRegsPerBlock; };
113103

114104
size_t getMaxAllocSize() const noexcept { return MaxAllocSize; };

source/adapters/cuda/enqueue.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ ur_result_t setCuMemAdvise(CUdeviceptr DevPtr, size_t Size,
140140
// dimension.
141141
void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
142142
const size_t *GlobalWorkSize, const uint32_t WorkDim,
143-
const size_t MaxThreadsPerBlock[3],
144143
ur_kernel_handle_t Kernel) {
145144
assert(ThreadsPerBlock != nullptr);
146145
assert(GlobalWorkSize != nullptr);
@@ -154,14 +153,14 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
154153
}
155154

156155
size_t MaxBlockDim[3];
157-
MaxBlockDim[0] = MaxThreadsPerBlock[0];
158-
MaxBlockDim[1] = Device->getMaxBlockDimY();
159-
MaxBlockDim[2] = Device->getMaxBlockDimZ();
156+
MaxBlockDim[0] = Device->getMaxWorkItemSizes(0);
157+
MaxBlockDim[1] = Device->getMaxWorkItemSizes(1);
158+
MaxBlockDim[2] = Device->getMaxWorkItemSizes(2);
160159

161160
int MinGrid, MaxBlockSize;
162161
UR_CHECK_ERROR(cuOccupancyMaxPotentialBlockSize(
163162
&MinGrid, &MaxBlockSize, Kernel->get(), NULL, Kernel->getLocalSize(),
164-
MaxThreadsPerBlock[0]));
163+
MaxBlockDim[0]));
165164

166165
roundToHighestFactorOfGlobalSizeIn3d(ThreadsPerBlock, GlobalSizeNormalized,
167166
MaxBlockDim, MaxBlockSize);
@@ -197,7 +196,6 @@ setKernelParams(const ur_context_handle_t Context,
197196
size_t (&BlocksPerGrid)[3]) {
198197
ur_result_t Result = UR_RESULT_SUCCESS;
199198
size_t MaxWorkGroupSize = 0u;
200-
size_t MaxThreadsPerBlock[3] = {};
201199
bool ProvidedLocalWorkGroupSize = LocalWorkSize != nullptr;
202200
uint32_t LocalSize = Kernel->getLocalSize();
203201

@@ -207,16 +205,14 @@ setKernelParams(const ur_context_handle_t Context,
207205
{
208206
size_t *ReqdThreadsPerBlock = Kernel->ReqdThreadsPerBlock;
209207
MaxWorkGroupSize = Device->getMaxWorkGroupSize();
210-
Device->getMaxWorkItemSizes(sizeof(MaxThreadsPerBlock),
211-
MaxThreadsPerBlock);
212208

213209
if (ProvidedLocalWorkGroupSize) {
214210
auto IsValid = [&](int Dim) {
215211
if (ReqdThreadsPerBlock[Dim] != 0 &&
216212
LocalWorkSize[Dim] != ReqdThreadsPerBlock[Dim])
217213
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
218214

219-
if (LocalWorkSize[Dim] > MaxThreadsPerBlock[Dim])
215+
if (LocalWorkSize[Dim] > Device->getMaxWorkItemSizes(Dim))
220216
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
221217
// Checks that local work sizes are a divisor of the global work sizes
222218
// which includes that the local work sizes are neither larger than
@@ -245,7 +241,7 @@ setKernelParams(const ur_context_handle_t Context,
245241
}
246242
} else {
247243
guessLocalWorkSize(Device, ThreadsPerBlock, GlobalWorkSize, WorkDim,
248-
MaxThreadsPerBlock, Kernel);
244+
Kernel);
249245
}
250246
}
251247

source/adapters/cuda/kernel.cpp

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,6 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
6868
case UR_KERNEL_GROUP_INFO_GLOBAL_WORK_SIZE: {
6969
size_t GlobalWorkSize[3] = {0, 0, 0};
7070

71-
int MaxBlockDimX{0}, MaxBlockDimY{0}, MaxBlockDimZ{0};
72-
UR_CHECK_ERROR(cuDeviceGetAttribute(
73-
&MaxBlockDimX, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, hDevice->get()));
74-
UR_CHECK_ERROR(cuDeviceGetAttribute(
75-
&MaxBlockDimY, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, hDevice->get()));
76-
UR_CHECK_ERROR(cuDeviceGetAttribute(
77-
&MaxBlockDimZ, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z, hDevice->get()));
78-
7971
int MaxGridDimX{0}, MaxGridDimY{0}, MaxGridDimZ{0};
8072
UR_CHECK_ERROR(cuDeviceGetAttribute(
8173
&MaxGridDimX, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, hDevice->get()));
@@ -84,9 +76,10 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
8476
UR_CHECK_ERROR(cuDeviceGetAttribute(
8577
&MaxGridDimZ, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, hDevice->get()));
8678

87-
GlobalWorkSize[0] = MaxBlockDimX * MaxGridDimX;
88-
GlobalWorkSize[1] = MaxBlockDimY * MaxGridDimY;
89-
GlobalWorkSize[2] = MaxBlockDimZ * MaxGridDimZ;
79+
GlobalWorkSize[0] = hDevice->getMaxWorkItemSizes(0) * MaxGridDimX;
80+
GlobalWorkSize[1] = hDevice->getMaxWorkItemSizes(1) * MaxGridDimY;
81+
GlobalWorkSize[2] = hDevice->getMaxWorkItemSizes(2) * MaxGridDimZ;
82+
9083
return ReturnValue(GlobalWorkSize, 3);
9184
}
9285
case UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE: {

source/adapters/cuda/platform.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,6 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
9595

9696
Platforms[i].Devices.emplace_back(new ur_device_handle_t_{
9797
Device, Context, EvBase, &Platforms[i]});
98-
{
99-
const auto &Dev = Platforms[i].Devices.back().get();
100-
size_t MaxWorkGroupSize = 0u;
101-
size_t MaxThreadsPerBlock[3] = {};
102-
UR_CHECK_ERROR(urDeviceGetInfo(
103-
Dev, UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES,
104-
sizeof(MaxThreadsPerBlock), MaxThreadsPerBlock, nullptr));
105-
106-
UR_CHECK_ERROR(urDeviceGetInfo(
107-
Dev, UR_DEVICE_INFO_MAX_WORK_GROUP_SIZE,
108-
sizeof(MaxWorkGroupSize), &MaxWorkGroupSize, nullptr));
109-
110-
Dev->saveMaxWorkItemSizes(sizeof(MaxThreadsPerBlock),
111-
MaxThreadsPerBlock);
112-
Dev->saveMaxWorkGroupSize(MaxWorkGroupSize);
113-
}
11498
}
11599
} catch (const std::bad_alloc &) {
116100
// Signal out-of-memory situation

0 commit comments

Comments (0)