Skip to content

Commit f7e21bb

Browse files
committed
Use UMF CUDA provider in UR
Signed-off-by: Lukasz Dorau <[email protected]>
1 parent e2df8ac commit f7e21bb

File tree

7 files changed

+183
-21
lines changed

7 files changed

+183
-21
lines changed

source/adapters/cuda/context.hpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "common.hpp"
2121
#include "device.hpp"
22+
#include "umf_helpers.hpp"
2223

2324
#include <umf/memory_pool.h>
2425

@@ -74,6 +75,31 @@ typedef void (*ur_context_extended_deleter_t)(void *user_data);
7475
/// if necessary.
7576
///
7677
///
78+
79+
static ur_result_t
80+
CreateHostMemoryProvider(ur_device_handle_t_ *DeviceHandle,
81+
umf_memory_provider_handle_t *MemoryProviderHost) {
82+
umf_cuda_memory_provider_params_handle_t CUMemoryProviderParams = nullptr;
83+
84+
*MemoryProviderHost = nullptr;
85+
CUcontext context = DeviceHandle->getNativeContext();
86+
87+
umf_result_t UmfResult =
88+
umfCUDAMemoryProviderParamsCreate(&CUMemoryProviderParams);
89+
UMF_RETURN_UR_ERROR(UmfResult);
90+
91+
umf::cuda_params_unique_handle_t CUMemoryProviderParamsUnique(
92+
CUMemoryProviderParams, umfCUDAMemoryProviderParamsDestroy);
93+
94+
// create UMF CUDA memory provider for the host memory (UMF_MEMORY_TYPE_HOST)
95+
UmfResult = umf::createMemoryProvider(
96+
CUMemoryProviderParamsUnique.get(), 0 /* cuDevice */, context,
97+
UMF_MEMORY_TYPE_HOST, MemoryProviderHost);
98+
UMF_RETURN_UR_ERROR(UmfResult);
99+
100+
return UR_RESULT_SUCCESS;
101+
}
102+
77103
struct ur_context_handle_t_ {
78104

79105
struct deleter_data {
@@ -86,14 +112,25 @@ struct ur_context_handle_t_ {
86112
std::vector<ur_device_handle_t> Devices;
87113
std::atomic_uint32_t RefCount;
88114

115+
// UMF CUDA memory provider for the host memory (UMF_MEMORY_TYPE_HOST)
116+
umf_memory_provider_handle_t MemoryProviderHost = nullptr;
117+
89118
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
90119
: Devices{Devs, Devs + NumDevices}, RefCount{1} {
91120
for (auto &Dev : Devices) {
92121
urDeviceRetain(Dev);
93122
}
123+
124+
// Create UMF CUDA memory provider for the host memory
125+
// (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because
126+
// it is guaranteed to exist).
127+
UR_CHECK_ERROR(CreateHostMemoryProvider(Devices[0], &MemoryProviderHost));
94128
};
95129

96130
~ur_context_handle_t_() {
131+
if (MemoryProviderHost) {
132+
umfMemoryProviderDestroy(MemoryProviderHost);
133+
}
97134
for (auto &Dev : Devices) {
98135
urDeviceRelease(Dev);
99136
}

source/adapters/cuda/device.hpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
#include <ur/ur.hpp>
1313

14+
#include <umf/memory_provider.h>
15+
1416
#include "common.hpp"
1517

1618
struct ur_device_handle_t_ {
@@ -79,9 +81,20 @@ struct ur_device_handle_t_ {
7981
// CUDA doesn't really have this concept, and could allow almost 100% of
8082
// global memory in one allocation, but is dependent on device usage.
8183
UR_CHECK_ERROR(cuDeviceTotalMem(&MaxAllocSize, cuDevice));
84+
85+
MemoryProviderDevice = nullptr;
86+
MemoryProviderShared = nullptr;
8287
}
8388

84-
~ur_device_handle_t_() { cuDevicePrimaryCtxRelease(CuDevice); }
89+
~ur_device_handle_t_() {
90+
if (MemoryProviderDevice) {
91+
umfMemoryProviderDestroy(MemoryProviderDevice);
92+
}
93+
if (MemoryProviderShared) {
94+
umfMemoryProviderDestroy(MemoryProviderShared);
95+
}
96+
cuDevicePrimaryCtxRelease(CuDevice);
97+
}
8598

8699
native_type get() const noexcept { return CuDevice; };
87100

@@ -117,6 +130,12 @@ struct ur_device_handle_t_ {
117130

118131
// bookkeeping for mipmappedArray leaks in Mapping external Memory
119132
std::map<CUarray, CUmipmappedArray> ChildCuarrayFromMipmapMap;
133+
134+
// UMF CUDA memory provider for the device memory (UMF_MEMORY_TYPE_DEVICE)
135+
umf_memory_provider_handle_t MemoryProviderDevice;
136+
137+
// UMF CUDA memory provider for the shared memory (UMF_MEMORY_TYPE_SHARED)
138+
umf_memory_provider_handle_t MemoryProviderShared;
120139
};
121140

122141
int getAttribute(ur_device_handle_t Device, CUdevice_attribute Attribute);

source/adapters/cuda/memory.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "context.hpp"
1515
#include "enqueue.hpp"
1616
#include "memory.hpp"
17+
#include "umf_helpers.hpp"
1718

1819
/// Creates a UR Memory object using a CUDA memory allocation.
1920
/// Can trigger a manual copy depending on the mode.
@@ -49,7 +50,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
4950
cuMemHostRegister(HostPtr, size, CU_MEMHOSTREGISTER_DEVICEMAP));
5051
AllocMode = BufferMem::AllocMode::UseHostPtr;
5152
} else if (flags & UR_MEM_FLAG_ALLOC_HOST_POINTER) {
52-
UR_CHECK_ERROR(cuMemAllocHost(&HostPtr, size));
53+
UMF_CHECK_ERROR(umfMemoryProviderAlloc(hContext->MemoryProviderHost, size,
54+
0, &HostPtr));
5355
AllocMode = BufferMem::AllocMode::AllocHostPtr;
5456
} else if (flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER) {
5557
AllocMode = BufferMem::AllocMode::CopyIn;
@@ -440,7 +442,8 @@ ur_result_t allocateMemObjOnDeviceIfNeeded(ur_mem_handle_t Mem,
440442
CU_MEMHOSTALLOC_DEVICEMAP));
441443
UR_CHECK_ERROR(cuMemHostGetDevicePointer(&DevPtr, Buffer.HostPtr, 0));
442444
} else {
443-
UR_CHECK_ERROR(cuMemAlloc(&DevPtr, Buffer.Size));
445+
UMF_CHECK_ERROR(umfMemoryProviderAlloc(hDevice->MemoryProviderDevice,
446+
Buffer.Size, 0, (void **)&DevPtr));
444447
}
445448
} else {
446449
CUarray ImageArray{};

source/adapters/cuda/platform.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,46 @@
1313
#include "common.hpp"
1414
#include "context.hpp"
1515
#include "device.hpp"
16+
#include "umf_helpers.hpp"
1617

1718
#include <cassert>
1819
#include <cuda.h>
1920
#include <sstream>
2021

22+
static ur_result_t
23+
CreateDeviceMemoryProviders(ur_platform_handle_t_ *Platform) {
24+
umf_cuda_memory_provider_params_handle_t CUMemoryProviderParams = nullptr;
25+
26+
umf_result_t UmfResult =
27+
umfCUDAMemoryProviderParamsCreate(&CUMemoryProviderParams);
28+
UMF_RETURN_UR_ERROR(UmfResult);
29+
30+
umf::cuda_params_unique_handle_t CUMemoryProviderParamsUnique(
31+
CUMemoryProviderParams, umfCUDAMemoryProviderParamsDestroy);
32+
33+
for (auto &Device : Platform->Devices) {
34+
ur_device_handle_t_ *device_handle = Device.get();
35+
CUdevice device = device_handle->get();
36+
CUcontext context = device_handle->getNativeContext();
37+
38+
// create UMF CUDA memory provider for the device memory
39+
// (UMF_MEMORY_TYPE_DEVICE)
40+
UmfResult = umf::createMemoryProvider(
41+
CUMemoryProviderParamsUnique.get(), device, context,
42+
UMF_MEMORY_TYPE_DEVICE, &device_handle->MemoryProviderDevice);
43+
UMF_RETURN_UR_ERROR(UmfResult);
44+
45+
// create UMF CUDA memory provider for the shared memory
46+
// (UMF_MEMORY_TYPE_SHARED)
47+
UmfResult = umf::createMemoryProvider(
48+
CUMemoryProviderParamsUnique.get(), device, context,
49+
UMF_MEMORY_TYPE_SHARED, &device_handle->MemoryProviderShared);
50+
UMF_RETURN_UR_ERROR(UmfResult);
51+
}
52+
53+
return UR_RESULT_SUCCESS;
54+
}
55+
2156
UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
2257
ur_platform_handle_t hPlatform, ur_platform_info_t PlatformInfoType,
2358
size_t Size, void *pPlatformInfo, size_t *pSizeRet) {
@@ -98,6 +133,8 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
98133
new ur_device_handle_t_{Device, Context, EvBase, &Platform,
99134
static_cast<uint32_t>(i)});
100135
}
136+
137+
UR_CHECK_ERROR(CreateDeviceMemoryProviders(&Platform));
101138
} catch (const std::bad_alloc &) {
102139
// Signal out-of-memory situation
103140
for (int i = 0; i < NumDevices; ++i) {

source/adapters/cuda/usm.cpp

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -102,26 +102,41 @@ urUSMSharedAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
102102
return UR_RESULT_SUCCESS;
103103
}
104104

105-
ur_result_t USMFreeImpl(ur_context_handle_t, void *Pointer) {
105+
ur_result_t USMFreeImpl(ur_context_handle_t hContext, void *Pointer) {
106106
ur_result_t Result = UR_RESULT_SUCCESS;
107107
try {
108108
unsigned int IsManaged;
109109
unsigned int Type;
110-
void *AttributeValues[2] = {&IsManaged, &Type};
111-
CUpointer_attribute Attributes[2] = {CU_POINTER_ATTRIBUTE_IS_MANAGED,
112-
CU_POINTER_ATTRIBUTE_MEMORY_TYPE};
113-
UR_CHECK_ERROR(cuPointerGetAttributes(2, Attributes, AttributeValues,
114-
(CUdeviceptr)Pointer));
110+
unsigned int DeviceOrdinal;
111+
const int NumAttributes = 3;
112+
void *AttributeValues[NumAttributes] = {&IsManaged, &Type, &DeviceOrdinal};
113+
114+
CUpointer_attribute Attributes[NumAttributes] = {
115+
CU_POINTER_ATTRIBUTE_IS_MANAGED, CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
116+
CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL};
117+
UR_CHECK_ERROR(cuPointerGetAttributes(
118+
NumAttributes, Attributes, AttributeValues, (CUdeviceptr)Pointer));
115119
UR_ASSERT(Type == CU_MEMORYTYPE_DEVICE || Type == CU_MEMORYTYPE_HOST,
116120
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
117-
if (IsManaged || Type == CU_MEMORYTYPE_DEVICE) {
118-
// Memory allocated with cuMemAlloc and cuMemAllocManaged must be freed
119-
// with cuMemFree
120-
UR_CHECK_ERROR(cuMemFree((CUdeviceptr)Pointer));
121+
122+
std::vector<ur_device_handle_t> ContextDevices = hContext->getDevices();
123+
ur_platform_handle_t Platform = ContextDevices[0]->getPlatform();
124+
unsigned int NumDevices = Platform->Devices.size();
125+
UR_ASSERT(DeviceOrdinal < NumDevices, UR_RESULT_ERROR_INVALID_DEVICE);
126+
127+
ur_device_handle_t Device = Platform->Devices[DeviceOrdinal].get();
128+
umf_memory_provider_handle_t MemoryProvider;
129+
130+
if (IsManaged) {
131+
MemoryProvider = Device->MemoryProviderShared;
132+
} else if (Type == CU_MEMORYTYPE_DEVICE) {
133+
MemoryProvider = Device->MemoryProviderDevice;
121134
} else {
122-
// Memory allocated with cuMemAllocHost must be freed with cuMemFreeHost
123-
UR_CHECK_ERROR(cuMemFreeHost(Pointer));
135+
MemoryProvider = hContext->MemoryProviderHost;
124136
}
137+
138+
UMF_CHECK_ERROR(umfMemoryProviderFree(MemoryProvider, Pointer,
139+
0 /* size is unknown */));
125140
} catch (ur_result_t Err) {
126141
Result = Err;
127142
}
@@ -143,7 +158,8 @@ ur_result_t USMDeviceAllocImpl(void **ResultPtr, ur_context_handle_t,
143158
uint32_t Alignment) {
144159
try {
145160
ScopedContext Active(Device);
146-
UR_CHECK_ERROR(cuMemAlloc((CUdeviceptr *)ResultPtr, Size));
161+
UMF_CHECK_ERROR(umfMemoryProviderAlloc(Device->MemoryProviderDevice, Size,
162+
Alignment, ResultPtr));
147163
} catch (ur_result_t Err) {
148164
return Err;
149165
}
@@ -164,8 +180,8 @@ ur_result_t USMSharedAllocImpl(void **ResultPtr, ur_context_handle_t,
164180
uint32_t Alignment) {
165181
try {
166182
ScopedContext Active(Device);
167-
UR_CHECK_ERROR(cuMemAllocManaged((CUdeviceptr *)ResultPtr, Size,
168-
CU_MEM_ATTACH_GLOBAL));
183+
UMF_CHECK_ERROR(umfMemoryProviderAlloc(Device->MemoryProviderShared, Size,
184+
Alignment, ResultPtr));
169185
} catch (ur_result_t Err) {
170186
return Err;
171187
}
@@ -179,11 +195,12 @@ ur_result_t USMSharedAllocImpl(void **ResultPtr, ur_context_handle_t,
179195
return UR_RESULT_SUCCESS;
180196
}
181197

182-
ur_result_t USMHostAllocImpl(void **ResultPtr, ur_context_handle_t,
198+
ur_result_t USMHostAllocImpl(void **ResultPtr, ur_context_handle_t hContext,
183199
ur_usm_host_mem_flags_t, size_t Size,
184200
uint32_t Alignment) {
185201
try {
186-
UR_CHECK_ERROR(cuMemAllocHost(ResultPtr, Size));
202+
UMF_CHECK_ERROR(umfMemoryProviderAlloc(hContext->MemoryProviderHost, Size,
203+
Alignment, ResultPtr));
187204
} catch (ur_result_t Err) {
188205
return Err;
189206
}

source/common/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ else()
6464
set(UMF_BUILD_EXAMPLES OFF CACHE INTERNAL "Build UMF examples")
6565
set(UMF_BUILD_SHARED_LIBRARY ${UMF_BUILD_SHARED_LIBRARY} CACHE INTERNAL "Build UMF shared library")
6666
set(UMF_BUILD_LIBUMF_POOL_DISJOINT ON CACHE INTERNAL "Build Disjoint Pool")
67-
set(UMF_BUILD_CUDA_PROVIDER OFF CACHE INTERNAL "Build UMF CUDA provider")
67+
set(UMF_BUILD_CUDA_PROVIDER ON CACHE INTERNAL "Build UMF CUDA provider")
6868

6969
FetchContent_MakeAvailable(unified-memory-framework)
7070
FetchContent_GetProperties(unified-memory-framework)

source/common/umf_helpers.hpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <umf/memory_pool_ops.h>
1717
#include <umf/memory_provider.h>
1818
#include <umf/memory_provider_ops.h>
19+
#include <umf/providers/provider_cuda.h>
1920
#include <ur_api.h>
2021

2122
#include "logger/ur_logger.hpp"
@@ -27,6 +28,24 @@
2728
#include <tuple>
2829
#include <utility>
2930

31+
#define UMF_CHECK_ERROR(UmfResult) UR_CHECK_ERROR(umf::umf2urResult(UmfResult));
32+
33+
#define UMF_RETURN_UMF_ERROR(UmfResult) \
34+
do { \
35+
umf_result_t UmfResult_ = (UmfResult); \
36+
if (UmfResult_ != UMF_RESULT_SUCCESS) { \
37+
return UmfResult_; \
38+
} \
39+
} while (0)
40+
41+
#define UMF_RETURN_UR_ERROR(UmfResult) \
42+
do { \
43+
umf_result_t UmfResult_ = (UmfResult); \
44+
if (UmfResult_ != UMF_RESULT_SUCCESS) { \
45+
return umf::umf2urResult(UmfResult_); \
46+
} \
47+
} while (0)
48+
3049
namespace umf {
3150

3251
using pool_unique_handle_t =
@@ -35,6 +54,9 @@ using pool_unique_handle_t =
3554
using provider_unique_handle_t =
3655
std::unique_ptr<umf_memory_provider_t,
3756
std::function<void(umf_memory_provider_handle_t)>>;
57+
using cuda_params_unique_handle_t = std::unique_ptr<
58+
umf_cuda_memory_provider_params_t,
59+
std::function<umf_result_t(umf_cuda_memory_provider_params_handle_t)>>;
3860

3961
#define DEFINE_CHECK_OP(op) \
4062
template <typename T> class HAS_OP_##op { \
@@ -279,6 +301,33 @@ inline ur_result_t umf2urResult(umf_result_t umfResult) {
279301
};
280302
}
281303

304+
inline umf_result_t createMemoryProvider(
305+
umf_cuda_memory_provider_params_handle_t CUMemoryProviderParams,
306+
int cuDevice, void *cuContext, umf_usm_memory_type_t memType,
307+
umf_memory_provider_handle_t *provider) {
308+
309+
umf_result_t UmfResult =
310+
umfCUDAMemoryProviderParamsSetContext(CUMemoryProviderParams, cuContext);
311+
UMF_RETURN_UMF_ERROR(UmfResult);
312+
313+
UmfResult =
314+
umfCUDAMemoryProviderParamsSetDevice(CUMemoryProviderParams, cuDevice);
315+
UMF_RETURN_UMF_ERROR(UmfResult);
316+
317+
UmfResult =
318+
umfCUDAMemoryProviderParamsSetMemoryType(CUMemoryProviderParams, memType);
319+
UMF_RETURN_UMF_ERROR(UmfResult);
320+
321+
umf_memory_provider_handle_t umfCUDAprovider = nullptr;
322+
UmfResult = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(),
323+
CUMemoryProviderParams, &umfCUDAprovider);
324+
UMF_RETURN_UMF_ERROR(UmfResult);
325+
326+
*provider = umfCUDAprovider;
327+
328+
return UMF_RESULT_SUCCESS;
329+
}
330+
282331
} // namespace umf
283332

284333
#endif /* UMF_HELPERS_H */

0 commit comments

Comments
 (0)