diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py
index 2ee3f9104d1..7132681050e 100644
--- a/vllm/usage/usage_lib.py
+++ b/vllm/usage/usage_lib.py
@@ -19,6 +19,7 @@
 
 import vllm.envs as envs
 from vllm.connections import global_http_connection
+from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties
 from vllm.version import __version__ as VLLM_VERSION
 
 _config_home = envs.VLLM_CONFIG_ROOT
@@ -168,10 +169,9 @@ def _report_usage_once(self, model_architecture: str,
         # Platform information
         from vllm.platforms import current_platform
         if current_platform.is_cuda_alike():
-            device_property = torch.cuda.get_device_properties(0)
-            self.gpu_count = torch.cuda.device_count()
-            self.gpu_type = device_property.name
-            self.gpu_memory_per_device = device_property.total_memory
+            self.gpu_count = cuda_device_count_stateless()
+            self.gpu_type, self.gpu_memory_per_device = (
+                cuda_get_device_properties(0, ("name", "total_memory")))
         if current_platform.is_cuda():
             self.cuda_runtime = torch.version.cuda
         self.provider = _detect_cloud_provider()
diff --git a/vllm/utils.py b/vllm/utils.py
index c65a370bd53..ed406a6b7b1 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -38,11 +38,13 @@
 from collections import UserDict, defaultdict
 from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
                              Iterable, Iterator, KeysView, Mapping)
+from concurrent.futures.process import ProcessPoolExecutor
 from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from types import MappingProxyType
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
-                    Optional, Tuple, Type, TypeVar, Union, cast, overload)
+                    Optional, Sequence, Tuple, Type, TypeVar, Union, cast,
+                    overload)
 from uuid import uuid4
 
 import cachetools
@@ -1235,6 +1237,22 @@ def cuda_is_initialized() -> bool:
     return torch.cuda.is_initialized()
 
 
+def cuda_get_device_properties(device,
+                               names: Sequence[str],
+                               init_cuda=False) -> tuple[Any, ...]:
+    """Get specified CUDA device property values without initializing CUDA in
+    the current process."""
+    if init_cuda or cuda_is_initialized():
+        props = torch.cuda.get_device_properties(device)
+        return tuple(getattr(props, name) for name in names)
+
+    # Run in subprocess to avoid initializing CUDA as a side effect.
+    mp_ctx = multiprocessing.get_context("fork")
+    with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
+        return executor.submit(cuda_get_device_properties, device, names,
+                               True).result()
+
+
 def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
     """Make an instance method that weakly references
     its associated instance and no-ops once that
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index bc49a0d3bb5..1149dfa9ce5 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -36,6 +36,7 @@
 from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
                                      StatLoggerBase)
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+from vllm.v1.utils import report_usage_stats
 
 logger = init_logger(__name__)
 
@@ -114,6 +115,9 @@ def __init__(
         except RuntimeError:
             pass
 
+        # If usage stat is enabled, collect relevant info.
+        report_usage_stats(vllm_config, usage_context)
+
     @classmethod
     def from_vllm_config(
         cls,
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index c05319f3d80..6fa90b26982 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -28,6 +28,7 @@
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.utils import report_usage_stats
 
 logger = init_logger(__name__)
 
@@ -99,6 +100,9 @@ def __init__(
         # for v0 compatibility
         self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore
 
+        # If usage stat is enabled, collect relevant info.
+        report_usage_stats(vllm_config, usage_context)
+
     @classmethod
     def from_vllm_config(
         cls,
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index 9c0fa2d0773..dc6457bf903 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -12,6 +12,8 @@
 
 from vllm.logger import init_logger
 from vllm.model_executor.models.utils import extract_layer_index
+from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
+                                  usage_message)
 from vllm.utils import get_mp_context, kill_process_tree
 
 if TYPE_CHECKING:
@@ -201,3 +203,45 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
     Returns the sliced target tensor.
     """
     return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
+
+
+def report_usage_stats(vllm_config, usage_context: UsageContext) -> None:
+    """Report usage statistics if enabled."""
+
+    if not is_usage_stats_enabled():
+        return
+
+    from vllm.model_executor.model_loader import get_architecture_class_name
+
+    usage_message.report_usage(
+        get_architecture_class_name(vllm_config.model_config),
+        usage_context,
+        extra_kvs={
+            # Common configuration
+            "dtype":
+            str(vllm_config.model_config.dtype),
+            "tensor_parallel_size":
+            vllm_config.parallel_config.tensor_parallel_size,
+            "block_size":
+            vllm_config.cache_config.block_size,
+            "gpu_memory_utilization":
+            vllm_config.cache_config.gpu_memory_utilization,
+
+            # Quantization
+            "quantization":
+            vllm_config.model_config.quantization,
+            "kv_cache_dtype":
+            str(vllm_config.cache_config.cache_dtype),
+
+            # Feature flags
+            "enable_lora":
+            bool(vllm_config.lora_config),
+            "enable_prompt_adapter":
+            bool(vllm_config.prompt_adapter_config),
+            "enable_prefix_caching":
+            vllm_config.cache_config.enable_prefix_caching,
+            "enforce_eager":
+            vllm_config.model_config.enforce_eager,
+            "disable_custom_all_reduce":
+            vllm_config.parallel_config.disable_custom_all_reduce,
+        })
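
Usage sketch (illustrative only, not part of the patch): a minimal example of how the new cuda_get_device_properties helper added to vllm/utils.py could be exercised, assuming a CUDA-capable host with vLLM installed; the property names mirror those requested by the usage reporter above.

    # Illustrative sketch; assumes a CUDA-capable host with vLLM installed.
    from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties

    if cuda_device_count_stateless() > 0:
        # When CUDA is not already initialized in this process, the helper runs
        # torch.cuda.get_device_properties(0) in a forked subprocess, so the
        # parent process avoids initializing CUDA as a side effect.
        name, total_memory = cuda_get_device_properties(0, ("name", "total_memory"))
        print(f"GPU 0: {name}, {total_memory / 2**30:.1f} GiB")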