|
| 1 | +import ctypes |
1 | 2 | import importlib.util
|
2 | 3 | import logging
|
3 | 4 | import os
|
|
13 | 14 | from setuptools import Extension, find_packages, setup
|
14 | 15 | from setuptools.command.build_ext import build_ext
|
15 | 16 | from setuptools_scm import get_version
|
16 |
| -from torch.utils.cpp_extension import CUDA_HOME |
| 17 | +from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME |
17 | 18 |
|
18 | 19 |
|
19 | 20 | def load_module_from_path(module_name, path):
|
@@ -379,25 +380,31 @@ def _build_custom_ops() -> bool:
|
379 | 380 | return _is_cuda() or _is_hip() or _is_cpu()
|
380 | 381 |
|
381 | 382 |
|
382 |
| -def get_hipcc_rocm_version(): |
383 |
| - # Run the hipcc --version command |
384 |
| - result = subprocess.run(['hipcc', '--version'], |
385 |
| - stdout=subprocess.PIPE, |
386 |
| - stderr=subprocess.STDOUT, |
387 |
| - text=True) |
| 383 | +def get_rocm_version(): |
| 384 | + # Get the Rocm version from the ROCM_HOME/bin/librocm-core.so |
| 385 | + # see https://github.com/ROCm/rocm-core/blob/d11f5c20d500f729c393680a01fa902ebf92094b/rocm_version.cpp#L21 |
| 386 | + try: |
| 387 | + librocm_core_file = Path(ROCM_HOME) / "lib" / "librocm-core.so" |
| 388 | + if not librocm_core_file.is_file(): |
| 389 | + return None |
| 390 | + librocm_core = ctypes.CDLL(librocm_core_file) |
| 391 | + VerErrors = ctypes.c_uint32 |
| 392 | + get_rocm_core_version = librocm_core.getROCmVersion |
| 393 | + get_rocm_core_version.restype = VerErrors |
| 394 | + get_rocm_core_version.argtypes = [ |
| 395 | + ctypes.POINTER(ctypes.c_uint32), |
| 396 | + ctypes.POINTER(ctypes.c_uint32), |
| 397 | + ctypes.POINTER(ctypes.c_uint32), |
| 398 | + ] |
| 399 | + major = ctypes.c_uint32() |
| 400 | + minor = ctypes.c_uint32() |
| 401 | + patch = ctypes.c_uint32() |
388 | 402 |
|
389 |
| - # Check if the command was executed successfully |
390 |
| - if result.returncode != 0: |
391 |
| - print("Error running 'hipcc --version'") |
| 403 | + if (get_rocm_core_version(ctypes.byref(major), ctypes.byref(minor), |
| 404 | + ctypes.byref(patch)) == 0): |
| 405 | + return "%d.%d.%d" % (major.value, minor.value, patch.value) |
392 | 406 | return None
|
393 |
| - |
394 |
| - # Extract the version using a regular expression |
395 |
| - match = re.search(r'HIP version: (\S+)', result.stdout) |
396 |
| - if match: |
397 |
| - # Return the version string |
398 |
| - return match.group(1) |
399 |
| - else: |
400 |
| - print("Could not find HIP version in the output") |
| 407 | + except Exception: |
401 | 408 | return None
|
402 | 409 |
|
403 | 410 |
|
@@ -479,11 +486,10 @@ def get_vllm_version() -> str:
|
479 | 486 | if "sdist" not in sys.argv:
|
480 | 487 | version += f"{sep}cu{cuda_version_str}"
|
481 | 488 | elif _is_hip():
|
482 |
| - # Get the HIP version |
483 |
| - hipcc_version = get_hipcc_rocm_version() |
484 |
| - if hipcc_version != MAIN_CUDA_VERSION: |
485 |
| - rocm_version_str = hipcc_version.replace(".", "")[:3] |
486 |
| - version += f"{sep}rocm{rocm_version_str}" |
| 489 | + # Get the Rocm Version |
| 490 | + rocm_version = get_rocm_version() or torch.version.hip |
| 491 | + if rocm_version and rocm_version != MAIN_CUDA_VERSION: |
| 492 | + version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}" |
487 | 493 | elif _is_neuron():
|
488 | 494 | # Get the Neuron version
|
489 | 495 | neuron_version = str(get_neuronxcc_version())
|
|
0 commit comments