Skip to content

Commit d1fa714

Browse files
noemotiovon and noemotiovon authored
[Refactor]A simple device-related refactor (#11163)
Signed-off-by: noemotiovon <[email protected]>
Co-authored-by: noemotiovon <[email protected]>
1 parent 969da7d commit d1fa714

File tree

7 files changed

+51
-31
lines changed

7 files changed

+51
-31
lines changed

vllm/platforms/cpu.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
9898
"vllm.worker.cpu_worker.CPUWorker"
9999
else:
100100
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
101+
102+
@classmethod
103+
def is_pin_memory_available(cls) -> bool:
104+
logger.warning("Pin memory is not supported on CPU.")
105+
return False

vllm/platforms/hpu.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22

33
import torch
44

5+
from vllm.logger import init_logger
6+
57
from .interface import Platform, PlatformEnum, _Backend
68

79
if TYPE_CHECKING:
810
from vllm.config import VllmConfig
911
else:
1012
VllmConfig = None
1113

14+
logger = init_logger(__name__)
15+
1216

1317
class HpuPlatform(Platform):
1418
_enum = PlatformEnum.HPU
@@ -43,3 +47,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
4347
parallel_config = vllm_config.parallel_config
4448
if parallel_config.worker_cls == "auto":
4549
parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
50+
51+
@classmethod
52+
def is_pin_memory_available(cls):
53+
logger.warning("Pin memory is not supported on HPU.")
54+
return False

vllm/platforms/interface.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import enum
22
import platform
33
import random
4+
from platform import uname
45
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
56

67
import numpy as np
@@ -16,6 +17,11 @@
1617
logger = init_logger(__name__)
1718

1819

20+
def in_wsl() -> bool:
21+
# Reference: https://github.com/microsoft/WSL/issues/4071
22+
return "microsoft" in " ".join(uname()).lower()
23+
24+
1925
class _Backend(enum.Enum):
2026
FLASH_ATTN = enum.auto()
2127
FLASH_ATTN_VLLM_V1 = enum.auto()
@@ -221,6 +227,17 @@ def get_cpu_architecture(cls) -> CpuArchEnum:
221227

222228
return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN
223229

230+
@classmethod
231+
def is_pin_memory_available(cls) -> bool:
232+
"""Checks whether pin memory is available on the current platform."""
233+
if in_wsl():
234+
# Pinning memory in WSL is not supported.
235+
# https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
236+
logger.warning("Using 'pin_memory=False' as WSL is detected. "
237+
"This may slow down the performance.")
238+
return False
239+
return True
240+
224241

225242
class UnspecifiedPlatform(Platform):
226243
_enum = PlatformEnum.UNSPECIFIED

vllm/platforms/neuron.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from typing import TYPE_CHECKING, Optional
22

3+
from vllm.logger import init_logger
4+
35
from .interface import Platform, PlatformEnum
46

57
if TYPE_CHECKING:
68
from vllm.config import VllmConfig
79
else:
810
VllmConfig = None
911

12+
logger = init_logger(__name__)
13+
1014

1115
class NeuronPlatform(Platform):
1216
_enum = PlatformEnum.NEURON
@@ -28,3 +32,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
2832
if parallel_config.worker_cls == "auto":
2933
parallel_config.worker_cls = \
3034
"vllm.worker.neuron_worker.NeuronWorker"
35+
36+
@classmethod
37+
def is_pin_memory_available(cls) -> bool:
38+
logger.warning("Pin memory is not supported on Neuron.")
39+
return False

vllm/platforms/openvino.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,27 +34,27 @@ def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
3434
return _Backend.OPENVINO
3535

3636
@classmethod
37-
def get_device_name(self, device_id: int = 0) -> str:
37+
def get_device_name(cls, device_id: int = 0) -> str:
3838
return "openvino"
3939

4040
@classmethod
4141
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
4242
return False
4343

4444
@classmethod
45-
def inference_mode(self):
45+
def inference_mode(cls):
4646
return torch.inference_mode(mode=True)
4747

4848
@classmethod
49-
def is_openvino_cpu(self) -> bool:
49+
def is_openvino_cpu(cls) -> bool:
5050
return "CPU" in envs.VLLM_OPENVINO_DEVICE
5151

5252
@classmethod
53-
def is_openvino_gpu(self) -> bool:
53+
def is_openvino_gpu(cls) -> bool:
5454
return "GPU" in envs.VLLM_OPENVINO_DEVICE
5555

5656
@classmethod
57-
def is_pin_memory_available(self) -> bool:
57+
def is_pin_memory_available(cls) -> bool:
5858
logger.warning("Pin memory is not supported on OpenViNO.")
5959
return False
6060

vllm/platforms/xpu.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
7878
parallel_config.distributed_executor_backend = "ray"
7979
if parallel_config.worker_cls == "auto":
8080
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
81+
82+
@classmethod
83+
def is_pin_memory_available(cls):
84+
logger.warning("Pin memory is not supported on XPU.")
85+
return False

vllm/utils.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from collections import UserDict, defaultdict
2525
from collections.abc import Iterable, Mapping
2626
from functools import lru_cache, partial, wraps
27-
from platform import uname
2827
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
2928
Dict, Generic, Hashable, List, Literal, Optional,
3029
OrderedDict, Set, Tuple, Type, TypeVar, Union, overload)
@@ -344,12 +343,6 @@ def random_uuid() -> str:
344343
return str(uuid.uuid4().hex)
345344

346345

347-
@lru_cache(maxsize=None)
348-
def in_wsl() -> bool:
349-
# Reference: https://github.com/microsoft/WSL/issues/4071
350-
return "microsoft" in " ".join(uname()).lower()
351-
352-
353346
def make_async(
354347
func: Callable[P, T],
355348
executor: Optional[concurrent.futures.Executor] = None
@@ -729,25 +722,7 @@ def print_warning_once(msg: str) -> None:
729722

730723
@lru_cache(maxsize=None)
731724
def is_pin_memory_available() -> bool:
732-
733-
if in_wsl():
734-
# Pinning memory in WSL is not supported.
735-
# https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
736-
print_warning_once("Using 'pin_memory=False' as WSL is detected. "
737-
"This may slow down the performance.")
738-
return False
739-
elif current_platform.is_xpu():
740-
print_warning_once("Pin memory is not supported on XPU.")
741-
return False
742-
elif current_platform.is_neuron():
743-
print_warning_once("Pin memory is not supported on Neuron.")
744-
return False
745-
elif current_platform.is_hpu():
746-
print_warning_once("Pin memory is not supported on HPU.")
747-
return False
748-
elif current_platform.is_cpu() or current_platform.is_openvino():
749-
return False
750-
return True
725+
return current_platform.is_pin_memory_available()
751726

752727

753728
class DeviceMemoryProfiler:

0 commit comments

Comments
 (0)