
Commit cb8bdfa

[V1] TPU - Add tensor parallel support via Ray (#13618)
Signed-off-by: Alexander Matveev <[email protected]>
1 parent 33f227e commit cb8bdfa

File tree (7 files changed, +80 −4 lines):

tests/entrypoints/llm/test_accuracy.py
tests/v1/tpu/__init__.py
tests/v1/tpu/test_basic.py
vllm/executor/ray_distributed_executor.py
vllm/executor/ray_utils.py
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_worker.py
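
For context, a minimal usage sketch (not part of this commit) of how the new path might be exercised from the offline API, assuming a multi-chip TPU host; the model name, sizes, and explicit Ray backend selection below are illustrative.

    import os

    from vllm import LLM, SamplingParams

    # V1 engine; on TPU the Ray executor now sets the SPMD / compiled-DAG
    # env vars itself and routes the DAG off the NCCL channel (see
    # vllm/executor/ray_distributed_executor.py below).
    os.environ["VLLM_USE_V1"] = "1"

    llm = LLM(model="meta-llama/Llama-3.1-8B",
              tensor_parallel_size=4,
              distributed_executor_backend="ray",
              max_model_len=8192,
              max_num_seqs=16,
              enforce_eager=True)

    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=5))
    print(outputs[0].outputs[0].text)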

tests/entrypoints/llm/test_accuracy.py

Lines changed: 8 additions & 0 deletions
@@ -42,6 +42,10 @@ def run_test(more_args=None):
     ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
 
 
+# TODO: [AlexM] Fix it with new CI/CD tests
+TPU_TP_TEST_STR = ""  # "tensor_parallel_size=4"
+
+
 @pytest.mark.skipif(not current_platform.is_cuda()
                     and not current_platform.is_tpu(),
                     reason="V1 is currently only supported on CUDA and TPU")
@@ -56,6 +60,10 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch):
             # Limit compilation time for TPU V1
             more_args = "max_num_seqs=64"
 
+        # Add TP test (if provided)
+        if TPU_TP_TEST_STR:
+            more_args += ",{}".format(TPU_TP_TEST_STR)
+
         run_test(more_args)
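
A small illustration (not from the commit) of what the composed argument string looks like once the TPU TP test string is enabled; run_test is assumed to forward this comma-separated key=value list as extra engine arguments.

    # Hypothetical value; the commit ships it disabled ("").
    TPU_TP_TEST_STR = "tensor_parallel_size=4"

    more_args = "max_num_seqs=64"
    if TPU_TP_TEST_STR:
        more_args += ",{}".format(TPU_TP_TEST_STR)

    assert more_args == "max_num_seqs=64,tensor_parallel_size=4"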

tests/v1/tpu/__init__.py

Whitespace-only changes.

tests/v1/tpu/test_basic.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# SPDX-License-Identifier: Apache-2.0
+"""A basic correctness check for TPUs
+
+Run `pytest tests/v1/tpu/test_basic.py`.
+"""
+import pytest
+
+from vllm.platforms import current_platform
+
+from ...conftest import VllmRunner
+
+MODELS = [
+    # "Qwen/Qwen2-7B-Instruct",
+    "meta-llama/Llama-3.1-8B",
+    # TODO: Add models here as necessary
+]
+
+TENSOR_PARALLEL_SIZES = [1]
+
+# TODO: Enable when CI/CD will have a multi-tpu instance
+# TENSOR_PARALLEL_SIZES = [1, 4]
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a basic test for TPU only")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
+def test_models(
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    enforce_eager: bool,
+    tensor_parallel_size: int,
+) -> None:
+    prompt = "The next numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with VllmRunner(
+                model,
+                max_model_len=8192,
+                enforce_eager=enforce_eager,
+                gpu_memory_utilization=0.7,
+                max_num_seqs=16,
+                tensor_parallel_size=tensor_parallel_size) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+        output = vllm_outputs[0][1]
+        assert "1024" in output

vllm/executor/ray_distributed_executor.py

Lines changed: 6 additions & 1 deletion
@@ -73,9 +73,14 @@ class RayDistributedExecutor(DistributedExecutorBase):
     def _init_executor(self) -> None:
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
         if envs.VLLM_USE_V1:
-            # v1 always uses the compiled DAG and SPMD worker.
+            # V1 uses SPMD worker and compiled DAG
             os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
             os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
+
+            # For TPU, avoid compiling NVIDIA's NCCL
+            if current_platform.is_tpu():
+                os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"
+
         # If the env var is set, it uses the Ray's compiled DAG API
         # which optimizes the control plane overhead.
         # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
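
As a usage note (not in the diff), the same behaviour can be forced by hand before engine construction; the variable names come from the code above, and on TPU the executor now sets them automatically.

    import os

    # Mirrors what RayDistributedExecutor._init_executor does on V1.
    os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
    os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
    # TPU hosts have no NCCL, so keep the compiled DAG off the NCCL channel.
    os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"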

vllm/executor/ray_utils.py

Lines changed: 8 additions & 2 deletions
@@ -11,6 +11,7 @@
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import get_ip
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -106,10 +107,15 @@ def setup_device_if_necessary(self):
         # on a background thread, so we need to reset torch's current
         # device.
         # We can remove this API after it is fixed in compiled graph.
-        import torch
         assert self.worker is not None, "Worker is not initialized"
         if not self.compiled_dag_cuda_device_set:
-            torch.cuda.set_device(self.worker.device)
+            if current_platform.is_tpu():
+                # Not needed
+                pass
+            else:
+                import torch
+                torch.cuda.set_device(self.worker.device)
+
             self.compiled_dag_cuda_device_set = True
 
     def execute_model_ray(

vllm/v1/worker/tpu_model_runner.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
+from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
                                                NUM_QUERIES_PER_BLOCK,
@@ -545,6 +546,7 @@ def _gather_encoder_outputs(
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> ModelRunnerOutput:
         # Update cached state
         self._update_states(scheduler_output)

vllm/v1/worker/tpu_worker.py

Lines changed: 2 additions & 1 deletion
@@ -96,7 +96,8 @@ def init_device(self):
 
         # Set random seed.
         set_random_seed(self.model_config.seed)
-        xm.set_rng_state(self.model_config.seed, self.device)
+        if self.model_config.seed is not None:
+            xm.set_rng_state(self.model_config.seed, self.device)
 
         # Increase the cache size limit, which is the maximum number of
         # dynamo graphs that can be compiled.
