@@ -12,15 +12,21 @@ extern "C" {
12
12
#include < cuda_runtime_api.h>
13
13
#include < cuda.h>
14
14
15
- #define CUDA_CHECK (condition ) \
16
- do { \
17
- CUresult error = condition; \
18
- if (error != 0 ) { \
19
- char * error_string; \
20
- cuGetErrorString (error, (const char **)&error_string); \
21
- std::cerr << " CUDA Error: " << error_string << " at " << __FILE__ << " :" \
22
- << __LINE__ << std::endl; \
23
- } \
15
+ char error_msg[10240 ]; // 10KB buffer to store error messages
16
+ CUresult no_error = CUresult(0 );
17
+ CUresult error_code = no_error; // store error code
18
+
19
+ #define CUDA_CHECK (condition ) \
20
+ do { \
21
+ CUresult error = condition; \
22
+ if (error != 0 ) { \
23
+ error_code = error; \
24
+ char * error_string; \
25
+ cuGetErrorString (error, (const char **)&error_string); \
26
+ snprintf (error_msg, sizeof (error_msg), " CUDA Error: %s at %s:%d" , \
27
+ error_string, __FILE__, __LINE__); \
28
+ std::cerr << error_msg << std::endl; \
29
+ } \
24
30
} while (0 )
25
31
26
32
// Global references to Python callables
@@ -54,14 +60,22 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
54
60
55
61
// Allocate memory using cuMemCreate
56
62
CUDA_CHECK (cuMemCreate (p_memHandle, size, &prop, 0 ));
63
+ if (error_code != 0 ) {
64
+ return ;
65
+ }
57
66
CUDA_CHECK (cuMemMap (d_mem, size, 0 , *p_memHandle, 0 ));
58
-
67
+ if (error_code != 0 ) {
68
+ return ;
69
+ }
59
70
CUmemAccessDesc accessDesc = {};
60
71
accessDesc.location .type = CU_MEM_LOCATION_TYPE_DEVICE;
61
72
accessDesc.location .id = device;
62
73
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
63
74
64
75
CUDA_CHECK (cuMemSetAccess (d_mem, size, &accessDesc, 1 ));
76
+ if (error_code != 0 ) {
77
+ return ;
78
+ }
65
79
// std::cout << "create_and_map: device=" << device << ", size=" << size << ",
66
80
// d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
67
81
}
@@ -73,7 +87,13 @@ void unmap_and_release(unsigned long long device, ssize_t size,
73
87
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
74
88
ensure_context (device);
75
89
CUDA_CHECK (cuMemUnmap (d_mem, size));
90
+ if (error_code != 0 ) {
91
+ return ;
92
+ }
76
93
CUDA_CHECK (cuMemRelease (*p_memHandle));
94
+ if (error_code != 0 ) {
95
+ return ;
96
+ }
77
97
}
78
98
79
99
PyObject* create_tuple_from_c_integers (unsigned long long a,
@@ -121,12 +141,16 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
121
141
size_t granularity;
122
142
CUDA_CHECK (cuMemGetAllocationGranularity (&granularity, &prop,
123
143
CU_MEM_ALLOC_GRANULARITY_MINIMUM));
124
-
144
+ if (error_code != 0 ) {
145
+ return nullptr ;
146
+ }
125
147
size_t alignedSize = ((size + granularity - 1 ) / granularity) * granularity;
126
148
127
149
CUdeviceptr d_mem;
128
150
CUDA_CHECK (cuMemAddressReserve (&d_mem, alignedSize, 0 , 0 , 0 ));
129
-
151
+ if (error_code != 0 ) {
152
+ return nullptr ;
153
+ }
130
154
// allocate the CUmemGenericAllocationHandle
131
155
CUmemGenericAllocationHandle* p_memHandle =
132
156
(CUmemGenericAllocationHandle*)malloc (
@@ -208,6 +232,9 @@ void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
208
232
209
233
// free address and the handle
210
234
CUDA_CHECK (cuMemAddressFree (d_mem, size));
235
+ if (error_code != 0 ) {
236
+ return ;
237
+ }
211
238
free (p_memHandle);
212
239
}
213
240
@@ -258,6 +285,12 @@ static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
258
285
259
286
unmap_and_release (recv_device, recv_size, d_mem_ptr, p_memHandle);
260
287
288
+ if (error_code != 0 ) {
289
+ error_code = no_error;
290
+ PyErr_SetString (PyExc_RuntimeError, error_msg);
291
+ return nullptr ;
292
+ }
293
+
261
294
Py_RETURN_NONE;
262
295
}
263
296
@@ -282,6 +315,12 @@ static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
282
315
283
316
create_and_map (recv_device, recv_size, d_mem_ptr, p_memHandle);
284
317
318
+ if (error_code != 0 ) {
319
+ error_code = no_error;
320
+ PyErr_SetString (PyExc_RuntimeError, error_msg);
321
+ return nullptr ;
322
+ }
323
+
285
324
Py_RETURN_NONE;
286
325
}
287
326
0 commit comments