CUDA: deduplicate FlashAttention code #7352

Merged

Conversation

JohannesGaessler
Collaborator

Follow-up to #7314.

This PR deduplicates the CUDA FlashAttention code, mainly the code around launching the kernels. For the kernels themselves I only deduplicated the ALiBi slope calculation.
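The ALiBi slope computation is a natural candidate for sharing because it depends only on the head index and a few precomputed constants. Below is a minimal, self-contained C++ sketch of that calculation; the function name `alibi_slope` and the exact parameterization are illustrative assumptions following the usual ALiBi scheme, not a quote of the shared helper this PR adds. Heads with index below `n_head_log2` use base `m0` with exponent `h + 1`; the remaining heads use base `m1` with odd exponents.

```cpp
#include <cmath>
#include <cstdio>

// Illustrative ALiBi slope for attention head h (hypothetical helper, not
// the actual ggml code). n_head_log2 is the largest power of two <= the
// number of heads; m0 and m1 are precomputed bases.
static float alibi_slope(float max_bias, unsigned h,
                         unsigned n_head_log2, float m0, float m1) {
    if (max_bias <= 0.0f) {
        return 1.0f; // ALiBi disabled: a slope of 1 leaves the bias unused
    }
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
    return powf(base, (float) exph);
}

int main() {
    const unsigned n_head      = 8;
    const unsigned n_head_log2 = 8;    // largest power of two <= n_head
    const float    max_bias    = 8.0f; // typical ALiBi maximum bias
    const float m0 = powf(2.0f, -max_bias          / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias/2.0f)   / n_head_log2);
    for (unsigned h = 0; h < n_head; ++h) {
        // For 8 heads this prints the classic ALiBi slopes 1/2, 1/4, ..., 1/256.
        printf("head %u: slope %g\n", h,
               alibi_slope(max_bias, h, n_head_log2, m0, m1));
    }
    return 0;
}
```

Keeping this logic in one helper means the different FlashAttention kernels cannot silently drift apart in how they compute the slopes.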

There are also two minor fixes:

  • Fixed the vec kernels writing their metadata to VRAM multiple times.
  • Fixed the tensor core kernel launch code calculating the number of blocks as parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block instead of parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block); see the sketch after this list. I don't think there is a case on master where this leads to incorrect results, though.
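To see why the parentheses matter: in C++, multiplication and division associate left to right, so the first form divides the whole product rather than multiplying by the rounded-up quotient. A minimal host-side sketch with made-up sizes (Q->ne[1] = 9 is a hypothetical value chosen to expose the difference):

```cpp
#include <cstdio>

int main() {
    const int ne1             = 9;   // stands in for Q->ne[1] (query rows)
    const int cols_per_block  = 32;
    const int parallel_blocks = 4;

    // master: the division applies to the whole product,
    // giving floor(4*40/32) = 5 blocks.
    const int blocks_master = parallel_blocks*(ne1 + cols_per_block - 1) / cols_per_block;

    // this PR: first round ne1 up to whole blocks (ceil(9/32) = 1),
    // then multiply, giving the intended 4*1 = 4 blocks.
    const int blocks_fixed  = parallel_blocks*((ne1 + cols_per_block - 1) / cols_per_block);

    printf("master: %d, fixed: %d\n", blocks_master, blocks_fixed); // master: 5, fixed: 4
    return 0;
}
```

Note that floor(p*a/c) >= p*floor(a/c), so the old form can only over-count blocks, never under-count them; this is consistent with the observation above that it likely never produced incorrect results on master.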

Contributor

github-actions bot commented May 17, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 544 iterations 🚀

Details (performance-related PRs only):
  • Concurrent users: 8, duration: 10m
  • HTTP request: avg=8605.39ms p(95)=20901.68ms fails=, finish reason: stop=492 truncated=52
  • Prompt processing (pp): avg=96.78tk/s p(95)=381.39tk/s
  • Token generation (tg): avg=58.33tk/s p(95)=47.55tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=cuda-fattn-refactor-4 commit=4d9e90ca98f24d905827199e9427c1a2f83ccef2

[Chart: llamacpp:prompt_tokens_seconds, llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 544 iterations]
[Chart: llamacpp:predicted_tokens_seconds, llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 544 iterations]

[Chart: llamacpp:kv_cache_usage_ratio, llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 544 iterations]
[Chart: llamacpp:requests_processing, llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 544 iterations]

@mofosyne added the refactoring, Nvidia GPU, and Review Complexity : High labels on May 18, 2024
@mofosyne added the merge ready label on May 18, 2024
@JohannesGaessler force-pushed the cuda-fattn-refactor-4 branch from 3ac059b to 4d9e90c on May 18, 2024
@JohannesGaessler merged commit 133d99c into ggml-org:master on May 18, 2024
63 of 69 checks passed
@github-actions bot added the ggml label on May 18, 2024
Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request May 18, 2024