
Commit 50ef555

fix recompilation issue on sampling graph; add new tpu sampler
Signed-off-by: NickLucche <[email protected]>
1 parent 10e1a04 commit 50ef555

File tree: 5 files changed (+184, -18 lines)

vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 3 additions & 1 deletion
@@ -89,7 +89,9 @@ def forward_tpu(
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # TODO Placeholder for TPU optimized topk/p kernel
-        return self.forward_native(logits, generators, k, p)
+        # logits = apply_top_k_top_p(logits, k, p)
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        return random_sample(probs, generators)
 
 
 def apply_top_k_top_p(
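Note: forward_tpu no longer delegates to forward_native; it now skips the top-k/top-p masking (left commented out pending a TPU-optimized kernel) and samples directly from the softmax probabilities. Below is a minimal, self-contained sketch of the exponential-noise categorical sampling that a helper like random_sample performs; the function name and the per-request generator handling here are illustrative, not vLLM's exact implementation.

import torch

def sample_from_probs(probs: torch.Tensor,
                      generators: dict[int, torch.Generator]) -> torch.Tensor:
    # Dividing probabilities by Exp(1) noise and taking the argmax draws from
    # the categorical distribution without a sort, so the traced graph keeps
    # a fixed shape regardless of per-request sampling parameters.
    q = torch.empty_like(probs)
    q.exponential_()
    # Seeded requests overwrite their noise row using their own generator.
    for i, generator in generators.items():
        q[i].exponential_(generator=generator)
    return probs.div(q).argmax(dim=-1).view(-1)

# Example: two requests over a 4-token vocabulary, request 1 seeded.
probs = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                      [0.7, 0.1, 0.1, 0.1]])
print(sample_from_probs(probs, {1: torch.Generator().manual_seed(0)}))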

vllm/v1/sample/tpu/__init__.py

Whitespace-only changes.

vllm/v1/sample/tpu_metadata.py renamed to vllm/v1/sample/tpu/metadata.py

Lines changed: 1 addition & 3 deletions
@@ -8,8 +8,6 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 
 
-# TODO (NickLucche) keep in sync with SamplingMetadata until we can drop
-# this class and support most options.
 @dataclass
 class TPUSupportedSamplingMetadata:
     # This class exposes a more xla-friendly interface than SamplingMetadata
@@ -158,4 +156,4 @@ def _get_default_params_values():
         # frequency_penalties=(0.0, torch.float32),
         # presence_penalties=(0.0, torch.float32),
         # repetition_penalties=(0.0, torch.float32),
-    )
+    )

vllm/v1/sample/tpu/sampler.py

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Sampler layer implementing TPU supported operations."""
+
+import torch
+import torch.nn as nn
+
+from vllm.v1.outputs import LogprobsTensors, SamplerOutput
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
+from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
+
+_SAMPLING_EPS = 1e-5
+
+
+class Sampler(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.topk_topp_sampler = TopKTopPSampler()
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: TPUSupportedSamplingMetadata,
+    ) -> SamplerOutput:
+        # NOTE(woosuk): Use the original logits (before any penalties or
+        # temperature scaling) for the top-k logprobs.
+        # This is different from the V0 sampler, which uses the logits that
+        # are used for sampling (after penalties and temperature scaling).
+
+        # Use float32 for the logits.
+        logits = logits.to(torch.float32)
+        # Sample the next token.
+        sampled = self.sample(logits, sampling_metadata)
+
+        # Use int32 to reduce the tensor size.
+        sampled = sampled.to(torch.int32)
+
+        # These are GPU tensors.
+        sampler_output = SamplerOutput(
+            # The sampled tokens are expanded to 2D tensor with shape
+            # [num_requests, 1], where each row represents one generated
+            # token per request.
+            sampled_token_ids=sampled.unsqueeze(-1),
+            logprobs_tensors=None,
+        )
+        return sampler_output
+
+    def apply_temperature(
+        self,
+        logits: torch.Tensor,
+        temp: torch.Tensor,
+    ) -> torch.Tensor:
+        # Use in-place division to avoid creating a new tensor.
+        return logits.div_(temp.unsqueeze(dim=1))
+
+    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
+        return logits.argmax(dim=-1).view(-1)
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: TPUSupportedSamplingMetadata,
+    ) -> torch.Tensor:
+        greedy_sampled = self.greedy_sample(logits)
+
+        assert sampling_metadata.temperature is not None
+
+        # Apply temperature.
+        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+
+        # Apply min_p.
+        if sampling_metadata.min_p is not None:
+            logits = self.apply_min_p(logits, sampling_metadata.min_p)
+
+        # Apply top_k and/or top_p.
+        random_sampled = self.topk_topp_sampler(
+            logits,
+            sampling_metadata.generators,
+            sampling_metadata.top_k,
+            sampling_metadata.top_p,
+        )
+
+        sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
+                              greedy_sampled, random_sampled)
+        return sampled
+
+    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
+        return logits.log_softmax(dim=-1, dtype=torch.float32)
+
+    def gather_logprobs(
+        self,
+        logprobs: torch.Tensor,
+        num_logprobs: int,
+        token_ids: torch.Tensor,
+    ) -> LogprobsTensors:
+        """
+        Gather logprobs for topk and sampled/prompt token.
+
+        Args:
+            logprobs: (num tokens) x (vocab) tensor
+            num_logprobs: minimum number of logprobs to
+                          retain per token
+            token_ids: prompt tokens (if prompt logprobs)
+                       or sampled tokens (if sampled
+                       logprobs); 1D token ID tensor
+                       with (num tokens) elements
+
+        Returns:
+            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
+            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
+            Sampled token rank tensor, (num tokens)
+        """
+        # Find the topK values.
+        topk_logprobs, topk_indices = torch.topk(logprobs,
+                                                 num_logprobs,
+                                                 dim=-1)
+
+        # Get the logprob of the prompt or sampled token.
+        token_ids = token_ids.unsqueeze(-1)
+        token_logprobs = logprobs.gather(-1, token_ids)
+
+        # Compute the ranks of the actual token.
+        token_ranks = (logprobs >= token_logprobs).sum(-1)
+
+        # Concatenate together with the topk.
+        indices = torch.cat((token_ids, topk_indices), dim=1)
+        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
+
+        # Use int32 to reduce the tensor size.
+        indices = indices.to(torch.int32)
+
+        return LogprobsTensors(indices, logprobs, token_ranks)
+
+    def apply_min_p(
+        self,
+        logits: torch.Tensor,
+        min_p: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Filters logits using adaptive probability thresholding.
+        """
+        # Convert logits to probability distribution
+        probability_values = torch.nn.functional.softmax(logits, dim=-1)
+        # Calculate maximum probabilities per sequence
+        max_probabilities = torch.amax(probability_values,
+                                       dim=-1,
+                                       keepdim=True)
+        # Reshape min_p for broadcasting
+        adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
+        # Identify valid tokens using threshold comparison
+        valid_token_mask = probability_values >= adjusted_min_p
+        # Apply mask using boolean indexing (xla friendly)
+        logits.masked_fill_(~valid_token_mask, -float("inf"))
+        return logits
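The design worth noting in this new Sampler: both the greedy and the random branch are computed on every step, and torch.where selects per request based on the temperature tensor. There is no Python-level branching on tensor values, so a single compiled graph serves any mix of greedy and sampled requests. Below is a compressed, self-contained sketch of that pattern, using torch.multinomial in place of the TopKTopPSampler for brevity; it is an illustration of the idea, not the committed code.

import torch

_SAMPLING_EPS = 1e-5

def blended_sample(logits: torch.Tensor,
                   temperature: torch.Tensor) -> torch.Tensor:
    # Greedy branch: plain argmax over the unscaled logits.
    greedy = logits.argmax(dim=-1)
    # Random branch: temperature-scaled softmax, then one categorical draw.
    # clamp_min avoids dividing by zero on greedy (temperature == 0) rows.
    scaled = logits / temperature.clamp_min(_SAMPLING_EPS).unsqueeze(-1)
    random = torch.multinomial(scaled.softmax(dim=-1), num_samples=1).squeeze(-1)
    # Per-request selection; no control flow depends on tensor values.
    return torch.where(temperature < _SAMPLING_EPS, greedy, random)

# Request 0 is greedy (temperature 0), request 1 samples at temperature 0.8.
logits = torch.randn(2, 32)
temperature = torch.tensor([0.0, 0.8])
print(blended_sample(logits, temperature))

Computing the branch that is ultimately discarded costs a little extra arithmetic per step, but it keeps the set of compiled graphs fixed across batches, which is the usual trade-off XLA-friendly code makes.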

vllm/v1/worker/tpu_model_runner.py

Lines changed: 26 additions & 14 deletions
@@ -23,16 +23,16 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
+                                               PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput, SamplerOutput)
-from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.sampler import Sampler
-from vllm.v1.sample.tpu_metadata import TPUSupportedSamplingMetadata
+from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
+from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
@@ -142,9 +142,10 @@ def __init__(
                                             dtype=torch.int64,
                                             device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
-
+        padded_max_num_blocks_per_req = _get_padded_number(
+            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, self.max_num_blocks_per_req),
+            (self.max_num_tokens, padded_max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")
 
@@ -595,8 +596,6 @@ def execute_model(
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
             from_sampling_metadata(sampling_metadata, logits_indices,
                                    num_reqs, self.device)
-        # Make mypy happy
-        sampling_metadata = cast(SamplingMetadata, tpu_sampling_metadata)
         # Run the decoder
         with set_forward_context(attn_metadata, self.vllm_config):
             hidden_states = self.model(
@@ -605,9 +604,8 @@ def execute_model(
                 kv_caches=self.kv_caches,
                 inputs_embeds=inputs_embeds,
             )
-        xm.mark_step()  # break model graph
         selected_token_ids = self.model.sample_from_hidden(
-            hidden_states, sampling_metadata)
+            hidden_states, tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
 
@@ -780,11 +778,21 @@ def capture_model(self) -> None:
         # n_tokens x max_num_reqs. Graph is really small so this is fine.
         while True:
             num_reqs_to_sample = MIN_NUM_SEQS
-            dummy_hidden = torch.randn((num_tokens, hsize), device=device)
+            dummy_hidden = torch.randn((num_tokens, hsize),
+                                       device=device,
+                                       dtype=torch.bfloat16)
             while True:
-                # To allow sampling, trace with all supported sampling args.
+                # Default metadata is an all_greedy setup. But since the
+                # `do_argmax` flag is a tensor, we still compile the full graph
+                meta = self.input_batch.sampling_metadata
+                indices = torch.zeros(
+                    num_reqs_to_sample,
+                    dtype=torch.int32,
+                    device=device,
+                )
                 sampling_meta = TPUSupportedSamplingMetadata.\
-                    get_default_sampling_params(num_reqs_to_sample, device)
+                    from_sampling_metadata(meta, indices,
+                                           num_reqs_to_sample, device)
                 logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
                             num_reqs_to_sample)
                 self.model.sample_from_hidden(dummy_hidden, sampling_meta)
@@ -843,7 +851,7 @@ class ModelWrapperV1(nn.Module):
     def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
-        self.sampler = Sampler()
+        self.sampler = TPUSampler()
 
     def sample(
         self, logits: torch.Tensor,
@@ -912,6 +920,10 @@ def get_input_embeddings(self, *args, **kwargs):
         return self.model.get_input_embeddings(*args, **kwargs)
 
 
+def _get_padded_number(n: int, multiple: int) -> int:
+    return ((n + multiple - 1) // multiple) * multiple
+
+
 def _get_padded_token_len(x: int) -> int:
     if x <= 16:
         return 16
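For reference, the new _get_padded_number helper rounds the block table's per-request width up to a multiple of NUM_KV_PAGES_PER_BLOCK, presumably so the shape handed to the Pallas attention backend stays fixed. A quick illustration of the rounding (16 is an arbitrary stand-in below, not the real NUM_KV_PAGES_PER_BLOCK value):

def _get_padded_number(n: int, multiple: int) -> int:
    # Round n up to the nearest multiple.
    return ((n + multiple - 1) // multiple) * multiple

assert _get_padded_number(1, 16) == 16
assert _get_padded_number(16, 16) == 16
assert _get_padded_number(17, 16) == 32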
