Skip to content

Commit ca564f5

Browse files
committed
address review
Signed-off-by: NickLucche <[email protected]>
1 parent 49d5fcb commit ca564f5

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

tests/v1/tpu/test_sampler.py

+10 -2
Original file line number | Diff line number | Diff line change
@@ -13,13 +13,17 @@
1313

1414

1515
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
16+
@pytest.mark.parametrize("disable_sampler", [False, True])
1617
@pytest.mark.skipif(not current_platform.is_tpu(),
1718
reason="This test needs a TPU")
18-
def test_sampler_different(model_name: str):
19+
def test_sampler_different(model_name: str, disable_sampler: bool,
20+
monkeypatch):
1921
"""
2022
Test significantly different sampling params to assert the model produces
2123
different results.
2224
"""
25+
if disable_sampler:
26+
monkeypatch.setenv("VLLM_TPU_DISABLE_SAMPLER_DEBUG", "1")
2327
llm = LLM(model_name,
2428
enforce_eager=False,
2529
max_num_seqs=1,
@@ -33,4 +37,8 @@ def test_sampler_different(model_name: str):
3337

3438
sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
3539
output2 = llm.generate(prompts, sampling_params)
36-
assert output[0].outputs[0].text != output2[0].outputs[0].text
40+
if disable_sampler:
41+
# When sampler is off, params are accepted but ignored (argmax-only).
42+
assert output[0].outputs[0].text == output2[0].outputs[0].text
43+
else:
44+
assert output[0].outputs[0].text != output2[0].outputs[0].text

vllm/v1/worker/tpu_model_runner.py

+11 -19
Original file line number | Diff line number | Diff line change
@@ -684,34 +684,26 @@ def execute_model(
684684
inputs_embeds = None
685685
num_reqs = self.input_batch.num_reqs
686686

687-
# Temporary debug pathway.
687+
with set_forward_context(attn_metadata, self.vllm_config):
688+
hidden_states = self.model(
689+
input_ids=input_ids,
690+
positions=self.position_ids,
691+
kv_caches=self.kv_caches,
692+
inputs_embeds=inputs_embeds,
693+
)
694+
# Temporary debug pathway for sampling.
688695
if self._disable_sampler:
689-
with set_forward_context(attn_metadata, self.vllm_config):
690-
hidden_states = self.model(
691-
input_ids=input_ids,
692-
positions=self.position_ids,
693-
kv_caches=self.kv_caches,
694-
inputs_embeds=inputs_embeds,
695-
)
696696
selected_token_ids = self.model.compute_logits_no_sampler(
697697
hidden_states, logits_indices)
698-
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
699698
else:
700699
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
701700
# are copied to device in chunks of pre-compiled padded shape to
702701
# avoid recompilations.
703702
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
704703
from_input_batch(self.input_batch, logits_indices)
705-
with set_forward_context(attn_metadata, self.vllm_config):
706-
hidden_states = self.model(
707-
input_ids=input_ids,
708-
positions=self.position_ids,
709-
kv_caches=self.kv_caches,
710-
inputs_embeds=inputs_embeds,
711-
)
712-
selected_token_ids = self.model.sample_from_hidden(
713-
hidden_states, tpu_sampling_metadata)
714-
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
704+
selected_token_ids = self.model.sample_from_hidden(
705+
hidden_states, tpu_sampling_metadata)
706+
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
715707

716708
# Update the cache state concurrently. Code above will not block until
717709
# we use `selected_token_ids`. Add mark_step if post-processing changes

0 commit comments

Comments
 (0)