@@ -24,17 +24,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(
24
24
Size , // /< [in] size in bytes of the USM memory object to be allocated
25
25
void **RetMem // /< [out] pointer to USM host memory object
26
26
) {
27
- std::ignore = Pool;
28
27
29
- uint32_t Align = USMDesc->align ;
28
+ uint32_t Align = USMDesc ? USMDesc ->align : 0 ;
30
29
// L0 supports alignment up to 64KB and silently ignores higher values.
31
30
// We flag alignment > 64KB as an invalid value.
32
31
if (Align > 65536 )
33
32
return UR_RESULT_ERROR_INVALID_VALUE;
34
33
35
- const ur_usm_advice_flags_t *USMHintFlags = &USMDesc->hints ;
36
- std::ignore = USMHintFlags;
37
-
38
34
ur_platform_handle_t Plt = Context->getPlatform ();
39
35
// If indirect access tracking is enabled then lock the mutex which is
40
36
// guarding contexts container in the platform. This prevents new kernels from
@@ -77,7 +73,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(
77
73
// find the allocator depending on context as we do for Shared and Device
78
74
// allocations.
79
75
try {
80
- *RetMem = Context->HostMemAllocContext ->allocate (Size , Align);
76
+ if (Pool) {
77
+ *RetMem = Pool->HostMemPool ->allocate (Size , Align);
78
+ } else {
79
+ *RetMem = Context->HostMemAllocContext ->allocate (Size , Align);
80
+ }
81
81
if (IndirectAccessTrackingEnabled) {
82
82
// Keep track of all memory allocations in the context
83
83
Context->MemAllocs .emplace (std::piecewise_construct,
@@ -105,18 +105,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
105
105
Size , // /< [in] size in bytes of the USM memory object to be allocated
106
106
void **RetMem // /< [out] pointer to USM device memory object
107
107
) {
108
- std::ignore = Pool;
109
108
110
- uint32_t Alignment = USMDesc->align ;
109
+ uint32_t Alignment = USMDesc ? USMDesc ->align : 0 ;
111
110
112
111
// L0 supports alignment up to 64KB and silently ignores higher values.
113
112
// We flag alignment > 64KB as an invalid value.
114
113
if (Alignment > 65536 )
115
114
return UR_RESULT_ERROR_INVALID_VALUE;
116
115
117
- const ur_usm_advice_flags_t *USMHintFlags = &USMDesc->hints ;
118
- std::ignore = USMHintFlags;
119
-
120
116
ur_platform_handle_t Plt = Device->Platform ;
121
117
122
118
// If indirect access tracking is enabled then lock the mutex which is
@@ -157,11 +153,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
157
153
}
158
154
159
155
try {
160
- auto It = Context->DeviceMemAllocContexts .find (Device->ZeDevice );
161
- if (It == Context->DeviceMemAllocContexts .end ())
162
- return UR_RESULT_ERROR_INVALID_VALUE;
163
156
164
- *RetMem = It->second .allocate (Size , Alignment);
157
+ if (Pool) {
158
+ *RetMem = Pool->DeviceMemPools [Device]->allocate (Size , Alignment);
159
+ } else {
160
+ auto It = Context->DeviceMemAllocContexts .find (Device->ZeDevice );
161
+ if (It == Context->DeviceMemAllocContexts .end ())
162
+ return UR_RESULT_ERROR_INVALID_VALUE;
163
+
164
+ *RetMem = It->second .allocate (Size , Alignment);
165
+ }
165
166
if (IndirectAccessTrackingEnabled) {
166
167
// Keep track of all memory allocations in the context
167
168
Context->MemAllocs .emplace (std::piecewise_construct,
@@ -190,17 +191,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
190
191
Size , // /< [in] size in bytes of the USM memory object to be allocated
191
192
void **RetMem // /< [out] pointer to USM shared memory object
192
193
) {
193
- std::ignore = Pool;
194
194
195
- uint32_t Alignment = USMDesc->align ;
195
+ uint32_t Alignment = USMDesc ? USMDesc ->align : 0 ;
196
196
197
197
ur_usm_host_mem_flags_t UsmHostFlags{};
198
198
199
199
// See if the memory is going to be read-only on the device.
200
200
bool DeviceReadOnly = false ;
201
201
ur_usm_device_mem_flags_t UsmDeviceFlags{};
202
202
203
- void *pNext = const_cast <void *>(USMDesc->pNext );
203
+ void *pNext = USMDesc ? const_cast <void *>(USMDesc->pNext ) : nullptr ;
204
204
while (pNext != nullptr ) {
205
205
const ur_base_desc_t *BaseDesc =
206
206
reinterpret_cast <const ur_base_desc_t *>(pNext);
@@ -259,13 +259,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
259
259
}
260
260
261
261
try {
262
- auto &Allocator = (DeviceReadOnly ? Context->SharedReadOnlyMemAllocContexts
263
- : Context->SharedMemAllocContexts );
264
- auto It = Allocator.find (Device->ZeDevice );
265
- if (It == Allocator.end ())
266
- return UR_RESULT_ERROR_INVALID_VALUE;
267
-
268
- *RetMem = It->second .allocate (Size , Alignment);
262
+ if (Pool) {
263
+ if (DeviceReadOnly) {
264
+ *RetMem =
265
+ Pool->SharedMemReadOnlyPools [Device]->allocate (Size , Alignment);
266
+ } else {
267
+ *RetMem = Pool->SharedMemPools [Device]->allocate (Size , Alignment);
268
+ }
269
+ } else {
270
+ auto &Allocator =
271
+ (DeviceReadOnly ? Context->SharedReadOnlyMemAllocContexts
272
+ : Context->SharedMemAllocContexts );
273
+ auto It = Allocator.find (Device->ZeDevice );
274
+ if (It == Allocator.end ())
275
+ return UR_RESULT_ERROR_INVALID_VALUE;
276
+
277
+ *RetMem = It->second .allocate (Size , Alignment);
278
+ }
269
279
if (DeviceReadOnly) {
270
280
Context->SharedReadOnlyAllocs .insert (*RetMem);
271
281
}
@@ -518,34 +528,87 @@ static ur_result_t USMAllocationMakeResident(
518
528
return UR_RESULT_SUCCESS;
519
529
}
520
530
531
+ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_ (ur_context_handle_t Context,
532
+ ur_usm_pool_desc_t *PoolDesc) {
533
+
534
+ zeroInit = static_cast <ur_bool_t >(PoolDesc->flags &
535
+ UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK);
536
+
537
+ void *pNext = const_cast <void *>(PoolDesc->pNext );
538
+ while (pNext != nullptr ) {
539
+ const ur_base_desc_t *BaseDesc =
540
+ reinterpret_cast <const ur_base_desc_t *>(pNext);
541
+ switch (BaseDesc->stype ) {
542
+ case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: {
543
+ const ur_usm_pool_limits_desc_t *Limits =
544
+ reinterpret_cast <const ur_usm_pool_limits_desc_t *>(BaseDesc);
545
+ for (auto &config : USMAllocatorConfigs.Configs ) {
546
+ config.MaxPoolableSize = Limits->maxPoolableSize ;
547
+ config.SlabMinSize = Limits->minDriverAllocSize ;
548
+ }
549
+ break ;
550
+ }
551
+ default : {
552
+ urPrint (" urUSMPoolCreate: unexpected chained stype\n " );
553
+ throw UsmAllocationException (UR_RESULT_ERROR_INVALID_ARGUMENT);
554
+ }
555
+ }
556
+ pNext = const_cast <void *>(BaseDesc->pNext );
557
+ }
558
+
559
+ HostMemPool = std::make_unique<USMAllocContext>(
560
+ std::unique_ptr<SystemMemory>(new USMHostMemoryAlloc (Context)),
561
+ this ->USMAllocatorConfigs .Configs [usm_settings::MemType::Host]);
562
+
563
+ for (auto device : Context->Devices ) {
564
+ DeviceMemPools[device] = std::make_unique<USMAllocContext>(
565
+ std::unique_ptr<SystemMemory>(
566
+ new USMDeviceMemoryAlloc (Context, device)),
567
+ this ->USMAllocatorConfigs .Configs [usm_settings::MemType::Device]);
568
+
569
+ SharedMemPools[device] = std::make_unique<USMAllocContext>(
570
+ std::unique_ptr<SystemMemory>(
571
+ new USMSharedMemoryAlloc (Context, device)),
572
+ this ->USMAllocatorConfigs .Configs [usm_settings::MemType::Shared]);
573
+ SharedMemReadOnlyPools[device] = std::make_unique<USMAllocContext>(
574
+ std::unique_ptr<SystemMemory>(
575
+ new USMSharedMemoryAlloc (Context, device)),
576
+ this ->USMAllocatorConfigs
577
+ .Configs [usm_settings::MemType::SharedReadOnly]);
578
+ }
579
+ }
580
+
521
581
/// Creates a USM memory pool for \p Context and returns it through \p Pool.
///
/// The pool object is heap-allocated here. Construction failures reported as
/// UsmAllocationException are translated to the error code they carry; any
/// other exception is mapped to UR_RESULT_ERROR_UNKNOWN so that nothing
/// propagates out of this C entry point.
///
/// \return UR_RESULT_SUCCESS when the pool was created and stored in *Pool.
UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate(
    ur_context_handle_t Context, ///< [in] handle of the context object
    ur_usm_pool_desc_t
        *PoolDesc, ///< [in] pointer to USM pool descriptor. Can be chained with
                   ///< ::ur_usm_pool_limits_desc_t
    ur_usm_pool_handle_t *Pool ///< [out] pointer to USM memory pool
) {

  try {
    *Pool = reinterpret_cast<ur_usm_pool_handle_t>(
        new ur_usm_pool_handle_t_(Context, PoolDesc));
  } catch (const UsmAllocationException &Ex) {
    return Ex.getError();
  } catch (...) {
    // E.g. an allocation failure while building the pool; exceptions must not
    // cross the API boundary.
    return UR_RESULT_ERROR_UNKNOWN;
  }
  return UR_RESULT_SUCCESS;
}
534
597
535
598
/// Takes an additional reference on \p Pool by incrementing its reference
/// counter.
///
/// \return UR_RESULT_SUCCESS.
ur_result_t
urUSMPoolRetain(ur_usm_pool_handle_t Pool ///< [in] pointer to USM memory pool
) {
  auto &References = Pool->RefCount;
  References.increment();
  return UR_RESULT_SUCCESS;
}
542
604
543
605
/// Drops one reference on \p Pool and deletes the pool object when the
/// counter's decrementAndTest() reports true.
///
/// \return UR_RESULT_SUCCESS.
ur_result_t
urUSMPoolRelease(ur_usm_pool_handle_t Pool ///< [in] pointer to USM memory pool
) {
  const bool LastReference = Pool->RefCount.decrementAndTest();
  if (LastReference)
    delete Pool;
  return UR_RESULT_SUCCESS;
}
550
613
551
614
ur_result_t urUSMPoolGetInfo (
0 commit comments