
Commit f5bf761

fraxy-v and slaren authored
Capture CUDA logging output (#7298)
* logging: output capture in cuda module

* fix compile error

* fix: vsnprintf terminates with 0, string use not correct

* post review

* Update llama.cpp
Co-authored-by: slaren <[email protected]>

* Update llama.cpp
Co-authored-by: slaren <[email protected]>

---------
Co-authored-by: slaren <[email protected]>
1 parent 059031b commit f5bf761

File tree

3 files changed: +75 / -30 lines
  ggml-cuda.cu
  ggml-cuda.h
  llama.cpp

ggml-cuda.cu

Lines changed: 70 additions & 30 deletions
@@ -43,19 +43,59 @@
 #include <mutex>
 #include <stdint.h>
 #include <stdio.h>
+#include <stdarg.h>
+#include <stdlib.h>
 #include <string>
 #include <vector>
 
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 
+static void ggml_cuda_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+    GGML_UNUSED(level);
+    GGML_UNUSED(user_data);
+    fprintf(stderr, "%s", msg);
+}
+
+ggml_log_callback ggml_cuda_log_callback = ggml_cuda_default_log_callback;
+void * ggml_cuda_log_user_data = NULL;
+
+GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data) {
+    ggml_cuda_log_callback = log_callback;
+    ggml_cuda_log_user_data = user_data;
+}
+
+#define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+#define GGML_CUDA_LOG_WARN(...) ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+#define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
+GGML_ATTRIBUTE_FORMAT(2, 3)
+static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) {
+    if (ggml_cuda_log_callback != NULL) {
+        va_list args;
+        va_start(args, format);
+        char buffer[128];
+        int len = vsnprintf(buffer, 128, format, args);
+        if (len < 128) {
+            ggml_cuda_log_callback(level, buffer, ggml_cuda_log_user_data);
+        } else {
+            std::vector<char> buffer2(len + 1); // vsnprintf adds a null terminator
+            va_end(args);
+            va_start(args, format);
+            vsnprintf(&buffer2[0], buffer2.size(), format, args);
+            ggml_cuda_log_callback(level, buffer2.data(), ggml_cuda_log_user_data);
+        }
+        va_end(args);
+    }
+}
+
 [[noreturn]]
 void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
     int id = -1; // in case cudaGetDevice fails
     cudaGetDevice(&id);
 
-    fprintf(stderr, "CUDA error: %s\n", msg);
-    fprintf(stderr, "  current device: %d, in function %s at %s:%d\n", id, func, file, line);
-    fprintf(stderr, "  %s\n", stmt);
+    GGML_CUDA_LOG_ERROR("CUDA error: %s\n", msg);
+    GGML_CUDA_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func, file, line);
+    GGML_CUDA_LOG_ERROR("  %s\n", stmt);
     // abort with GGML_ASSERT to get a stack trace
     GGML_ASSERT(!"CUDA error");
 }
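
Note: the hunk above routes all backend messages through ggml_cuda_log, which formats into a 128-byte stack buffer first and only heap-allocates when vsnprintf reports that the message did not fit (its return value is the full length excluding the terminating '\0'). Below is a minimal standalone sketch of that two-pass pattern; log_formatted, log_sink, and the buffer sizes are illustrative names and choices, not identifiers from the commit.

    #include <cstdarg>
    #include <cstdio>
    #include <vector>

    // Two-pass formatting: try a small stack buffer, then retry with an
    // exact-size heap buffer when the first pass reports an overflow.
    static void log_formatted(void (*log_sink)(const char * msg), const char * format, ...) {
        va_list args;
        va_start(args, format);
        char buffer[128];
        int len = std::vsnprintf(buffer, sizeof(buffer), format, args); // length excluding '\0'
        if (len >= 0 && len < (int) sizeof(buffer)) {
            log_sink(buffer);
        } else if (len >= 0) {
            std::vector<char> buffer2(len + 1);   // +1 for the null terminator
            va_end(args);
            va_start(args, format);               // a consumed va_list must be restarted
            std::vsnprintf(buffer2.data(), buffer2.size(), format, args);
            log_sink(buffer2.data());
        }
        va_end(args);
    }

Restarting the va_list before the second vsnprintf call matters: the first pass consumes it, and reusing a consumed va_list is undefined behavior.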
@@ -91,24 +131,24 @@ static ggml_cuda_device_info ggml_cuda_init() {
 
     cudaError_t err = cudaGetDeviceCount(&info.device_count);
     if (err != cudaSuccess) {
-        fprintf(stderr, "%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
+        GGML_CUDA_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
         return info;
     }
 
     GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
 
     int64_t total_vram = 0;
 #if defined(GGML_CUDA_FORCE_MMQ)
-    fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
 #else
-    fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
 #endif
 #if defined(CUDA_USE_TENSOR_CORES)
-    fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
+    GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
 #else
-    fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
+    GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
 #endif
-    fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
+    GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;
 

@@ -129,7 +169,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
 
         cudaDeviceProp prop;
         CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
-        fprintf(stderr, " Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
+        GGML_CUDA_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 
         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;

@@ -235,8 +275,8 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
         *actual_size = look_ahead_size;
         pool_size += look_ahead_size;
 #ifdef DEBUG_CUDA_MALLOC
-        fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
-                (uint32_t)(max_size/1024/1024), (uint32_t)(pool_size/1024/1024), (uint32_t)(size/1024/1024));
+        GGML_CUDA_LOG_INFO("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
+                           (uint32_t)(max_size / 1024 / 1024), (uint32_t)(pool_size / 1024 / 1024), (uint32_t)(size / 1024 / 1024));
 #endif
         return ptr;
     }

@@ -250,7 +290,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
                 return;
             }
         }
-        fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+        GGML_CUDA_LOG_WARN("Cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
         ggml_cuda_set_device(device);
         CUDA_CHECK(cudaFree(ptr));
         pool_size -= size;

@@ -499,7 +539,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
     void * dev_ptr;
     cudaError_t err = cudaMalloc(&dev_ptr, size);
     if (err != cudaSuccess) {
-        fprintf(stderr, "%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size/1024.0/1024.0, buft_ctx->device, cudaGetErrorString(err));
+        GGML_CUDA_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
         return nullptr;
     }
 

@@ -1002,8 +1042,8 @@ static void * ggml_cuda_host_malloc(size_t size) {
     if (err != cudaSuccess) {
         // clear the error
         cudaGetLastError();
-        fprintf(stderr, "%s: warning: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
-                size/1024.0/1024.0, cudaGetErrorString(err));
+        GGML_CUDA_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
+                           size / 1024.0 / 1024.0, cudaGetErrorString(err));
         return nullptr;
     }
 

@@ -2246,7 +2286,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             break;
         case GGML_OP_MUL_MAT:
             if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
-                fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
+                GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
                 return false;
             } else {
                 ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);

@@ -2300,7 +2340,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
 
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess) {
-        fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
+        GGML_CUDA_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst));
         CUDA_CHECK(err);
     }
 

@@ -2476,7 +2516,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
             cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
 #ifndef NDEBUG
-            fprintf(stderr, "%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
 #endif
         }
     }

@@ -2523,14 +2563,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
             if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) {
                 use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
 #ifndef NDEBUG
-                fprintf(stderr, "%s: disabling CUDA graphs due to split buffer\n", __func__);
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to split buffer\n", __func__);
 #endif
             }
 
             if (node->op == GGML_OP_MUL_MAT_ID) {
                 use_cuda_graph = false; // This node type is not supported by CUDA graph capture
 #ifndef NDEBUG
-                fprintf(stderr, "%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
 #endif
             }
 

@@ -2539,7 +2579,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
                 // Changes in batch size or context size can cause changes to the grid size of some kernels.
                 use_cuda_graph = false;
 #ifndef NDEBUG
-                fprintf(stderr, "%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
 #endif
             }
 

@@ -2567,7 +2607,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
             cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
 #ifndef NDEBUG
-            fprintf(stderr, "%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
 #endif
         }
     }

@@ -2605,7 +2645,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
             bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
             if (!ok) {
-                fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+                GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
             }
             GGML_ASSERT(ok);
         }

@@ -2624,7 +2664,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
             use_cuda_graph = false;
             cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
 #ifndef NDEBUG
-            fprintf(stderr, "%s: disabling CUDA graphs due to failed graph capture\n", __func__);
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
 #endif
         } else {
             graph_evaluated_or_captured = true; // CUDA graph has been captured

@@ -2691,7 +2731,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
         if (stat == cudaErrorGraphExecUpdateFailure) {
 #ifndef NDEBUG
-            fprintf(stderr, "%s: CUDA graph update failed\n", __func__);
+            GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
 #endif
             // The pre-existing graph exec cannot be updated due to violated constraints
             // so instead clear error and re-instantiate

@@ -2948,13 +2988,13 @@ static ggml_guid_t ggml_backend_cuda_guid() {
 
 GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
     if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
-        fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
+        GGML_CUDA_LOG_ERROR("%s: invalid device %d\n", __func__, device);
         return nullptr;
     }
 
     ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device);
     if (ctx == nullptr) {
-        fprintf(stderr, "%s: error: failed to allocate context\n", __func__);
+        GGML_CUDA_LOG_ERROR("%s: failed to allocate context\n", __func__);
         return nullptr;
     }
 

@@ -2998,8 +3038,8 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size
         // clear the error
         cudaGetLastError();
 
-        fprintf(stderr, "%s: warning: failed to register %.2f MiB of pinned memory: %s\n", __func__,
-                size/1024.0/1024.0, cudaGetErrorString(err));
+        GGML_CUDA_LOG_WARN("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
+                           size / 1024.0 / 1024.0, cudaGetErrorString(err));
         return false;
     }
     return true;

ggml-cuda.h

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t *
 GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
 GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
 
+GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data);
 #ifdef __cplusplus
 }
 #endif
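
Note: with this declaration, an application can redirect the CUDA backend's log output instead of having it written to stderr by the default callback. A hedged usage sketch; my_cuda_log, the FILE * user_data, and the file name are illustrative choices, not part of the commit.

    #include <cstdio>
    #include "ggml-cuda.h"

    // Forward CUDA backend log messages to a FILE * passed as user_data.
    static void my_cuda_log(enum ggml_log_level level, const char * text, void * user_data) {
        (void) level;
        std::fputs(text, (FILE *) user_data);
    }

    int main() {
        FILE * log_file = std::fopen("cuda.log", "w");
        if (log_file != NULL) {
            ggml_backend_cuda_log_set_callback(my_cuda_log, log_file);
        }
        // ... initialize the CUDA backend and run work; its messages now land in cuda.log ...
        return 0;
    }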

llama.cpp

Lines changed: 4 additions & 0 deletions
@@ -1697,6 +1697,8 @@ struct llama_state {
     llama_state() {
 #ifdef GGML_USE_METAL
         ggml_backend_metal_log_set_callback(log_callback, log_callback_user_data);
+#elif defined(GGML_USE_CUDA)
+        ggml_backend_cuda_log_set_callback(log_callback, log_callback_user_data);
 #endif
     }
 

@@ -18174,6 +18176,8 @@ void llama_log_set(ggml_log_callback log_callback, void * user_data) {
     g_state.log_callback_user_data = user_data;
 #ifdef GGML_USE_METAL
     ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
+#elif defined(GGML_USE_CUDA)
+    ggml_backend_cuda_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
 #endif
 }
 
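Note: the llama.cpp change makes llama_log_set forward the application's callback to the CUDA backend as well (when built with GGML_USE_CUDA), so a single call now captures both llama.cpp's own messages and the GGML CUDA messages that previously went straight to stderr. A hedged example; my_log is an illustrative callback name, not part of the commit.

    #include <cstdio>
    #include "llama.h"

    // One callback for both llama.cpp and (with this commit) GGML CUDA log output.
    static void my_log(enum ggml_log_level level, const char * text, void * user_data) {
        (void) level;
        (void) user_data;
        std::fputs(text, stderr); // or buffer/filter the messages in the application
    }

    int main() {
        llama_log_set(my_log, NULL);
        // ... backend init, model loading, inference: all logging goes through my_log ...
        return 0;
    }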