diff --git a/tests/conftest.py b/tests/conftest.py
index 25e70319e2c..000da74c2db 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -272,7 +272,8 @@ class HfRunner:
     def get_default_device(self):
         from vllm.platforms import current_platform

-        return ("cpu" if current_platform.is_cpu() else "cuda")
+        return ("cpu"
+                if current_platform.is_cpu() else current_platform.device_type)

     def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
         if x is None or isinstance(x, (bool, )):
diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py
index 5f041b44893..24b759bc1fa 100644
--- a/tests/v1/sample/test_sampler.py
+++ b/tests/v1/sample/test_sampler.py
@@ -6,6 +6,7 @@
 import pytest
 import torch

+from vllm.platforms import current_platform
 from vllm.utils import make_tensor_with_pad
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import Sampler
@@ -13,7 +14,8 @@
 VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20
 CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+    f"{current_platform.device_type}:{i}"
+    for i in range(1 if current_platform.device_count() == 1 else 2)
 ]
 MAX_NUM_PROMPT_TOKENS = 64

diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index 7ddf382079c..d044e18c257 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -14,6 +14,7 @@
                                                 SamplerOutput, SamplingMetadata,
                                                 get_logprobs,
                                                 get_pythonized_sample_results)
+from vllm.platforms import current_platform
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceGroupMetadata, SequenceOutput)
 from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream
@@ -158,8 +159,8 @@ class StatefulModelInput(BroadcastableModelInput):
     is_first_multi_step: bool = False
     base_output_proc_callback: Optional[Callable] = None
     # ping-pong data structures for multi-step to wait on the previous step
-    step_cuda_events: List[torch.cuda.Event] = field(
-        default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2)
+    step_cuda_events: List[current_platform.Event] = field(
+        default_factory=lambda: [current_platform.Event(blocking=True)] * 2)
     num_seqs: int = -1
     num_queries: int = -1
     num_single_step_prefills: int = 0
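
For context, below is a minimal sketch of the platform-agnostic pattern this patch applies. It assumes, as the diff does, that the active `Platform` object from `vllm.platforms` exposes `is_cpu()`, `device_type`, `device_count()`, and an `Event` class mirroring `torch.cuda.Event`; the helper names in the sketch are hypothetical and not part of vLLM.

```python
# Sketch only (not part of the patch): the device-neutral idioms used above.
# Assumes current_platform provides is_cpu(), device_type, device_count(),
# and Event, as this patch expects of the Platform interface.
from vllm.platforms import current_platform


def default_device() -> str:
    """Hypothetical helper mirroring HfRunner.get_default_device()."""
    # CPU platforms have no device string suffix; others report their
    # backend name, e.g. "cuda" or "xpu".
    return "cpu" if current_platform.is_cpu() else current_platform.device_type


def per_device_ids(max_devices: int = 2) -> list:
    """Hypothetical helper mirroring the CUDA_DEVICES list in the tests."""
    n = 1 if current_platform.device_count() == 1 else max_devices
    return [f"{current_platform.device_type}:{i}" for i in range(n)]


def make_step_events() -> list:
    """Hypothetical helper mirroring the StatefulModelInput default factory."""
    # current_platform.Event resolves to the backend's event class
    # (torch.cuda.Event on CUDA), keeping the worker code device-neutral.
    return [current_platform.Event(blocking=True) for _ in range(2)]
```

The same `f"{current_platform.device_type}:{i}"` string works wherever a `torch.device` spec is accepted, which is what lets the test and worker code above drop their hard-coded `"cuda"` literals.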