[V1][TPU] Support V1 Sampler for ragged attention #14227
Merged: robertgshaw2-redhat merged 22 commits into vllm-project:main from NickLucche:tpu-sampler-ragged on Mar 20, 2025.
Commits (22):
4baf255  xla friendly minp+topk (NickLucche)
0f50a2a  fix platform check (NickLucche)
4be47d6  tracing sampler (NickLucche)
d53f7c8  forward_tpu + revert topk selection (NickLucche)
e9e48f1  refactor to avoid recompiling None values and disable topk/topp (NickLucche)
b84baa9  tests + updated torch dev version (NickLucche)
25dead0  wip: adapt to ragged attn kernel (NickLucche)
deda02a  adapt to ragged attn kernel and multimodal (NickLucche)
317714e  add tests (NickLucche)
be7bcec  break up model|sample graph to speed up compilation (NickLucche)
07d9a1f  minor check on optional temp (NickLucche)
e65c55d  fix greedy sampling (NickLucche)
ade6054  move tpu sampling params in own file (NickLucche)
178f104  address review (NickLucche)
76460c6  rebase cruft (NickLucche)
d1f79a5  max_num_tokens stopping condition when compiling (NickLucche)
4596951  rebase changes (NickLucche)
10e1a04  fix capture_graph loop (NickLucche)
50ef555  fix recompilation issue on sampling graph; add new tpu sampler (NickLucche)
92d23cd  newline conflict(?) (NickLucche)
c2b5760  Merge branch 'main' into tpu-sampler-ragged (NickLucche)
4d6d30c  revert gpu sampler change (NickLucche)
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
import tempfile
from time import time

import pytest

from vllm import LLM, envs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams

if not envs.VLLM_USE_V1:
    pytest.skip(
        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
        allow_module_level=True,
    )


@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
def test_sampler_compilation(model_name: str, monkeypatch):
    """
    Check that no recompilation happens despite changing sampling parameters.
    We can't read XLA metrics from the engine process, hence we measure time.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
        # Compiling model init may still take some time, enforce_eager to skip.
        llm = LLM(model_name,
                  enforce_eager=True,
                  max_num_seqs=16,
                  max_model_len=1024,
                  gpu_memory_utilization=0.5)
        prompts = [
            "A robot may not injure a human being",
            "It is only with the heart that one can see rightly;",
        ]
        # First inference should be slow
        sampling_params = SamplingParams(
            temperature=0.7,
            # top_p=0.6,  # TODO too slow!
            # top_k=10,
            min_p=0.2,
            max_tokens=16)
        s = time()
        _ = llm.generate(prompts, sampling_params)
        run1 = time() - s

        # Second request with different params, but ones we already
        # compiled for in the previous eager iteration.
        sampling_params = SamplingParams(temperature=0.1,
                                         min_p=0.8,
                                         max_tokens=24)
        s = time()
        _ = llm.generate(prompts, sampling_params)
        run2 = time() - s
        # Much faster after compiling
        assert run1 * 0.1 > run2
        print("TIMES", run1, run2)

        # Third request with min_p set to "None". It will not trigger
        # recompilation as a default 0 value will be used.
        sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
        s = time()
        _ = llm.generate(prompts, sampling_params)
        run3 = time() - s
        assert run1 * 0.1 > run3
        print("TIMES", run1, run3)


@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
def test_sampler_different(model_name: str):
    """
    Test significantly different sampling params to assert the model produces
    different results.
    """
    llm = LLM(
        model_name,
        enforce_eager=True,
        max_num_seqs=1,
        max_model_len=64,
        # TODO: set to 0.5 or it will go OOM
        gpu_memory_utilization=0.5)
    prompts = [
        "Write a short story about a robot that dreams for the first time."
    ]
    sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
    output = llm.generate(prompts, sampling_params)

    sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
    output2 = llm.generate(prompts, sampling_params)
    assert output[0].outputs[0].text != output2[0].outputs[0].text
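
The compilation test above relies on unset sampling parameters being replaced by fixed defaults rather than None (e.g. a missing min_p falls back to 0), and on per-request parameters being padded to a pre-compiled size, so the traced sampling graph always sees tensors of the same shape and dtype. The snippet below is a standalone sketch of that idea, not vLLM code; `pad_sampling_params` and the bucket size of 4 are illustrative assumptions.

import torch

# Illustrative helper (not part of vLLM): map optional per-request sampling
# params onto fixed-shape tensors so the traced graph never sees None and
# never changes shape. Unset/padded entries get neutral defaults
# (temperature=1.0 scales by 1, min_p=0.0 filters nothing).
def pad_sampling_params(temps: list[float], min_ps: list[float],
                        padded_len: int) -> tuple[torch.Tensor, torch.Tensor]:
    temperature = torch.ones(padded_len, dtype=torch.float32)
    min_p = torch.zeros(padded_len, dtype=torch.float32)
    temperature[:len(temps)] = torch.tensor(temps, dtype=torch.float32)
    min_p[:len(min_ps)] = torch.tensor(min_ps, dtype=torch.float32)
    return temperature, min_p

# Two active requests padded to a bucket of 4: the tensor shapes stay the same
# no matter how many requests are running, so a compiled graph can be reused.
temperature, min_p = pad_sampling_params([0.7, 0.1], [0.2, 0.8], padded_len=4)
print(temperature)  # tensor([0.7000, 0.1000, 1.0000, 1.0000])
print(min_p)        # tensor([0.2000, 0.8000, 0.0000, 0.0000])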
Empty file.
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch_xla.core.xla_model as xm

from vllm.v1.sample.metadata import SamplingMetadata


@dataclass
class TPUSupportedSamplingMetadata:
    # This class exposes a more xla-friendly interface than SamplingMetadata
    # on TPU. In particular, all arguments should be traceable and no
    # Optionals are allowed, to avoid graph recompilation on Nones.
    temperature: torch.Tensor

    min_p: torch.Tensor
    # Still too slow on forward_native!
    top_k: torch.Tensor = None
    top_p: torch.Tensor = None

    # XLA-unfriendly control flow in Sampler
    all_greedy: bool = False
    all_random: bool = False
    # Greedy sampling flag for compiling a single xla graph.
    do_argmax: torch.Tensor = None

    # speculation not supported
    spec_token_ids = None

    # Generator not supported by xla
    generators: dict[int,
                     torch.Generator] = field(default_factory=lambda: dict())

    # unsupported, you need to return an extra tensor of static size BxV
    max_num_logprobs = None

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=lambda: list())

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(
        default_factory=lambda: list())

    allowed_token_ids_mask = None
    bad_words_token_ids = None
    indices_do_sample: torch.Tensor = None

    def __post_init__(self):
        temp = self.temperature
        if self.indices_do_sample is None:
            self.indices_do_sample = torch.zeros(temp.shape[0],
                                                 device=temp.device,
                                                 dtype=torch.int32)
        if self.do_argmax is None:
            self.do_argmax = torch.tensor(0,
                                          dtype=torch.bool,
                                          device=temp.device)

    @classmethod
    def from_sampling_metadata(
            cls, metadata: SamplingMetadata,
            padded_do_sample_indices: torch.Tensor, num_do_sample: int,
            device: torch.device) -> "TPUSupportedSamplingMetadata":
        """
        Create an XLA-friendly SamplingMetadata structure. Do so by first
        instantiating an object with fixed-sized tensors and then writing the
        values from the input `metadata`. Do that only for non-None values so
        that recompilation is not triggered for optional values
        (None/torch.Tensor).

        In order to handle different sizes for the params that range from 1 up
        to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
        Same thing for `padded_do_sample_indices`, which contains the indices
        to be fed to the Sampler, padded to the closest pre-compiled shape.

        E.g. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
        do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
        """
        metadata = cls._validate_sampling_metadata(metadata)
        # NOTE we have to initialize default tensor-based params first and
        # skip None values altogether to produce the same xla graph.
        num_samples = len(padded_do_sample_indices)
        do_argmax = torch.tensor(metadata.all_greedy,
                                 dtype=torch.bool,
                                 device=device)
        new_metadata = cls.get_default_sampling_params(
            num_samples,
            device,
            indices_do_sample=padded_do_sample_indices,
            do_argmax=do_argmax)
        supported_params = \
            TPUSupportedSamplingMetadata._get_default_params_values()
        # Copy input non-None values into `new_metadata` fixed-sized tensors.
        for p_name in supported_params:
            old_val = getattr(metadata, p_name)
            new_val = getattr(new_metadata, p_name)
            if isinstance(old_val, torch.Tensor):
                new_val[:num_do_sample] = old_val
                setattr(new_metadata, p_name, new_val)

        xm.mark_step()
        xm.wait_device_ops()
        return new_metadata

    @classmethod
    def get_default_sampling_params(
            cls,
            num_samples: int,
            device: torch.device,
            indices_do_sample=None,
            do_argmax=None) -> "TPUSupportedSamplingMetadata":
        # As sampling happens on a single traced graph, options
        # are "disabled" by having them evaluate to an Identity op.
        # Note that initialization is dependent on num_samples.
        sampling_metadata_disable_value = \
            TPUSupportedSamplingMetadata._get_default_params_values()
        init_kwargs = dict()
        for p_name, (default_val,
                     dtype) in sampling_metadata_disable_value.items():
            default_tensor = torch.full((num_samples, ),
                                        default_val,
                                        dtype=dtype,
                                        device=device)
            init_kwargs[p_name] = default_tensor

        return cls(**init_kwargs,
                   indices_do_sample=indices_do_sample,
                   do_argmax=do_argmax)

    @staticmethod
    def _validate_sampling_metadata(
            sampling_metadata: SamplingMetadata) -> SamplingMetadata:
        if sampling_metadata.all_greedy:
            # Set to None since #13587. Make sure the default isn't overruled.
            assert sampling_metadata.temperature is None
        return sampling_metadata

    @staticmethod
    def _get_default_params_values():
        return dict(
            # Since #13587 greedy sampling requires branching off, which leads
            # to separate graphs. We set temp to a no-op and handle argmax here.
            temperature=(1.0, torch.float32),
            min_p=(0.0, torch.float32),
            # strictly disabled for now
            # top_k=(-1, torch.int32),
            # top_p=(0.0, torch.float32),
            # frequency_penalties=(0.0, torch.float32),
            # presence_penalties=(0.0, torch.float32),
            # repetition_penalties=(0.0, torch.float32),
        )
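
The `get_default_sampling_params` comment above notes that sampling options are "disabled" by making them evaluate to an identity op: temperature defaults to 1.0 and min_p to 0.0, so a single traced graph serves every parameter combination. The snippet below illustrates that property with a toy temperature/min_p step; it is a hedged sketch, not the sampler shipped in this PR, and `apply_temperature_and_min_p` is a made-up name.

import torch

# Toy sketch of "disable by identity": with the neutral defaults above,
# both ops reduce to no-ops, so one compiled graph covers all combinations.
def apply_temperature_and_min_p(logits: torch.Tensor,
                                temperature: torch.Tensor,
                                min_p: torch.Tensor) -> torch.Tensor:
    # Temperature scaling: dividing by 1.0 leaves the logits untouched.
    logits = logits / temperature.unsqueeze(-1)
    # min_p filtering: drop tokens whose probability is below
    # min_p * max probability; a threshold of 0.0 keeps every token.
    probs = torch.softmax(logits, dim=-1)
    threshold = min_p.unsqueeze(-1) * probs.max(dim=-1, keepdim=True).values
    return logits.masked_fill(probs < threshold, float("-inf"))

logits = torch.randn(4, 32)      # [padded_num_seqs, vocab_size]
temperature = torch.ones(4)      # default 1.0 -> no-op scaling
min_p = torch.zeros(4)           # default 0.0 -> no filtering
min_p[0] = 0.2                   # only the first request actually filters
masked = apply_temperature_and_min_p(logits, temperature, min_p)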