
Commit 9c77620

NickLucche authored and mgoin committed
[TPU][V1] Enable Top-P (vllm-project#16843)
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: mgoin <[email protected]>
1 parent 1929795 commit 9c77620
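
With this commit, top-p (nucleus) sampling is accepted by the TPU V1 backend through the same SamplingParams API as the GPU backends. A minimal usage sketch, assuming a TPU-supported model (the model name and prompt are illustrative, not part of this change):

from vllm import LLM, SamplingParams

# Any model supported by the TPU V1 backend works the same way.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)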

File tree

2 files changed: +9 -10 lines changed

  tests/v1/tpu/test_sampler.py (+4 -3)
  vllm/v1/sample/tpu/metadata.py (+5 -7)


tests/v1/tpu/test_sampler.py

+4 -3

@@ -42,7 +42,7 @@ def test_sampler_different(model_name: str):
     sampling_params = SamplingParams(temperature=0.3, seed=42)
     output2 = llm.generate(prompts, sampling_params)
 
-    # Batch-case with TopK
+    # Batch-case with TopK/P
     for B in [4, 16]:
         p = prompts * B
         sampling_params = [
@@ -51,9 +51,10 @@ def test_sampler_different(model_name: str):
                 min_p=0.8,
                 max_tokens=64,
                 # Vary number of ks
-                top_k=random.randint(4, 12)) for _ in range(B)
+                top_k=random.randint(4, 12),
+                top_p=random.random()) for _ in range(B)
         ]
-        # Make sure first two reqs have the same K
+        # Make sure first two reqs have the same K/P
         sampling_params[0] = sampling_params[1]
         output = llm.generate(p, sampling_params)
         assert output[0].outputs[0].text == output[1].outputs[0].text
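
For intuition on the new top_p=random.random() values: top-p (nucleus) sampling keeps the smallest set of highest-probability tokens whose cumulative mass reaches p, and samples only from that set. A worked sketch in plain Python (the probabilities are illustrative):

probs = [0.5, 0.3, 0.15, 0.05]  # already sorted, descending
p = 0.8
kept, mass = [], 0.0
for q in probs:
    if mass >= p:  # mass accumulated before this token already reaches p
        break
    kept.append(q)
    mass += q
print(kept)  # [0.5, 0.3]: the smallest prefix whose mass reaches 0.8

With p = 1.0 every token is kept, which is why the metadata change below uses 1.0 as the neutral "disabled" default.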

vllm/v1/sample/tpu/metadata.py

+5 -7

@@ -11,7 +11,7 @@
     min_p=0.0,
     # strictly disabled for now
     top_k=0,
-    # top_p=0.0,
+    top_p=1.0,
     # frequency_penalties=0.0,
     # presence_penalties=0.0,
     # repetition_penalties=0.0,
@@ -26,11 +26,9 @@ class TPUSupportedSamplingMetadata:
     temperature: torch.Tensor = None
 
     min_p: torch.Tensor = None
-    # Still too slow on forward_native!
     top_k: torch.Tensor = None
     top_p: torch.Tensor = None
 
-    # Greedy sampling flag for compiling single xla graph.
     all_greedy: bool = True
 
     # unsupported, you need to return an extra tensor of static size BxV
@@ -103,17 +101,17 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
                    DEFAULT_SAMPLING_PARAMS["min_p"])
         fill_slice(input_batch.top_k_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["top_k"])
-        # TODO Temporarily disabled until sampling options are enabled
-        # fill_slice(input_batch.top_p_cpu_tensor,
-        #            DEFAULT_SAMPLING_PARAMS["top_p"])
+        fill_slice(input_batch.top_p_cpu_tensor,
+                   DEFAULT_SAMPLING_PARAMS["top_p"])
 
         # Slice persistent device tensors to a fixed pre-compiled padded shape.
         return cls(
             temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
             to(xla_device),
             all_greedy=input_batch.all_greedy,
             # TODO enable more and avoid returning None values
-            top_p=None,  # input_batch.top_p[:padded_num_reqs],
+            top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(
+                xla_device),
             top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
                 xla_device),
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
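
The key detail above is the new default top_p=1.0: padded or unset requests fall back to a value that keeps the whole distribution, so top-p can be applied uniformly across the fixed-shape batch without special-casing. A minimal sketch of such a batched masking step in PyTorch (illustrative only, not vLLM's actual TPU sampler):

import torch

def apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Mask logits outside the nucleus. logits: [B, V], p: [B]."""
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    # Drop a token when the mass strictly before it already reaches p.
    # With p = 1.0 no token with nonzero probability is dropped, so the
    # padded/default rows pass through unchanged.
    drop = (cum - sorted_probs) >= p.unsqueeze(-1)
    sorted_logits = torch.gather(logits, -1, sorted_idx)
    sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
    # Scatter the masked logits back to their original vocab positions.
    return torch.full_like(logits, float("-inf")).scatter(
        -1, sorted_idx, sorted_logits)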
