@@ -125,18 +125,17 @@ ur_result_t USMFreeImpl(ur_context_handle_t hContext, void *Pointer) {
     UR_ASSERT(DeviceOrdinal < NumDevices, UR_RESULT_ERROR_INVALID_DEVICE);
 
     ur_device_handle_t Device = Platform->Devices[DeviceOrdinal].get();
-    umf_memory_provider_handle_t MemoryProvider;
+    umf_memory_pool_handle_t MemoryPool;
 
     if (IsManaged) {
-      MemoryProvider = Device->MemoryProviderShared;
+      MemoryPool = Device->MemoryPoolShared;
     } else if (Type == CU_MEMORYTYPE_DEVICE) {
-      MemoryProvider = Device->MemoryProviderDevice;
+      MemoryPool = Device->MemoryPoolDevice;
     } else {
-      MemoryProvider = hContext->MemoryProviderHost;
+      MemoryPool = hContext->MemoryPoolHost;
     }
 
-    UMF_CHECK_ERROR(umfMemoryProviderFree(MemoryProvider, Pointer,
-                                          0 /* size is unknown */));
+    UMF_CHECK_ERROR(umfPoolFree(MemoryPool, Pointer));
   } catch (ur_result_t Err) {
     Result = Err;
   }
@@ -158,8 +157,8 @@ ur_result_t USMDeviceAllocImpl(void **ResultPtr, ur_context_handle_t,
                                uint32_t Alignment) {
   try {
     ScopedContext Active(Device);
-    UMF_CHECK_ERROR(umfMemoryProviderAlloc(Device->MemoryProviderDevice, Size,
-                                           Alignment, ResultPtr));
+    *ResultPtr = umfPoolMalloc(Device->MemoryPoolDevice, Size);
+    UMF_CHECK_PTR(*ResultPtr);
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -180,8 +179,8 @@ ur_result_t USMSharedAllocImpl(void **ResultPtr, ur_context_handle_t,
                                uint32_t Alignment) {
   try {
     ScopedContext Active(Device);
-    UMF_CHECK_ERROR(umfMemoryProviderAlloc(Device->MemoryProviderShared, Size,
-                                           Alignment, ResultPtr));
+    *ResultPtr = umfPoolMalloc(Device->MemoryPoolShared, Size);
+    UMF_CHECK_PTR(*ResultPtr);
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -199,8 +198,8 @@ ur_result_t USMHostAllocImpl(void **ResultPtr, ur_context_handle_t hContext,
                              ur_usm_host_mem_flags_t, size_t Size,
                              uint32_t Alignment) {
   try {
-    UMF_CHECK_ERROR(umfMemoryProviderAlloc(hContext->MemoryProviderHost, Size,
-                                           Alignment, ResultPtr));
+    *ResultPtr = umfPoolMalloc(hContext->MemoryPoolHost, Size);
+    UMF_CHECK_PTR(*ResultPtr);
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -326,73 +325,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMReleaseExp(ur_context_handle_t Context,
   return UR_RESULT_SUCCESS;
 }
 
-umf_result_t USMMemoryProvider::initialize(ur_context_handle_t Ctx,
-                                           ur_device_handle_t Dev) {
-  Context = Ctx;
-  Device = Dev;
-  // There isn't a way to query this in cuda, and there isn't much info on
-  // cuda's approach to alignment or transfer granularity between host and
-  // device. Within UMF this is only used to influence alignment, and since we
-  // discard that in our alloc implementations it seems we can safely ignore
-  // this as well, for now.
-  MinPageSize = 0;
-
-  return UMF_RESULT_SUCCESS;
-}
-
-enum umf_result_t USMMemoryProvider::alloc(size_t Size, size_t Align,
-                                           void **Ptr) {
-  auto Res = allocateImpl(Ptr, Size, Align);
-  if (Res != UR_RESULT_SUCCESS) {
-    getLastStatusRef() = Res;
-    return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
-  }
-
-  return UMF_RESULT_SUCCESS;
-}
-
-enum umf_result_t USMMemoryProvider::free(void *Ptr, size_t Size) {
-  (void)Size;
-
-  auto Res = USMFreeImpl(Context, Ptr);
-  if (Res != UR_RESULT_SUCCESS) {
-    getLastStatusRef() = Res;
-    return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
-  }
-
-  return UMF_RESULT_SUCCESS;
-}
-
-void USMMemoryProvider::get_last_native_error(const char **ErrMsg,
-                                              int32_t *ErrCode) {
-  (void)ErrMsg;
-  *ErrCode = static_cast<int32_t>(getLastStatusRef());
-}
-
-umf_result_t USMMemoryProvider::get_min_page_size(void *Ptr, size_t *PageSize) {
-  (void)Ptr;
-  *PageSize = MinPageSize;
-
-  return UMF_RESULT_SUCCESS;
-}
-
-ur_result_t USMSharedMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
-                                                  uint32_t Alignment) {
-  return USMSharedAllocImpl(ResultPtr, Context, Device, /*host flags*/ 0,
-                            /*device flags*/ 0, Size, Alignment);
-}
-
-ur_result_t USMDeviceMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
-                                                  uint32_t Alignment) {
-  return USMDeviceAllocImpl(ResultPtr, Context, Device, /* flags */ 0, Size,
-                            Alignment);
-}
-
-ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
-                                                uint32_t Alignment) {
-  return USMHostAllocImpl(ResultPtr, Context, /* flags */ 0, Size, Alignment);
-}
-
 ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
                                              ur_usm_pool_desc_t *PoolDesc)
     : Context{Context} {
@@ -416,36 +348,28 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
     pNext = BaseDesc->pNext;
   }
 
-  auto MemProvider =
-      umf::memoryProviderMakeUnique<USMHostMemoryProvider>(Context, nullptr)
-          .second;
-
   auto UmfHostParamsHandle = getUmfParamsHandle(
       DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Host]);
-  HostMemPool =
-      umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(MemProvider),
-                                 UmfHostParamsHandle.get())
-          .second;
+  HostMemPool = umf::poolMakeUniqueFromOps_CudaProvider(
+                    umfDisjointPoolOps(), Context->MemoryProviderHost,
+                    UmfHostParamsHandle.get())
+                    .second;
 
   for (const auto &Device : Context->getDevices()) {
-    MemProvider =
-        umf::memoryProviderMakeUnique<USMDeviceMemoryProvider>(Context, Device)
-            .second;
     auto UmfDeviceParamsHandle = getUmfParamsHandle(
         DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Device]);
-    DeviceMemPool =
-        umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(MemProvider),
-                                   UmfDeviceParamsHandle.get())
-            .second;
-    MemProvider =
-        umf::memoryProviderMakeUnique<USMSharedMemoryProvider>(Context, Device)
-            .second;
+    DeviceMemPool = umf::poolMakeUniqueFromOps_CudaProvider(
+                        umfDisjointPoolOps(), Device->MemoryProviderDevice,
+                        UmfDeviceParamsHandle.get())
+                        .second;
+
     auto UmfSharedParamsHandle = getUmfParamsHandle(
         DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Shared]);
-    SharedMemPool =
-        umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(MemProvider),
-                                   UmfSharedParamsHandle.get())
-            .second;
+    SharedMemPool = umf::poolMakeUniqueFromOps_CudaProvider(
+                        umfDisjointPoolOps(), Device->MemoryProviderShared,
+                        UmfSharedParamsHandle.get())
+                        .second;
+
     Context->addPool(this);
   }
 }
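
Note (not part of the patch): a minimal standalone sketch of the pool-based pattern the diff switches to, assuming an already-constructed umf_memory_pool_handle_t (for example the Device->MemoryPoolDevice member used above). umfPoolMalloc and umfPoolFree are the public UMF pool API; the helper names below are illustrative only.

#include <umf/memory_pool.h>

#include <cstddef>

// Allocate Size bytes from a UMF pool; returns nullptr on failure
// (the adapter checks this with UMF_CHECK_PTR).
static void *poolAlloc(umf_memory_pool_handle_t Pool, size_t Size) {
  return umfPoolMalloc(Pool, Size);
}

// Return an allocation to its pool; unlike the old umfMemoryProviderFree
// path, the allocation size no longer has to be passed.
static umf_result_t poolFree(umf_memory_pool_handle_t Pool, void *Ptr) {
  return umfPoolFree(Pool, Ptr);
}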