Skip to content

Commit ee45b72

Browse files
youkaichao and SzymonOzog
authored and committed
[core] improve error handling when wake up from sleep mode (vllm-project#12981)
Signed-off-by: youkaichao <[email protected]> Signed-off-by: SzymonOzog <[email protected]>
1 parent 746c22b commit ee45b72

File tree

2 files changed

+78
-12
lines changed

2 files changed

+78
-12
lines changed

csrc/cumem_allocator.cpp

+51-12
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,21 @@ extern "C" {
1212
#include <cuda_runtime_api.h>
1313
#include <cuda.h>
1414

15-
#define CUDA_CHECK(condition) \
16-
do { \
17-
CUresult error = condition; \
18-
if (error != 0) { \
19-
char* error_string; \
20-
cuGetErrorString(error, (const char**)&error_string); \
21-
std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \
22-
<< __LINE__ << std::endl; \
23-
} \
15+
char error_msg[10240]; // 10KB buffer to store error messages
16+
CUresult no_error = CUresult(0);
17+
CUresult error_code = no_error; // store error code
18+
19+
#define CUDA_CHECK(condition) \
20+
do { \
21+
CUresult error = condition; \
22+
if (error != 0) { \
23+
error_code = error; \
24+
char* error_string; \
25+
cuGetErrorString(error, (const char**)&error_string); \
26+
snprintf(error_msg, sizeof(error_msg), "CUDA Error: %s at %s:%d", \
27+
error_string, __FILE__, __LINE__); \
28+
std::cerr << error_msg << std::endl; \
29+
} \
2430
} while (0)
2531

2632
// Global references to Python callables
@@ -54,14 +60,22 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
5460

5561
// Allocate memory using cuMemCreate
5662
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
63+
if (error_code != 0) {
64+
return;
65+
}
5766
CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0));
58-
67+
if (error_code != 0) {
68+
return;
69+
}
5970
CUmemAccessDesc accessDesc = {};
6071
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
6172
accessDesc.location.id = device;
6273
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
6374

6475
CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1));
76+
if (error_code != 0) {
77+
return;
78+
}
6579
// std::cout << "create_and_map: device=" << device << ", size=" << size << ",
6680
// d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
6781
}
@@ -73,7 +87,13 @@ void unmap_and_release(unsigned long long device, ssize_t size,
7387
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
7488
ensure_context(device);
7589
CUDA_CHECK(cuMemUnmap(d_mem, size));
90+
if (error_code != 0) {
91+
return;
92+
}
7693
CUDA_CHECK(cuMemRelease(*p_memHandle));
94+
if (error_code != 0) {
95+
return;
96+
}
7797
}
7898

7999
PyObject* create_tuple_from_c_integers(unsigned long long a,
@@ -121,12 +141,16 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
121141
size_t granularity;
122142
CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop,
123143
CU_MEM_ALLOC_GRANULARITY_MINIMUM));
124-
144+
if (error_code != 0) {
145+
return nullptr;
146+
}
125147
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
126148

127149
CUdeviceptr d_mem;
128150
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0));
129-
151+
if (error_code != 0) {
152+
return nullptr;
153+
}
130154
// allocate the CUmemGenericAllocationHandle
131155
CUmemGenericAllocationHandle* p_memHandle =
132156
(CUmemGenericAllocationHandle*)malloc(
@@ -208,6 +232,9 @@ void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
208232

209233
// free address and the handle
210234
CUDA_CHECK(cuMemAddressFree(d_mem, size));
235+
if (error_code != 0) {
236+
return;
237+
}
211238
free(p_memHandle);
212239
}
213240

@@ -258,6 +285,12 @@ static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
258285

259286
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);
260287

288+
if (error_code != 0) {
289+
error_code = no_error;
290+
PyErr_SetString(PyExc_RuntimeError, error_msg);
291+
return nullptr;
292+
}
293+
261294
Py_RETURN_NONE;
262295
}
263296

@@ -282,6 +315,12 @@ static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
282315

283316
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);
284317

318+
if (error_code != 0) {
319+
error_code = no_error;
320+
PyErr_SetString(PyExc_RuntimeError, error_msg);
321+
return nullptr;
322+
}
323+
285324
Py_RETURN_NONE;
286325
}
287326

tests/basic_correctness/test_cumem.py

+27
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
import pytest
34
import torch
45

56
from vllm import LLM, SamplingParams
@@ -9,6 +10,32 @@
910
from ..utils import fork_new_process_for_each_test
1011

1112

13+
@fork_new_process_for_each_test
def test_python_error():
    """
    A low-level CUDA failure in the C++ allocator must surface as a
    Python RuntimeError rather than failing silently.
    """
    allocator = CuMemAllocator.get_instance()
    _, total_bytes = torch.cuda.mem_get_info()
    alloc_bytes = int(total_bytes * 0.7)
    keep_alive = []

    # Grab 70% of device memory through the custom pool, then put the
    # allocator to sleep so the backing memory is released.
    with allocator.use_memory_pool():
        keep_alive.append(
            torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda'))
    allocator.sleep()

    # Consume most of the device memory outside the pool, so waking the
    # allocator cannot re-map its original allocation.
    keep_alive.append(
        torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda'))

    with pytest.raises(RuntimeError):
        # Waking up must fail loudly: there is no longer enough memory.
        allocator.wake_up()
37+
38+
1239
@fork_new_process_for_each_test
1340
def test_basic_cumem():
1441
# some tensors from default memory pool

0 commit comments

Comments
 (0)