Skip to content

Commit 97088aa

Browse files
committed
Refactor device initialization
Some logic for device initialization was split across platform init and device init. This duplicated some calls to get info for max block dim in Y and Z dimension. This rectifies that and makes sure the max block dims are only queried once, which is called from the device constructor.
1 parent 3fd11f1 commit 97088aa

File tree

2 files changed

+8
-27
lines changed

2 files changed

+8
-27
lines changed

source/adapters/cuda/device.hpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,21 @@ struct ur_device_handle_t_ {
4040
: CuDevice(cuDevice), CuContext(cuContext), EvBase(evBase), RefCount{1},
4141
Platform(platform) {
4242

43-
UR_CHECK_ERROR(cuDeviceGetAttribute(
44-
&MaxBlockDimY, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, cuDevice));
45-
UR_CHECK_ERROR(cuDeviceGetAttribute(
46-
&MaxBlockDimZ, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z, cuDevice));
4743
UR_CHECK_ERROR(cuDeviceGetAttribute(
4844
&MaxRegsPerBlock, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK,
4945
cuDevice));
5046
UR_CHECK_ERROR(cuDeviceGetAttribute(
5147
&MaxCapacityLocalMem,
5248
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, cuDevice));
5349

50+
UR_CHECK_ERROR(urDeviceGetInfo(this, UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES,
51+
sizeof(MaxWorkItemSizes), MaxWorkItemSizes,
52+
nullptr));
53+
54+
UR_CHECK_ERROR(urDeviceGetInfo(this, UR_DEVICE_INFO_MAX_WORK_GROUP_SIZE,
55+
sizeof(MaxWorkGroupSize), &MaxWorkGroupSize,
56+
nullptr));
57+
5458
// Set local mem max size if env var is present
5559
static const char *LocalMemSizePtrUR =
5660
std::getenv("UR_CUDA_MAX_LOCAL_MEM_SIZE");
@@ -91,13 +95,6 @@ struct ur_device_handle_t_ {
9195

9296
uint64_t getElapsedTime(CUevent) const;
9397

94-
void saveMaxWorkItemSizes(size_t Size,
95-
size_t *SaveMaxWorkItemSizes) noexcept {
96-
memcpy(MaxWorkItemSizes, SaveMaxWorkItemSizes, Size);
97-
};
98-
99-
void saveMaxWorkGroupSize(int Value) noexcept { MaxWorkGroupSize = Value; };
100-
10198
void getMaxWorkItemSizes(size_t RetSize,
10299
size_t *RetMaxWorkItemSizes) const noexcept {
103100
memcpy(RetMaxWorkItemSizes, MaxWorkItemSizes, RetSize);

source/adapters/cuda/platform.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,6 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
9595

9696
Platforms[i].Devices.emplace_back(new ur_device_handle_t_{
9797
Device, Context, EvBase, &Platforms[i]});
98-
{
99-
const auto &Dev = Platforms[i].Devices.back().get();
100-
size_t MaxWorkGroupSize = 0u;
101-
size_t MaxThreadsPerBlock[3] = {};
102-
UR_CHECK_ERROR(urDeviceGetInfo(
103-
Dev, UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES,
104-
sizeof(MaxThreadsPerBlock), MaxThreadsPerBlock, nullptr));
105-
106-
UR_CHECK_ERROR(urDeviceGetInfo(
107-
Dev, UR_DEVICE_INFO_MAX_WORK_GROUP_SIZE,
108-
sizeof(MaxWorkGroupSize), &MaxWorkGroupSize, nullptr));
109-
110-
Dev->saveMaxWorkItemSizes(sizeof(MaxThreadsPerBlock),
111-
MaxThreadsPerBlock);
112-
Dev->saveMaxWorkGroupSize(MaxWorkGroupSize);
113-
}
11498
}
11599
} catch (const std::bad_alloc &) {
116100
// Signal out-of-memory situation

0 commit comments

Comments
 (0)