Skip to content

Commit 59d6bb4

Browse files
authored
[Hardware][AMD]: Replace HIPCC version with more precise ROCm version (#11515)
Signed-off-by: hjwei <[email protected]>
1 parent b7dcc00 commit 59d6bb4

File tree

1 file changed

+29
-23
lines changed

1 file changed

+29
-23
lines changed

setup.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ctypes
12
import importlib.util
23
import logging
34
import os
@@ -13,7 +14,7 @@
1314
from setuptools import Extension, find_packages, setup
1415
from setuptools.command.build_ext import build_ext
1516
from setuptools_scm import get_version
16-
from torch.utils.cpp_extension import CUDA_HOME
17+
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
1718

1819

1920
def load_module_from_path(module_name, path):
@@ -379,25 +380,31 @@ def _build_custom_ops() -> bool:
379380
return _is_cuda() or _is_hip() or _is_cpu()
380381

381382

382-
def get_hipcc_rocm_version():
383-
# Run the hipcc --version command
384-
result = subprocess.run(['hipcc', '--version'],
385-
stdout=subprocess.PIPE,
386-
stderr=subprocess.STDOUT,
387-
text=True)
383+
def get_rocm_version():
384+
# Get the Rocm version from the ROCM_HOME/bin/librocm-core.so
385+
# see https://github.com/ROCm/rocm-core/blob/d11f5c20d500f729c393680a01fa902ebf92094b/rocm_version.cpp#L21
386+
try:
387+
librocm_core_file = Path(ROCM_HOME) / "lib" / "librocm-core.so"
388+
if not librocm_core_file.is_file():
389+
return None
390+
librocm_core = ctypes.CDLL(librocm_core_file)
391+
VerErrors = ctypes.c_uint32
392+
get_rocm_core_version = librocm_core.getROCmVersion
393+
get_rocm_core_version.restype = VerErrors
394+
get_rocm_core_version.argtypes = [
395+
ctypes.POINTER(ctypes.c_uint32),
396+
ctypes.POINTER(ctypes.c_uint32),
397+
ctypes.POINTER(ctypes.c_uint32),
398+
]
399+
major = ctypes.c_uint32()
400+
minor = ctypes.c_uint32()
401+
patch = ctypes.c_uint32()
388402

389-
# Check if the command was executed successfully
390-
if result.returncode != 0:
391-
print("Error running 'hipcc --version'")
403+
if (get_rocm_core_version(ctypes.byref(major), ctypes.byref(minor),
404+
ctypes.byref(patch)) == 0):
405+
return "%d.%d.%d" % (major.value, minor.value, patch.value)
392406
return None
393-
394-
# Extract the version using a regular expression
395-
match = re.search(r'HIP version: (\S+)', result.stdout)
396-
if match:
397-
# Return the version string
398-
return match.group(1)
399-
else:
400-
print("Could not find HIP version in the output")
407+
except Exception:
401408
return None
402409

403410

@@ -479,11 +486,10 @@ def get_vllm_version() -> str:
479486
if "sdist" not in sys.argv:
480487
version += f"{sep}cu{cuda_version_str}"
481488
elif _is_hip():
482-
# Get the HIP version
483-
hipcc_version = get_hipcc_rocm_version()
484-
if hipcc_version != MAIN_CUDA_VERSION:
485-
rocm_version_str = hipcc_version.replace(".", "")[:3]
486-
version += f"{sep}rocm{rocm_version_str}"
489+
# Get the Rocm Version
490+
rocm_version = get_rocm_version() or torch.version.hip
491+
if rocm_version and rocm_version != MAIN_CUDA_VERSION:
492+
version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}"
487493
elif _is_neuron():
488494
# Get the Neuron version
489495
neuron_version = str(get_neuronxcc_version())

0 commit comments

Comments
 (0)