[KERNEL] Sampler. CUDA kernel for applying repetition penalty #18437
@@ -0,0 +1,86 @@
#include "dispatch_utils.h"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

#ifndef USE_ROCM
  #include <cub/cub.cuh>
#else
  #include <hipcub/hipcub.hpp>
#endif

namespace vllm {

template <typename scalar_t>
__global__ void apply_repetition_penalties_kernel(
    scalar_t* __restrict__ logits,          // [num_seqs, vocab_size]
    const bool* __restrict__ prompt_mask,   // [num_seqs, vocab_size]
    const bool* __restrict__ output_mask,   // [num_seqs, vocab_size]
    const scalar_t* __restrict__ repetition_penalties,  // [num_seqs]
    const int num_seqs, const int vocab_size, const int tile_size) {
  // Each block handles one sequence and a tile of vocab
  const int seq_idx = blockIdx.x;
  if (seq_idx >= num_seqs) return;

  const int tile_start = blockIdx.y * tile_size;
  const int tile_end = min(tile_start + tile_size, vocab_size);

  // Load repetition penalty for this sequence
  const scalar_t penalty = repetition_penalties[seq_idx];

  // Each thread processes multiple vocab items within the tile
  for (int vocab_idx = tile_start + threadIdx.x; vocab_idx < tile_end;
       vocab_idx += blockDim.x) {
    const int64_t idx = static_cast<int64_t>(seq_idx) * vocab_size + vocab_idx;
    const bool is_repeated = prompt_mask[idx] || output_mask[idx];
    if (is_repeated) {
      scalar_t logit = logits[idx];
      if (logit > 0) {
        logits[idx] = logit / penalty;
      } else {
        logits[idx] = logit * penalty;
      }
    }
  }
}

}  // namespace vllm

void apply_repetition_penalties_(
    torch::Tensor& logits,             // [num_seqs, vocab_size], in-place
    const torch::Tensor& prompt_mask,  // [num_seqs, vocab_size]
    const torch::Tensor& output_mask,  // [num_seqs, vocab_size]
    const torch::Tensor& repetition_penalties) {  // [num_seqs]
  TORCH_CHECK(logits.is_contiguous());
  TORCH_CHECK(prompt_mask.is_contiguous());
  TORCH_CHECK(output_mask.is_contiguous());
  TORCH_CHECK(repetition_penalties.is_contiguous());

  int vocab_size = logits.size(-1);
  int num_seqs = logits.size(0);

  // Get number of SMs on the current device
  int sms = 0;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
                         logits.get_device());

  // Compute tile_num and tile_size
  int tile_num =
      std::min(vocab_size, std::max(1, (sms + num_seqs - 1) / num_seqs));
  int tile_size = (vocab_size + tile_num - 1) / tile_num;

  // Each block handles one sequence and a tile of vocab
  dim3 grid(num_seqs, tile_num);
  dim3 block(std::min(tile_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(logits));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      logits.scalar_type(), "apply_repetition_penalties_kernel", [&] {
        vllm::apply_repetition_penalties_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
                logits.data_ptr<scalar_t>(), prompt_mask.data_ptr<bool>(),
                output_mask.data_ptr<bool>(),
                repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
                tile_size);
      });
}
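
For intuition on the launch configuration computed above: the host code splits the vocabulary into `tile_num` tiles per sequence so that the total number of blocks lands near the SM count. Below is a minimal Python sketch of that same arithmetic, not part of the PR; the 108-SM device and the shapes are illustrative assumptions only.

```python
# Sketch only: mirrors the tile_num / tile_size / grid arithmetic of
# apply_repetition_penalties_ above. SM count and shapes are hypothetical.
def launch_config(sms: int, num_seqs: int, vocab_size: int):
    # Aim for roughly one block per SM: ceil(sms / num_seqs) vocab tiles
    # per sequence, capped at one tile per vocab entry.
    tile_num = min(vocab_size, max(1, (sms + num_seqs - 1) // num_seqs))
    # Each tile covers ceil(vocab_size / tile_num) vocab entries.
    tile_size = (vocab_size + tile_num - 1) // tile_num
    grid = (num_seqs, tile_num)   # one block per (sequence, vocab tile)
    block = min(tile_size, 1024)  # threads per block, capped at 1024
    return grid, block, tile_size

print(launch_config(sms=108, num_seqs=4, vocab_size=151936))
# -> ((4, 27), 1024, 5628): 4 * 27 = 108 blocks, matching the SM count.
```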
@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

from tests.kernels.utils import opcheck
from vllm._custom_ops import (apply_repetition_penalties_cuda,
                              apply_repetition_penalties_torch)
from vllm.platforms import current_platform

NUM_SEQS = [1, 2, 3, 4, 8, 13, 17, 32, 37, 256, 1023, 1024, 1025]
# [stress, stress, stress, Qwen, llama 4]
VOCAB_SIZES = [17, 256, 1019, 151936, 202048]

REPETITION_PENALTY_VALUES = [1.05]
SEEDS = [0]
DTYPES = [torch.float32, torch.float16]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
[Review thread on the block above] Reviewer: "I don't think this is needed." Author: "fixed."


@pytest.mark.parametrize("num_seqs", NUM_SEQS)
@pytest.mark.parametrize("vocab_size", VOCAB_SIZES)
@pytest.mark.parametrize("repetition_penalty", REPETITION_PENALTY_VALUES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_apply_repetition_penalties(
    num_seqs: int,
    vocab_size: int,
    repetition_penalty: float,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """
    Test the apply_repetition_penalties custom op
    against a reference implementation.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Create test data
    logits = torch.randn(num_seqs, vocab_size, dtype=dtype)

    # Create masks with some random tokens marked as repeated
    prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
    output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)

    # Mark some tokens as repeated in prompt and output
    prompt_indices = torch.randint(0, vocab_size,
                                   (num_seqs, max(1, vocab_size // 200)))
    output_indices = torch.randint(0, vocab_size,
                                   (num_seqs, max(1, vocab_size // 200)))

    for i in range(num_seqs):
        prompt_mask[i, prompt_indices[i]] = True
        output_mask[i, output_indices[i]] = True

    # Create repetition penalties tensor
    repetition_penalties = torch.full((num_seqs, ),
                                      repetition_penalty,
                                      dtype=dtype)

    # Run both implementations
    logits_torch = logits.clone()
    logits_cuda = logits.clone()

    apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask,
                                     repetition_penalties)
    apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask,
                                    repetition_penalties)

    # Compare the CUDA output against the torch reference
    torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3)

    # Test the operator by applying the opcheck utility
    opcheck(torch.ops._C.apply_repetition_penalties_,
            (logits.clone(), prompt_mask, output_mask, repetition_penalties))
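
Beyond the correctness check above, the review discussion further down compares the pure-torch and CUDA paths for speed. The PR's own benchmark script (referenced in its description) is not reproduced here; the following is only a rough, hypothetical timing sketch of such a comparison.

```python
# Hedged sketch, not the script from the PR description: a rough timing
# comparison of the torch and CUDA paths on synthetic data.
import torch
from vllm._custom_ops import (apply_repetition_penalties_cuda,
                              apply_repetition_penalties_torch)

def bench(fn, *args, iters=100):
    """Return the average time per call in milliseconds using CUDA events."""
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

num_seqs, vocab_size = 32, 151936
logits = torch.randn(num_seqs, vocab_size, device="cuda")
prompt_mask = torch.zeros_like(logits, dtype=torch.bool)
output_mask = torch.zeros_like(logits, dtype=torch.bool)
prompt_mask[:, :32] = True  # mark a few tokens as repeated
penalties = torch.full((num_seqs,), 1.05, device="cuda")

# Both ops mutate logits in place, so each path gets its own clone.
print("torch:", bench(apply_repetition_penalties_torch,
                      logits.clone(), prompt_mask, output_mask, penalties))
print("cuda: ", bench(apply_repetition_penalties_cuda,
                      logits.clone(), prompt_mask, output_mask, penalties))
```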
@@ -281,6 +281,45 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


def apply_repetition_penalties_torch(
        logits: torch.Tensor, prompt_mask: torch.Tensor,
        output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None:
    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
        1, logits.size(1))
    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
    penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
                            1.0)
    # If logits are positive, divide by penalty, otherwise multiply by penalty.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits *= scaling

[Review thread on this function] Reviewer: "For example, simply put this under …". Author: benchmarked with the script from the PR description, comparing "pure torch" against the CUDA kernel; for small lengths (<= 32) the CUDA implementation came out ahead.

def apply_repetition_penalties_cuda(
        logits: torch.Tensor, prompt_mask: torch.Tensor,
        output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None:
    torch.ops._C.apply_repetition_penalties_(logits, prompt_mask, output_mask,
                                             repetition_penalties)


def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
                               output_mask: torch.Tensor,
                               repetition_penalties: torch.Tensor) -> None:
    """Apply repetition penalties to logits in-place.

    Args:
        logits: The logits tensor of shape [num_seqs, vocab_size].
        prompt_mask: A boolean tensor indicating which tokens appear in the prompt.
        output_mask: A boolean tensor indicating which tokens appear in the output.
        repetition_penalties: The repetition penalties of shape (num_seqs, ).
    """
    if current_platform.is_cuda():
        apply_repetition_penalties_cuda(logits, prompt_mask, output_mask,
                                        repetition_penalties)
    else:
        apply_repetition_penalties_torch(logits, prompt_mask, output_mask,
                                         repetition_penalties)
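
A short usage sketch of the dispatcher above (the import path follows this PR; the tensor shapes match the docstring, while the token ids and penalty values are made up for illustration):

```python
import torch
from vllm._custom_ops import apply_repetition_penalties

num_seqs, vocab_size = 2, 8
logits = torch.randn(num_seqs, vocab_size, device="cuda")

# Mark which token ids already appeared in the prompt / generated output.
prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool, device="cuda")
output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool, device="cuda")
prompt_mask[0, 3] = True   # token 3 appeared in sequence 0's prompt
output_mask[1, 5] = True   # token 5 was already generated by sequence 1

penalties = torch.tensor([1.05, 1.2], device="cuda")

# In-place: positive logits of repeated tokens are divided by the penalty,
# non-positive ones are multiplied, making repeats less likely either way.
apply_repetition_penalties(logits, prompt_mask, output_mask, penalties)
```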


def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
                           input_tokens: torch.Tensor,
                           sampled_token_ids: torch.Tensor,