Skip to content

[KERNEL] Sampler. CUDA kernel for applying repetition penalty #18437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

vadiklyutiy
Copy link
Contributor

@vadiklyutiy vadiklyutiy commented May 20, 2025

Problem

Sampler's part that responsible for applying repetition penalty took a long time, especially for small models.

Solution

This PR introduce a CUDA kernel that implements applying repetition penalty

Performance

All measurements on H100

Unit test

Tests of "torch implementation" vs "new CUDA kernel" shows speed up 2.85x-6.23x depending of input.

Benchmark source code
import torch
import vllm
from vllm._custom_ops import (
    apply_repetition_penalties,
    apply_repetition_penalties_torch,
)

VOCAB_SIZE = 151936
NUM_SEQS = [1, 8, 16, 32, 64, 100, 256, 1024, 1025]
REPETITION_PENALTY = 1.2
ITERATIONS = 20  # Reduce for large sizes to avoid long runs
NUM_INPUT_TOKENS = 256
NUM_OUTPUT_TOKENS = 256

device = torch.device("cuda:0")

for num_seqs in NUM_SEQS:
    print(f"\n===== num_seqs={num_seqs}, vocab_size={VOCAB_SIZE} =====")
    logits = torch.randn(num_seqs, VOCAB_SIZE, dtype=torch.float32, device=device)
    prompt_mask = torch.zeros(num_seqs, VOCAB_SIZE, dtype=torch.bool, device=device)
    output_mask = torch.zeros(num_seqs, VOCAB_SIZE, dtype=torch.bool, device=device)
    prompt_tokens_per_seq = NUM_INPUT_TOKENS
    output_tokens_per_seq = NUM_OUTPUT_TOKENS
    for i in range(num_seqs):
        prompt_indices = torch.randint(0, VOCAB_SIZE, (prompt_tokens_per_seq,), device=device)
        output_indices = torch.randint(0, VOCAB_SIZE, (output_tokens_per_seq,), device=device)
        prompt_mask[i, prompt_indices] = True
        output_mask[i, output_indices] = True
    repetition_penalties = torch.full((num_seqs,), REPETITION_PENALTY, dtype=torch.float32, device=device)

    # Torch implementation (baseline)
    logits_torch = logits.clone()
    apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask, repetition_penalties)

    # CUDA implementation
    logits_cuda = logits.clone()
    apply_repetition_penalties(logits_cuda, prompt_mask, output_mask, repetition_penalties)
    max_diff_cuda = torch.max(torch.abs(logits_cuda - logits_torch)).item()
    print(f"Max difference (cuda vs torch): {max_diff_cuda}")
    if max_diff_cuda < 1e-5:
        print("CUDA implementation matches torch.")
    else:
        print("CUDA implementation does NOT match torch!")

    # Timing
    import time
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Torch timing
    for _ in range(2):  # Warm-up
        logits_tmp = logits.clone()
        apply_repetition_penalties_torch(logits_tmp, prompt_mask, output_mask, repetition_penalties)
    torch.cuda.synchronize()
    start.record()
    for _ in range(ITERATIONS):
        logits_tmp = logits.clone()
        apply_repetition_penalties_torch(logits_tmp, prompt_mask, output_mask, repetition_penalties)
    end.record()
    torch.cuda.synchronize()
    torch_time = start.elapsed_time(end) / ITERATIONS
    print(f"Torch implementation average time: {torch_time:.6f} ms")

    # CUDA timing
    for _ in range(2):
        logits_tmp = logits.clone()
        apply_repetition_penalties(logits_tmp, prompt_mask, output_mask, repetition_penalties)
    torch.cuda.synchronize()
    start.record()
    for _ in range(ITERATIONS):
        logits_tmp = logits.clone()
        apply_repetition_penalties(logits_tmp, prompt_mask, output_mask, repetition_penalties)
    end.record()
    torch.cuda.synchronize()
    cuda_time = start.elapsed_time(end) / ITERATIONS
    print(f"CUDA implementation average time: {cuda_time:.6f} ms (speedup: {torch_time/cuda_time:.2f}x)") 
Benchmark results
===== num_seqs=1, vocab_size=151936 =====
Max difference (cuda vs torch): 1.1920928955078125e-07
CUDA implementation matches torch.
Torch implementation average time: 0.075288 ms
CUDA implementation average time: 0.012088 ms (speedup: 6.23x)

