Skip to content

Commit 8447bc8

Browse files
CUDA: mmq CLI option, fixed mmq build issues
1 parent 11f3ca0 commit 8447bc8

File tree

10 files changed

+65
-25
lines changed

10 files changed

+65
-25
lines changed

CMakeLists.txt

+10-6
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework
6868
option(LLAMA_BLAS "llama: use BLAS" OFF)
6969
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
7070
option(LLAMA_CUBLAS "llama: use CUDA" OFF)
71-
option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF)
71+
#option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF)
7272
set(LLAMA_CUDA_MMQ_Y "64" CACHE STRING "llama: y tile size for mmq CUDA kernels")
7373
option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
7474
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
@@ -253,9 +253,9 @@ if (LLAMA_CUBLAS)
253253
set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
254254

255255
add_compile_definitions(GGML_USE_CUBLAS)
256-
if (LLAMA_CUDA_CUBLAS)
257-
add_compile_definitions(GGML_CUDA_CUBLAS)
258-
endif()
256+
# if (LLAMA_CUDA_CUBLAS)
257+
# add_compile_definitions(GGML_CUDA_CUBLAS)
258+
# endif()
259259
add_compile_definitions(GGML_CUDA_MMQ_Y=${LLAMA_CUDA_MMQ_Y})
260260
if (LLAMA_CUDA_FORCE_DMMV)
261261
add_compile_definitions(GGML_CUDA_FORCE_DMMV)
@@ -277,10 +277,14 @@ if (LLAMA_CUBLAS)
277277
endif()
278278

279279
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
280+
# 52 == lowest CUDA 12 standard
281+
# 60 == f16 CUDA intrinsics
282+
# 61 == integer CUDA intrinsics
283+
# 70 == (assumed) compute capability at which unrolling a loop in mul_mat_q kernels is faster
280284
if (LLAMA_CUDA_DMMV_F16)
281-
set(CMAKE_CUDA_ARCHITECTURES "60;61") # needed for f16 CUDA intrinsics
285+
set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
282286
else()
283-
set(CMAKE_CUDA_ARCHITECTURES "52;61") # lowest CUDA 12 standard + lowest for integer intrinsics
287+
set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
284288
endif()
285289
endif()
286290
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")

Makefile

+3-3
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,9 @@ ifdef LLAMA_CUDA_MMQ_Y
236236
else
237237
NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64
238238
endif # LLAMA_CUDA_MMQ_Y
239-
ifdef LLAMA_CUDA_CUBLAS
240-
NVCCFLAGS += -DGGML_CUDA_CUBLAS
241-
endif # LLAMA_CUDA_CUBLAS
239+
#ifdef LLAMA_CUDA_CUBLAS
240+
# NVCCFLAGS += -DGGML_CUDA_CUBLAS
241+
#endif # LLAMA_CUDA_CUBLAS
242242
ifdef LLAMA_CUDA_CCBIN
243243
NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
244244
endif

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,9 @@ Building the program with BLAS support may lead to some performance improvements
402402

403403
| Option | Legal values | Default | Description |
404404
|-------------------------|------------------------|---------|-------------|
405+
<!---
405406
| LLAMA_CUDA_CUBLAS | Boolean | false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). |
407+
--->
406408
| LLAMA_CUDA_MMQ_Y | Positive integer >= 32 | 64 | Tile size in y direction when using the custom CUDA kernels for prompt processing. Higher values can be faster depending on the amount of shared memory available. Power of 2 heavily recommended. |
407409
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
408410
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |

examples/common.cpp

