
Vulkan Improvements #5835


Merged: 15 commits from 0cc4m/vulkan-improvements into master on Mar 5, 2024

Conversation

@0cc4m (Collaborator) commented Mar 2, 2024

Here's a batch of Vulkan improvements:

  • Shader optimizations, especially for legacy quants, to get them up to speed
    • I also optimized them for AMD GCN GPUs
  • Allow simple ops (add, mul, etc.) to work on non-contiguous tensors
  • Simplify and standardize matrix-matrix multiplication shader selection
  • Add dequant matrix-matrix multiplication shaders for legacy quants
    • k-quants will come later; they are not as simple to implement
  • Add environment variable GGML_VK_FORCE_MAX_ALLOCATION_SIZE to limit the maximum buffer size, as a workaround for "Vulkan Device memory allocation failed (ErrorOutOfDeviceMemory)" (#5441); a usage sketch follows below
  • Update soft_max to support ALIBI

I'll add some benchmarks of my GPUs later. Let me know what you think.
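
To make the new environment variable concrete, here is a minimal C++ sketch of how a backend could clamp its maximum buffer size to GGML_VK_FORCE_MAX_ALLOCATION_SIZE. This is an illustration, not the actual ggml-vulkan code; the helper name `effective_max_allocation` and the clamping policy are assumptions.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <string>

// Hypothetical helper: returns the effective maximum buffer size in bytes.
// device_max_alloc would come from the Vulkan driver, e.g.
// VkPhysicalDeviceMaintenance3Properties::maxMemoryAllocationSize.
static uint64_t effective_max_allocation(uint64_t device_max_alloc) {
    const char * env = std::getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
    if (env == nullptr || *env == '\0') {
        return device_max_alloc; // no override requested
    }
    // Clamp to the smaller of the user-provided limit and the device limit,
    // so oversized tensors get split across several smaller buffers.
    return std::min(device_max_alloc, static_cast<uint64_t>(std::stoull(env)));
}
```

Limiting the allocation size along these lines is the idea behind using the variable as a workaround for the ErrorOutOfDeviceMemory failures reported in #5441.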

@Artefact2 (Collaborator)

Q4_0 benchmarks on gfx1030/radv
| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| llama 13B Q4_0 | 6.88 GiB | 13.02 B | Vulkan (PR) | 99 | pp 512 | 203.97 ± 1.45 |
| llama 13B Q4_0 | 6.88 GiB | 13.02 B | Vulkan (PR) | 99 | tg 128 | 23.97 ± 0.09 |
| llama 13B Q4_0 | 6.88 GiB | 13.02 B | Vulkan (master) | 99 | pp 512 | 160.00 ± 1.69 |
| llama 13B Q4_0 | 6.88 GiB | 13.02 B | Vulkan (master) | 99 | tg 128 | 13.62 ± 0.03 |
| llama 13B Q4_0 | 6.88 GiB | 13.02 B | ROCm | 99 | pp 512 | 499.00 ± 0.15 |
| llama 13B Q4_0 | 6.88 GiB | 13.02 B | ROCm | 99 | tg 128 | 40.98 ± 0.01 |

@sorasoras

@0cc4m what about IQ quants?
I think IQ4_XS is an upgrade over Q4_K_M in most cases.
Thanks.

@0cc4m (Collaborator, Author) commented Mar 2, 2024

@0cc4m what about IQ quants? I think IQ4_XS is an upgrade over Q4_K_M in most cases. Thanks.

Sure, but they're quite a bit of work to implement. I'll get to them eventually, but MoE support in particular takes precedence.

@0cc4m (Collaborator, Author) commented Mar 2, 2024

Here are some benchmarks:

Vulkan0: AMD Radeon Pro VII (RADV VEGA20) | uma: 0 | fp16: 1 | warp size: 64

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (Master) | 99 | pp 512 | 168.31 ± 0.78 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (Master) | 99 | tg 128 | 10.03 ± 0.35 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (PR) | 99 | pp 512 | 181.23 ± 0.82 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (PR) | 99 | tg 128 | 19.24 ± 0.81 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | ROCm | 99 | pp 512 | 271.75 ± 0.78 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | ROCm | 99 | tg 128 | 45.41 ± 0.05 |

Vulkan0: AMD Radeon RX 6800 XT (RADV NAVI21) | uma: 0 | fp16: 1 | warp size: 64

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (Master) | 99 | pp 512 | 276.86 ± 5.49 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (Master) | 99 | tg 128 | 20.15 ± 0.33 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (PR) | 99 | pp 512 | 374.73 ± 8.15 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (PR) | 99 | tg 128 | 34.21 ± 1.14 |

(ROCm failed to run on that PC for whatever reason)

Vulkan0: Intel(R) Arc(tm) A770 Graphics (DG2) | uma: 0 | fp16: 1 | warp size: 32

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (Master) | 99 | pp 512 | 109.11 ± 2.30 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (Master) | 99 | tg 128 | 2.82 ± 0.03 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (PR) | 99 | pp 512 | 100.91 ± 2.09 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (PR) | 99 | tg 128 | 16.18 ± 1.17 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | SYCL | 99 | pp 512 | 453.10 ± 46.87 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | SYCL | 99 | tg 128 | 18.64 ± 0.01 |

Vulkan0: NVIDIA GeForce RTX 3090 | uma: 0 | fp16: 1 | warp size: 32

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (Master) | 99 | pp 512 | 317.80 ± 14.31 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (Master) | 99 | tg 128 | 27.32 ± 1.17 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (PR) | 99 | pp 512 | 492.58 ± 32.60 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | Vulkan (PR) | 99 | tg 128 | 26.20 ± 0.18 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | CUDA | 99 | pp 512 | 2452.13 ± 38.87 |
| llama 13B Q4_0 | 6.86 GiB | 13.02 B | CUDA | 99 | tg 128 | 82.22 ± 0.02 |

@Nindaleth (Contributor) commented Mar 3, 2024

Here are my numbers:
Radeon RX 6700 XT, Ryzen 5700X ECO, model mistral-7b-instruct-v0.2 fully offloaded to the GPU. ROCm 6.0.

I've only tested the quants known to be much slower, plus Q4_K_S, which was already faster than ROCm. Very nice speed improvements across the board.

Vulkan0: AMD Radeon RX 6700 XT (RADV NAVI22) | uma: 0 | fp16: 1 | warp size: 64

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.83 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 309.00 ± 7.35 |
| llama 7B Q4_0 | 3.83 GiB | 7.24 B | Vulkan (master) | 99 | tg 128 | 23.68 ± 0.34 |
| llama 7B Q4_0 | 3.83 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 425.07 ± 12.91 |
| llama 7B Q4_0 | 3.83 GiB | 7.24 B | Vulkan (PR) | 99 | tg 128 | 41.81 ± 1.27 |
| llama 7B Q4_0 | 3.83 GiB | 7.24 B | ROCm | 99 | pp 512 | 1017.28 ± 5.60 |
| llama 7B Q4_0 | 3.83 GiB | 7.24 B | ROCm | 99 | tg 128 | 58.97 ± 0.21 |
| llama 7B Q4_1 | 4.24 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 306.82 ± 7.91 |
| llama 7B Q4_1 | 4.24 GiB | 7.24 B | Vulkan (master) | 99 | tg 128 | 23.39 ± 0.39 |
| llama 7B Q4_1 | 4.24 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 358.06 ± 9.70 |
| llama 7B Q4_1 | 4.24 GiB | 7.24 B | Vulkan (PR) | 99 | tg 128 | 41.32 ± 1.40 |
| llama 7B Q4_1 | 4.24 GiB | 7.24 B | ROCm | 99 | pp 512 | 974.39 ± 0.82 |
| llama 7B Q4_1 | 4.24 GiB | 7.24 B | ROCm | 99 | tg 128 | 54.88 ± 0.57 |
| llama 7B Q5_0 | 4.65 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 309.18 ± 8.64 |
| llama 7B Q5_0 | 4.65 GiB | 7.24 B | Vulkan (master) | 99 | tg 128 | 17.65 ± 0.25 |
| llama 7B Q5_0 | 4.65 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 357.03 ± 8.72 |
| llama 7B Q5_0 | 4.65 GiB | 7.24 B | Vulkan (PR) | 99 | tg 128 | 29.65 ± 0.38 |
| llama 7B Q5_0 | 4.65 GiB | 7.24 B | ROCm | 99 | pp 512 | 875.97 ± 0.45 |
| llama 7B Q5_0 | 4.65 GiB | 7.24 B | ROCm | 99 | tg 128 | 52.33 ± 0.04 |
| llama 7B Q5_1 | 5.07 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 301.61 ± 7.63 |
| llama 7B Q5_1 | 5.07 GiB | 7.24 B | Vulkan (master) | 99 | tg 128 | 17.85 ± 0.38 |
| llama 7B Q5_1 | 5.07 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 356.00 ± 9.50 |
| llama 7B Q5_1 | 5.07 GiB | 7.24 B | Vulkan (PR) | 99 | tg 128 | 28.87 ± 0.47 |
| llama 7B Q5_1 | 5.07 GiB | 7.24 B | ROCm | 99 | pp 512 | 864.23 ± 16.58 |
| llama 7B Q5_1 | 5.07 GiB | 7.24 B | ROCm | 99 | tg 128 | 48.27 ± 0.19 |
| llama 7B Q4_K - Small | 3.86 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 356.89 ± 10.47 |
| llama 7B Q4_K - Small | 3.86 GiB | 7.24 B | Vulkan (master) | 99 | tg 128 | 50.49 ± 1.41 |
| llama 7B Q4_K - Small | 3.86 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 352.91 ± 9.94 |
| llama 7B Q4_K - Small | 3.86 GiB | 7.24 B | Vulkan (PR) | 99 | tg 128 | 54.58 ± 2.35 |
| llama 7B Q4_K - Small | 3.86 GiB | 7.24 B | ROCm | 99 | pp 512 | 838.06 ± 0.35 |
| llama 7B Q4_K - Small | 3.86 GiB | 7.24 B | ROCm | 99 | tg 128 | 49.70 ± 0.06 |
| llama 7B Q4_K - Medium | 4.07 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 356.70 ± 10.60 |
| llama 7B Q4_K - Medium | 4.07 GiB | 7.24 B | Vulkan (master) | 99 | tg 128 | 49.21 ± 1.43 |
| llama 7B Q4_K - Medium | 4.07 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 353.23 ± 9.58 |
| llama 7B Q4_K - Medium | 4.07 GiB | 7.24 B | Vulkan (PR) | 99 | tg 128 | 54.41 ± 2.23 |
| llama 7B Q4_K - Medium | 4.07 GiB | 7.24 B | ROCm | 99 | pp 512 | 827.09 ± 1.75 |
| llama 7B Q4_K - Medium | 4.07 GiB | 7.24 B | ROCm | 99 | tg 128 | 49.19 ± 0.42 |
| llama 7B Q5_K - Medium | 4.78 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 355.22 ± 11.17 |
| llama 7B Q5_K - Medium | 4.78 GiB | 7.24 B | Vulkan (master) | 99 | tg 128 | 41.38 ± 1.18 |
| llama 7B Q5_K - Medium | 4.78 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 353.14 ± 9.81 |
| llama 7B Q5_K - Medium | 4.78 GiB | 7.24 B | Vulkan (PR) | 99 | tg 128 | 45.67 ± 1.64 |
| llama 7B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | pp 512 | 818.90 ± 0.33 |
| llama 7B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | tg 128 | 47.00 ± 0.03 |
| llama 7B Q6_K | 5.53 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 336.32 ± 9.85 |
| llama 7B Q6_K | 5.53 GiB | 7.24 B | Vulkan (master) | 99 | tg 128 | 41.38 ± 1.12 |
| llama 7B Q6_K | 5.53 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 333.10 ± 7.90 |
| llama 7B Q6_K | 5.53 GiB | 7.24 B | Vulkan (PR) | 99 | tg 128 | 44.35 ± 1.45 |
| llama 7B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | pp 512 | 772.27 ± 5.53 |
| llama 7B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | tg 128 | 45.06 ± 0.16 |
| llama 7B Q8_0 | 7.17 GiB | 7.24 B | Vulkan (master) | 99 | pp 512 | 306.12 ± 10.09 |
| llama 7B Q8_0 | 7.17 GiB | 7.24 B | Vulkan (master) | 99 | tg 128 | 22.52 ± 0.27 |
| llama 7B Q8_0 | 7.17 GiB | 7.24 B | Vulkan (PR) | 99 | pp 512 | 352.81 ± 9.02 |
| llama 7B Q8_0 | 7.17 GiB | 7.24 B | Vulkan (PR) | 99 | tg 128 | 33.43 ± 1.15 |
| llama 7B Q8_0 | 7.17 GiB | 7.24 B | ROCm | 99 | pp 512 | 1030.44 ± 4.23 |
| llama 7B Q8_0 | 7.17 GiB | 7.24 B | ROCm | 99 | tg 128 | 37.21 ± 0.22 |

@Nindaleth (Contributor)

Couldn't resist testing the commonly recommended Q4_K_M and Q5_K_M; I've edited my table. Q4_K_* tg is epic, Q5_K_M tg is excellent.

@daniandtheweb (Contributor) commented Mar 3, 2024

Here are some benchmarks on an AMD Radeon RX 5700 XT. The results are quite impressive, since prompt processing is now faster than ROCm on Q4_0 (ROCm is not particularly well optimized for this card, which may explain its less impressive numbers).

Model: llama 2

Vulkan0: AMD Radeon RX 5700 XT (RADV NAVI10) | uma: 0 | fp16: 1 | warp size: 64

| model | size | params | backend | ngl | test | t/s | speedup |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan (master) | 99 | pp 512 | 177.59 ± 0.75 | |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan (master) | 99 | tg 128 | 22.76 ± 0.09 | |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan (PR) | 99 | pp 512 | 323.04 ± 3.48 | 1.81 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan (PR) | 99 | tg 128 | 37.80 ± 0.14 | 1.66 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | ROCm | 99 | pp 512 | 318.71 ± 0.70 | |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | ROCm | 99 | tg 128 | 60.44 ± 0.06 | |
| llama 7B Q4_1 | 3.95 GiB | 6.74 B | Vulkan (master) | 99 | pp 512 | 177.18 ± 1.19 | |
| llama 7B Q4_1 | 3.95 GiB | 6.74 B | Vulkan (master) | 99 | tg 128 | 22.36 ± 0.22 | |
| llama 7B Q4_1 | 3.95 GiB | 6.74 B | Vulkan (PR) | 99 | pp 512 | 189.11 ± 1.13 | 1.06 |
| llama 7B Q4_1 | 3.95 GiB | 6.74 B | Vulkan (PR) | 99 | tg 128 | 38.13 ± 0.08 | 1.71 |
| llama 7B Q4_1 | 3.95 GiB | 6.74 B | ROCm | 99 | pp 512 | 314.79 ± 0.63 | |
| llama 7B Q4_1 | 3.95 GiB | 6.74 B | ROCm | 99 | tg 128 | 56.92 ± 0.02 | |
| llama 7B Q5_0 | 4.33 GiB | 6.74 B | Vulkan (master) | 99 | pp 512 | 177.06 ± 0.79 | |
| llama 7B Q5_0 | 4.33 GiB | 6.74 B | Vulkan (master) | 99 | tg 128 | 17.39 ± 0.03 | |
| llama 7B Q5_0 | 4.33 GiB | 6.74 B | Vulkan (PR) | 99 | pp 512 | 187.29 ± 1.29 | 1.06 |
| llama 7B Q5_0 | 4.33 GiB | 6.74 B | Vulkan (PR) | 99 | tg 128 | 27.96 ± 0.21 | 1.61 |
| llama 7B Q5_0 | 4.33 GiB | 6.74 B | ROCm | 99 | pp 512 | 306.27 ± 0.50 | |
| llama 7B Q5_0 | 4.33 GiB | 6.74 B | ROCm | 99 | tg 128 | 51.86 ± 0.05 | |
| llama 7B Q5_1 | 4.72 GiB | 6.74 B | Vulkan (master) | 99 | pp 512 | 176.66 ± 1.02 | |
| llama 7B Q5_1 | 4.72 GiB | 6.74 B | Vulkan (master) | 99 | tg 128 | 17.78 ± 0.04 | |
| llama 7B Q5_1 | 4.72 GiB | 6.74 B | Vulkan (PR) | 99 | pp 512 | 188.32 ± 1.20 | 1.07 |
| llama 7B Q5_1 | 4.72 GiB | 6.74 B | Vulkan (PR) | 99 | tg 128 | 28.17 ± 0.03 | 1.58 |
| llama 7B Q5_1 | 4.72 GiB | 6.74 B | ROCm | 99 | pp 512 | 301.39 ± 0.74 | |
| llama 7B Q5_1 | 4.72 GiB | 6.74 B | ROCm | 99 | tg 128 | 49.03 ± 0.53 | |

For the remaining quants the improvements are mostly in token generation, while prompt processing shows small but consistent gains over the master branch.

@0cc4m (Collaborator, Author) commented Mar 5, 2024

@ggerganov @slaren Can one of you approve the minimal change to llama.cpp?

The flake8 linting issue isn't coming from my changes.

0cc4m merged commit 61d1c88 into master on Mar 5, 2024
0cc4m deleted the 0cc4m/vulkan-improvements branch on Mar 5, 2024 at 12:33
hazelnutcloud pushed a commit to hazelnutcloud/llama.cpp that referenced this pull request Mar 10, 2024
* Improve dequant shaders, add fast q4_0 dequant

* Optimize dmmv non-kquants for GCN

Remove unnecessary SPIR-V shader duplication

* Fix q4_0 dequant dispatch sizes

Fix backend free bug

* Optimize dequant shaders for q4_1, q5_0, q5_1 and q8_0

* Add unary and binary op shader templates

* Fix Vulkan check results

* Enable non-contiguous support for simple ops

* Add argsort

Basic q4_0 mmq shader and unit test

* Speed up q4_0 dequant code, enable mmq for q4_0

* Rework matmul pipeline selection

* Add soft_max alibi support

* Add q4_1, q5_0, q5_1 and q8_0 dequant mat mat mul shaders

* Add environment variable GGML_VK_FORCE_MAX_ALLOCATION_SIZE to limit max buffer size

Rename GGML_VULKAN_DISABLE_F16 to GGML_VK_DISABLE_F16 for consistency
jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
@akingoverlook

This probably wasn't the focus, but FYI, things are still pretty broken with Adreno. Finding a compute queue is fixed, but dequant_q4k and dequant_q5k still choke Adreno with an unknown error when creating the pipeline. Aside from that, the backend still dies with DeviceLost on submit() if more than a few layers are offloaded to the GPU.

That still happens even when the new environment variable restricts the maximum allocation size.

hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024