
[KERNEL] Sampler. CUDA kernel for applying repetition penalty #18437


Merged (11 commits, Jun 4, 2025)

Changes from 3 commits
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -235,6 +235,7 @@ set(VLLM_EXT_SRC
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
5 changes: 5 additions & 0 deletions csrc/ops.h
@@ -92,6 +92,11 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);

void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask,
const torch::Tensor& output_mask,
const torch::Tensor& repetition_penalties);

void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,
double epsilon);
86 changes: 86 additions & 0 deletions csrc/sampler.cu
@@ -0,0 +1,86 @@
#include "dispatch_utils.h"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif

namespace vllm {

template <typename scalar_t>
__global__ void apply_repetition_penalties_kernel(
scalar_t* __restrict__ logits, // [num_seqs, vocab_size]
const bool* __restrict__ prompt_mask, // [num_seqs, vocab_size]
const bool* __restrict__ output_mask, // [num_seqs, vocab_size]
const scalar_t* __restrict__ repetition_penalties, // [num_seqs]
const int num_seqs, const int vocab_size, const int tile_size) {
// Each block handles one sequence and a tile of vocab
const int seq_idx = blockIdx.x;
if (seq_idx >= num_seqs) return;

const int tile_start = blockIdx.y * tile_size;
const int tile_end = min(tile_start + tile_size, vocab_size);

// Load repetition penalty for this sequence
const scalar_t penalty = repetition_penalties[seq_idx];

// Each thread processes multiple vocab items within the tile
for (int vocab_idx = tile_start + threadIdx.x; vocab_idx < tile_end;
vocab_idx += blockDim.x) {
const int64_t idx = static_cast<int64_t>(seq_idx) * vocab_size + vocab_idx;
const bool is_repeated = prompt_mask[idx] || output_mask[idx];
if (is_repeated) {
scalar_t logit = logits[idx];
if (logit > 0) {
logits[idx] = logit / penalty;
} else {
logits[idx] = logit * penalty;
}
}
}
}

} // namespace vllm

void apply_repetition_penalties_(
torch::Tensor& logits, // [num_seqs, vocab_size], in-place
const torch::Tensor& prompt_mask, // [num_seqs, vocab_size]
const torch::Tensor& output_mask, // [num_seqs, vocab_size]
const torch::Tensor& repetition_penalties) { // [num_seqs]
TORCH_CHECK(logits.is_contiguous());
TORCH_CHECK(prompt_mask.is_contiguous());
TORCH_CHECK(output_mask.is_contiguous());
TORCH_CHECK(repetition_penalties.is_contiguous());

int vocab_size = logits.size(-1);
int num_seqs = logits.size(0);

// Get number of SMs on the current device
int sms = 0;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
logits.get_device());

// Compute tile_num and tile_size
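// Heuristic: give each sequence ceil(sms / num_seqs) vocab tiles so the
// grid launches roughly one block per SM, clamped to [1, vocab_size].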
int tile_num =
std::min(vocab_size, std::max(1, (sms + num_seqs - 1) / num_seqs));
int tile_size = (vocab_size + tile_num - 1) / tile_num;

// Each block handles one sequence and a tile of vocab
dim3 grid(num_seqs, tile_num);
dim3 block(std::min(tile_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(logits));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
logits.scalar_type(), "apply_repetition_penalties_kernel", [&] {
vllm::apply_repetition_penalties_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
logits.data_ptr<scalar_t>(), prompt_mask.data_ptr<bool>(),
output_mask.data_ptr<bool>(),
repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
tile_size);
});
}
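
For a concrete sense of the kernel's branch: with penalty 1.05, a positive logit of a repeated token shrinks (2.0 -> ~1.905) while a negative one grows in magnitude (-1.0 -> -1.05), so repeated tokens always lose probability mass. Below is a minimal PyTorch sketch of the same rule, with illustrative values only (this is not code from the PR):

import torch

penalty = 1.05
logits = torch.tensor([2.0, -1.0, 0.5])       # last token is not repeated
repeated = torch.tensor([True, True, False])  # prompt_mask | output_mask

# Divide positive logits by the penalty, multiply negative ones by it,
# and leave non-repeated tokens untouched.
penalized = torch.where(
    repeated,
    torch.where(logits > 0, logits / penalty, logits * penalty),
    logits,
)
print(penalized)  # tensor([ 1.9048, -1.0500,  0.5000])
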
7 changes: 7 additions & 0 deletions csrc/torch_bindings.cpp
@@ -170,6 +170,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

// Apply repetition penalties to logits in-place
ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
"Tensor output_mask, Tensor repetition_penalties) -> ()");
ops.impl("apply_repetition_penalties_", torch::kCUDA,
&apply_repetition_penalties_);
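// Note: the trailing underscore follows the PyTorch convention for in-place
// ops, and the "Tensor!" annotation in the schema marks logits as mutated.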

// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
79 changes: 79 additions & 0 deletions tests/kernels/test_apply_repetition_penalties.py
@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

from tests.kernels.utils import opcheck
from vllm._custom_ops import (apply_repetition_penalties_cuda,
apply_repetition_penalties_torch)
from vllm.platforms import current_platform

NUM_SEQS = [1, 2, 3, 4, 8, 13, 17, 32, 37, 256, 1023, 1024, 1025]
# [stress, stress, stress Qwen, llama 4]
VOCAB_SIZES = [17, 256, 1019, 151936, 202048]
Review comment (Collaborator):

Suggested change:
-# [stress, stress, stress Qwen, llama 4]
-VOCAB_SIZES = [17, 256, 1019, 151936, 202048]
+# [stress, stress, stress, Qwen3, Llama4]
+VOCAB_SIZES = [17, 256, 1019, 151936, 202048]

Reply (Contributor, Author): fixed
REPETITION_PENALTY_VALUES = [1.05]
SEEDS = [0]
DTYPES = [torch.float32, torch.float16]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
Review comment (Collaborator): I don't think this is needed.

Reply (Contributor, Author): fixed


@pytest.mark.parametrize("num_seqs", NUM_SEQS)
@pytest.mark.parametrize("vocab_size", VOCAB_SIZES)
@pytest.mark.parametrize("repetition_penalty", REPETITION_PENALTY_VALUES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_apply_repetition_penalties(
num_seqs: int,
vocab_size: int,
repetition_penalty: float,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
"""
Test the apply_repetition_penalties custom op
against a reference implementation.
"""
current_platform.seed_everything(seed)
torch.set_default_device(device)

# Create test data
logits = torch.randn(num_seqs, vocab_size, dtype=dtype)

# Create masks with some random tokens marked as repeated
prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)

# Mark some tokens as repeated in prompt and output
prompt_indices = torch.randint(0, vocab_size,
(num_seqs, max(1, vocab_size // 200)))
output_indices = torch.randint(0, vocab_size,
(num_seqs, max(1, vocab_size // 200)))

for i in range(num_seqs):
prompt_mask[i, prompt_indices[i]] = True
output_mask[i, output_indices[i]] = True

# Create repetition penalties tensor
repetition_penalties = torch.full((num_seqs, ),
repetition_penalty,
dtype=dtype)

# Run both implementations
logits_torch = logits.clone()
logits_cuda = logits.clone()

apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask,
repetition_penalties)
apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask,
repetition_penalties)

# Compare the CUDA output to the torch reference
torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3)

# Test the operator by applying the opcheck utility
opcheck(torch.ops._C.apply_repetition_penalties_,
(logits.clone(), prompt_mask, output_mask, repetition_penalties))
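
To run just this suite locally, a standard pytest invocation should suffice (assuming a CUDA device is available):

pytest tests/kernels/test_apply_repetition_penalties.py -q
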
39 changes: 39 additions & 0 deletions vllm/_custom_ops.py
@@ -281,6 +281,45 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


Review comment (Collaborator): For example, simply put this under @torch.compile. Does it help match the CUDA performance?

Reply (vadiklyutiy, Contributor, Author, May 28, 2025): I used the script from the PR description to compare "pure torch" vs "torch.compile" vs the CUDA kernel. Results are latency in ms:

len     pure torch   torch.compile   cuda (this PR)
1       0.066        0.044           0.011
8       0.065        0.052           0.012
16      0.077        0.054           0.020
32      0.137        0.055           0.048
64      0.273        0.069           0.080
100     0.410        0.106           0.120
256     0.980        0.251           0.254
1024    3.791        0.984           0.986
1025    3.795        0.985           0.988

For small len (<= 32), the CUDA implementation is faster than torch.compile by up to 4.5x. For medium len (64, 100), torch.compile is faster by 10-15%. For len >= 256, the results are the same.

def apply_repetition_penalties_torch(
logits: torch.Tensor, prompt_mask: torch.Tensor,
output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None:
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, logits.size(1))
# If token appears in prompt or output, apply the penalty, otherwise use 1.0 for no-op.
penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
logits *= scaling
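
Following the review thread above, here is a sketch of the torch.compile variant the collaborator suggested; it is an experiment, not code shipped in this PR, and apply_repetition_penalties_compiled is a hypothetical name:

_compiled_repetition_penalties = torch.compile(apply_repetition_penalties_torch)

def apply_repetition_penalties_compiled(
        logits: torch.Tensor, prompt_mask: torch.Tensor,
        output_mask: torch.Tensor,
        repetition_penalties: torch.Tensor) -> None:
    # Same in-place semantics as the eager reference; torch.compile can fuse
    # the where/mul chain into fewer kernels after the first warm-up call.
    _compiled_repetition_penalties(logits, prompt_mask, output_mask,
                                   repetition_penalties)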


def apply_repetition_penalties_cuda(
logits: torch.Tensor, prompt_mask: torch.Tensor,
output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None:
torch.ops._C.apply_repetition_penalties_(logits, prompt_mask, output_mask,
repetition_penalties)


def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
output_mask: torch.Tensor,
repetition_penalties: torch.Tensor) -> None:
"""Apply repetition penalties to logits in-place.

Args:
logits: The logits tensor of shape [num_seqs, vocab_size].
prompt_mask: A boolean tensor of shape [num_seqs, vocab_size] indicating which tokens appear in the prompt.
output_mask: A boolean tensor of shape [num_seqs, vocab_size] indicating which tokens appear in the output.
repetition_penalties: The repetition penalties of shape (num_seqs, ).
"""
if current_platform.is_cuda():
apply_repetition_penalties_cuda(logits, prompt_mask, output_mask,
repetition_penalties)
else:
apply_repetition_penalties_torch(logits, prompt_mask, output_mask,
repetition_penalties)
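
A minimal usage sketch of the dispatcher (illustrative shapes; the hand-rolled masks below stand in for vLLM's actual sampler bookkeeping):

import torch
from vllm import _custom_ops as ops

num_seqs, vocab_size = 2, 8
logits = torch.randn(num_seqs, vocab_size, device="cuda")

# Token 3 appeared in sequence 0's prompt; token 5 was generated by sequence 1.
prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool, device="cuda")
output_mask = torch.zeros_like(prompt_mask)
prompt_mask[0, 3] = True
output_mask[1, 5] = True

repetition_penalties = torch.full((num_seqs,), 1.05, device="cuda")
ops.apply_repetition_penalties(logits, prompt_mask, output_mask,
                               repetition_penalties)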


def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
13 changes: 4 additions & 9 deletions vllm/model_executor/layers/utils.py
@@ -49,16 +49,11 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
vocab_size, num_seqs)
output_bin_counts, output_mask = get_token_bin_counts_and_mask(
output_tokens_tensor, vocab_size, num_seqs)
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, vocab_size)

# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
1.0)

# If logits are positive, divide by penalty, otherwise multiply by penalty.
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
logits *= scaling
# Apply repetition penalties as a custom op
from vllm._custom_ops import apply_repetition_penalties
apply_repetition_penalties(logits, prompt_mask, output_mask,
repetition_penalties)

# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details