[V1] TPU - Add tensor parallel support via Ray #13618


Merged (1 commit, Mar 8, 2025)
8 changes: 8 additions & 0 deletions tests/entrypoints/llm/test_accuracy.py
@@ -42,6 +42,10 @@ def run_test(more_args=None):
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


# TODO: [AlexM] Fix it with new CI/CD tests
TPU_TP_TEST_STR = ""  # "tensor_parallel_size=4"


@pytest.mark.skipif(not current_platform.is_cuda()
                    and not current_platform.is_tpu(),
                    reason="V1 is currently only supported on CUDA and TPU")
@@ -56,6 +60,10 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch):
    # Limit compilation time for TPU V1
    more_args = "max_num_seqs=64"

    # Add TP test (if provided)
    if TPU_TP_TEST_STR:
        more_args += ",{}".format(TPU_TP_TEST_STR)

    run_test(more_args)


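For reference, the extra engine arguments travel as a single comma-separated key=value string (e.g. "max_num_seqs=64,tensor_parallel_size=4"). A minimal sketch of how such a string maps to keyword arguments; the helper name is hypothetical and not the test's actual parsing code:

def parse_engine_args(arg_str: str) -> dict:
    # "max_num_seqs=64,tensor_parallel_size=4" -> {"max_num_seqs": "64", "tensor_parallel_size": "4"}
    return dict(item.split("=", 1) for item in arg_str.split(",") if item)

assert parse_engine_args("max_num_seqs=64,tensor_parallel_size=4") == {
    "max_num_seqs": "64",
    "tensor_parallel_size": "4",
}
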
Empty file added tests/v1/tpu/__init__.py
54 changes: 54 additions & 0 deletions tests/v1/tpu/test_basic.py
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
"""A basic correctness check for TPUs

Run `pytest tests/v1/tpu/test_basic.py`.
"""
import pytest

from vllm.platforms import current_platform

from ...conftest import VllmRunner

MODELS = [
    # "Qwen/Qwen2-7B-Instruct",
    "meta-llama/Llama-3.1-8B",
    # TODO: Add models here as necessary
]

TENSOR_PARALLEL_SIZES = [1]

# TODO: Enable when CI/CD has a multi-TPU instance
# TENSOR_PARALLEL_SIZES = [1, 4]


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a basic test for TPU only")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
def test_models(
    monkeypatch,
    model: str,
    max_tokens: int,
    enforce_eager: bool,
    tensor_parallel_size: int,
) -> None:
    prompt = "The next numbers of the sequence " + ", ".join(
        str(i) for i in range(1024)) + " are:"
    example_prompts = [prompt]

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        with VllmRunner(
                model,
                max_model_len=8192,
                enforce_eager=enforce_eager,
                gpu_memory_utilization=0.7,
                max_num_seqs=16,
                tensor_parallel_size=tensor_parallel_size) as vllm_model:
            vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                      max_tokens)
            output = vllm_outputs[0][1]
            assert "1024" in output
7 changes: 6 additions & 1 deletion vllm/executor/ray_distributed_executor.py
@@ -73,9 +73,14 @@ class RayDistributedExecutor(DistributedExecutorBase):
    def _init_executor(self) -> None:
        self.forward_dag: Optional[ray.dag.CompiledDAG] = None
        if envs.VLLM_USE_V1:
            # V1 uses SPMD worker and compiled DAG
            os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
            os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

            # For TPU, avoid using the NCCL-based channel of Ray's compiled DAG
            if current_platform.is_tpu():
                os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"

        # If the env var is set, it uses Ray's compiled DAG API
        # which optimizes the control plane overhead.
        # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
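
Putting the executor change together with the basic test above, a tensor-parallel TPU run would be launched roughly as follows. This is a minimal sketch assuming a 4-chip TPU host and the meta-llama/Llama-3.1-8B checkpoint from the test; the parameter choices are illustrative, not taken from this PR:

import os
from vllm import LLM, SamplingParams

os.environ["VLLM_USE_V1"] = "1"  # the executor above then enables SPMD workers + compiled DAG itself

llm = LLM(
    model="meta-llama/Llama-3.1-8B",
    tensor_parallel_size=4,               # assumption: one rank per chip on a 4-chip TPU host
    distributed_executor_backend="ray",   # TP on TPU goes through the Ray executor in V1
    max_model_len=8192,
    enforce_eager=True,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)
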
10 changes: 8 additions & 2 deletions vllm/executor/ray_utils.py
@@ -11,6 +11,7 @@
from vllm.config import ParallelConfig
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase
@@ -106,10 +107,15 @@ def setup_device_if_necessary(self):
        # on a background thread, so we need to reset torch's current
        # device.
        # We can remove this API after it is fixed in compiled graph.
        assert self.worker is not None, "Worker is not initialized"
        if not self.compiled_dag_cuda_device_set:
            if current_platform.is_tpu():
                # Not needed on TPU (there is no CUDA device to set)
                pass
            else:
                import torch
                torch.cuda.set_device(self.worker.device)

            self.compiled_dag_cuda_device_set = True

def execute_model_ray(
2 changes: 2 additions & 0 deletions vllm/v1/worker/tpu_model_runner.py
@@ -21,6 +21,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
NUM_QUERIES_PER_BLOCK,
@@ -543,6 +544,7 @@ def _gather_encoder_outputs(
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
Member:
Is this necessary? I thought intermediate_tensors was just needed for PP

Collaborator (Author):
It is not used, but it is part of the API, else it errors.

    ) -> ModelRunnerOutput:
        # Update cached state
        self._update_states(scheduler_output)
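
Regarding the exchange above: the parameter is apparently kept purely for signature compatibility, since the worker calls every model runner with the same arguments. A rough illustration of that calling convention; these are simplified sketch classes, not the actual vLLM ones:

from typing import Optional

class GPUModelRunnerSketch:
    def execute_model(self, scheduler_output, intermediate_tensors: Optional[object] = None):
        ...  # intermediate_tensors is consumed when pipeline parallelism is active

class TPUModelRunnerSketch:
    def execute_model(self, scheduler_output, intermediate_tensors: Optional[object] = None):
        ...  # ignored on TPU, accepted only so every runner exposes one uniform signature
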
3 changes: 2 additions & 1 deletion vllm/v1/worker/tpu_worker.py
@@ -93,7 +93,8 @@ def init_device(self):

        # Set random seed.
        set_random_seed(self.model_config.seed)
        if self.model_config.seed is not None:
            xm.set_rng_state(self.model_config.seed, self.device)

        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.
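
A minimal sketch of the guarded seeding added above, assuming model_config.seed can be None when no seed is configured, in which case the XLA RNG is simply left alone:

from typing import Callable, Optional

def maybe_seed_xla(seed: Optional[int], set_rng_state: Callable[[int], None]) -> None:
    # Only touch the device RNG when an explicit integer seed was requested.
    if seed is not None:
        set_rng_state(seed)

# e.g. maybe_seed_xla(model_config.seed, lambda s: xm.set_rng_state(s, device))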