[WIP] [V1] TPU support #11936

Status: Closed (wants to merge 7 commits)

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -89,4 +89,4 @@ repos:
name: Suggestion
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
language: system
-verbose: true
+verbose: true
Collaborator review comment: nit

6 changes: 3 additions & 3 deletions examples/offline_inference/basic.py
@@ -8,15 +8,15 @@
"The future of AI is",
]
# Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)
Collaborator review comment: revert


# Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
-print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
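
For anyone trying this change locally, a minimal sketch of running the modified example against the V1 engine; the VLLM_USE_V1 toggle mirrors its use in the test change below, and the model and limits are the ones picked in this diff (a sketch only, not part of the PR):

    import os

    # V1 is opted into via the VLLM_USE_V1 environment variable, the same
    # switch exercised in tests/entrypoints/openai/test_accuracy.py below.
    os.environ["VLLM_USE_V1"] = "1"

    from vllm import LLM, SamplingParams

    prompts = ["The future of AI is"]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # The small context window and batch size mirror the values in this diff
    # and keep TPU compilation time manageable.
    llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16)

    for output in llm.generate(prompts, sampling_params):
        print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")
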
15 changes: 11 additions & 4 deletions tests/entrypoints/openai/test_accuracy.py
@@ -20,7 +20,7 @@
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.58
-DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"]
+DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
MORE_ARGS_LIST = [
[], # Default
["--enable-chunked-prefill"], # Chunked
@@ -66,14 +66,21 @@ def run_test(more_args):
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


-@pytest.mark.skipif(not current_platform.is_cuda(),
-reason="V1 currently only supported on CUDA")
+@pytest.mark.skipif(not current_platform.is_cuda()
+and not current_platform.is_tpu(),
+reason="V1 currently only supported on CUDA and TPU")
def test_lm_eval_accuracy_v1_engine(monkeypatch):
"""Run with the V1 Engine."""

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
-run_test([])
+more_args = []
+
+# Limit compilation time for V1
+if current_platform.is_tpu():
+more_args = ["--max-num-seqs", "64"]
+
+run_test(more_args)


@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
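
For context, run_test (not shown in this diff) boils down to an lm_eval run against an OpenAI-compatible vLLM server started with DEFAULT_ARGS plus more_args. A rough sketch of that call, where the model name, URL, and concurrency are illustrative assumptions rather than values taken from this diff:

    import lm_eval

    # Assumed: a vLLM OpenAI-compatible server is already running locally,
    # launched with DEFAULT_ARGS plus the TPU-specific "--max-num-seqs 64".
    MODEL = "Qwen/Qwen2-1.5B-Instruct"            # illustrative
    URL = "http://localhost:8000/v1/completions"  # illustrative

    results = lm_eval.simple_evaluate(
        model="local-completions",
        model_args=f"model={MODEL},base_url={URL},"
        "num_concurrent=32,tokenized_requests=False",
        tasks=["gsm8k"],
    )
    measured = results["results"]["gsm8k"]["exact_match,strict-match"]
    # Same tolerance check as this file: EXPECTED_VALUE = 0.58, RTOL = 0.03.
    assert abs(measured - 0.58) < 0.03
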
2 changes: 1 addition & 1 deletion tools/mypy.sh
@@ -34,4 +34,4 @@ run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/worker
-run_mypy vllm/v1
+run_mypy vllm/v1
Collaborator review comment: nit

2 changes: 1 addition & 1 deletion vllm/platforms/cuda.py
@@ -135,7 +135,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
else:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
-"vllm.v1.worker.gpu_worker.Worker"
+"vllm.v1.worker.gpu_worker.GPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"

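
The worker_cls values above are dotted import paths rather than classes, so the rename only matters when the string is resolved at load time. A minimal sketch of that kind of resolution (an illustrative helper, not vLLM's actual API):

    import importlib

    def resolve_worker_cls(qualname: str):
        """Resolve a dotted path such as 'vllm.v1.worker.gpu_worker.GPUWorker'."""
        module_name, _, class_name = qualname.rpartition(".")
        module = importlib.import_module(module_name)
        return getattr(module, class_name)

    # Example (requires an install where the V1 GPU worker class exists):
    # worker_cls = resolve_worker_cls("vllm.v1.worker.gpu_worker.GPUWorker")
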
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
@@ -32,6 +32,7 @@ class _Backend(enum.Enum):
FLASHINFER = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
+PALLAS_VLLM_V1 = enum.auto()
IPEX = enum.auto()
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto()
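
The new member simply extends the _Backend enum, so name-based lookup (presumably how a string override such as VLLM_ATTENTION_BACKEND would map to it, which this diff does not confirm) picks it up without further changes. A trivial sketch using a simplified stand-in enum:

    import enum

    # Simplified stand-in for vllm.platforms.interface._Backend.
    class _Backend(enum.Enum):
        PALLAS = enum.auto()
        PALLAS_VLLM_V1 = enum.auto()

    # Name-based lookup, e.g. from an environment-variable override.
    assert _Backend["PALLAS_VLLM_V1"] is _Backend.PALLAS_VLLM_V1
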
52 changes: 37 additions & 15 deletions vllm/platforms/tpu.py
@@ -2,6 +2,7 @@

import torch

+import vllm.envs as envs
from vllm.logger import init_logger

from .interface import Platform, PlatformEnum, _Backend
@@ -30,10 +31,16 @@ class TpuPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
-if selected_backend != _Backend.PALLAS:
+if (selected_backend != _Backend.PALLAS
+and selected_backend != _Backend.PALLAS_VLLM_V1):
logger.info("Cannot use %s backend on TPU.", selected_backend)
-logger.info("Using Pallas backend.")
-return "vllm.attention.backends.pallas.PallasAttentionBackend"
+
+if use_v1:
+logger.info("Using Pallas V1 backend.")
+return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
+else:
+logger.info("Using Pallas backend.")
+return "vllm.attention.backends.pallas.PallasAttentionBackend"

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
@@ -45,7 +52,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:

@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-return True
+return not envs.VLLM_USE_V1

@classmethod
def inference_mode(cls):
@@ -60,22 +67,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
cache_config.block_size = 16

compilation_config = vllm_config.compilation_config
-if compilation_config.level == CompilationLevel.NO_COMPILATION:
-# TPU does not support NO_COMPILATION
+
+# TPU only supports DYNAMO_ONCE compilation level
+if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
+logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
compilation_config.level = CompilationLevel.DYNAMO_ONCE
-assert compilation_config.level < CompilationLevel.PIECEWISE,\
-"TPU does not support Inductor."

if compilation_config.backend == "":
compilation_config.backend = "openxla"

-assert vllm_config.speculative_config is None, \
-"TPU does not support speculative decoding"

assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
"Chunked prefill is not yet supported for TPU backend")
-assert not vllm_config.speculative_config, (
-"Speculative decoding is not yet supported for TPU backend")
if vllm_config.model_config.dtype in (torch.float16, torch.float32):
logger.warning(
"The TPU backend currently does not support %s. "
@@ -85,8 +88,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
-if scheduler_config.is_multi_step:
+if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
-"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+"vllm.v1.worker.tpu_worker.TPUWorker"
else:
-parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
+if scheduler_config.is_multi_step:
+parallel_config.worker_cls = \
+"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+else:
+parallel_config.worker_cls = \
+"vllm.worker.tpu_worker.TPUWorker"

+# Adjust scheduler config for V1
+# TODO: Add support for these
+if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching:
+logger.warning("[V1][TPU] Disable prefix caching")
+vllm_config.cache_config.enable_prefix_caching = False
+
+assert not vllm_config.speculative_config, (
+"Speculative decoding is not yet supported for TPU backend")
+
+@classmethod
+def is_pin_memory_available(cls):
+logger.warning("Pin memory is not supported on TPU.")
+return False
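
To make the V1 switch concrete, a small sketch of how the attention-backend string now depends on use_v1. The head size, dtype, and block size are arbitrary illustrative values (the method above ignores them), and importing the TPU platform module outside a TPU host is assumed to work:

    import torch

    from vllm.platforms.interface import _Backend
    from vllm.platforms.tpu import TpuPlatform

    # use_v1=True routes to the new V1 Pallas backend; otherwise the V0 path is kept.
    v1_backend = TpuPlatform.get_attn_backend_cls(
        _Backend.PALLAS_VLLM_V1, head_size=128, dtype=torch.bfloat16,
        kv_cache_dtype=None, block_size=16, use_v1=True)
    v0_backend = TpuPlatform.get_attn_backend_cls(
        _Backend.PALLAS, head_size=128, dtype=torch.bfloat16,
        kv_cache_dtype=None, block_size=16, use_v1=False)
    assert v1_backend == "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
    assert v0_backend == "vllm.attention.backends.pallas.PallasAttentionBackend"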