Skip to content

Commit 9765940

Browse files
authored
[TPU] Enable gemma3-27b with TP>1 on multi-chips. (#17335)
Signed-off-by: Xiongfei Wei <[email protected]>
1 parent 5ea5c51 commit 9765940

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

tests/v1/tpu/test_basic.py

+43
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import TYPE_CHECKING
99

1010
import pytest
11+
from torch_xla._internal import tpu
1112

1213
from vllm.platforms import current_platform
1314

@@ -63,3 +64,45 @@ def test_basic(
6364
output = vllm_outputs[0][1]
6465

6566
assert "1024" in output or "0, 1" in output
67+
68+
69+
# Number of TPU chips required for the tensor-parallel smoke test below.
TP_SIZE_8 = 8


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a test for TPU only")
@pytest.mark.skipif(tpu.num_available_chips() < TP_SIZE_8,
                    reason=f"This test requires {TP_SIZE_8} TPU chips.")
def test_gemma3_27b_with_text_input_and_tp(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Smoke-test gemma-3-27b with tensor parallelism over 8 TPU chips.

    Greedily decodes a handful of short prompts and checks that each
    expected continuation appears somewhere in the generated text.
    """
    model_name = "google/gemma-3-27b-it"
    decode_tokens = 16
    # (prompt, expected continuation) pairs, kept together so the
    # assertion loop cannot fall out of sync with the prompt order.
    cases = [
        ("A robot may not injure a human being",
         " or, through inaction, allow a human being to come to harm."),
        ("It is only with the heart that one can see rightly;",
         " what is essential is invisible to the eye."),
        ("The greatest glory in living lies not in never falling,",
         " but in rising every time we fall."),
    ]

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        with vllm_runner(model_name,
                         max_num_batched_tokens=256,
                         max_num_seqs=4,
                         tensor_parallel_size=TP_SIZE_8) as vllm_model:
            vllm_outputs = vllm_model.generate_greedy(
                [prompt for prompt, _ in cases], decode_tokens)
        # Each element of vllm_outputs is a (token_ids, text) tuple; the
        # text includes the prompt, so a substring check is sufficient.
        for (_, expected), output in zip(cases, vllm_outputs):
            generated_text = output[1]
            assert expected in generated_text

vllm/platforms/tpu.py

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class TpuPlatform(Platform):
3030
dispatch_key: str = "XLA"
3131
ray_device_key: str = "TPU"
3232
device_control_env_var: str = "TPU_VISIBLE_CHIPS"
33+
simple_compile_backend: str = "openxla"
3334

3435
supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]
3536

0 commit comments

Comments
 (0)