Commit d406bc9

[TPU][V1] Add support for top-logprobs (vllm-project#17072)
NickLucche authored and mawong-amd committed
Signed-off-by: NickLucche <[email protected]>
1 parent 8375309 commit d406bc9

File tree

4 files changed: +105 -17 lines changed

tests/v1/tpu/test_sampler.py

+48
@@ -61,3 +61,51 @@ def test_sampler_different(model_name: str):
     # to have deterministic results over many tokens, tests the first ~20
     # tokens match.
     assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
+
+
+@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
+# TODO TPU will appear busy if we fan-out test params here
+@pytest.mark.parametrize("n_prompts", [1])
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This test needs a TPU")
+def test_logprobs(model_name: str, n_prompts: int):
+    """
+    Request top logprobs with different sampling settings and check
+    that results contain the requested number, ordered by rank.
+    """
+
+    def check_num_logprobs(logprobs, expected_num: int):
+        for step in logprobs:
+            prev_logp = 1.0
+            # order by rank
+            sorted_step = dict(
+                sorted(step.items(), key=lambda item: item[1].rank))
+
+            # Can contain the sampled token
+            assert len(step) == expected_num or len(step) == expected_num + 1
+            # Check results are ordered by prob value
+            for rankno, (tid, logp) in enumerate(sorted_step.items()):
+                assert logp.logprob <= prev_logp
+                prev_logp = logp.logprob
+                assert logp.rank == rankno + 1
+
+    llm = LLM(model_name,
+              enforce_eager=False,
+              max_num_seqs=1,
+              max_model_len=128,
+              max_num_batched_tokens=128)
+    prompts = [
+        "Write a short story about a robot that dreams for the first time."
+    ] * n_prompts
+    greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,
+                                            logprobs=4)
+    regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,
+                                             logprobs=4)
+    topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,
+                                           logprobs=4, top_k=12, top_p=0.5)
+
+    for sp in [greedy_sampling_params, regular_sampling_params,
+               topkp_sampling_params]:
+        output = llm.generate(prompts, sp)
+        for o in output:
+            check_num_logprobs(o.outputs[0].logprobs, 4)
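For reference, the user-facing flow this test exercises looks roughly like the sketch below (model, prompt, and logprob count are only examples). Each generated token comes with a dict mapping token id to a Logprob entry carrying `logprob`, `rank`, and `decoded_token`.

from vllm import LLM, SamplingParams

llm = LLM("Qwen/Qwen2.5-1.5B-Instruct", max_model_len=128)
params = SamplingParams(temperature=0.4, max_tokens=16, logprobs=4)

for out in llm.generate(["Write a haiku about TPUs."], params):
    # One dict per generated token: token id -> Logprob(logprob, rank, ...).
    for step in out.outputs[0].logprobs:
        ranked = sorted(step.items(), key=lambda kv: kv[1].rank)
        print([(tid, round(lp.logprob, 3)) for tid, lp in ranked])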

vllm/v1/sample/tpu/metadata.py

+9 -4

@@ -31,8 +31,10 @@ class TPUSupportedSamplingMetadata:
 
     all_greedy: bool = True
 
-    # unsupported, you need to return an extra tensor of static size BxV
-    max_num_logprobs = None
+    # Whether logprobs are to be gathered in this batch of requests. To
+    # balance out compile time and runtime, a fixed `max_num_logprobs` value
+    # is used when gathering logprobs, regardless of the values in the batch.
+    logprobs: bool = False
 
     # TODO No penalties for now
     no_penalties: bool = True
@@ -84,10 +86,12 @@ def from_input_batch(
         we want to pre-compile a graph with sampling parameters, even if
         they are not strictly needed for greedy decoding.
        """
+        needs_logprobs = input_batch.max_num_logprobs > 0 if \
+            input_batch.max_num_logprobs else False
         # Early return to avoid unnecessary cpu to tpu copy
         if (input_batch.all_greedy is True
                 and generate_params_if_all_greedy is False):
-            return cls(all_greedy=True)
+            return cls(all_greedy=True, logprobs=needs_logprobs)
 
         num_reqs = input_batch.num_reqs
 
@@ -115,4 +119,5 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
             top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
                 xla_device),
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
-                xla_device))
+                xla_device),
+            logprobs=needs_logprobs)
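As a rough, hypothetical sketch of the idea behind the new field (not vLLM's actual class): the batch only records whether any request wants logprobs, so the traced graph never depends on the specific counts.

from dataclasses import dataclass
from typing import Optional


@dataclass
class BatchSamplingFlags:
    # A single batch-level switch instead of per-request logprob counts, so
    # changing the requested count never changes the compiled graph.
    all_greedy: bool = True
    logprobs: bool = False

    @classmethod
    def from_max_num_logprobs(cls, max_num_logprobs: Optional[int],
                              all_greedy: bool = True) -> "BatchSamplingFlags":
        needs_logprobs = bool(max_num_logprobs and max_num_logprobs > 0)
        return cls(all_greedy=all_greedy, logprobs=needs_logprobs)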

vllm/v1/sample/tpu/sampler.py

+2 -12

@@ -22,35 +22,25 @@ def forward(
         logits: torch.Tensor,
         sampling_metadata: TPUSupportedSamplingMetadata,
     ) -> SamplerOutput:
-        # NOTE(woosuk): Use the original logits (before any penalties or
-        # temperature scaling) for the top-k logprobs.
-        # This is different from the V0 sampler, which uses the logits that
-        # is used for sampling (after penalties and temperature scaling).
-
         # Use float32 for the logits.
         logits = logits.to(torch.float32)
         # Sample the next token.
         sampled = self.sample(logits, sampling_metadata)
 
-        # Use int32 to reduce the tensor size.
-        sampled = sampled.to(torch.int32)
-
-        # These are GPU tensors.
+        # These are TPU tensors.
         sampler_output = SamplerOutput(
             # The sampled tokens are expanded to 2D tensor with shape
             # [num_requests, 1], where each row represents one generated
             # token per request.
             sampled_token_ids=sampled.unsqueeze(-1),
-            logprobs_tensors=None,
-        )
+            logprobs_tensors=None)
         return sampler_output
 
     def apply_temperature(
         self,
         logits: torch.Tensor,
         temp: torch.Tensor,
     ) -> torch.Tensor:
-        # Use in-place division to avoid creating a new tensor.
         return logits.div_(temp.unsqueeze(dim=1))
 
     def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:

vllm/v1/worker/tpu_model_runner.py

+46 -1

@@ -791,8 +791,18 @@ def execute_model(
                 arange)
             selected_token_ids = self.sample_from_logits(logits,
                                                          tpu_sampling_metadata)
+
+            # NOTE (NickLucche) Use the original logits (before any penalties or
+            # temperature scaling) for the top-k logprobs. We can't enforce it due
+            # to recompilations outside torch.compiled code, so just make sure
+            # `sample_from_logits` does not modify the logits in-place.
+            logprobs = self.gather_logprobs(logits, selected_token_ids) \
+                if tpu_sampling_metadata.logprobs else None
+
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+        logprobs_lists = logprobs.tolists() \
+            if tpu_sampling_metadata.logprobs else None
 
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
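The pattern here, compute over the padded batch on device, then slice off padding and convert to Python lists only after moving to CPU, can be sketched as follows; the helper name and signature are hypothetical, not vLLM code.

from typing import Optional

import torch


def postprocess_on_cpu(selected_token_ids: torch.Tensor,
                       padded_logprobs: Optional[torch.Tensor],
                       num_reqs: int):
    # Device tensors cover the padded batch; the dynamic slice down to the
    # real number of requests happens on CPU, outside the compiled graph.
    token_ids = selected_token_ids.cpu()[:num_reqs]
    logprobs = (padded_logprobs.cpu()[:num_reqs].tolist()
                if padded_logprobs is not None else None)
    return token_ids, logprobs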
@@ -862,7 +872,7 @@ def execute_model(
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
             spec_token_ids=None,
-            logprobs=None,
+            logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
 
@@ -1121,6 +1131,22 @@ def _precompile_sample_from_logits(self) -> None:
         logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("sample_from_logits")
 
+    def _precompile_gather_logprobs(self) -> None:
+        logger.info("Compiling gather_logprobs with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            dummy_tokens = torch.zeros((num_reqs, 1),
+                                       dtype=torch.int64).to(self.device)
+            self.gather_logprobs(dummy_logits, dummy_tokens)
+            logger.info("  -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("gather_logprobs")
+
     def capture_model(self) -> None:
         """
         Precompile all the subgraphs with possible input shapes.
@@ -1131,6 +1157,7 @@ def capture_model(self) -> None:
         self._precompile_compute_logits()
         self._precompile_structured_decoding()
         self._precompile_sample_from_logits()
+        self._precompile_gather_logprobs()
 
     def profile_run(
         self,
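The precompile helpers all follow the same recipe: loop over a fixed set of padded batch sizes and call the target function once per shape with dummy inputs, so every shape seen at runtime is already traced. A minimal standalone sketch of that recipe (bucket sizes and vocab size are made up):

import torch

# Hypothetical padding buckets, mirroring the role of `num_reqs_paddings`.
NUM_REQS_PADDINGS = [8, 16, 32, 64]
VOCAB_SIZE = 32_000  # illustrative; the real value comes from the model


def warmup(fn) -> None:
    # Run `fn` once per padded shape so each shape is traced ahead of time
    # instead of on the first real request.
    for padded in NUM_REQS_PADDINGS:
        dummy_logits = torch.zeros((padded, VOCAB_SIZE), dtype=torch.float32)
        dummy_tokens = torch.zeros((padded, 1), dtype=torch.int64)
        fn(dummy_logits, dummy_tokens)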
@@ -1254,13 +1281,31 @@ def compute_logits(self,
     def sample_from_logits(
             self, logits: torch.Tensor,
             sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
+        """
+        Sample with an xla-friendly function. This function is to be traced
+        separately from `forward` for lighter compilation overhead.
+        """
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
         else:
             out_tokens = self.sampler(logits,
                                       sampling_metadata).sampled_token_ids
         return out_tokens
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def gather_logprobs(self, logits: torch.Tensor,
+                        sampled_tokens: torch.Tensor) -> LogprobsTensors:
+        """
+        Gather the top_logprobs with corresponding tokens. Use a fixed number
+        of logprobs as an alternative to having multiple pre-compiled graphs.
+        Select the number of logprobs actually demanded by each request on CPU.
+        """
+        logprobs = self.sampler.compute_logprobs(logits)
+        return self.sampler.gather_logprobs(
+            logprobs,
+            self.model_config.max_logprobs,
+            token_ids=sampled_tokens.squeeze(-1))
+
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def structured_decode(self, require_struct_decoding: torch.Tensor,
                           grammar_bitmask: torch.Tensor, logits: torch.Tensor,
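To illustrate the fixed-size gather that `gather_logprobs` relies on (a simplified stand-in, not the actual `Sampler.compute_logprobs` / `gather_logprobs` implementations): every request gets `max_logprobs` candidates on device, and trimming to each request's own count happens later on the host.

import torch


def topk_logprobs(logits: torch.Tensor, sampled: torch.Tensor,
                  max_logprobs: int):
    # Log-probabilities over the vocabulary for every request in the batch.
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    # A fixed k keeps the output shape independent of what each request asked.
    top_vals, top_ids = torch.topk(logprobs, max_logprobs, dim=-1)
    # Also keep the logprob of the token that was actually sampled.
    sampled_vals = logprobs.gather(-1, sampled.unsqueeze(-1))
    token_ids = torch.cat([sampled.unsqueeze(-1), top_ids], dim=-1)
    values = torch.cat([sampled_vals, top_vals], dim=-1)
    return token_ids, values

Each request then keeps only its own requested number of entries once these tensors reach the host.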
