[TPU][V1] Add support for top-logprobs #17072

Merged
9 commits merged on May 5, 2025
Changes from all commits
48 changes: 48 additions & 0 deletions tests/v1/tpu/test_sampler.py
@@ -61,3 +61,51 @@ def test_sampler_different(model_name: str):
# to have deterministic results over many tokens, tests the first ~20
# tokens match.
assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]


@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
# TODO TPU will appear busy if we fan-out test params here
@pytest.mark.parametrize("n_prompts", [1])
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
def test_logprobs(model_name: str, n_prompts: int):
"""
    Request top logprobs with different sampling settings and check
    that results contain the requested number of entries, ordered by rank.
"""

def check_num_logprobs(logprobs, expected_num: int):
for step in logprobs:
prev_logp = 1.0
# order by rank
sorted_step = dict(
sorted(step.items(), key=lambda item: item[1].rank))

            # The step may also contain the sampled token.
            assert len(step) == expected_num or len(step) == expected_num + 1
            # Check that logprobs are sorted by decreasing probability.
for rankno, (tid, logp) in enumerate(sorted_step.items()):
assert logp.logprob <= prev_logp
prev_logp = logp.logprob
assert logp.rank == rankno + 1

llm = LLM(model_name,
enforce_eager=False,
max_num_seqs=1,
max_model_len=128,
max_num_batched_tokens=128)
prompts = [
"Write a short story about a robot that dreams for the first time."
] * n_prompts
    greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,
                                            logprobs=4)
    regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,
                                             logprobs=4)
    topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,
                                           logprobs=4, top_k=12, top_p=0.5)

    for sp in [greedy_sampling_params, regular_sampling_params,
               topkp_sampling_params]:
output = llm.generate(prompts, sp)
for o in output:
check_num_logprobs(o.outputs[0].logprobs, 4)
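
For context, a hedged sketch (not part of this PR) of how a client requests and reads per-token top-logprobs through the public vLLM API; the model, prompt, and logprobs=2 values are illustrative only:

from vllm import LLM, SamplingParams

# Illustrative values; any small model and prompt would do.
llm = LLM("Qwen/Qwen2.5-1.5B-Instruct", max_model_len=128)
params = SamplingParams(temperature=0.0, max_tokens=8, logprobs=2)

out = llm.generate(["The capital of France is"], params)[0]
for step in out.outputs[0].logprobs:
    # Each step maps token_id -> Logprob(logprob=..., rank=...); it holds the
    # top `logprobs` entries plus, possibly, the sampled token itself.
    for token_id, lp in sorted(step.items(), key=lambda kv: kv[1].rank):
        print(token_id, lp.rank, lp.logprob)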
13 changes: 9 additions & 4 deletions vllm/v1/sample/tpu/metadata.py
@@ -31,8 +31,10 @@ class TPUSupportedSamplingMetadata:

all_greedy: bool = True

# unsupported, you need to return an extra tensor of static size BxV
max_num_logprobs = None
    # Whether logprobs are to be gathered in this batch of requests. To
    # balance compile time and runtime, a fixed `max_logprobs` value is used
    # when gathering logprobs, regardless of the values specified in the batch.
logprobs: bool = False

# TODO No penalties for now
no_penalties: bool = True
@@ -84,10 +86,12 @@ def from_input_batch(
we want to pre-compile a graph with sampling parameters, even if
they are not strictly needed for greedy decoding.
"""
        needs_logprobs = (input_batch.max_num_logprobs > 0
                          if input_batch.max_num_logprobs else False)
# Early return to avoid unnecessary cpu to tpu copy
if (input_batch.all_greedy is True
and generate_params_if_all_greedy is False):
return cls(all_greedy=True)
return cls(all_greedy=True, logprobs=needs_logprobs)

num_reqs = input_batch.num_reqs

@@ -115,4 +119,5 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
xla_device),
min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
xla_device))
xla_device),
logprobs=needs_logprobs)
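
As a minimal sketch of the flag derivation the comments above describe (plain Python with hypothetical names, not the vLLM API):

MAX_LOGPROBS = 20  # assumed fixed upper bound, e.g. model_config.max_logprobs

# Per-request counts can differ or be absent, but the device-side gather
# always uses MAX_LOGPROBS so the traced graph keeps a static output shape.
requested = [None, 2, 4]
needs_logprobs = any(n is not None and n > 0 for n in requested)
# Only this single boolean reaches TPUSupportedSamplingMetadata.logprobs;
# the per-request counts are applied later on the CPU.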
14 changes: 2 additions & 12 deletions vllm/v1/sample/tpu/sampler.py
@@ -22,35 +22,25 @@ def forward(
logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata,
) -> SamplerOutput:
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling).

# Use float32 for the logits.
logits = logits.to(torch.float32)
# Sample the next token.
sampled = self.sample(logits, sampling_metadata)

# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)

# These are GPU tensors.
# These are TPU tensors.
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.unsqueeze(-1),
logprobs_tensors=None,
)
logprobs_tensors=None)
return sampler_output

def apply_temperature(
self,
logits: torch.Tensor,
temp: torch.Tensor,
) -> torch.Tensor:
# Use in-place division to avoid creating a new tensor.
return logits.div_(temp.unsqueeze(dim=1))

def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
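A small illustration (toy tensors, not vLLM code) of the constraint noted in tpu_model_runner.py below: gather_logprobs needs the original logits, so temperature scaling during sampling must not modify them in place.

import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])
temp = torch.tensor([0.7])

scaled = logits / temp.unsqueeze(1)   # out-of-place: `logits` is untouched
# logits.div_(temp.unsqueeze(1))      # in-place: would overwrite `logits` and
#                                     # corrupt the later logprob gather
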
47 changes: 46 additions & 1 deletion vllm/v1/worker/tpu_model_runner.py
@@ -791,8 +791,18 @@ def execute_model(
arange)
selected_token_ids = self.sample_from_logits(logits,
tpu_sampling_metadata)

# NOTE (NickLucche) Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs. We can't enforce it due
# to recompilations outside torch.compiled code, so just make sure
# `sample_from_logits` does not modify the logits in-place.
logprobs = self.gather_logprobs(logits, selected_token_ids) \
if tpu_sampling_metadata.logprobs else None

        # Remove padding on the CPU and keep the dynamic op outside of the
        # XLA graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
logprobs_lists = logprobs.tolists() \
if tpu_sampling_metadata.logprobs else None

# Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes
@@ -862,7 +872,7 @@ def execute_model(
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=None,
logprobs=None,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
)

@@ -1121,6 +1131,22 @@ def _precompile_sample_from_logits(self) -> None:
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("sample_from_logits")

def _precompile_gather_logprobs(self) -> None:
logger.info("Compiling gather_logprobs with different input shapes.")
start = time.perf_counter()
for num_reqs in self.num_reqs_paddings:
dummy_logits = torch.zeros((num_reqs, self.vocab_size),
device=self.device,
dtype=self._hidden_states_dtype)
dummy_tokens = torch.zeros((num_reqs, 1),
dtype=torch.int64).to(self.device)
self.gather_logprobs(dummy_logits, dummy_tokens)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("gather_logprobs")

def capture_model(self) -> None:
"""
Precompile all the subgraphs with possible input shapes.
@@ -1131,6 +1157,7 @@ def capture_model(self) -> None:
self._precompile_compute_logits()
self._precompile_structured_decoding()
self._precompile_sample_from_logits()
self._precompile_gather_logprobs()

def profile_run(
self,
@@ -1254,13 +1281,31 @@ def compute_logits(self,
def sample_from_logits(
self, logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
"""
        Sample with an XLA-friendly function. This function is traced
        separately from `forward` to reduce compilation overhead.
"""
if sampling_metadata.all_greedy:
out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
else:
out_tokens = self.sampler(logits,
sampling_metadata).sampled_token_ids
return out_tokens

@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def gather_logprobs(self, logits: torch.Tensor,
sampled_tokens: torch.Tensor) -> LogprobsTensors:
"""
        Gather the top logprobs and their corresponding tokens. A fixed number
        of logprobs is gathered as an alternative to pre-compiling multiple
        graphs; the number actually requested by each request is selected
        later on the CPU.
"""
logprobs = self.sampler.compute_logprobs(logits)
return self.sampler.gather_logprobs(
logprobs,
self.model_config.max_logprobs,
token_ids=sampled_tokens.squeeze(-1))

@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def structured_decode(self, require_struct_decoding: torch.Tensor,
grammar_bitmask: torch.Tensor, logits: torch.Tensor,
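
For illustration, a hedged standalone sketch of the fixed-width gather that gather_logprobs performs (plain PyTorch with a hypothetical helper, not the vLLM Sampler API); trimming each request down to the number of logprobs it actually asked for then happens on the CPU after .tolists(), as in execute_model above.

import torch

def gather_topk_logprobs(logits: torch.Tensor, sampled: torch.Tensor, k: int):
    # Log-softmax over the vocab, then a fixed top-k so the output shapes
    # depend only on (batch, k) and the traced XLA graph stays static.
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    topk_vals, topk_ids = torch.topk(logprobs, k=k, dim=-1)
    # Logprob and rank of the token that was actually sampled ([batch, 1]).
    sampled_logprob = logprobs.gather(-1, sampled.long())
    sampled_rank = (logprobs > sampled_logprob).sum(dim=-1, keepdim=True) + 1
    return topk_ids, topk_vals, sampled_logprob, sampled_rank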