
Commit 032dfba

divakar-amd authored and shreyankg committed
[ROCm] fix get_device_name for rocm (vllm-project#13438)
Signed-off-by: Divakar Verma <[email protected]>
1 parent 666b104 · commit 032dfba

File tree

1 file changed: +43 -6 lines changed

vllm/platforms/rocm.py

Lines changed: 43 additions & 6 deletions
@@ -1,9 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from functools import lru_cache
+import os
+from functools import lru_cache, wraps
 from typing import TYPE_CHECKING, Dict, List, Optional
 
 import torch
+from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles,
+                    amdsmi_init, amdsmi_shut_down)
 
 import vllm.envs as envs
 from vllm.logger import init_logger
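
For reference, the amdsmi bindings imported here follow an init/query/shutdown lifecycle. A minimal standalone sketch of that lifecycle, assuming a ROCm machine with the amdsmi Python package installed (not part of the commit):

    # Sketch: enumerate GPUs and print the field get_device_name reads below.
    from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles,
                        amdsmi_init, amdsmi_shut_down)

    amdsmi_init()
    try:
        for handle in amdsmi_get_processor_handles():
            # amdsmi_get_gpu_asic_info returns a dict of ASIC properties;
            # "market_name" holds the marketing name of the GPU.
            print(amdsmi_get_gpu_asic_info(handle)["market_name"])
    finally:
        amdsmi_shut_down()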
@@ -53,6 +56,41 @@
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
 }
 
+# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
+if "HIP_VISIBLE_DEVICES" in os.environ:
+    val = os.environ["HIP_VISIBLE_DEVICES"]
+    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
+        assert val == cuda_val
+    else:
+        os.environ["CUDA_VISIBLE_DEVICES"] = val
+
+# AMDSMI utils
+# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
+# all the related functions work on real physical device ids.
+# the major benefit of using AMDSMI is that it will not initialize CUDA
+
+
+def with_amdsmi_context(fn):
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        amdsmi_init()
+        try:
+            return fn(*args, **kwargs)
+        finally:
+            amdsmi_shut_down()
+
+    return wrapper
+
+
+def device_id_to_physical_device_id(device_id: int) -> int:
+    if "CUDA_VISIBLE_DEVICES" in os.environ:
+        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
+        physical_device_id = device_ids[device_id]
+        return int(physical_device_id)
+    else:
+        return device_id
+
 
 class RocmPlatform(Platform):
     _enum = PlatformEnum.ROCM
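
The device_id_to_physical_device_id helper exists because amdsmi enumerates physical devices and, unlike the CUDA runtime, is not filtered by CUDA_VISIBLE_DEVICES. A sketch of the intended mapping, assuming the module imports cleanly on the target machine (not part of the commit):

    import os
    from vllm.platforms.rocm import device_id_to_physical_device_id

    # With a restricted visible set, logical index 0 maps to physical GPU 2
    # and logical index 1 to physical GPU 5.
    os.environ["CUDA_VISIBLE_DEVICES"] = "2,5"
    assert device_id_to_physical_device_id(0) == 2
    assert device_id_to_physical_device_id(1) == 5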
@@ -96,13 +134,12 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
         return DeviceCapability(major=major, minor=minor)
 
     @classmethod
+    @with_amdsmi_context
     @lru_cache(maxsize=8)
     def get_device_name(cls, device_id: int = 0) -> str:
-        # NOTE: When using V1 this function is called when overriding the
-        # engine args. Calling torch.cuda.get_device_name(device_id) here
-        # will result in the ROCm context being initialized before other
-        # processes can be created.
-        return "AMD"
+        physical_device_id = device_id_to_physical_device_id(device_id)
+        handle = amdsmi_get_processor_handles()[physical_device_id]
+        return amdsmi_get_gpu_asic_info(handle)["market_name"]
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
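
With this change, get_device_name reports the real ASIC market name instead of the hard-coded "AMD": @with_amdsmi_context scopes amdsmi initialization to the call itself (so, per the note this commit removes, no ROCm context is initialized before other processes can be created), and @lru_cache(maxsize=8) memoizes the result per device_id. A hedged usage sketch; the returned string depends entirely on the installed hardware:

    from vllm.platforms.rocm import RocmPlatform

    # First call enters the amdsmi context, queries the ASIC info, and caches
    # the name; later calls with the same device_id are served from the cache.
    name = RocmPlatform.get_device_name(device_id=0)
    print(name)  # e.g. an "AMD Instinct ..." market name, hardware-dependent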
