Skip to content

Commit 50a00ad

Browse files
committed
address review
Signed-off-by: NickLucche <[email protected]>
1 parent 61ec7f8 commit 50a00ad

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

Diff for: tests/v1/tpu/test_sampler.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@
1313

1414

1515
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
16+
@pytest.mark.parametrize("disable_sampler", [False, True])
1617
@pytest.mark.skipif(not current_platform.is_tpu(),
1718
reason="This test needs a TPU")
18-
def test_sampler_different(model_name: str):
19+
def test_sampler_different(model_name: str, disable_sampler: bool,
20+
monkeypatch):
1921
"""
2022
Test significantly different sampling params to assert the model produces
2123
different results.
2224
"""
25+
if disable_sampler:
26+
monkeypatch.setenv("VLLM_TPU_DISABLE_SAMPLER_DEBUG", "1")
2327
llm = LLM(model_name,
2428
enforce_eager=False,
2529
max_num_seqs=1,
@@ -33,4 +37,8 @@ def test_sampler_different(model_name: str):
3337

3438
sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
3539
output2 = llm.generate(prompts, sampling_params)
36-
assert output[0].outputs[0].text != output2[0].outputs[0].text
40+
if disable_sampler:
41+
# When sampler is off, params are accepted but ignored (argmax-only).
42+
assert output[0].outputs[0].text == output2[0].outputs[0].text
43+
else:
44+
assert output[0].outputs[0].text != output2[0].outputs[0].text

Diff for: vllm/v1/worker/tpu_model_runner.py

+11-19
Original file line numberDiff line numberDiff line change
@@ -639,34 +639,26 @@ def execute_model(
639639
inputs_embeds = None
640640
num_reqs = self.input_batch.num_reqs
641641

642-
# Temporary debug pathway.
642+
with set_forward_context(attn_metadata, self.vllm_config):
643+
hidden_states = self.model(
644+
input_ids=input_ids,
645+
positions=self.position_ids,
646+
kv_caches=self.kv_caches,
647+
inputs_embeds=inputs_embeds,
648+
)
649+
# Temporary debug pathway for sampling.
643650
if self._disable_sampler:
644-
with set_forward_context(attn_metadata, self.vllm_config):
645-
hidden_states = self.model(
646-
input_ids=input_ids,
647-
positions=self.position_ids,
648-
kv_caches=self.kv_caches,
649-
inputs_embeds=inputs_embeds,
650-
)
651651
selected_token_ids = self.model.compute_logits_no_sampler(
652652
hidden_states, logits_indices)
653-
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
654653
else:
655654
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
656655
# are copied to device in chunks of pre-compiled padded shape to
657656
# avoid recompilations.
658657
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
659658
from_input_batch(self.input_batch, logits_indices)
660-
with set_forward_context(attn_metadata, self.vllm_config):
661-
hidden_states = self.model(
662-
input_ids=input_ids,
663-
positions=self.position_ids,
664-
kv_caches=self.kv_caches,
665-
inputs_embeds=inputs_embeds,
666-
)
667-
selected_token_ids = self.model.sample_from_hidden(
668-
hidden_states, tpu_sampling_metadata)
669-
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
659+
selected_token_ids = self.model.sample_from_hidden(
660+
hidden_states, tpu_sampling_metadata)
661+
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
670662

671663
# Update the cache state concurrently. Code above will not block until
672664
# we use `selected_token_ids`. Add mark_step if post-processing changes

0 commit comments

Comments
 (0)