===== num_seqs=8, vocab_size=151936 =====
Max difference (cuda vs torch): 2.384185791015625e-07
CUDA implementation matches torch.
Torch implementation average time: 0.069330 ms
CUDA implementation average time: 0.011997 ms (speedup: 5.78x)

===== num_seqs=16, vocab_size=151936 =====
Max difference (cuda vs torch): 2.384185791015625e-07
CUDA implementation matches torch.
Torch implementation average time: 0.072938 ms
CUDA implementation average time: 0.019819 ms (speedup: 3.68x)

===== num_seqs=32, vocab_size=151936 =====
Max difference (cuda vs torch): 2.384185791015625e-07
CUDA implementation matches torch.
Torch implementation average time: 0.135734 ms
CUDA implementation average time: 0.047549 ms (speedup: 2.85x)

===== num_seqs=64, vocab_size=151936 =====
Max difference (cuda vs torch): 2.384185791015625e-07
CUDA implementation matches torch.
Torch implementation average time: 0.270811 ms
CUDA implementation average time: 0.079683 ms (speedup: 3.40x)

===== num_seqs=100, vocab_size=151936 =====
Max difference (cuda vs torch): 2.384185791015625e-07
CUDA implementation matches torch.
Torch implementation average time: 0.408611 ms
CUDA implementation average time: 0.119597 ms (speedup: 3.42x)

===== num_seqs=256, vocab_size=151936 =====
Max difference (cuda vs torch): 2.384185791015625e-07
CUDA implementation matches torch.
Torch implementation average time: 0.977638 ms
CUDA implementation average time: 0.253621 ms (speedup: 3.85x)

===== num_seqs=1024, vocab_size=151936 =====
Max difference (cuda vs torch): 2.384185791015625e-07
CUDA implementation matches torch.
Torch implementation average time: 3.787849 ms
CUDA implementation average time: 0.984485 ms (speedup: 3.85x)

===== num_seqs=1025, vocab_size=151936 =====
Max difference (cuda vs torch): 2.384185791015625e-07
CUDA implementation matches torch.
Torch implementation average time: 3.792390 ms
CUDA implementation average time: 0.986283 ms (speedup: 3.85x)

E2E performance test
Used benchmark_latency.py

python benchmarks/benchmark_latency.py --model meta-llama/Llama-3.2-1B-Instruct  --input-len 256 --output-len 256 --batch-size 256 --num-iters-warmup 1 --num-iters 3  --max-model-len=8192
with small modification to enable repetition penalty
diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index d5aaceeb8..fdada4af2 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -56,6 +56,7 @@ def main(args: argparse.Namespace):
         ignore_eos=True,
         max_tokens=args.output_len,
         detokenize=not args.disable_detokenize,