+13-3
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
352352
#ifdef GGML_USE_CUBLAS
353353
params.main_gpu = std::stoi(argv[i]);
354354
#else
355-
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
355+
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
356356
#endif
357357
} else if (arg == "--tensor-split" || arg == "-ts") {
358358
if (++i >= argc) {
@@ -376,13 +376,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
376376
}
377377
}
378378
#else
379-
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
379+
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
380+
#endif // GGML_USE_CUBLAS
381+
} else if (arg == "--mul-mat-q" || arg == "-mmq") {
382+
#ifdef GGML_USE_CUBLAS
383+
params.mul_mat_q = true;
384+
#else
385+
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to use mul_mat_q kernels.\n");
380386
#endif // GGML_USE_CUBLAS
381387
} else if (arg == "--low-vram" || arg == "-lv") {
382388
#ifdef GGML_USE_CUBLAS
383389
params.low_vram = true;
384390
#else
385-
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
391+
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
386392
#endif // GGML_USE_CUBLAS
387393
} else if (arg == "--no-mmap") {
388394
params.use_mmap = false;
@@ -585,6 +591,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
585591
fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
586592
fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" );
587593
fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n" );
594+
fprintf(stdout, " -mmq, --mul-mat-q use experimental mul_mat_q CUDA kernels instead of cuBLAS. TEMP!!!\n" );
595+
fprintf(stdout, " Reduces VRAM usage by 700/970/1430 MiB for 7b/13b/33b but prompt processing speed\n" );
596+
fprintf(stdout, " is still suboptimal, especially q2_K, q3_K, q5_K, and q6_K.\n" );
588597
#endif
589598
fprintf(stdout, " --mtest compute maximum memory usage\n");
590599
fprintf(stdout, " --export export the computation graph to 'llama.ggml'\n");
@@ -637,6 +646,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
637646
lparams.main_gpu = params.main_gpu;
638647
lparams.tensor_split = params.tensor_split;
639648
lparams.low_vram = params.low_vram;
649+
lparams.mul_mat_q = params.mul_mat_q;
640650
lparams.seed = params.seed;
641651
lparams.f16_kv = params.memory_f16;
642652
lparams.use_mmap = params.use_mmap;

examples/common.h

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ struct gpt_params {
7474
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
7575

7676
bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
77+
bool mul_mat_q = false; // if true, use experimental mul_mat_q kernels
7778
bool memory_f16 = true; // use f16 instead of f32 for memory kv
7879
bool random_prompt = false; // do not randomize prompt if none provided
7980
bool use_color = false; // use color to distinguish generations and inputs

examples/server/server.cpp

+12-1
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,9 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
631631
fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
632632
fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
633633
fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n");
634+
fprintf(stdout, " -mmq, --mul-mat-q use experimental mul_mat_q CUDA kernels instead of cuBLAS. TEMP!!!\n" );
635+
fprintf(stdout, " Reduces VRAM usage by 700/970/1430 MiB for 7b/13b/33b but prompt processing speed\n" );
636+
fprintf(stdout, " is still suboptimal, especially q2_K, q3_K, q5_K, and q6_K.\n" );
634637
#endif
635638
fprintf(stdout, " -m FNAME, --model FNAME\n");
636639
fprintf(stdout, " model path (default: %s)\n", params.model.c_str());
@@ -835,7 +838,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
835838
#ifdef GGML_USE_CUBLAS
836839
params.low_vram = true;
837840
#else
838-
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
841+
LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
842+
#endif // GGML_USE_CUBLAS
843+
}
844+
else if (arg == "--mul-mat-q" || arg == "-mmq")
845+
{
846+
#ifdef GGML_USE_CUBLAS
847+
params.mul_mat_q = true;
848+
#else
849+
LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
839850
#endif // GGML_USE_CUBLAS
840851
}
841852
else if (arg == "--main-gpu" || arg == "-mg")

ggml-cuda.cu

+14-10
Original file line numberDiff line numberDiff line change
@@ -3536,10 +3536,9 @@ static size_t g_scratch_offset = 0;
35363536

35373537
static int g_device_count = -1;
35383538
static int g_main_device = 0;
3539-
#ifndef GGML_CUDA_FORCE_DMMV
35403539
static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
3541-
#endif
35423540
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
3541+
static bool g_mul_mat_q = false;
35433542

35443543
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
35453544

@@ -3561,9 +3560,7 @@ void ggml_init_cublas() {
35613560
g_tensor_split[id] = total_vram;
35623561
total_vram += prop.totalGlobalMem;
35633562

3564-
#ifndef GGML_CUDA_FORCE_DMMV
35653563
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
3566-
#endif
35673564
}
35683565
for (int id = 0; id < g_device_count; ++id) {
35693566
g_tensor_split[id] /= total_vram;
@@ -3916,6 +3913,7 @@ inline void ggml_cuda_op_mul_mat_vec(
39163913

39173914
#ifdef GGML_CUDA_FORCE_DMMV
39183915
const bool use_mul_mat_vec_q = false;
3916+
(void) g_compute_capabilities[0];
39193917
#else
39203918
int id;
39213919
CUDA_CHECK(cudaGetDevice(&id));
@@ -4657,12 +4655,14 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
46574655
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
46584656
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
46594657
} else {
4660-
#ifdef GGML_CUDA_CUBLAS
4661-
const bool use_mul_mat_q = false;
4662-
#else
4663-
const bool use_mul_mat_q = ggml_is_quantized(src0->type);
4664-
#endif // GGML_CUDA_CUBLAS
4665-
if (use_mul_mat_q) {
4658+
int min_compute_capability = INT_MAX;
4659+
for (int id = 0; id < g_device_count; ++id) {
4660+
if (min_compute_capability > g_compute_capabilities[id]) {
4661+
min_compute_capability = g_compute_capabilities[id];
4662+
}
4663+
}
4664+
4665+
if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
46664666
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
46674667
} else {
46684668
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
@@ -4953,6 +4953,10 @@ void ggml_cuda_set_main_device(int main_device) {
49534953
}
49544954
}
49554955

4956+
void ggml_cuda_set_mul_mat_q(bool mul_mat_q) {
4957+
g_mul_mat_q = mul_mat_q;
4958+
}
4959+
49564960
void ggml_cuda_set_scratch_size(size_t scratch_size) {
49574961
g_scratch_size = scratch_size;
49584962
}

ggml-cuda.h

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
2727
void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
2828
void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
2929
void ggml_cuda_set_main_device(int main_device);
30+
void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
3031
void ggml_cuda_set_scratch_size(size_t scratch_size);
3132
void ggml_cuda_free_scratch(void);
3233
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);

llama.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,7 @@ struct llama_context_params llama_context_default_params() {
879879
/*.progress_callback =*/ nullptr,
880880
/*.progress_callback_user_data =*/ nullptr,
881881
/*.low_vram =*/ false,
882+
/*.mul_mat_q =*/ false,
882883
/*.f16_kv =*/ true,
883884
/*.logits_all =*/ false,
884885
/*.vocab_only =*/ false,
@@ -1006,6 +1007,7 @@ static void llama_model_load_internal(
10061007
int n_gpu_layers,
10071008
int main_gpu,
10081009
const float * tensor_split,
1010+
const bool mul_mat_q,
10091011
float rope_freq_base,
10101012
float rope_freq_scale,
10111013
bool low_vram,
@@ -1134,9 +1136,11 @@ static void llama_model_load_internal(
11341136
}
11351137

11361138
(void) main_gpu;
1139+
(void) mul_mat_q;
11371140
#if defined(GGML_USE_CUBLAS)
11381141
fprintf(stderr, "%s: using CUDA for GPU acceleration\n", __func__);
11391142
ggml_cuda_set_main_device(main_gpu);
1143+
ggml_cuda_set_mul_mat_q(mul_mat_q);
11401144
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
11411145
#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
11421146
#elif defined(GGML_USE_CLBLAST)
@@ -1341,6 +1345,7 @@ static bool llama_model_load(
13411345
int n_gpu_layers,
13421346
int main_gpu,
13431347
const float * tensor_split,
1348+
const bool mul_mat_q,
13441349
float rope_freq_base,
13451350
float rope_freq_scale,
13461351
bool low_vram,
@@ -1351,7 +1356,8 @@ static bool llama_model_load(
13511356
llama_progress_callback progress_callback,
13521357
void *progress_callback_user_data) {
13531358
try {
1354-
llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
1359+
llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers,
1360+
main_gpu, tensor_split, mul_mat_q, rope_freq_base, rope_freq_scale, low_vram, memory_type,
13551361
use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
13561362
return true;
13571363
} catch (const std::exception & err) {
@@ -3103,7 +3109,7 @@ struct llama_model * llama_load_model_from_file(
31033109
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
31043110

31053111
if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers,
3106-
params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
3112+
params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
31073113
memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
31083114
params.progress_callback_user_data)) {
31093115
delete model;

llama.h

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ extern "C" {
108108

109109
// Keep the booleans together to avoid misalignment during copy-by-value.
110110
bool low_vram; // if true, reduce VRAM usage at the cost of performance
111+
bool mul_mat_q; // if true, use experimental mul_mat_q kernels
111112
bool f16_kv; // use fp16 for KV cache
112113
bool logits_all; // the llama_eval() call computes all logits, not just the last one
113114
bool vocab_only; // only load the vocabulary, no weights

0 commit comments

Comments
 (0)