
Commit 5852e64

[V1] TPU - Add tensor parallel support via Ray
Signed-off-by: Alexander Matveev <[email protected]>
1 parent 4167252 commit 5852e64
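
This commit lets the V1 engine shard a model across TPU chips with Ray-backed tensor parallelism. Below is a minimal sketch of how the feature would be exercised from user code, assuming a host with 4 TPU chips; the model name comes from the new test added in this commit, while the TP size of 4 and the explicit Ray backend are illustrative choices, not requirements of the change:

    import os
    from vllm import LLM, SamplingParams

    # The tests in this commit select the V1 engine the same way.
    os.environ["VLLM_USE_V1"] = "1"

    llm = LLM(
        model="Qwen/Qwen2-1.5B-Instruct",      # model used in tests/v1/tpu/test_basic.py
        tensor_parallel_size=4,                # assumes 4 TPU chips are available
        distributed_executor_backend="ray",    # the Ray executor path touched by this commit
        max_model_len=8192,
        enforce_eager=True,
    )
    params = SamplingParams(temperature=0.0, max_tokens=16)
    outputs = llm.generate(["The capital of France is"], params)
    print(outputs[0].outputs[0].text)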

File tree (6 files changed, +81 −3 lines):

  tests/entrypoints/llm/test_accuracy.py
  tests/v1/tpu/__init__.py
  tests/v1/tpu/test_basic.py
  vllm/executor/ray_distributed_executor.py
  vllm/executor/ray_utils.py
  vllm/v1/worker/tpu_model_runner.py

tests/entrypoints/llm/test_accuracy.py

Lines changed: 8 additions & 0 deletions
@@ -42,6 +42,10 @@ def run_test(more_args=None):
     ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
 
 
+# TODO: [AlexM] Fix it with new CI/CD tests
+TPU_TP_TEST_STR = ""  #"tensor_parallel_size=4"
+
+
 @pytest.mark.skipif(not current_platform.is_cuda()
                     and not current_platform.is_tpu(),
                     reason="V1 is currently only supported on CUDA and TPU")
@@ -56,6 +60,10 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch):
         # Limit compilation time for TPU V1
         more_args = "max_num_seqs=64"
 
+        # Add TP test (if provided)
+        if TPU_TP_TEST_STR:
+            more_args += ",{}".format(TPU_TP_TEST_STR)
+
         run_test(more_args)
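
For reference, TPU_TP_TEST_STR is intentionally left empty in this commit; once multi-TPU CI exists and it is set to the commented value, the argument string handed to run_test simply grows by the TP setting:

    # Hypothetical result with TPU_TP_TEST_STR = "tensor_parallel_size=4"
    more_args = "max_num_seqs=64,tensor_parallel_size=4"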
tests/v1/tpu/__init__.py

Whitespace-only changes.

tests/v1/tpu/test_basic.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+# SPDX-License-Identifier: Apache-2.0
+"""A basic correctness check for TPUs
+
+Run `pytest tests/v1/tpu/test_basic.py`.
+"""
+import pytest
+
+from vllm.platforms import current_platform
+
+from ...conftest import VllmRunner
+
+MODELS = [
+    "Qwen/Qwen2-1.5B-Instruct",
+    # TODO: Add models here as necessary
+]
+
+TENSOR_PARALLEL_SIZES = [1]
+
+# TODO: Enable when CI/CD will have a multi-tpu instance
+# TENSOR_PARALLEL_SIZES = [1, 4]
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a basic test for TPU only")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
+def test_models(
+    monkeypatch,
+    hf_runner,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    enforce_eager: bool,
+    tensor_parallel_size: int,
+) -> None:
+    prompt = "The following numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with VllmRunner(
+                model,
+                max_model_len=8192,
+                dtype=dtype,
+                enforce_eager=enforce_eager,
+                gpu_memory_utilization=0.7,
+                max_num_seqs=16,
+                tensor_parallel_size=tensor_parallel_size) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+        output = vllm_outputs[0][1]
+        assert output.strip().endswith("1024")

vllm/executor/ray_distributed_executor.py

Lines changed: 6 additions & 1 deletion
@@ -54,9 +54,14 @@ class RayDistributedExecutor(DistributedExecutorBase):
     def _init_executor(self) -> None:
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
         if envs.VLLM_USE_V1:
-            # v1 always uses the compiled DAG and SPMD worker.
+            # V1 uses SPMD worker and compiled DAG
             os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
             os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
+
+            # For TPU, avoid compiling NVIDIA's NCCL
+            if current_platform.is_tpu():
+                os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"
+
         # If the env var is set, it uses the Ray's compiled DAG API
         # which optimizes the control plane overhead.
         # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
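
Net effect of the hunk above: on a TPU host running V1, the Ray executor still uses the SPMD worker and compiled DAG, but disables the NCCL channel. A sketch of the resulting environment after _init_executor(), following the logic in this hunk:

    # Sketch only: expected environment on a TPU host with VLLM_USE_V1 enabled.
    assert os.environ["VLLM_USE_RAY_SPMD_WORKER"] == "1"
    assert os.environ["VLLM_USE_RAY_COMPILED_DAG"] == "1"
    # Per the comment above: avoid compiling NVIDIA's NCCL on TPU.
    assert os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] == "0"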

vllm/executor/ray_utils.py

Lines changed: 8 additions & 2 deletions
@@ -11,6 +11,7 @@
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import get_ip
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -106,10 +107,15 @@ def setup_device_if_necessary(self):
         # on a background thread, so we need to reset torch's current
         # device.
         # We can remove this API after it is fixed in compiled graph.
-        import torch
         assert self.worker is not None, "Worker is not initialized"
         if not self.compiled_dag_cuda_device_set:
-            torch.cuda.set_device(self.worker.device)
+            if current_platform.is_tpu():
+                # Not needed
+                pass
+            else:
+                import torch
+                torch.cuda.set_device(self.worker.device)
+
             self.compiled_dag_cuda_device_set = True
 
     def execute_model_ray(

vllm/v1/worker/tpu_model_runner.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.sampling_params import SamplingType
+from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
                                                NUM_QUERIES_PER_BLOCK,
@@ -430,6 +431,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> ModelRunnerOutput:
         # Update cached state
         self._update_states(scheduler_output)