// A CUDAPluggableAllocator based on cumem* APIs.
// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle*
// are passed across the Python boundary as unsigned long long.
#include <iostream>

extern "C" {

#define PY_SSIZE_T_CLEAN
#include <Python.h>

#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <cuda.h>

#define CUDA_CHECK(condition)                                             \
  do {                                                                    \
    CUresult error = condition;                                           \
    if (error != CUDA_SUCCESS) {                                          \
      const char* error_string;                                           \
      cuGetErrorString(error, &error_string);                             \
      std::cerr << "CUDA Error: " << error_string << " at " << __FILE__   \
                << ":" << __LINE__ << std::endl;                          \
    }                                                                     \
  } while (0)
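// NOTE: CUDA errors are logged to stderr and otherwise ignored; execution
// continues past a failed call, so callers should treat these as best-effort.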

// Global references to Python callables
// NOTE: these are borrowed references, so we never DECREF them. This brings
// the limitation that the allocator must be a singleton and the callables
// must be kept alive on the Python side.
static PyObject* g_python_malloc_callback = nullptr;
static PyObject* g_python_free_callback = nullptr;

// ---------------------------------------------------------------------------
// Helper functions:

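// Make sure the calling thread has a current CUDA context: if none is
// current, retain the device's primary context (the same context the CUDA
// runtime API uses) and make it current.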
void ensure_context(unsigned long long device) {
  CUcontext pctx;
  CUDA_CHECK(cuCtxGetCurrent(&pctx));
  if (!pctx) {
    // Ensure device context.
    CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, (CUdevice)device));
    CUDA_CHECK(cuCtxSetCurrent(pctx));
  }
}

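// Create a physical memory allocation on `device` and map it into the
// previously reserved virtual address range starting at d_mem, then enable
// read/write access to it from that device.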
void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
                    CUmemGenericAllocationHandle* p_memHandle) {
  ensure_context(device);
  // Define memory allocation properties
  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = device;
  prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;

  // Allocate memory using cuMemCreate
  CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
  CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0));

  CUmemAccessDesc accessDesc = {};
  accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  accessDesc.location.id = device;
  accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;

  CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1));
  // std::cout << "create_and_map: device=" << device << ", size=" << size
  //           << ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle
  //           << std::endl;
}

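// Undo create_and_map: unmap the physical memory behind d_mem and release
// the allocation handle. The virtual address range itself stays reserved,
// so the same range can be re-mapped later by create_and_map.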
void unmap_and_release(unsigned long long device, ssize_t size,
                       CUdeviceptr d_mem,
                       CUmemGenericAllocationHandle* p_memHandle) {
  // std::cout << "unmap_and_release: device=" << device << ", size=" << size
  //           << ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle
  //           << std::endl;
  ensure_context(device);
  CUDA_CHECK(cuMemUnmap(d_mem, size));
  CUDA_CHECK(cuMemRelease(*p_memHandle));
}

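// Pack four C integers into a new Python 4-tuple; pointers are handed to
// Python as plain integers and converted back on the receiving side.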
PyObject* create_tuple_from_c_integers(unsigned long long a,
                                       unsigned long long b,
                                       unsigned long long c,
                                       unsigned long long d) {
  // Create a new tuple of size 4
  PyObject* tuple = PyTuple_New(4);
  if (!tuple) {
    return NULL;  // Return NULL on failure
  }

  // Convert integers to Python objects and set them in the tuple.
  // PyTuple_SetItem "steals" a reference to each object, so we do not need
  // to Py_DECREF the PyLong objects explicitly.
  PyTuple_SetItem(tuple, 0, PyLong_FromUnsignedLongLong(a));
  PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b));
  PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c));
  PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d));

  return tuple;  // Return the created tuple
}

// ---------------------------------------------------------------------------
// Our exported C functions that call Python:

// CUstream is used instead of cudaStream_t in these signatures; the two
// types are ABI-compatible, so the exported functions still match the
// signatures CUDAPluggableAllocator expects.
void* my_malloc(ssize_t size, int device, CUstream stream) {
  if (!g_python_malloc_callback) {
    std::cerr << "ERROR: g_python_malloc_callback not set.\n";
    return nullptr;
  }

  ensure_context(device);

  // First allocation: align the size, reserve a virtual address range, and
  // allocate a CUmemGenericAllocationHandle on the heap.

  // Define memory allocation properties
  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = device;
  prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;

  // Query the minimum allocation granularity and round the size up to it.
  size_t granularity;
  CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop,
                                           CU_MEM_ALLOC_GRANULARITY_MINIMUM));

  size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;

  CUdeviceptr d_mem;
  CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0));

  // allocate the CUmemGenericAllocationHandle
  CUmemGenericAllocationHandle* p_memHandle =
      (CUmemGenericAllocationHandle*)malloc(
          sizeof(CUmemGenericAllocationHandle));

  // Acquire the GIL before calling back into Python.
  PyGILState_STATE gstate = PyGILState_Ensure();

  PyObject* arg_tuple = create_tuple_from_c_integers(
      (unsigned long long)device, (unsigned long long)alignedSize,
      (unsigned long long)d_mem, (unsigned long long)p_memHandle);

  if (!arg_tuple) {
    PyGILState_Release(gstate);
    CUDA_CHECK(cuMemAddressFree(d_mem, alignedSize));
    free(p_memHandle);
    return nullptr;
  }

  // Call g_python_malloc_callback with the tuple as its single argument.
  PyObject* py_result =
      PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL);
  Py_DECREF(arg_tuple);

  if (!py_result) {
    PyErr_Print();
    PyGILState_Release(gstate);
    // Undo the reservation so the failed allocation does not leak.
    CUDA_CHECK(cuMemAddressFree(d_mem, alignedSize));
    free(p_memHandle);
    return nullptr;
  }
  Py_DECREF(py_result);

  PyGILState_Release(gstate);

  // do the final mapping
  create_and_map(device, alignedSize, d_mem, p_memHandle);

  return (void*)d_mem;
}
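
// NOTE: my_malloc deliberately separates virtual address reservation from
// physical mapping. The Python callback records (device, size, ptr, handle),
// which lets python_unmap_and_release / python_create_and_map below drop and
// restore the physical memory later while the virtual address stays stable.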

// See the note above my_malloc regarding CUstream.
void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
  // Ask Python for the memory handle that belongs to this pointer.
  if (!g_python_free_callback) {
    std::cerr << "ERROR: g_python_free_callback not set.\n";
    return;
  }

  // Acquire the GIL before calling back into Python.
  PyGILState_STATE gstate = PyGILState_Ensure();

  PyObject* py_ptr =
      PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr));

  PyObject* py_result =
      PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL);
  Py_XDECREF(py_ptr);

  if (!py_result || !PyTuple_Check(py_result) ||
      PyTuple_Size(py_result) != 4) {
    PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
    Py_XDECREF(py_result);
    PyGILState_Release(gstate);
    return;
  }

  unsigned long long recv_device, recv_size;
  unsigned long long recv_d_mem, recv_p_memHandle;
  // Unpack the tuple into four C integers
  if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size,
                        &recv_d_mem, &recv_p_memHandle)) {
    // PyArg_ParseTuple sets an error if it fails
    Py_DECREF(py_result);
    PyGILState_Release(gstate);
    return;
  }
  Py_DECREF(py_result);

  PyGILState_Release(gstate);

  // recv_size == size and recv_device == device are expected to hold.

  // Unmap and release the physical memory, then free the reserved virtual
  // address range and the heap-allocated handle.
  CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
  CUmemGenericAllocationHandle* p_memHandle =
      (CUmemGenericAllocationHandle*)recv_p_memHandle;
  unmap_and_release(device, size, d_mem, p_memHandle);

  CUDA_CHECK(cuMemAddressFree(d_mem, size));
  free(p_memHandle);
}

// ---------------------------------------------------------------------------
// Python extension boilerplate:

// Python-exposed function: init_module(python_malloc, python_free)
static PyObject* py_init_module(PyObject* self, PyObject* args) {
  PyObject* malloc_callback = nullptr;
  PyObject* free_callback = nullptr;

  if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) {
    return nullptr;
  }

  if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) {
    PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
    return nullptr;
  }

  // Save the Python callables as borrowed references. This module does not
  // manage their lifetime, so they must be kept alive outside of it.
  g_python_malloc_callback = malloc_callback;
  g_python_free_callback = free_callback;

  Py_RETURN_NONE;
}
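
// For reference, a minimal sketch of the Python-side callbacks that
// init_module expects (illustrative only; the dict and function names are
// hypothetical, not part of this module):
//
//   allocation_data = {}  # d_mem -> (device, aligned_size, d_mem, p_memHandle)
//
//   def python_malloc(args):
//       device, aligned_size, d_mem, p_memHandle = args
//       allocation_data[d_mem] = args
//
//   def python_free(d_mem):
//       return allocation_data.pop(d_mem)
//
//   cumem_allocator.init_module(python_malloc, python_free)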

static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
  if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
    PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
    return nullptr;
  }

  unsigned long long recv_device, recv_size;
  unsigned long long recv_d_mem, recv_p_memHandle;
  // Unpack the tuple into four C integers
  if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
                        &recv_p_memHandle)) {
    // PyArg_ParseTuple sets an error if it fails
    return nullptr;
  }

  CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
  CUmemGenericAllocationHandle* p_memHandle =
      (CUmemGenericAllocationHandle*)recv_p_memHandle;

  unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);

  Py_RETURN_NONE;
}

static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
  if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
    PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
    return nullptr;
  }

  unsigned long long recv_device, recv_size;
  unsigned long long recv_d_mem, recv_p_memHandle;
  // Unpack the tuple into four C integers
  if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
                        &recv_p_memHandle)) {
    // PyArg_ParseTuple sets an error if it fails
    return nullptr;
  }

  CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
  CUmemGenericAllocationHandle* p_memHandle =
      (CUmemGenericAllocationHandle*)recv_p_memHandle;

  create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);

  Py_RETURN_NONE;
}

static PyMethodDef module_methods[] = {
    {"init_module", (PyCFunction)py_init_module, METH_VARARGS,
     "Initialize module with python_malloc and python_free callables."},
    {"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS,
     "Create and map memory on the device."},
    {"python_unmap_and_release", (PyCFunction)python_unmap_and_release,
     METH_VARARGS, "Unmap and release memory on the device."},
    {NULL, NULL, 0, NULL}  // sentinel
};

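// The m_size field of -1 below means the module keeps its state in C globals
// and therefore cannot be safely re-initialized in sub-interpreters.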
static struct PyModuleDef cumem_allocator_module = {
    PyModuleDef_HEAD_INIT, "cumem_allocator",
    "cumem-based allocator for CUDAPluggableAllocator", -1, module_methods};

PyMODINIT_FUNC PyInit_cumem_allocator(void) {
  // Initialize the module
  PyObject* module = PyModule_Create(&cumem_allocator_module);
  if (!module) {
    return NULL;
  }
  return module;
}
}  // extern "C"
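
// ---------------------------------------------------------------------------
// Usage sketch (hypothetical; assumes PyTorch's pluggable-allocator API, and
// the .so path is illustrative):
//
//   import torch
//
//   allocator = torch.cuda.memory.CUDAPluggableAllocator(
//       "./cumem_allocator.so", "my_malloc", "my_free")
//   torch.cuda.memory.change_current_allocator(allocator)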