[V1][TPU] Support V1 Sampler for ragged attention #14227

Merged: 22 commits, Mar 20, 2025
94 changes: 94 additions & 0 deletions tests/v1/tpu/test_sampler.py
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
import tempfile
from time import time

import pytest

from vllm import LLM, envs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams

if not envs.VLLM_USE_V1:
pytest.skip(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
allow_module_level=True,
)
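# NOTE (editor): these tests assume a TPU host; with the V1 engine enabled they
# can be run as, e.g., `VLLM_USE_V1=1 pytest tests/v1/tpu/test_sampler.py -s`.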


@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
def test_sampler_compilation(model_name: str, monkeypatch):
"""
Check that no recompilation happens despite changing sampling parameters.
We can't read XLA metrics from the engine process, so we measure wall-clock
time instead.
"""
with tempfile.TemporaryDirectory() as temp_dir:
monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
# Model init compilation may still take some time; use enforce_eager to skip it.
llm = LLM(model_name,
enforce_eager=True,
max_num_seqs=16,
max_model_len=1024,
gpu_memory_utilization=0.5)
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
]
# First inference should be slow
sampling_params = SamplingParams(
temperature=0.7,
# top_p=0.6, # TODO too slow!
# top_k=10,
min_p=0.2,
max_tokens=16)
s = time()
_ = llm.generate(prompts, sampling_params)
run1 = time() - s

# Second request with different param values, which the previous eager
# iteration already compiled for.
sampling_params = SamplingParams(temperature=0.1,
min_p=0.8,
max_tokens=24)
s = time()
_ = llm.generate(prompts, sampling_params)
run2 = time() - s
# Much faster after compiling
assert run1 * 0.1 > run2
print("TIMES", run1, run2)

# Third request with min_p left to None. It will not trigger
# recompilation, since the default value of 0 will be used.
sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
s = time()
_ = llm.generate(prompts, sampling_params)
run3 = time() - s
assert run1 * 0.1 > run3
print("TIMES", run1, run3)


@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
def test_sampler_different(model_name: str):
"""
Test significantly different sampling params to assert the model produces
different results.
"""
llm = LLM(
model_name,
enforce_eager=True,
max_num_seqs=1,
max_model_len=64,
# TODO: keep at 0.5, otherwise it goes OOM
gpu_memory_utilization=0.5)
prompts = [
"Write a short story about a robot that dreams for the first time."
]
sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
output = llm.generate(prompts, sampling_params)

sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
output2 = llm.generate(prompts, sampling_params)
assert output[0].outputs[0].text != output2[0].outputs[0].text
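Both tests above vary `min_p`. For reference, min_p filtering keeps only tokens whose probability is at least `min_p` times the probability of the most likely token; a minimal sketch of that semantics (illustrative only, not the TPU sampler in this PR):

import torch

def min_p_filter(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    # logits: [batch, vocab]; min_p: [batch], with 0.0 meaning "disabled".
    probs = logits.softmax(dim=-1)
    max_probs = probs.max(dim=-1, keepdim=True).values
    # Drop tokens whose probability falls below the per-row threshold.
    below_threshold = probs < min_p.unsqueeze(-1) * max_probs
    return logits.masked_fill(below_threshold, float("-inf"))

With min_p=0.0 the threshold is zero and nothing is filtered, which is why leaving it unset cannot change the compiled graph.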
16 changes: 15 additions & 1 deletion vllm/v1/sample/ops/topk_topp_sampler.py
@@ -65,6 +65,8 @@ def __init__(self):
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FlashInfer.")
self.forward = self.forward_native
elif current_platform.is_tpu():
self.forward = self.forward_tpu
else:
self.forward = self.forward_native

@@ -96,6 +98,18 @@ def forward_cuda(
return random_sample(probs, generators)
return flashinfer_sample(probs, k, p, generators)

def forward_tpu(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
# TODO: placeholder for a TPU-optimized top-k/top-p kernel
# logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
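# Editor sketch, not part of this diff: `random_sample` is expected to draw
# one token id per row from the categorical distribution in `probs`. A
# hypothetical stand-in with the same contract (ignoring the per-request
# `generators` dict, which the real helper threads through) might be:
def _naive_random_sample(probs: torch.Tensor) -> torch.Tensor:
    # probs: [num_seqs, vocab_size]; each row sums to 1.
    return torch.multinomial(probs, num_samples=1).squeeze(-1)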


def apply_top_k_top_p(
logits: torch.Tensor,
@@ -112,7 +126,7 @@ def apply_top_k_top_p(

if k is not None:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
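A quick worked example of the top-k masking above, assuming `logits_sort` is sorted ascending along the vocab dimension (the sort happens earlier in the function, outside this hunk):

import torch

# One sequence, vocab of 5, logits already sorted ascending; keep the top k=2.
logits_sort = torch.tensor([[0.1, 0.5, 1.0, 2.0, 3.0]])
k = torch.tensor([2])
idx = logits_sort.size(1) - k.to(torch.long)             # 5 - 2 = 3
threshold = logits_sort.gather(1, idx.unsqueeze(dim=1))  # 2.0, the k-th largest value
mask = logits_sort < threshold                           # True for 0.1, 0.5, 1.0
# Masking those entries with -inf leaves exactly the top-2 logits {2.0, 3.0} in play.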
Empty file.
159 changes: 159 additions & 0 deletions vllm/v1/sample/tpu/metadata.py
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch_xla.core.xla_model as xm

from vllm.v1.sample.metadata import SamplingMetadata


@dataclass
class TPUSupportedSamplingMetadata:
# This class exposes a more XLA-friendly interface than SamplingMetadata
# on TPU. In particular, all arguments should be traceable and no Optionals
# are allowed, to avoid graph recompilation when a value is None.
temperature: torch.Tensor

min_p: torch.Tensor
# Still too slow on forward_native!
top_k: torch.Tensor = None
top_p: torch.Tensor = None

# XLA-unfriendly control flow in Sampler
all_greedy: bool = False
all_random: bool = False
# Greedy sampling flag, so a single XLA graph can be compiled.
do_argmax: torch.Tensor = None

# speculation not supported
spec_token_ids = None

# Generator not supported by xla
generators: dict[int,
torch.Generator] = field(default_factory=lambda: dict())

# Unsupported; it would require returning an extra tensor of static size BxV
max_num_logprobs = None

# TODO No penalties for now
no_penalties: bool = True
prompt_token_ids = None
frequency_penalties = None
presence_penalties = None
repetition_penalties = None
# should use tensor
output_token_ids: list[list[int]] = field(default_factory=lambda: list())

min_tokens = None # impl is not vectorized

logit_bias: list[Optional[dict[int, float]]] = field(
default_factory=lambda: list())

allowed_token_ids_mask = None
bad_words_token_ids = None
indices_do_sample: torch.Tensor = None

def __post_init__(self):
temp = self.temperature
if self.indices_do_sample is None:
self.indices_do_sample = torch.zeros(temp.shape[0],
device=temp.device,
dtype=torch.int32)
if self.do_argmax is None:
self.do_argmax = torch.tensor(0,
dtype=torch.bool,
device=temp.device)

@classmethod
def from_sampling_metadata(
cls, metadata: SamplingMetadata,
padded_do_sample_indices: torch.Tensor, num_do_sample: int,
device: torch.device) -> "TPUSupportedSamplingMetadata":
"""
Create an XLA-friendly SamplingMetadata structure. Do so by first
instantiating an object with fixed-sized tensors and then copying in the
values from the input `metadata`. Do that only for non-None values, so that
recompilation is not triggered for optional values (None/torch.Tensor).

In order to handle different sizes for the params that range from 1 up
to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
Same thing for `padded_do_sample_indices`, which contains the indices
to be fed to the Sampler, padded to the closest pre-compiled shape.

E.g., padding to 4: temperature [0.7, 0.2] => [0.7, 0.2, 0.0, 0.0]
do_sample_indices [4, 10] => padded_do_sample_indices [4, 10, 0, 0]
"""
metadata = cls._validate_sampling_metadata(metadata)
# NOTE we have to initialize default tensor-based params first and
# skip None values altogether to produce the same xla graph.
num_samples = len(padded_do_sample_indices)
do_argmax = torch.tensor(metadata.all_greedy,
dtype=torch.bool,
device=device)
new_metadata = cls.get_default_sampling_params(num_samples, device,
indices_do_sample=\
padded_do_sample_indices,
do_argmax=do_argmax
)
supported_params = \
TPUSupportedSamplingMetadata._get_default_params_values()
# Copy input non-None values into `new_metadata` fixed-sized tensors.
for p_name in supported_params:
old_val = getattr(metadata, p_name)
new_val = getattr(new_metadata, p_name)
if isinstance(old_val, torch.Tensor):
new_val[:num_do_sample] = old_val
setattr(new_metadata, p_name, new_val)

xm.mark_step()
xm.wait_device_ops()
return new_metadata

@classmethod
def get_default_sampling_params(
cls,
num_samples: int,
device: torch.device,
indices_do_sample=None,
do_argmax=None) -> "TPUSupportedSamplingMetadata":
# As sampling happens on a single traced graph, options
# are "disabled" by having them evaluate to an Identity op.
# Note that initialization is dependent on num_samples.
sampling_metadata_disable_value = \
TPUSupportedSamplingMetadata._get_default_params_values()
init_kwargs = dict()
for p_name, (default_val,
dtype) in sampling_metadata_disable_value.items():
default_tensor = torch.full((num_samples, ),
default_val,
dtype=dtype,
device=device)
init_kwargs[p_name] = default_tensor

return cls(**init_kwargs,
indices_do_sample=indices_do_sample,
do_argmax=do_argmax)

@staticmethod
def _validate_sampling_metadata(
sampling_metadata: SamplingMetadata) -> SamplingMetadata:
if sampling_metadata.all_greedy:
# Temperature is set to None since #13587. Make sure the default isn't overridden.
assert sampling_metadata.temperature is None
return sampling_metadata

@staticmethod
def _get_default_params_values():
return dict(
# Since #13587, greedy sampling requires branching, which leads to
# separate graphs. We set temperature to a no-op value and handle argmax here.
temperature=(1.0, torch.float32),
min_p=(0.0, torch.float32),
# strictly disabled for now
# top_k=(-1, torch.int32),
# top_p=(0.0, torch.float32),
# frequency_penalties=(0.0, torch.float32),
# presence_penalties=(0.0, torch.float32),
# repetition_penalties=(0.0, torch.float32),
)
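The default values above are chosen so that, inside the single traced graph, each disabled option degenerates to a no-op. A small sanity sketch of that property, assuming the usual temperature-scaling and min_p formulas (illustrative, not code from this PR):

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

# temperature = 1.0: dividing the logits by one changes nothing.
assert torch.equal(logits / 1.0, logits)

# min_p = 0.0: the threshold 0.0 * max_prob is zero, so no token is filtered.
probs = logits.softmax(dim=-1)
keep = probs >= 0.0 * probs.max(dim=-1, keepdim=True).values
assert bool(keep.all())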