diff --git a/source/adapters/cuda/context.hpp b/source/adapters/cuda/context.hpp index 96a1464a87..e84b4b7f7a 100644 --- a/source/adapters/cuda/context.hpp +++ b/source/adapters/cuda/context.hpp @@ -77,8 +77,9 @@ typedef void (*ur_context_extended_deleter_t)(void *user_data); /// static ur_result_t -CreateHostMemoryProvider(ur_device_handle_t_ *DeviceHandle, - umf_memory_provider_handle_t *MemoryProviderHost) { +CreateHostMemoryProviderPool(ur_device_handle_t_ *DeviceHandle, + umf_memory_provider_handle_t *MemoryProviderHost, + umf_memory_pool_handle_t *MemoryPoolHost) { umf_cuda_memory_provider_params_handle_t CUMemoryProviderParams = nullptr; *MemoryProviderHost = nullptr; @@ -91,10 +92,20 @@ CreateHostMemoryProvider(ur_device_handle_t_ *DeviceHandle, umf::cuda_params_unique_handle_t CUMemoryProviderParamsUnique( CUMemoryProviderParams, umfCUDAMemoryProviderParamsDestroy); - // create UMF CUDA memory provider for the host memory (UMF_MEMORY_TYPE_HOST) - UmfResult = umf::createMemoryProvider( - CUMemoryProviderParamsUnique.get(), 0 /* cuDevice */, context, - UMF_MEMORY_TYPE_HOST, MemoryProviderHost); + UmfResult = umf::setCUMemoryProviderParams(CUMemoryProviderParamsUnique.get(), + 0 /* cuDevice */, context, + UMF_MEMORY_TYPE_HOST); + UMF_RETURN_UR_ERROR(UmfResult); + + // create UMF CUDA memory provider and pool for the host memory + // (UMF_MEMORY_TYPE_HOST) + UmfResult = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(), + CUMemoryProviderParamsUnique.get(), + MemoryProviderHost); + UMF_RETURN_UR_ERROR(UmfResult); + + UmfResult = umfPoolCreate(umfProxyPoolOps(), *MemoryProviderHost, nullptr, 0, + MemoryPoolHost); UMF_RETURN_UR_ERROR(UmfResult); return UR_RESULT_SUCCESS; @@ -112,8 +123,10 @@ struct ur_context_handle_t_ { std::vector Devices; std::atomic_uint32_t RefCount; - // UMF CUDA memory provider for the host memory (UMF_MEMORY_TYPE_HOST) + // UMF CUDA memory provider and pool for the host memory + // (UMF_MEMORY_TYPE_HOST) umf_memory_provider_handle_t MemoryProviderHost = nullptr; + umf_memory_pool_handle_t MemoryPoolHost = nullptr; ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices) : Devices{Devs, Devs + NumDevices}, RefCount{1} { @@ -124,10 +137,14 @@ struct ur_context_handle_t_ { // Create UMF CUDA memory provider for the host memory // (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because // it is guaranteed to exist). - UR_CHECK_ERROR(CreateHostMemoryProvider(Devices[0], &MemoryProviderHost)); + UR_CHECK_ERROR(CreateHostMemoryProviderPool(Devices[0], &MemoryProviderHost, + &MemoryPoolHost)); }; ~ur_context_handle_t_() { + if (MemoryPoolHost) { + umfPoolDestroy(MemoryPoolHost); + } if (MemoryProviderHost) { umfMemoryProviderDestroy(MemoryProviderHost); } diff --git a/source/adapters/cuda/device.hpp b/source/adapters/cuda/device.hpp index eaf4ba6765..e94291367b 100644 --- a/source/adapters/cuda/device.hpp +++ b/source/adapters/cuda/device.hpp @@ -11,6 +11,7 @@ #include +#include #include #include "common.hpp" @@ -84,9 +85,17 @@ struct ur_device_handle_t_ { MemoryProviderDevice = nullptr; MemoryProviderShared = nullptr; + MemoryPoolDevice = nullptr; + MemoryPoolShared = nullptr; } ~ur_device_handle_t_() { + if (MemoryPoolDevice) { + umfPoolDestroy(MemoryPoolDevice); + } + if (MemoryPoolShared) { + umfPoolDestroy(MemoryPoolShared); + } if (MemoryProviderDevice) { umfMemoryProviderDestroy(MemoryProviderDevice); } @@ -131,11 +140,15 @@ struct ur_device_handle_t_ { // bookkeeping for mipmappedArray leaks in Mapping external Memory std::map ChildCuarrayFromMipmapMap; - // UMF CUDA memory provider for the device memory (UMF_MEMORY_TYPE_DEVICE) + // UMF CUDA memory provider and pool for the device memory + // (UMF_MEMORY_TYPE_DEVICE) umf_memory_provider_handle_t MemoryProviderDevice; + umf_memory_pool_handle_t MemoryPoolDevice; - // UMF CUDA memory provider for the shared memory (UMF_MEMORY_TYPE_SHARED) + // UMF CUDA memory provider and pool for the shared memory + // (UMF_MEMORY_TYPE_SHARED) umf_memory_provider_handle_t MemoryProviderShared; + umf_memory_pool_handle_t MemoryPoolShared; }; int getAttribute(ur_device_handle_t Device, CUdevice_attribute Attribute); diff --git a/source/adapters/cuda/memory.cpp b/source/adapters/cuda/memory.cpp index 651fe0f43d..6e68275f3a 100644 --- a/source/adapters/cuda/memory.cpp +++ b/source/adapters/cuda/memory.cpp @@ -50,8 +50,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate( cuMemHostRegister(HostPtr, size, CU_MEMHOSTREGISTER_DEVICEMAP)); AllocMode = BufferMem::AllocMode::UseHostPtr; } else if (flags & UR_MEM_FLAG_ALLOC_HOST_POINTER) { - UMF_CHECK_ERROR(umfMemoryProviderAlloc(hContext->MemoryProviderHost, size, - 0, &HostPtr)); + HostPtr = umfPoolMalloc(hContext->MemoryPoolHost, size); + UMF_CHECK_PTR(HostPtr); AllocMode = BufferMem::AllocMode::AllocHostPtr; } else if (flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER) { AllocMode = BufferMem::AllocMode::CopyIn; @@ -442,8 +442,8 @@ ur_result_t allocateMemObjOnDeviceIfNeeded(ur_mem_handle_t Mem, CU_MEMHOSTALLOC_DEVICEMAP)); UR_CHECK_ERROR(cuMemHostGetDevicePointer(&DevPtr, Buffer.HostPtr, 0)); } else { - UMF_CHECK_ERROR(umfMemoryProviderAlloc(hDevice->MemoryProviderDevice, - Buffer.Size, 0, (void **)&DevPtr)); + *(void **)&DevPtr = umfPoolMalloc(hDevice->MemoryPoolDevice, Buffer.Size); + UMF_CHECK_PTR(*(void **)&DevPtr); } } else { CUarray ImageArray{}; diff --git a/source/adapters/cuda/memory.hpp b/source/adapters/cuda/memory.hpp index 6dcaa28414..f0fc14f864 100644 --- a/source/adapters/cuda/memory.hpp +++ b/source/adapters/cuda/memory.hpp @@ -158,7 +158,7 @@ struct BufferMem { case AllocMode::Classic: for (auto &DevPtr : Ptrs) { if (DevPtr != native_type{0}) { - UR_CHECK_ERROR(cuMemFree(DevPtr)); + UMF_CHECK_ERROR(umfFree((void *)DevPtr)); } } break; @@ -166,7 +166,7 @@ struct BufferMem { UR_CHECK_ERROR(cuMemHostUnregister(HostPtr)); break; case AllocMode::AllocHostPtr: - UR_CHECK_ERROR(cuMemFreeHost(HostPtr)); + UMF_CHECK_ERROR(umfFree((void *)HostPtr)); } return UR_RESULT_SUCCESS; } diff --git a/source/adapters/cuda/platform.cpp b/source/adapters/cuda/platform.cpp index d53a027160..ac66c39afb 100644 --- a/source/adapters/cuda/platform.cpp +++ b/source/adapters/cuda/platform.cpp @@ -20,7 +20,7 @@ #include static ur_result_t -CreateDeviceMemoryProviders(ur_platform_handle_t_ *Platform) { +CreateDeviceMemoryProvidersPools(ur_platform_handle_t_ *Platform) { umf_cuda_memory_provider_params_handle_t CUMemoryProviderParams = nullptr; umf_result_t UmfResult = @@ -37,16 +37,40 @@ CreateDeviceMemoryProviders(ur_platform_handle_t_ *Platform) { // create UMF CUDA memory provider for the device memory // (UMF_MEMORY_TYPE_DEVICE) - UmfResult = umf::createMemoryProvider( - CUMemoryProviderParamsUnique.get(), device, context, - UMF_MEMORY_TYPE_DEVICE, &device_handle->MemoryProviderDevice); + UmfResult = + umf::setCUMemoryProviderParams(CUMemoryProviderParamsUnique.get(), + device, context, UMF_MEMORY_TYPE_DEVICE); + UMF_RETURN_UR_ERROR(UmfResult); + + UmfResult = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(), + CUMemoryProviderParamsUnique.get(), + &device_handle->MemoryProviderDevice); UMF_RETURN_UR_ERROR(UmfResult); // create UMF CUDA memory provider for the shared memory // (UMF_MEMORY_TYPE_SHARED) - UmfResult = umf::createMemoryProvider( - CUMemoryProviderParamsUnique.get(), device, context, - UMF_MEMORY_TYPE_SHARED, &device_handle->MemoryProviderShared); + UmfResult = + umf::setCUMemoryProviderParams(CUMemoryProviderParamsUnique.get(), + device, context, UMF_MEMORY_TYPE_SHARED); + UMF_RETURN_UR_ERROR(UmfResult); + + UmfResult = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(), + CUMemoryProviderParamsUnique.get(), + &device_handle->MemoryProviderShared); + UMF_RETURN_UR_ERROR(UmfResult); + + // create UMF CUDA memory pool for the device memory + // (UMF_MEMORY_TYPE_DEVICE) + UmfResult = + umfPoolCreate(umfProxyPoolOps(), device_handle->MemoryProviderDevice, + nullptr, 0, &device_handle->MemoryPoolDevice); + UMF_RETURN_UR_ERROR(UmfResult); + + // create UMF CUDA memory pool for the shared memory + // (UMF_MEMORY_TYPE_SHARED) + UmfResult = + umfPoolCreate(umfProxyPoolOps(), device_handle->MemoryProviderShared, + nullptr, 0, &device_handle->MemoryPoolShared); UMF_RETURN_UR_ERROR(UmfResult); } @@ -134,7 +158,7 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, static_cast(i)}); } - UR_CHECK_ERROR(CreateDeviceMemoryProviders(&Platform)); + UR_CHECK_ERROR(CreateDeviceMemoryProvidersPools(&Platform)); } catch (const std::bad_alloc &) { // Signal out-of-memory situation for (int i = 0; i < NumDevices; ++i) { diff --git a/source/adapters/cuda/usm.cpp b/source/adapters/cuda/usm.cpp index e40927b7a8..5d2d43442d 100644 --- a/source/adapters/cuda/usm.cpp +++ b/source/adapters/cuda/usm.cpp @@ -102,54 +102,12 @@ urUSMSharedAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice, return UR_RESULT_SUCCESS; } -ur_result_t USMFreeImpl(ur_context_handle_t hContext, void *Pointer) { - ur_result_t Result = UR_RESULT_SUCCESS; - try { - unsigned int IsManaged; - unsigned int Type; - unsigned int DeviceOrdinal; - const int NumAttributes = 3; - void *AttributeValues[NumAttributes] = {&IsManaged, &Type, &DeviceOrdinal}; - - CUpointer_attribute Attributes[NumAttributes] = { - CU_POINTER_ATTRIBUTE_IS_MANAGED, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, - CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL}; - UR_CHECK_ERROR(cuPointerGetAttributes( - NumAttributes, Attributes, AttributeValues, (CUdeviceptr)Pointer)); - UR_ASSERT(Type == CU_MEMORYTYPE_DEVICE || Type == CU_MEMORYTYPE_HOST, - UR_RESULT_ERROR_INVALID_MEM_OBJECT); - - std::vector ContextDevices = hContext->getDevices(); - ur_platform_handle_t Platform = ContextDevices[0]->getPlatform(); - unsigned int NumDevices = Platform->Devices.size(); - UR_ASSERT(DeviceOrdinal < NumDevices, UR_RESULT_ERROR_INVALID_DEVICE); - - ur_device_handle_t Device = Platform->Devices[DeviceOrdinal].get(); - umf_memory_provider_handle_t MemoryProvider; - - if (IsManaged) { - MemoryProvider = Device->MemoryProviderShared; - } else if (Type == CU_MEMORYTYPE_DEVICE) { - MemoryProvider = Device->MemoryProviderDevice; - } else { - MemoryProvider = hContext->MemoryProviderHost; - } - - UMF_CHECK_ERROR(umfMemoryProviderFree(MemoryProvider, Pointer, - 0 /* size is unknown */)); - } catch (ur_result_t Err) { - Result = Err; - } - return Result; -} - /// USM: Frees the given USM pointer associated with the context. /// UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext, void *pMem) { - if (auto Pool = umfPoolByPtr(pMem)) - return umf::umf2urResult(umfPoolFree(Pool, pMem)); - return USMFreeImpl(hContext, pMem); + (void)hContext; // unused + return umf::umf2urResult(umfFree(pMem)); } ur_result_t USMDeviceAllocImpl(void **ResultPtr, ur_context_handle_t, @@ -158,8 +116,8 @@ ur_result_t USMDeviceAllocImpl(void **ResultPtr, ur_context_handle_t, uint32_t Alignment) { try { ScopedContext Active(Device); - UMF_CHECK_ERROR(umfMemoryProviderAlloc(Device->MemoryProviderDevice, Size, - Alignment, ResultPtr)); + *ResultPtr = umfPoolMalloc(Device->MemoryPoolDevice, Size); + UMF_CHECK_PTR(*ResultPtr); } catch (ur_result_t Err) { return Err; } @@ -180,8 +138,8 @@ ur_result_t USMSharedAllocImpl(void **ResultPtr, ur_context_handle_t, uint32_t Alignment) { try { ScopedContext Active(Device); - UMF_CHECK_ERROR(umfMemoryProviderAlloc(Device->MemoryProviderShared, Size, - Alignment, ResultPtr)); + *ResultPtr = umfPoolMalloc(Device->MemoryPoolShared, Size); + UMF_CHECK_PTR(*ResultPtr); } catch (ur_result_t Err) { return Err; } @@ -199,8 +157,8 @@ ur_result_t USMHostAllocImpl(void **ResultPtr, ur_context_handle_t hContext, ur_usm_host_mem_flags_t, size_t Size, uint32_t Alignment) { try { - UMF_CHECK_ERROR(umfMemoryProviderAlloc(hContext->MemoryProviderHost, Size, - Alignment, ResultPtr)); + *ResultPtr = umfPoolMalloc(hContext->MemoryPoolHost, Size); + UMF_CHECK_PTR(*ResultPtr); } catch (ur_result_t Err) { return Err; } @@ -326,73 +284,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMReleaseExp(ur_context_handle_t Context, return UR_RESULT_SUCCESS; } -umf_result_t USMMemoryProvider::initialize(ur_context_handle_t Ctx, - ur_device_handle_t Dev) { - Context = Ctx; - Device = Dev; - // There isn't a way to query this in cuda, and there isn't much info on - // cuda's approach to alignment or transfer granularity between host and - // device. Within UMF this is only used to influence alignment, and since we - // discard that in our alloc implementations it seems we can safely ignore - // this as well, for now. - MinPageSize = 0; - - return UMF_RESULT_SUCCESS; -} - -enum umf_result_t USMMemoryProvider::alloc(size_t Size, size_t Align, - void **Ptr) { - auto Res = allocateImpl(Ptr, Size, Align); - if (Res != UR_RESULT_SUCCESS) { - getLastStatusRef() = Res; - return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC; - } - - return UMF_RESULT_SUCCESS; -} - -enum umf_result_t USMMemoryProvider::free(void *Ptr, size_t Size) { - (void)Size; - - auto Res = USMFreeImpl(Context, Ptr); - if (Res != UR_RESULT_SUCCESS) { - getLastStatusRef() = Res; - return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC; - } - - return UMF_RESULT_SUCCESS; -} - -void USMMemoryProvider::get_last_native_error(const char **ErrMsg, - int32_t *ErrCode) { - (void)ErrMsg; - *ErrCode = static_cast(getLastStatusRef()); -} - -umf_result_t USMMemoryProvider::get_min_page_size(void *Ptr, size_t *PageSize) { - (void)Ptr; - *PageSize = MinPageSize; - - return UMF_RESULT_SUCCESS; -} - -ur_result_t USMSharedMemoryProvider::allocateImpl(void **ResultPtr, size_t Size, - uint32_t Alignment) { - return USMSharedAllocImpl(ResultPtr, Context, Device, /*host flags*/ 0, - /*device flags*/ 0, Size, Alignment); -} - -ur_result_t USMDeviceMemoryProvider::allocateImpl(void **ResultPtr, size_t Size, - uint32_t Alignment) { - return USMDeviceAllocImpl(ResultPtr, Context, Device, /* flags */ 0, Size, - Alignment); -} - -ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size, - uint32_t Alignment) { - return USMHostAllocImpl(ResultPtr, Context, /* flags */ 0, Size, Alignment); -} - ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context, ur_usm_pool_desc_t *PoolDesc) : Context{Context} { @@ -416,36 +307,28 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context, pNext = BaseDesc->pNext; } - auto MemProvider = - umf::memoryProviderMakeUnique(Context, nullptr) - .second; - auto UmfHostParamsHandle = getUmfParamsHandle( DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Host]); - HostMemPool = - umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(MemProvider), - UmfHostParamsHandle.get()) - .second; + HostMemPool = umf::poolMakeUniqueFromOpsProviderHandle( + umfDisjointPoolOps(), Context->MemoryProviderHost, + UmfHostParamsHandle.get()) + .second; for (const auto &Device : Context->getDevices()) { - MemProvider = - umf::memoryProviderMakeUnique(Context, Device) - .second; auto UmfDeviceParamsHandle = getUmfParamsHandle( DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Device]); - DeviceMemPool = - umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(MemProvider), - UmfDeviceParamsHandle.get()) - .second; - MemProvider = - umf::memoryProviderMakeUnique(Context, Device) - .second; + DeviceMemPool = umf::poolMakeUniqueFromOpsProviderHandle( + umfDisjointPoolOps(), Device->MemoryProviderDevice, + UmfDeviceParamsHandle.get()) + .second; + auto UmfSharedParamsHandle = getUmfParamsHandle( DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Shared]); - SharedMemPool = - umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(MemProvider), - UmfSharedParamsHandle.get()) - .second; + SharedMemPool = umf::poolMakeUniqueFromOpsProviderHandle( + umfDisjointPoolOps(), Device->MemoryProviderShared, + UmfSharedParamsHandle.get()) + .second; + Context->addPool(this); } } diff --git a/source/adapters/cuda/usm.hpp b/source/adapters/cuda/usm.hpp index 7c6a2ea666..8258043d2b 100644 --- a/source/adapters/cuda/usm.hpp +++ b/source/adapters/cuda/usm.hpp @@ -48,80 +48,6 @@ class UsmAllocationException { ur_result_t getError() const { return Error; } }; -// Implements memory allocation via driver API for USM allocator interface. -class USMMemoryProvider { -private: - ur_result_t &getLastStatusRef() { - static thread_local ur_result_t LastStatus = UR_RESULT_SUCCESS; - return LastStatus; - } - -protected: - ur_context_handle_t Context; - ur_device_handle_t Device; - size_t MinPageSize; - - // Internal allocation routine which must be implemented for each allocation - // type - virtual ur_result_t allocateImpl(void **ResultPtr, size_t Size, - uint32_t Alignment) = 0; - -public: - umf_result_t initialize(ur_context_handle_t Ctx, ur_device_handle_t Dev); - umf_result_t alloc(size_t Size, size_t Align, void **Ptr); - umf_result_t free(void *Ptr, size_t Size); - void get_last_native_error(const char **ErrMsg, int32_t *ErrCode); - umf_result_t get_min_page_size(void *, size_t *); - umf_result_t get_recommended_page_size(size_t, size_t *) { - return UMF_RESULT_ERROR_NOT_SUPPORTED; - }; - umf_result_t purge_lazy(void *, size_t) { - return UMF_RESULT_ERROR_NOT_SUPPORTED; - }; - umf_result_t purge_force(void *, size_t) { - return UMF_RESULT_ERROR_NOT_SUPPORTED; - }; - umf_result_t allocation_merge(void *, void *, size_t) { - return UMF_RESULT_ERROR_UNKNOWN; - } - umf_result_t allocation_split(void *, size_t, size_t) { - return UMF_RESULT_ERROR_UNKNOWN; - } - virtual const char *get_name() = 0; - - virtual ~USMMemoryProvider() = default; -}; - -// Allocation routines for shared memory type -class USMSharedMemoryProvider final : public USMMemoryProvider { -public: - const char *get_name() override { return "USMSharedMemoryProvider"; } - -protected: - ur_result_t allocateImpl(void **ResultPtr, size_t Size, - uint32_t Alignment) override; -}; - -// Allocation routines for device memory type -class USMDeviceMemoryProvider final : public USMMemoryProvider { -public: - const char *get_name() override { return "USMSharedMemoryProvider"; } - -protected: - ur_result_t allocateImpl(void **ResultPtr, size_t Size, - uint32_t Alignment) override; -}; - -// Allocation routines for host memory type -class USMHostMemoryProvider final : public USMMemoryProvider { -public: - const char *get_name() override { return "USMSharedMemoryProvider"; } - -protected: - ur_result_t allocateImpl(void **ResultPtr, size_t Size, - uint32_t Alignment) override; -}; - ur_result_t USMDeviceAllocImpl(void **ResultPtr, ur_context_handle_t Context, ur_device_handle_t Device, ur_usm_device_mem_flags_t Flags, size_t Size, diff --git a/source/common/umf_helpers.hpp b/source/common/umf_helpers.hpp index 2433560a39..8f8f4c3c8c 100644 --- a/source/common/umf_helpers.hpp +++ b/source/common/umf_helpers.hpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -30,6 +31,13 @@ #define UMF_CHECK_ERROR(UmfResult) UR_CHECK_ERROR(umf::umf2urResult(UmfResult)); +#define UMF_CHECK_PTR(ptr) \ + do { \ + if ((ptr) == nullptr) { \ + UR_CHECK_ERROR(UR_RESULT_ERROR_OUT_OF_HOST_MEMORY); \ + } \ + } while (0) + #define UMF_RETURN_UMF_ERROR(UmfResult) \ do { \ umf_result_t UmfResult_ = (UmfResult); \ @@ -243,6 +251,21 @@ static inline auto poolMakeUniqueFromOps(umf_memory_pool_ops_t *ops, UMF_RESULT_SUCCESS, pool_unique_handle_t(hPool, umfPoolDestroy)}; } +static inline auto +poolMakeUniqueFromOpsProviderHandle(umf_memory_pool_ops_t *ops, + umf_memory_provider_handle_t provider, + void *params) { + umf_memory_pool_handle_t hPool; + auto ret = umfPoolCreate(ops, provider, params, 0, &hPool); + if (ret != UMF_RESULT_SUCCESS) { + return std::pair{ + ret, pool_unique_handle_t(nullptr, nullptr)}; + } + + return std::pair{ + UMF_RESULT_SUCCESS, pool_unique_handle_t(hPool, umfPoolDestroy)}; +} + static inline auto providerMakeUniqueFromOps(umf_memory_provider_ops_t *ops, void *params) { umf_memory_provider_handle_t hProvider; @@ -301,10 +324,9 @@ inline ur_result_t umf2urResult(umf_result_t umfResult) { }; } -inline umf_result_t createMemoryProvider( +inline umf_result_t setCUMemoryProviderParams( umf_cuda_memory_provider_params_handle_t CUMemoryProviderParams, - int cuDevice, void *cuContext, umf_usm_memory_type_t memType, - umf_memory_provider_handle_t *provider) { + int cuDevice, void *cuContext, umf_usm_memory_type_t memType) { umf_result_t UmfResult = umfCUDAMemoryProviderParamsSetContext(CUMemoryProviderParams, cuContext); @@ -318,13 +340,6 @@ inline umf_result_t createMemoryProvider( umfCUDAMemoryProviderParamsSetMemoryType(CUMemoryProviderParams, memType); UMF_RETURN_UMF_ERROR(UmfResult); - umf_memory_provider_handle_t umfCUDAprovider = nullptr; - UmfResult = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(), - CUMemoryProviderParams, &umfCUDAprovider); - UMF_RETURN_UMF_ERROR(UmfResult); - - *provider = umfCUDAprovider; - return UMF_RESULT_SUCCESS; }