
metal: implement flash attention kernel for quantized KV cache #9735


Closed

Conversation

FanShupei
Contributor

This PR is mainly for discussion. The strategy and code quality are far from merge-ready.

To support a quantized KV cache, I wrote a new FA kernel similar to kernel_flash_attn_ext_vec_f16 and added dequantization support. Since kernel_flash_attn_ext_vec_f16 uses vec4 types extensively, it requires the head dimension D to be at least 128. The new version uses only scalars, so D is only required to be a multiple of 32.

I only implement ctk = ctv = q8_0 as a proof of concept. The code is generic, and support for other formats could be added easily.
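
To make the per-block dequantization work concrete, here is a minimal host-side C++ sketch of what such a kernel does for a q8_0 K cache. The block layout (32 int8 quants plus one per-block scale, QK8_0 = 32) follows ggml's q8_0 format, though ggml stores the scale as fp16; the function name and loop structure here are illustrative, not the PR's actual Metal code.

```cpp
#include <cstdint>
#include <cstdio>

// q8_0 block size, as in ggml
constexpr int QK8_0 = 32;

// Conceptually mirrors ggml's block_q8_0; ggml stores `d` as fp16,
// a float is used here only to keep the host-side sketch simple.
struct block_q8_0 {
    float  d;          // per-block scale
    int8_t qs[QK8_0];  // 32 signed 8-bit quants
};

// Scalar Q.K dot product over head dimension D, dequantizing K on the fly.
// D only needs to be a multiple of the 32-wide block size, which is the
// constraint the scalar kernel relaxes relative to the vec4-based kernel.
float dot_q_k_q8_0(const float * q, const block_q8_0 * k, int D) {
    float sum = 0.0f;
    for (int ib = 0; ib < D / QK8_0; ++ib) {
        const block_q8_0 & b = k[ib];
        for (int i = 0; i < QK8_0; ++i) {
            sum += q[ib*QK8_0 + i] * (b.d * b.qs[i]);  // dequant: x = d * qs
        }
    }
    return sum;
}

int main() {
    constexpr int D = 96;  // a multiple of 32 that is smaller than 128
    float q[D];
    block_q8_0 k[D / QK8_0];
    for (int i = 0; i < D; ++i) q[i] = 1.0f;
    for (auto & b : k) {
        b.d = 0.5f;
        for (auto & v : b.qs) v = 2;
    }
    printf("dot = %f\n", dot_q_k_q8_0(q, k, D));  // expect 96 * (0.5 * 2) = 96
    return 0;
}
```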

measurement (before this PR)

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | 1 | pp64 | 269.44 ± 0.22 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | 1 | pp128 | 276.87 ± 0.08 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | 1 | pp512 | 280.58 ± 0.06 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | 1 | pp2048 | 264.88 ± 0.02 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | 1 | tg128 | 30.97 ± 0.03 |

measurement (after this PR)

For measurement purposes, this PR forces all FA code paths to use the new kernel_flash_attn_ext_scalar_f16.

I observe that prefill slows down severely for long inputs (131 t/s vs. 265 t/s at pp2048).

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | 1 | pp64 | 227.37 ± 1.04 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | 1 | pp128 | 233.13 ± 0.11 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | 1 | pp512 | 198.81 ± 0.09 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | 1 | pp2048 | 131.52 ± 0.03 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | 1 | tg128 | 31.00 ± 0.03 |

| model | size | params | backend | ngl | type_k | type_v | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | q8_0 | q8_0 | 1 | pp64 | 226.74 ± 0.95 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | q8_0 | q8_0 | 1 | pp128 | 231.97 ± 0.21 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | q8_0 | q8_0 | 1 | pp512 | 193.76 ± 0.05 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | q8_0 | q8_0 | 1 | pp2048 | 126.06 ± 0.01 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Metal | 99 | q8_0 | q8_0 | 1 | tg128 | 30.84 ± 0.01 |
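
The tables above have the shape of llama-bench output; a run along the following lines should reproduce the quantized-KV rows. The model filename is a placeholder, and the exact flag set is an assumption based on llama-bench's documented options (-ctk/-ctv select the KV cache types, -fa 1 enables flash attention):

```sh
# Hypothetical reproduction command; substitute your own model path.
./llama-bench -m phi3-mini-q4_k_m.gguf -ngl 99 -fa 1 \
    -ctk q8_0 -ctv q8_0 -p 64,128,512,2048 -n 128
```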
