Skip to content

Commit 11f3ca0

Browse files
CUDA: Quantized matrix matrix multiplication (#2160)
* mmq implementation for non k-quants * q6_K * q2_K * q3_k * q4_K * vdr * q5_K * faster q8_1 loading * loop unrolling * add __restrict__ * q2_K sc_high * GGML_CUDA_MMQ_Y * Updated Makefile * Update Makefile * DMMV_F16 -> F16 * Updated README, CMakeLists * Fix CMakeLists.txt * Fix CMakeLists.txt * Fix multi GPU out-of-bounds
1 parent 9baf9ef commit 11f3ca0

File tree

4 files changed

+1293
-322
lines changed

4 files changed

+1293
-322
lines changed

CMakeLists.txt

+7-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ endif()
6767
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
6868
option(LLAMA_BLAS "llama: use BLAS" OFF)
6969
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
70-
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
70+
option(LLAMA_CUBLAS "llama: use CUDA" OFF)
71+
option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF)
72+
set(LLAMA_CUDA_MMQ_Y "64" CACHE STRING "llama: y tile size for mmq CUDA kernels")
7173
option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
7274
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
7375
set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
@@ -251,6 +253,10 @@ if (LLAMA_CUBLAS)
251253
set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
252254

253255
add_compile_definitions(GGML_USE_CUBLAS)
256+
if (LLAMA_CUDA_CUBLAS)
257+
add_compile_definitions(GGML_CUDA_CUBLAS)
258+
endif()
259+
add_compile_definitions(GGML_CUDA_MMQ_Y=${LLAMA_CUDA_MMQ_Y})
254260
if (LLAMA_CUDA_FORCE_DMMV)
255261
add_compile_definitions(GGML_CUDA_FORCE_DMMV)
256262
endif()

Makefile

+13-2
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ ifdef LLAMA_CUBLAS
194194
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
195195
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
196196
OBJS += ggml-cuda.o
197-
NVCCFLAGS = --forward-unknown-to-host-compiler
197+
NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
198198
ifdef LLAMA_CUDA_NVCC
199199
NVCC = $(LLAMA_CUDA_NVCC)
200200
else
@@ -220,14 +220,25 @@ else ifdef LLAMA_CUDA_DMMV_Y
220220
else
221221
NVCCFLAGS += -DGGML_CUDA_MMV_Y=1
222222
endif # LLAMA_CUDA_MMV_Y
223+
ifdef LLAMA_CUDA_F16
224+
NVCCFLAGS += -DGGML_CUDA_F16
225+
endif # LLAMA_CUDA_F16
223226
ifdef LLAMA_CUDA_DMMV_F16
224-
NVCCFLAGS += -DGGML_CUDA_DMMV_F16
227+
NVCCFLAGS += -DGGML_CUDA_F16
225228
endif # LLAMA_CUDA_DMMV_F16
226229
ifdef LLAMA_CUDA_KQUANTS_ITER
227230
NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
228231
else
229232
NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
230233
endif
234+
ifdef LLAMA_CUDA_MMQ_Y
235+
NVCCFLAGS += -DGGML_CUDA_MMQ_Y=$(LLAMA_CUDA_MMQ_Y)
236+
else
237+
NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64
238+
endif # LLAMA_CUDA_MMQ_Y
239+
ifdef LLAMA_CUDA_CUBLAS
240+
NVCCFLAGS += -DGGML_CUDA_CUBLAS
241+
endif # LLAMA_CUDA_CUBLAS
231242
ifdef LLAMA_CUDA_CCBIN
232243
NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
233244
endif

README.md

+4-2
Original file line numberDiff line numberDiff line change
@@ -402,10 +402,12 @@ Building the program with BLAS support may lead to some performance improvements
402402

403403
| Option | Legal values | Default | Description |
404404
|-------------------------|------------------------|---------|-------------|
405+
| LLAMA_CUDA_CUBLAS | Boolean | false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). |
406+
| LLAMA_CUDA_MMQ_Y | Positive integer >= 32 | 64 | Tile size in y direction when using the custom CUDA kernels for prompt processing. Higher values can be faster depending on the amount of shared memory available. Power of 2 heavily recommended. |
405407
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
406408
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
407-
| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
408-
| LLAMA_CUDA_DMMV_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels. Can improve performance on relatively recent GPUs. |
409+
| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
410+
| LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
409411
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
410412

411413
- #### CLBlast

0 commit comments

Comments (0)