+        repetition_penalty=1.05,
     )
     print(sampling_params)
     dummy_prompt_token_ids = np.random.randint(

Before

Avg latency: 2.6969182553342157 seconds
10% percentile latency: 2.682154784200975 seconds
25% percentile latency: 2.6872044325009483 seconds
50% percentile latency: 2.695620513000904 seconds
75% percentile latency: 2.705983207000827 seconds
90% percentile latency: 2.7122008234007806 seconds
99% percentile latency: 2.715931393240753 seconds

After

Avg latency: 2.545895525667826 seconds
10% percentile latency: 2.543504690201371 seconds
25% percentile latency: 2.5439511310014495 seconds
50% percentile latency: 2.5446951990015805 seconds
75% percentile latency: 2.54723975700108 seconds
90% percentile latency: 2.5487664918007793 seconds
99% percentile latency: 2.5496825326805994 seconds

Avg latency speed up 5.9%

Correctness

Covered by added tests: tests/kernels/test_apply_repetition_penalties.py

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label May 20, 2025
@vadiklyutiy
Copy link
Contributor Author

@WoosukKwon @tlrmchlsmth
I'd like kindly remind about this PR

@simon-mo
Copy link
Collaborator

can you compare this against wrapping apply_penalties with @torch.compile?

@vadiklyutiy
Copy link
Contributor Author

vadiklyutiy commented May 27, 2025

can you compare this against wrapping apply_penalties with @torch.compile?

Seems wrapping apply_penalties with @torch.compile is better.

Same test
Before 2.62 sec
With custom kernel from this PR 2.45 sec
With wrapping apply_penalties with @torch.compile: vary between 2.28-2.35 sec

@vadiklyutiy
Copy link
Contributor Author

@simon-mo
Lets me guess you thoughts when you wrote "wrapping apply_penalties with @torch.compile".

The next proposal after my results in previous message is why don't wrap it with torch.compile.
Actually I investigated it a bit.

The first though is wrap class Sampler with @support_torch_compile. Seems this path somewhere between hard and impossible. There is code that not supported by dynamo currently (pin_memory). Also the internal logic of Sampler a bit complicated. There are a number of parameters is SamplingMetadata (passed to Sampler.forward) like all_random, top_p, etc and them depends on input prompts. If new prompt come with sampling parameters that we haven't seen before come, we got new path and re-compilation is required. But we don't allow recompilations by design.

The second though is wrap apply_penalties only. In general is a good approach IMO. But we don't have functionality that allow to wrap function(right now with @support_torch_compile we can wrap only forward method of class). If somebody doing it - this is good, if it is planned take implementation of himself - this is good. If not one of the above, I'd be glad to take implementation of it on myself.

But anyway the path with implementation of wrapping standalone function with torch.compile is not fast. It's needed in thinking, discussion, implementation, debugging. Meantime speed up implemented in this PR is pretty sufficient and is ready immediately. Also it is pretty safe and may be tested widely(what I believe I did with tests/kernels/test_apply_repetition_penalties.py). So, I propose to merge it now and when "wrapping standalone function with torch.compile" is ready we will disable this custom kernel. Does it make sense?

@simon-mo
Copy link
Collaborator

My thought is mostly thinking about whether this is a kernel that torch compiler or triton can generate directly if so it reduces complexity.

@@ -281,6 +281,45 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


def apply_repetition_penalties_torch(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, simply put this under @torch.compile. Does it help match the CUDA performance?

Copy link
Contributor Author

@vadiklyutiy vadiklyutiy May 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used the script from the PR description to compare "pure torch" vs "torch.compile" vs "cuda kernel". Results are latency in ms

len pure torch torch.compile cuda(this PR)
1 0.066 0.044 0.011
8 0.065 0.052 0.012
16 0.077 0.054 0.020
32 0.137 0.055 0.048
64 0.273 0.069 0.080
100 0.410 0.106 0.120
256 0.980 0.251 0.254
1024 3.791 0.984 0.986
1025 3.795 0.985 0.988

For small len(<=32), CUDA implementation is better than torch.compile by up to 4.5x. For medium (64, 100) torch.compile better by 10-15%. For len>=256 results the same.

@vadiklyutiy
Copy link
Contributor Author

My thought is mostly thinking about whether this is a kernel that torch compiler or triton can generate directly if so it reduces complexity.

I agree that torch.compile is preferable option. But we can not just add torch.compile because it might cause re-compilation in inference time, what is not so good. We need to implement something similar to @support_torch_compile but for standalone function.

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR looks good overall. I do share the same sentiment as @simon-mo that it would be nicer to have this in triton, or to use torch.compile.

Signed-off-by: Vadim Gimpelson <[email protected]>

Co-authored-by: Tyler Michael Smith <[email protected]>
@vadiklyutiy vadiklyutiy force-pushed the sampler-penalty-kernel branch from dce5236 to 5a7038d Compare May 28, 2025 01:23
Signed-off-by: Vadim Gimpelson <[email protected]>
@vadiklyutiy vadiklyutiy force-pushed the sampler-penalty-kernel branch from 5a7038d to 7757f3a Compare May 28, 2025 01:29
@vadiklyutiy
Copy link
Contributor Author

PR looks good overall. I do share the same sentiment as @simon-mo that it would be nicer to have this in triton, or to use torch.compile.

I think torch.compile is the best option. But we can't just add torch.compile because it might cause re-compilation in inference time, what is not so good. We have to implement something similar to @support_torch_compile but for standalone function. I'd be glad to add such support if nobody is already doing it. But this is not very fast changes. So, how about merge this PR because it bring improvement immediately and after having support of torch.compile for standalone function we change this code torch.compile version?

Copy link
Collaborator

@simon-mo simon-mo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please submit a follow up PR with torch compile

Comment on lines 11 to 12
# [stress, stress, stress Qwen, llama 4]
VOCAB_SIZES = [17, 256, 1019, 151936, 202048]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# [stress, stress, stress Qwen, llama 4]
VOCAB_SIZES = [17, 256, 1019, 151936, 202048]
# [stress, stress, stress, Qwen3, Llama4]
VOCAB_SIZES = [17, 256, 1019, 151936, 202048]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Comment on lines 16 to 18
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Signed-off-by: Vadim Gimpelson <[email protected]>
@simon-mo simon-mo enabled auto-merge (squash) May 30, 2025 16:25
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label May 30, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants