ggml-cuda: Adding support for unified memory #8035

Merged: 7 commits, Aug 1, 2024
Conversation

@matteoserva (Contributor) commented Jun 20, 2024

This adds an environment variable for launching llama.cpp with unified memory on CUDA.
This is useful when the model barely fits in VRAM and inference causes OOM errors.
In that case token generation with unified memory is much faster than partially offloading the model to CPU RAM.

Example: llama-3-70b-IQ2_XS on 24GB of VRAM.
Launch command: GGML_CUDA_ENABLE_UNIFIED_MEMORY=1 ./llama-server
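For readers unfamiliar with the mechanism, here is a minimal sketch of the kind of allocation switch involved (simplified, with an illustrative function name rather than the exact code from the diff): when the environment variable is set, device buffers are allocated with cudaMallocManaged instead of cudaMalloc, so the CUDA driver can page data between system RAM and VRAM on demand instead of failing the allocation outright.

```cpp
// Sketch only: simplified illustration of the idea behind the PR, not the exact patch.
#include <cuda_runtime.h>
#include <cstdlib>

static cudaError_t ggml_cuda_device_malloc_sketch(void ** ptr, size_t size, int device) {
    cudaSetDevice(device);
    if (std::getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
        // Managed (unified) memory: one pointer valid on host and device,
        // pages are migrated on demand and may spill to system RAM.
        return cudaMallocManaged(ptr, size);
    }
    // Default path: plain device memory, returns cudaErrorMemoryAllocation on OOM.
    return cudaMalloc(ptr, size);
}
```

With this switch, a model that slightly oversubscribes VRAM can still be "fully offloaded", at the cost of page migrations during inference.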

@JohannesGaessler (Collaborator) commented Jun 20, 2024

On my machine with an Epyc 7742 and an RTX 4090, llama-bench still OOMs even with this change. (Correction: I had forgotten to set the environment variable.)

@JohannesGaessler (Collaborator):

> In that case token generation with unified memory is much faster than partially offloading the model to CPU RAM.

Based on what testing methodology? On master I get this performance using llama-bench (Epyc 7742, 8x 3200 MHz RAM, RTX 4090):

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 76 | 1 | pp512 | 784.18 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 76 | 1 | tg128 | 22.62 ± 0.00 |

I OOM with 77 layers. With the mode added by this PR I get:

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 70 | 1 | pp512 | 641.00 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 70 | 1 | tg128 | 17.53 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 71 | 1 | pp512 | 656.85 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 71 | 1 | tg128 | 18.17 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 72 | 1 | pp512 | 677.92 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 72 | 1 | tg128 | 19.08 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 73 | 1 | pp512 | 699.05 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 73 | 1 | tg128 | 19.99 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 74 | 1 | pp512 | 722.37 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 74 | 1 | tg128 | 20.69 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 75 | 1 | pp512 | 747.27 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 75 | 1 | tg128 | 21.79 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 76 | 1 | pp512 | 773.13 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 76 | 1 | tg128 | 22.75 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 77 | 1 | pp512 | 45.96 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 77 | 1 | tg128 | 23.91 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 78 | 1 | pp512 | 45.55 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 78 | 1 | tg128 | 24.97 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 79 | 1 | pp512 | 44.96 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 79 | 1 | tg128 | 0.31 ± 0.00 |

The generation speed can be ~10% faster, but the prompt processing speed is ~20x worse, so I don't think this would ever be worth using. And if you offload too many layers, the thrashing kills the generation performance too.

@matteoserva (Contributor, Author) commented Jun 20, 2024

Thanks for testing the PR.
Here is the benchmark run on my system, with and without the PR. A description of my system follows.
Note that I increased the prompt to 8192 to use more memory. Both pp and tg are much faster with this PR.

With UNIFIED MEMORY (PR enabled)

GGML_CUDA_ENABLE_UNIFIED_MEMORY=1 ./llama-bench -m bartowski/Meta-Llama-3-70B-Instruct-IQ2_M.gguf -fa  1 -p 8192
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4060 Ti, compute capability 8.9, VMM: yes
  Device 1: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 1 | pp8192 | 249.85 ± 14.20 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 1 | tg128 | 9.27 ± 0.01 |

With offloading (master)

./llama-bench -m bartowski/Meta-Llama-3-70B-Instruct-IQ2_M.gguf -fa  1 -p 8192 -ngl 76 -t 6
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4060 Ti, compute capability 8.9, VMM: yes
  Device 1: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 76 | 1 | pp8192 | 186.66 ± 0.18 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 76 | 1 | tg128 | 5.16 ± 0.01 |

My system:

GPU

RTX 4060 Ti 16GB VRAM
RTX 3060 12GB VRAM
Total VRAM: 28GB
Both connected via PCIe 3.0 x8

CPU

Intel 8700K with 6 physical cores (12 threads)

Motherboard

Asus PRIME Z390-A

Additional info:

I also previously tested the same change with two RTX 3060 12GB cards in parallel, for a total of 24GB of VRAM.

Possible explanations:

  • Slow CPU?
  • PCIe bottleneck

EDIT: more benchmarks with pp512

With unified memory

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 1 | pp512 | 223.83 ± 1.35 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 1 | tg128 | 9.25 ± 0.01 |

master

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 76 | 1 | pp512 | 205.60 ± 0.38 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 76 | 1 | tg128 | 5.15 ± 0.01 |

EDIT2:

I just noticed that you tested the PR without offloading the full model to the GPU. The idea of unified memory is letting the GPU manage the transfer between RAM and VRAM (see the short standalone illustration after the table below). More benchmarks for comparison:

Unified memory enabled

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 79 | 1 | pp512 | 215.34 ± 0.56 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 79 | 1 | tg128 | 6.91 ± 0.03 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 1 | pp512 | 223.87 ± 0.38 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 1 | tg128 | 9.26 ± 0.02 |
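As a standalone illustration of that idea (a toy program, not code from this PR; the kernel and sizes are made up for the example), a managed allocation yields one pointer that is valid on both the host and the device, and the driver migrates pages on first access in either direction, which is what makes oversubscribing VRAM possible:

```cpp
// Toy example: demonstrates on-demand page migration with cudaMallocManaged.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void scale(float * x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= 2.0f;
}

int main() {
    const int n = 1 << 20;
    float * x = nullptr;
    cudaMallocManaged(&x, n * sizeof(float)); // may exceed free VRAM; pages can live in RAM
    for (int i = 0; i < n; i++) x[i] = 1.0f;  // pages are resident on the host here
    scale<<<(n + 255) / 256, 256>>>(x, n);    // pages migrate to the GPU on first access
    cudaDeviceSynchronize();
    printf("x[0] = %f\n", x[0]);              // pages migrate back to the host
    cudaFree(x);
    return 0;
}
```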

@Nexesenex (Contributor) commented Jun 21, 2024

@matteoserva

Testing your PR on Windows 11 (LCPP b3197 with your PR merged, and the env variable GGML_CUDA_ENABLE_UNIFIED_MEMORY set to 1), I get this:

Q:\Lla\LLAMA_CUDA_121>llama-perplexity -m X:\text-generation-webui\Models\L3-70B-Euryale-v2.1-Imat-IQ4_SR.gguf -bf arc-challenge-validation.bin --multiple-choice -ngl 81 -b 128 -ts 54,27 --no-mmap -fa -c 512

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
  Device 1: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
llm_load_tensors: ggml ctx size =    1.10 MiB
CUDA error: CUDA-capable device(s) is/are busy or unavailable
  current device: 1, in function ggml_cuda_set_device at Q:\GitHub\llama.cpp\ggml-cuda.cu:113
  cudaSetDevice(device)
GGML_ASSERT: Q:\GitHub\llama.cpp\ggml-cuda.cu:100: !"CUDA error"

It loads with GPU0 alone or GPU1 alone, but not with both.
As soon as I disable the env variable, things run as usual.

@matteoserva (Contributor, Author):

> Testing your PR on Windows 11 (LCPP b3197 with your PR merged, and the env variable GGML_CUDA_ENABLE_UNIFIED_MEMORY set to 1), I get this:
> [...]
>
> It loads with GPU0 alone or GPU1 alone, but not with both. As soon as I disable the env variable, things run as usual.

I think this setting is managed directly by the NVIDIA CUDA driver on Windows. Could you try changing the settings referenced on this page? https://nvidia.custhelp.com/app/answers/detail/a_id/5490/~/system-memory-fallback-for-stable-diffusion

Anyway, I am going to install Windows on my computer and test this myself in the next few days.

@Nexesenex (Contributor):

> I think this setting is managed directly by the NVIDIA CUDA driver on Windows. Could you try changing the settings referenced on this page? https://nvidia.custhelp.com/app/answers/detail/a_id/5490/~/system-memory-fallback-for-stable-diffusion

I did that, and when system memory fallback is disabled, llama.cpp runs properly with both GPUs. When it is re-enabled, the problem occurs again. It does indeed behave like the env variable.

> Anyway, I am going to install Windows on my computer and test this myself in the next few days.

Fingers crossed, because I'd love to access shared memory via CUDA on dual GPUs.

@mofosyne added the Review Complexity : Low label on Jun 21, 2024
@JohannesGaessler (Collaborator):

RTX 4090, Epyc 7742, 8x 3200 MHz RAM:

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 81 | 1 | pp512 | 44.61 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 81 | 1 | tg128 | 0.15 ± 0.00 |

P40, Xeon E5-2683 v4, 4x 2133 MHz RAM:

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 70 | 1 | pp512 | 77.93 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 70 | 1 | tg128 | 4.75 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 71 | 1 | pp512 | 78.45 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 71 | 1 | tg128 | 4.82 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 72 | 1 | pp512 | 79.34 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 72 | 1 | tg128 | 4.88 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 73 | 1 | pp512 | 79.78 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 73 | 1 | tg128 | 5.00 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 74 | 1 | pp512 | 80.06 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 74 | 1 | tg128 | 5.10 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 75 | 1 | pp512 | 80.43 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 75 | 1 | tg128 | 5.21 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 76 | 1 | pp512 | 80.79 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 76 | 1 | tg128 | 5.27 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 77 | 1 | pp512 | 80.81 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 77 | 1 | tg128 | 5.30 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 78 | 1 | pp512 | 79.74 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 78 | 1 | tg128 | 5.46 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 79 | 1 | pp512 | 79.44 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 79 | 1 | tg128 | 5.56 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 80 | 1 | pp512 | 78.89 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 80 | 1 | tg128 | 5.80 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 81 | 1 | pp512 | 32.84 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 81 | 1 | tg128 | 0.76 ± 0.00 |

For both hardware configurations I am not able to get better performance than with master. However, if there are people for whom this does provide better performance, I would still merge this PR since it's a very simple and self-contained change that is not enabled by default. But please also add some brief documentation to the README that explains how to enable this option.

@JohannesGaessler (Collaborator):

RTX 3090, Ryzen 5950X, 2x 3200 MHz RAM:

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 70 | 1 | pp512 | 387.01 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 70 | 1 | tg128 | 8.32 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 71 | 1 | pp512 | 394.24 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 71 | 1 | tg128 | 8.71 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 72 | 1 | pp512 | 401.56 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 72 | 1 | tg128 | 9.21 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 73 | 1 | pp512 | 39.39 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 73 | 1 | tg128 | 9.74 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 74 | 1 | pp512 | 60.15 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 74 | 1 | tg128 | 10.37 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 75 | 1 | pp512 | 59.81 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 75 | 1 | tg128 | 11.10 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 76 | 1 | pp512 | 59.51 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 76 | 1 | tg128 | 0.22 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 81 | 1 | pp512 | 56.52 ± 0.00 |
| llama 70B Q2_K - Medium | 23.71 GiB | 68.98 B | CUDA | 81 | 1 | tg128 | 0.15 ± 0.00 |

@JohannesGaessler (Collaborator) commented Jun 23, 2024

I just noticed: are you testing with LLAMA_CUDA_FORCE_MMQ and --flash-attention? With those, the size of the temporary buffers needed for inference should be greatly reduced.

@matteoserva (Contributor, Author) commented Jun 23, 2024

Thanks for all your testing. I reran the benchmark with FORCE_MMQ and I'm getting the same results. My guess is that in my specific hardware configuration the actual computation on the CPU is so slow that streaming the data to the GPU is actually faster.

Another difference I can see in our setups is that I'm using IQ quants and you are using Q2_K. Could this be the culprit?

I added the README documentation to the PR.

Here are the benchmarks with FORCE_MMQ:

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   yes
ggml_cuda_init: CUDA_USE_TENSOR_CORES: no
| model | size | params | backend | ngl | fa | test | t/s | PR |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 1 | pp512 | 220.35 ± 2.33 | yes |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 1 | tg128 | 9.17 ± 0.06 | yes |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 1 | pp8192 | 255.76 ± 5.93 | yes |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 1 | tg512 | 9.32 ± 0.00 | yes |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 79 | 1 | pp512 | 214.77 ± 1.74 | no |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 79 | 1 | tg128 | 6.78 ± 0.05 | no |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 76 | 1 | pp8192 | 188.16 ± 0.21 | no |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 76 | 1 | tg512 | 5.17 ± 0.00 | no |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 0 | pp512 | 210.88 ± 0.79 | yes |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 0 | tg128 | 9.12 ± 0.02 | yes |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 0 | pp8192 | 89.18 ± 4.93 | yes |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 99 | 0 | tg512 | 9.19 ± 0.00 | yes |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 79 | 0 | pp512 | 205.75 ± 1.02 | no |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 79 | 0 | tg128 | 6.79 ± 0.03 | no |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 74 | 0 | pp8192 | 136.54 ± 0.03 | no |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 74 | 0 | tg512 | 4.40 ± 0.00 | no |

@JohannesGaessler (Collaborator):

I procured an IQ2_M model. The best performance I could get on master with an RTX 4090, Epyc 7742, and 8x 3200 MHz RAM is:

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 72 | 1 | pp8192 | 572.80 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 72 | 1 | tg128 | 17.34 ± 0.00 |

With this PR I get this performance:

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 70 | 0 | pp8192 | 432.39 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 70 | 0 | tg128 | 15.53 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 70 | 1 | pp8192 | 538.03 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 70 | 1 | tg128 | 15.99 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 71 | 0 | pp8192 | 439.14 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 71 | 0 | tg128 | 16.07 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 71 | 1 | pp8192 | 552.28 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 71 | 1 | tg128 | 16.62 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 72 | 0 | pp8192 | 448.57 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 72 | 0 | tg128 | 16.87 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 72 | 1 | pp8192 | 567.12 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 72 | 1 | tg128 | 17.21 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 73 | 0 | pp8192 | 388.96 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 73 | 0 | tg128 | 17.56 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 73 | 1 | pp8192 | 583.31 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 73 | 1 | tg128 | 18.07 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 74 | 0 | pp8192 | 282.11 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 74 | 0 | tg128 | 16.94 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 74 | 1 | pp8192 | 597.74 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 74 | 1 | tg128 | 17.49 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 75 | 0 | pp8192 | 209.05 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 75 | 0 | tg128 | 18.01 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 75 | 1 | pp8192 | 613.11 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 75 | 1 | tg128 | 18.79 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 76 | 0 | pp8192 | 165.12 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 76 | 0 | tg128 | 19.39 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 76 | 1 | pp8192 | 321.75 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 76 | 1 | tg128 | 20.01 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 77 | 0 | pp8192 | 136.76 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 77 | 0 | tg128 | 20.20 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 77 | 1 | pp8192 | 339.15 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 77 | 1 | tg128 | 21.19 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 78 | 0 | pp8192 | 91.00 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 78 | 0 | tg128 | 20.96 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 78 | 1 | pp8192 | 184.71 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 78 | 1 | tg128 | 22.47 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 79 | 0 | pp8192 | 58.13 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 79 | 0 | tg128 | 22.83 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 79 | 1 | pp8192 | 143.29 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 79 | 1 | tg128 | 23.62 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 80 | 0 | pp8192 | 48.13 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 80 | 0 | tg128 | 25.12 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 80 | 1 | pp8192 | 153.06 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 80 | 1 | tg128 | 26.06 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 81 | 0 | pp8192 | 47.48 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 81 | 0 | tg128 | 29.80 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 81 | 1 | pp8192 | 102.01 ± 0.00 |
| llama 70B IQ2_M - 2.7 bpw | 22.46 GiB | 70.55 B | CUDA | 81 | 1 | tg128 | 31.09 ± 0.00 |

Peak prompt processing performance is higher. Token generation continues to improve with more offloaded layers, but token generation doesn't OOM anyway and is the same speed as on master.

You said you wanted to test Windows. Do you still plan to do that?

@matteoserva (Contributor, Author):

First, thanks again for benchmarking this PR so extensively.

From my understanding of your results, you hit a sweet spot at ngl 75 where both pp and tg were higher than on master. There is another sweet spot at ngl 81 where pp is much lower but tg went from 17.34 on master to 31.09, a 1.79x speedup. The PR probably benefits IQ quants more than others.

My understanding of the NVIDIA docs is that the Windows CUDA driver implements the same functionality as this PR, so the PR wouldn't make any difference on Windows. I think a Windows test is not warranted. Of course I'm open to other opinions. I could possibly test this on Windows this weekend.

@JohannesGaessler (Collaborator):

> My understanding of the NVIDIA docs is that the Windows CUDA driver implements the same functionality as this PR, so the PR wouldn't make any difference on Windows. I think a Windows test is not warranted. Of course I'm open to other opinions. I could possibly test this on Windows this weekend.

I'm only asking because you said:

> Anyway, I am going to install Windows on my computer and test this myself in the next few days.

If you are not going to do that, then I think we should move to finalizing the PR.

@matteoserva (Contributor, Author):

Ok. So I think that the PR is ready.

@slaren (Member) commented Jun 26, 2024

Is the total memory used by model + context + compute buffer actually larger than the capacity of the GPU in the cases where it seems to improve performance? This model is 22.46 GiB, so it should fit on a 24GB GPU with a very low context.

@matteoserva (Contributor, Author):

> Is the total memory used by model + context + compute buffer actually larger than the capacity of the GPU in the cases where it seems to improve performance? This model is 22.46 GiB, so it should fit on a 24GB GPU with a very low context.

Yes. The total memory must be slightly larger than the total memory of the GPU.

If it's smaller, then there is no improvement.
If it's much larger, then a better solution is to partially offload the model by setting -ngl.

@slaren (Member) commented Jun 26, 2024

I tried Meta-Llama-3-70B-Instruct-IQ2_XS.gguf without additional parameters (i.e. max context, no FA), since that's what was mentioned. I get this memory usage:

llm_load_tensors:      CUDA0 buffer size = 19826.41 MiB
llama_kv_cache_init:      CUDA0 KV buffer size =  2560.00 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =  1104.00 MiB

The total is 23490.41 MiB, which is definitely lower than 24GB.

@slaren (Member) commented Jun 26, 2024

The reason I ask is that the only way I can see this working at all is if there is memory fragmentation that prevents the model buffer from being allocated in the first place.

Can you test 1e6e363? This change will allocate one buffer per tensor and should avoid any fragmentation issues.

@matteoserva (Contributor, Author):

> The reason I ask is that the only way I can see this working at all is if there is memory fragmentation that prevents the model buffer from being allocated in the first place.
>
> Can you test 1e6e363? This change will allocate one buffer per tensor and should avoid any fragmentation issues.

I don't have the 24GB setup anymore since I replaced the graphics cards in my PC. Now I have an RTX 4060 Ti 16GB + RTX 3060 12GB.
I tried to replicate the exact same issue I was facing by using the IQ2_M model.
Both branches crash with the same error.

llama-cli

Here is the launch command:
CUDA_VISIBLE_DEVICES=0,1 ./llama-cli -ngl 79 -t 6 -c 8192 -m ~/tmp/models_cache/Meta-Llama-3-70B-Instruct-IQ2_M.gguf -p "Hello"

Here is the relevant portion of the log from llama-cli:

ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4060 Ti, compute capability 8.9, VMM: yes
  Device 1: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
llm_load_tensors: ggml ctx size =    1,01 MiB
llm_load_tensors: offloading 79 repeating layers to GPU
llm_load_tensors: offloaded 79/81 layers to GPU
llm_load_tensors:        CPU buffer size = 22994,45 MiB
llm_load_tensors:      CUDA0 buffer size = 12657,75 MiB
llm_load_tensors:      CUDA1 buffer size =  8922,38 MiB
..................................................................................................
llama_new_context_with_model: n_ctx      = 8192
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 500000,0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size =    32,00 MiB
llama_kv_cache_init:      CUDA0 KV buffer size =  1472,00 MiB
llama_kv_cache_init:      CUDA1 KV buffer size =  1056,00 MiB
llama_new_context_with_model: KV self size  = 2560,00 MiB, K (f16): 1280,00 MiB, V (f16): 1280,00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0,49 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =  1108,50 MiB
llama_new_context_with_model:      CUDA1 compute buffer size =  1104,00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    32,01 MiB
llama_new_context_with_model: graph nodes  = 2566
llama_new_context_with_model: graph splits = 16
CUDA error: CUBLAS_STATUS_NOT_INITIALIZED
  current device: 1, in function cublas_handle at ggml-cuda/common.cuh:796
  cublasCreate_v2(&cublas_handles[device])
GGML_ASSERT: ggml-cuda.cu:100: !"CUDA error"
[New LWP 15866]
[... repeated]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/usr/lib/x86_64-linux-gnu/libthread_db.so.1".
0x00007f8d52af2b57 in __GI___wait4 (pid=15909, stat_loc=0x0, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30      ../sysdeps/unix/sysv/linux/wait4.c: File o directory non esistente.
#0  0x00007f8d52af2b57 in __GI___wait4 (pid=15909, stat_loc=0x0, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30      in ../sysdeps/unix/sysv/linux/wait4.c
#1  0x00005605fc28eb9b in ggml_print_backtrace ()
#2  0x00005605fc3e8557 in ggml_cuda_error(char const*, char const*, char const*, int, char const*) ()
#3  0x00005605fc3eacf2 in ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*) ()
#4  0x00005605fc3eeda0 in ggml_cuda_mul_mat(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*) ()
#5  0x00005605fc3f0c8d in ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) ()
#6  0x00005605fc51e8cd in ggml_backend_sched_graph_compute_async ()
#7  0x00005605fc2ef37a in llama_decode ()
#8  0x00005605fc37d697 in llama_init_from_gpt_params(gpt_params&) ()
#9  0x00005605fc2868df in main ()

llama-server

  • llama-server crashes with the same error if I set -ngl to 79.
  • The server starts correctly with -ngl 78, but then it crashes during inference (when I send a prompt).

Server launch command:
CUDA_VISIBLE_DEVICES=0,1 ./llama-server -ngl 175 -t 6 -c 8192 --host 0.0.0.0 -m ~/tmp/models_cache/Meta-Llama-3-70B-Instruct-IQ2_M.gguf -ngl 78

Server error log:

ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4060 Ti, compute capability 8.9, VMM: yes
  Device 1: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
llm_load_tensors: ggml ctx size =    1.01 MiB
llm_load_tensors: offloading 78 repeating layers to GPU
llm_load_tensors: offloaded 78/81 layers to GPU
llm_load_tensors:        CPU buffer size = 22994.45 MiB
llm_load_tensors:      CUDA0 buffer size = 12633.25 MiB
llm_load_tensors:      CUDA1 buffer size =  8652.00 MiB
..................................................................................................
llama_new_context_with_model: n_ctx      = 8192
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 500000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size =    64.00 MiB
llama_kv_cache_init:      CUDA0 KV buffer size =  1472.00 MiB
llama_kv_cache_init:      CUDA1 KV buffer size =  1024.00 MiB
llama_new_context_with_model: KV self size  = 2560.00 MiB, K (f16): 1280.00 MiB, V (f16): 1280.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.98 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =  1108.50 MiB
llama_new_context_with_model:      CUDA1 compute buffer size =  1104.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    32.01 MiB
llama_new_context_with_model: graph nodes  = 2566
llama_new_context_with_model: graph splits = 27
INFO [                    init] initializing slots | tid="140479682437120" timestamp=1719422754 n_slots=1
INFO [                    init] new slot | tid="140479682437120" timestamp=1719422754 id_slot=0 n_ctx_slot=8192
INFO [                    main] model loaded | tid="140479682437120" timestamp=1719422754
INFO [                    main] chat template | tid="140479682437120" timestamp=1719422754 chat_example="<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHow are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" built_in=true
INFO [                    main] HTTP server listening | tid="140479682437120" timestamp=1719422754 n_threads_http="11" port="8080" hostname="0.0.0.0"
INFO [            update_slots] all slots are idle | tid="140479682437120" timestamp=1719422754
INFO [      log_server_request] request | tid="140423569248256" timestamp=1719422758 remote_addr="127.0.0.1" remote_port=33910 status=200 method="GET" path="/props" params={}
INFO [      log_server_request] request | tid="140423588540416" timestamp=1719422758 remote_addr="127.0.0.1" remote_port=33920 status=200 method="POST" path="/tokenize" params={}
INFO [   launch_slot_with_task] slot is processing task | tid="140479682437120" timestamp=1719422758 id_slot=0 id_task=0
INFO [            update_slots] kv cache rm [p0, end) | tid="140479682437120" timestamp=1719422758 id_slot=0 id_task=0 p0=0
CUDA error: out of memory
  current device: 0, in function alloc at ggml-cuda.cu:359
  cuMemCreate(&handle, reserve_size, &prop, 0)
GGML_ASSERT: ggml-cuda.cu:100: !"CUDA error"
[New LWP 17044]
[... repeated]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/usr/lib/x86_64-linux-gnu/libthread_db.so.1".
0x00007fc3f36f2b57 in __GI___wait4 (pid=17135, stat_loc=0x0, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30      ../sysdeps/unix/sysv/linux/wait4.c: File o directory non esistente.
#0  0x00007fc3f36f2b57 in __GI___wait4 (pid=17135, stat_loc=0x0, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30      in ../sysdeps/unix/sysv/linux/wait4.c
#1  0x0000562292f1b19b in ggml_print_backtrace ()
#2  0x00005622930780b7 in ggml_cuda_error(char const*, char const*, char const*, int, char const*) ()
#3  0x00005622930821b7 in ggml_cuda_pool_vmm::alloc(unsigned long, unsigned long*) ()
#4  0x000056229307ad69 in ggml_cuda_op_mul_mat_cublas(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*, char const*, float const*, char const*, float*, long, long, long, long, CUstream_st*) ()
#5  0x000056229307d640 in ggml_cuda_op_mul_mat(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*, void (*)(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*, char const*, float const*, char const*, float*, long, long, long, long, CUstream_st*), void (*)(float const*, void*, long, long, long, long, ggml_type, CUstream_st*)) ()
#6  0x000056229307e6f5 in ggml_cuda_mul_mat(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*) ()
#7  0x00005622930807ed in ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) ()
#8  0x00005622931ae3cd in ggml_backend_sched_graph_compute_async ()
#9  0x0000562292f7bd6a in llama_decode ()
#10 0x0000562293260413 in server_context::update_slots() ()
#11 0x0000562293244746 in server_queue::start_loop() ()
#12 0x0000562292f1854f in main ()

Here is the output of nvidia-smi at idle (I have an X11 server running on GPU1, the 12GB one):

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
|  0%   39C    P8    10W / 165W |      6MiB / 16380MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  On   | 00000000:02:00.0  On |                  N/A |
|  0%   46C    P8    19W / 170W |    492MiB / 12288MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

I didn't include the logs from both branches because I see no differences between them.

The github-actions bot added the documentation label on Jul 6, 2024
@Gingeropolous:

I'm trying to compile the PR because I keep getting OOMed (running 2x Vega 64 8GB... yeah, I know...).

ggml/src/ggml-cuda.cu:563:15: error: use of undeclared identifier 'cudaMallocManaged'; did you mean 'hipMallocManaged'?
563 | err = cudaMallocManaged(&dev_ptr, size);
| ^~~~~~~~~~~~~~~~~
| hipMallocManaged
/opt/rocm-6.1.2/include/hip/hip_runtime_api.h:3120:12: note: 'hipMallocManaged' declared here
3120 | hipError_t hipMallocManaged(void** dev_ptr,
| ^

@matteoserva (Contributor, Author):

> hipMallocManaged

The latest llama.cpp master should already support unified memory on AMD devices.
The relevant compile flag is GGML_HIP_UMA.
There is an explanation in the documentation: https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md
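For context, here is a minimal sketch of what that flag selects (the function name is illustrative and the real guards in ggml-cuda.cu may differ): with GGML_HIP_UMA defined, device allocations go through HIP's managed memory, which is the ROCm counterpart of cudaMallocManaged.

```cpp
// Sketch only, assuming a HIP/ROCm build of the CUDA backend; names are illustrative.
#include <hip/hip_runtime.h>

static hipError_t ggml_hip_device_malloc_sketch(void ** ptr, size_t size) {
#ifdef GGML_HIP_UMA
    // Unified memory: accessible from host and device, paged in by the driver on demand.
    return hipMallocManaged(ptr, size);
#else
    // Regular device memory.
    return hipMalloc(ptr, size);
#endif
}
```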

Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Jul 30, 2024
At the correct location, per the 4th commit of the PR.
@JohannesGaessler (Collaborator) left a review comment:

Sorry for the long radio silence. Since the changes in this PR are comparatively simple, I think it's fine to merge as long as the feature is documented and as long as said documentation is not misleading. If the documentation simply says that the option exists and that it enables swapping VRAM to RAM on Linux, then I think that would be acceptable. I would only be comfortable with positive performance claims if we can reach a consensus regarding the conditions under which the option helps with performance.

@matteoserva (Contributor, Author):

I removed the line in the documentation about the performance improvement in that specific case. I prefer not to make any performance claims since I don't have enough data.

The PR should be ready for merging.

@JohannesGaessler merged commit afbb4c1 into ggml-org:master on Aug 1, 2024
53 checks passed
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Aug 2, 2024
* Adding support for unified memory

* adding again the documentation about unified memory

* refactoring: Moved the unified memory code in the correct location.

* Fixed compilation error when using hipblas

* cleaning up the documentation

* Updating the documentation

Co-authored-by: Johannes Gäßler <[email protected]>

* adding one more case where the PR should not be enabled

---------

Co-authored-by: matteo serva <[email protected]>
Co-authored-by: Johannes Gäßler <[email protected]>