Skip to content

Commit 7a3deca

Browse files
bmyates, Jaime Arteaga
authored and
Jaime Arteaga
committed
Add implementation of USM pools (intel#11)
Signed-off-by: Brandon Yates <[email protected]>
1 parent 7ce01a7 commit 7a3deca

File tree

3 files changed

+117
-37
lines changed

3 files changed

+117
-37
lines changed

sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(
573573
case UR_DEVICE_INFO_USM_CROSS_SHARED_SUPPORT:
574574
case UR_DEVICE_INFO_USM_SYSTEM_SHARED_SUPPORT: {
575575
auto MapCaps = [](const ze_memory_access_cap_flags_t &ZeCapabilities) {
576-
uint64_t Capabilities = 0;
576+
ur_device_usm_access_capability_flags_t Capabilities = 0;
577577
if (ZeCapabilities & ZE_MEMORY_ACCESS_CAP_FLAG_RW)
578578
Capabilities |= UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS;
579579
if (ZeCapabilities & ZE_MEMORY_ACCESS_CAP_FLAG_ATOMIC)

sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_usm.cpp

Lines changed: 99 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(
2424
Size, ///< [in] size in bytes of the USM memory object to be allocated
2525
void **RetMem ///< [out] pointer to USM host memory object
2626
) {
27-
std::ignore = Pool;
2827

29-
uint32_t Align = USMDesc->align;
28+
uint32_t Align = USMDesc ? USMDesc->align : 0;
3029
// L0 supports alignment up to 64KB and silently ignores higher values.
3130
// We flag alignment > 64KB as an invalid value.
3231
if (Align > 65536)
3332
return UR_RESULT_ERROR_INVALID_VALUE;
3433

35-
const ur_usm_advice_flags_t *USMHintFlags = &USMDesc->hints;
36-
std::ignore = USMHintFlags;
37-
3834
ur_platform_handle_t Plt = Context->getPlatform();
3935
// If indirect access tracking is enabled then lock the mutex which is
4036
// guarding contexts container in the platform. This prevents new kernels from
@@ -77,7 +73,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(
7773
// find the allocator depending on context as we do for Shared and Device
7874
// allocations.
7975
try {
80-
*RetMem = Context->HostMemAllocContext->allocate(Size, Align);
76+
if (Pool) {
77+
*RetMem = Pool->HostMemPool->allocate(Size, Align);
78+
} else {
79+
*RetMem = Context->HostMemAllocContext->allocate(Size, Align);
80+
}
8181
if (IndirectAccessTrackingEnabled) {
8282
// Keep track of all memory allocations in the context
8383
Context->MemAllocs.emplace(std::piecewise_construct,
@@ -105,18 +105,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
105105
Size, ///< [in] size in bytes of the USM memory object to be allocated
106106
void **RetMem ///< [out] pointer to USM device memory object
107107
) {
108-
std::ignore = Pool;
109108

110-
uint32_t Alignment = USMDesc->align;
109+
uint32_t Alignment = USMDesc ? USMDesc->align : 0;
111110

112111
// L0 supports alignment up to 64KB and silently ignores higher values.
113112
// We flag alignment > 64KB as an invalid value.
114113
if (Alignment > 65536)
115114
return UR_RESULT_ERROR_INVALID_VALUE;
116115

117-
const ur_usm_advice_flags_t *USMHintFlags = &USMDesc->hints;
118-
std::ignore = USMHintFlags;
119-
120116
ur_platform_handle_t Plt = Device->Platform;
121117

122118
// If indirect access tracking is enabled then lock the mutex which is
@@ -157,11 +153,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
157153
}
158154

159155
try {
160-
auto It = Context->DeviceMemAllocContexts.find(Device->ZeDevice);
161-
if (It == Context->DeviceMemAllocContexts.end())
162-
return UR_RESULT_ERROR_INVALID_VALUE;
163156

164-
*RetMem = It->second.allocate(Size, Alignment);
157+
if (Pool) {
158+
*RetMem = Pool->DeviceMemPools[Device]->allocate(Size, Alignment);
159+
} else {
160+
auto It = Context->DeviceMemAllocContexts.find(Device->ZeDevice);
161+
if (It == Context->DeviceMemAllocContexts.end())
162+
return UR_RESULT_ERROR_INVALID_VALUE;
163+
164+
*RetMem = It->second.allocate(Size, Alignment);
165+
}
165166
if (IndirectAccessTrackingEnabled) {
166167
// Keep track of all memory allocations in the context
167168
Context->MemAllocs.emplace(std::piecewise_construct,
@@ -190,17 +191,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
190191
Size, ///< [in] size in bytes of the USM memory object to be allocated
191192
void **RetMem ///< [out] pointer to USM shared memory object
192193
) {
193-
std::ignore = Pool;
194194

195-
uint32_t Alignment = USMDesc->align;
195+
uint32_t Alignment = USMDesc ? USMDesc->align : 0;
196196

197197
ur_usm_host_mem_flags_t UsmHostFlags{};
198198

199199
// See if the memory is going to be read-only on the device.
200200
bool DeviceReadOnly = false;
201201
ur_usm_device_mem_flags_t UsmDeviceFlags{};
202202

203-
void *pNext = const_cast<void *>(USMDesc->pNext);
203+
void *pNext = USMDesc ? const_cast<void *>(USMDesc->pNext) : nullptr;
204204
while (pNext != nullptr) {
205205
const ur_base_desc_t *BaseDesc =
206206
reinterpret_cast<const ur_base_desc_t *>(pNext);
@@ -259,13 +259,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
259259
}
260260

261261
try {
262-
auto &Allocator = (DeviceReadOnly ? Context->SharedReadOnlyMemAllocContexts
263-
: Context->SharedMemAllocContexts);
264-
auto It = Allocator.find(Device->ZeDevice);
265-
if (It == Allocator.end())
266-
return UR_RESULT_ERROR_INVALID_VALUE;
267-
268-
*RetMem = It->second.allocate(Size, Alignment);
262+
if (Pool) {
263+
if (DeviceReadOnly) {
264+
*RetMem =
265+
Pool->SharedMemReadOnlyPools[Device]->allocate(Size, Alignment);
266+
} else {
267+
*RetMem = Pool->SharedMemPools[Device]->allocate(Size, Alignment);
268+
}
269+
} else {
270+
auto &Allocator =
271+
(DeviceReadOnly ? Context->SharedReadOnlyMemAllocContexts
272+
: Context->SharedMemAllocContexts);
273+
auto It = Allocator.find(Device->ZeDevice);
274+
if (It == Allocator.end())
275+
return UR_RESULT_ERROR_INVALID_VALUE;
276+
277+
*RetMem = It->second.allocate(Size, Alignment);
278+
}
269279
if (DeviceReadOnly) {
270280
Context->SharedReadOnlyAllocs.insert(*RetMem);
271281
}
@@ -518,34 +528,87 @@ static ur_result_t USMAllocationMakeResident(
518528
return UR_RESULT_SUCCESS;
519529
}
520530

531+
/// Builds a USM pool for \p Context.
///
/// Records the zero-initialize flag from \p PoolDesc, applies any chained
/// ::ur_usm_pool_limits_desc_t to every allocator config, and then creates one
/// host allocator plus device, shared and shared-read-only allocators for
/// each device listed in the context.
///
/// Throws UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT) when the
/// descriptor chain contains a structure type other than
/// UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC.
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
                                             ur_usm_pool_desc_t *PoolDesc) {
  const auto ZeroInitBits =
      PoolDesc->flags & UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK;
  zeroInit = static_cast<ur_bool_t>(ZeroInitBits);

  // Walk the extension chain attached to the pool descriptor.
  for (auto *Desc = static_cast<const ur_base_desc_t *>(PoolDesc->pNext);
       Desc != nullptr;
       Desc = static_cast<const ur_base_desc_t *>(Desc->pNext)) {
    // Pool limits are the only extension understood here.
    if (Desc->stype != UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC) {
      urPrint("urUSMPoolCreate: unexpected chained stype\n");
      throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT);
    }

    const auto *PoolLimits =
        reinterpret_cast<const ur_usm_pool_limits_desc_t *>(Desc);
    // The same limits are written into every memory type's config.
    for (auto &Cfg : USMAllocatorConfigs.Configs) {
      Cfg.MaxPoolableSize = PoolLimits->maxPoolableSize;
      Cfg.SlabMinSize = PoolLimits->minDriverAllocSize;
    }
  }

  auto &Configs = USMAllocatorConfigs.Configs;

  // Host memory: a single allocator for the whole pool.
  std::unique_ptr<SystemMemory> HostBackend =
      std::make_unique<USMHostMemoryAlloc>(Context);
  HostMemPool = std::make_unique<USMAllocContext>(
      std::move(HostBackend), Configs[usm_settings::MemType::Host]);

  // Device and shared memory: one allocator of each kind per device.
  for (auto Dev : Context->Devices) {
    std::unique_ptr<SystemMemory> DeviceBackend =
        std::make_unique<USMDeviceMemoryAlloc>(Context, Dev);
    DeviceMemPools[Dev] = std::make_unique<USMAllocContext>(
        std::move(DeviceBackend), Configs[usm_settings::MemType::Device]);

    std::unique_ptr<SystemMemory> SharedBackend =
        std::make_unique<USMSharedMemoryAlloc>(Context, Dev);
    SharedMemPools[Dev] = std::make_unique<USMAllocContext>(
        std::move(SharedBackend), Configs[usm_settings::MemType::Shared]);

    // Read-only shared pools use the same backend type as regular shared
    // pools but their own SharedReadOnly config.
    std::unique_ptr<SystemMemory> SharedReadOnlyBackend =
        std::make_unique<USMSharedMemoryAlloc>(Context, Dev);
    SharedMemReadOnlyPools[Dev] = std::make_unique<USMAllocContext>(
        std::move(SharedReadOnlyBackend),
        Configs[usm_settings::MemType::SharedReadOnly]);
  }
}
580+
521581
/// Creates a USM memory pool for \p Context as described by \p PoolDesc and
/// stores the new handle in \p Pool.
///
/// \returns UR_RESULT_SUCCESS on success; the error carried by a
/// UsmAllocationException thrown while the pool is being constructed (for
/// example an unexpected structure in the descriptor chain); or
/// UR_RESULT_ERROR_UNKNOWN if construction fails with any other exception.
UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate(
    ur_context_handle_t Context, ///< [in] handle of the context object
    ur_usm_pool_desc_t
        *PoolDesc, ///< [in] pointer to USM pool descriptor. Can be chained with
                   ///< ::ur_usm_pool_limits_desc_t
    ur_usm_pool_handle_t *Pool ///< [out] pointer to USM memory pool
) {
  try {
    *Pool = reinterpret_cast<ur_usm_pool_handle_t>(
        new ur_usm_pool_handle_t_(Context, PoolDesc));
  } catch (const UsmAllocationException &Ex) {
    return Ex.getError();
  } catch (...) {
    // Pool construction performs heap allocations, so it can fail with
    // exceptions other than UsmAllocationException. Do not let them propagate
    // out of this exported API entry point; report a failure code instead.
    return UR_RESULT_ERROR_UNKNOWN;
  }
  return UR_RESULT_SUCCESS;
}
534597

535598
/// Takes an additional reference on \p Pool by incrementing its reference
/// counter. Always returns UR_RESULT_SUCCESS.
ur_result_t
urUSMPoolRetain(ur_usm_pool_handle_t Pool ///< [in] pointer to USM memory pool
) {
  auto &References = Pool->RefCount;
  References.increment();
  return UR_RESULT_SUCCESS;
}
542604

543605
/// Drops one reference on \p Pool and deletes the pool object when
/// RefCount.decrementAndTest() reports true (presumably meaning the last
/// reference was just released). Always returns UR_RESULT_SUCCESS.
ur_result_t
urUSMPoolRelease(ur_usm_pool_handle_t Pool ///< [in] pointer to USM memory pool
) {
  // Other holders remain: nothing to destroy yet.
  if (!Pool->RefCount.decrementAndTest())
    return UR_RESULT_SUCCESS;

  delete Pool;
  return UR_RESULT_SUCCESS;
}
550613

551614
ur_result_t urUSMPoolGetInfo(

sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_usm.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,23 @@
99

1010
#include "ur_level_zero_common.hpp"
1111

12+
struct ur_usm_pool_handle_t_ : _ur_object {
13+
bool zeroInit;
14+
15+
usm_settings::USMAllocatorConfig USMAllocatorConfigs;
16+
17+
std::unique_ptr<USMAllocContext> HostMemPool;
18+
std::unordered_map<ur_device_handle_t, std::unique_ptr<USMAllocContext>>
19+
SharedMemPools;
20+
std::unordered_map<ur_device_handle_t, std::unique_ptr<USMAllocContext>>
21+
SharedMemReadOnlyPools;
22+
std::unordered_map<ur_device_handle_t, std::unique_ptr<USMAllocContext>>
23+
DeviceMemPools;
24+
25+
ur_usm_pool_handle_t_(ur_context_handle_t Context,
26+
ur_usm_pool_desc_t *PoolDesc);
27+
};
28+
1229
// Exception type to pass allocation errors
1330
class UsmAllocationException {
1431
const ur_result_t Error;

0 commit comments

Comments
 (0)