From caf268c703a127236b332dcafbcd5aee906f4665 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Tue, 22 Apr 2025 14:06:43 +0000
Subject: [PATCH 1/4] Enable V1 usage stats

Signed-off-by: mgoin
---
 vllm/v1/engine/async_llm.py  | 40 +++++++++++++++++++++++++++++++++++-
 vllm/v1/engine/llm_engine.py | 40 +++++++++++++++++++++++++++++++++++-
 2 files changed, 78 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index bc49a0d3bb5..1c94288d43d 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -23,7 +23,8 @@
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.usage.usage_lib import UsageContext
+from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
+                                  usage_message)
 from vllm.utils import Device, cdiv
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
@@ -114,6 +115,43 @@ def __init__(
         except RuntimeError:
             pass
 
+        # If usage stat is enabled, collect relevant info.
+        if is_usage_stats_enabled():
+            from vllm.model_executor.model_loader import (
+                get_architecture_class_name)
+            usage_message.report_usage(
+                get_architecture_class_name(self.model_config),
+                usage_context,
+                extra_kvs={
+                    # Common configuration
+                    "dtype":
+                    str(self.model_config.dtype),
+                    "tensor_parallel_size":
+                    vllm_config.parallel_config.tensor_parallel_size,
+                    "block_size":
+                    vllm_config.cache_config.block_size,
+                    "gpu_memory_utilization":
+                    vllm_config.cache_config.gpu_memory_utilization,
+
+                    # Quantization
+                    "quantization":
+                    self.model_config.quantization,
+                    "kv_cache_dtype":
+                    str(vllm_config.cache_config.cache_dtype),
+
+                    # Feature flags
+                    "enable_lora":
+                    bool(vllm_config.lora_config),
+                    "enable_prompt_adapter":
+                    bool(vllm_config.prompt_adapter_config),
+                    "enable_prefix_caching":
+                    vllm_config.cache_config.enable_prefix_caching,
+                    "enforce_eager":
+                    vllm_config.model_config.enforce_eager,
+                    "disable_custom_all_reduce":
+                    vllm_config.parallel_config.disable_custom_all_reduce,
+                })
+
     @classmethod
     def from_vllm_config(
         cls,
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index c05319f3d80..bbd6510cb6b 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -21,7 +21,8 @@
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import (
     BaseTokenizerGroup, init_tokenizer_from_configs)
-from vllm.usage.usage_lib import UsageContext
+from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
+                                  usage_message)
 from vllm.utils import Device
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
@@ -99,6 +100,43 @@ def __init__(
             # for v0 compatibility
             self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore
 
+        # If usage stat is enabled, collect relevant info.
+        if is_usage_stats_enabled():
+            from vllm.model_executor.model_loader import (
+                get_architecture_class_name)
+            usage_message.report_usage(
+                get_architecture_class_name(self.model_config),
+                usage_context,
+                extra_kvs={
+                    # Common configuration
+                    "dtype":
+                    str(self.model_config.dtype),
+                    "tensor_parallel_size":
+                    parallel_config.tensor_parallel_size,
+                    "block_size":
+                    self.cache_config.block_size,
+                    "gpu_memory_utilization":
+                    self.cache_config.gpu_memory_utilization,
+
+                    # Quantization
+                    "quantization":
+                    self.model_config.quantization,
+                    "kv_cache_dtype":
+                    str(self.cache_config.cache_dtype),
+
+                    # Feature flags
+                    "enable_lora":
+                    bool(vllm_config.lora_config),
+                    "enable_prompt_adapter":
+                    bool(vllm_config.prompt_adapter_config),
+                    "enable_prefix_caching":
+                    self.cache_config.enable_prefix_caching,
+                    "enforce_eager":
+                    self.model_config.enforce_eager,
+                    "disable_custom_all_reduce":
+                    parallel_config.disable_custom_all_reduce,
+                })
+
     @classmethod
     def from_vllm_config(
         cls,

From 24dfb08a6b130f6906ab743c47a13fe0088e24d0 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Tue, 22 Apr 2025 17:33:15 +0000
Subject: [PATCH 2/4] Move to util

Signed-off-by: mgoin
---
 vllm/v1/engine/async_llm.py  | 40 +++++---------------
 vllm/v1/engine/llm_engine.py | 40 +++++---------------
 vllm/v1/utils.py             | 48 ++++++++++++++++++++++++++++++++++++
 3 files changed, 54 insertions(+), 74 deletions(-)

diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 1c94288d43d..1149dfa9ce5 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -23,8 +23,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
-                                  usage_message)
+from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, cdiv
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
@@ -37,6 +36,7 @@
 from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
                                      StatLoggerBase)
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+from vllm.v1.utils import report_usage_stats
 
 logger = init_logger(__name__)
 
@@ -116,41 +116,7 @@ def __init__(
             pass
 
         # If usage stat is enabled, collect relevant info.
-        if is_usage_stats_enabled():
-            from vllm.model_executor.model_loader import (
-                get_architecture_class_name)
-            usage_message.report_usage(
-                get_architecture_class_name(self.model_config),
-                usage_context,
-                extra_kvs={
-                    # Common configuration
-                    "dtype":
-                    str(self.model_config.dtype),
-                    "tensor_parallel_size":
-                    vllm_config.parallel_config.tensor_parallel_size,
-                    "block_size":
-                    vllm_config.cache_config.block_size,
-                    "gpu_memory_utilization":
-                    vllm_config.cache_config.gpu_memory_utilization,
-
-                    # Quantization
-                    "quantization":
-                    self.model_config.quantization,
-                    "kv_cache_dtype":
-                    str(vllm_config.cache_config.cache_dtype),
-
-                    # Feature flags
-                    "enable_lora":
-                    bool(vllm_config.lora_config),
-                    "enable_prompt_adapter":
-                    bool(vllm_config.prompt_adapter_config),
-                    "enable_prefix_caching":
-                    vllm_config.cache_config.enable_prefix_caching,
-                    "enforce_eager":
-                    vllm_config.model_config.enforce_eager,
-                    "disable_custom_all_reduce":
-                    vllm_config.parallel_config.disable_custom_all_reduce,
-                })
+        report_usage_stats(vllm_config, usage_context)
 
     @classmethod
     def from_vllm_config(
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index bbd6510cb6b..6fa90b26982 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -21,14 +21,14 @@
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import (
     BaseTokenizerGroup, init_tokenizer_from_configs)
-from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
-                                  usage_message)
+from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.utils import report_usage_stats
 
 logger = init_logger(__name__)
 
@@ -101,41 +101,7 @@ def __init__(
         self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore
 
         # If usage stat is enabled, collect relevant info.
-        if is_usage_stats_enabled():
-            from vllm.model_executor.model_loader import (
-                get_architecture_class_name)
-            usage_message.report_usage(
-                get_architecture_class_name(self.model_config),
-                usage_context,
-                extra_kvs={
-                    # Common configuration
-                    "dtype":
-                    str(self.model_config.dtype),
-                    "tensor_parallel_size":
-                    parallel_config.tensor_parallel_size,
-                    "block_size":
-                    self.cache_config.block_size,
-                    "gpu_memory_utilization":
-                    self.cache_config.gpu_memory_utilization,
-
-                    # Quantization
-                    "quantization":
-                    self.model_config.quantization,
-                    "kv_cache_dtype":
-                    str(self.cache_config.cache_dtype),
-
-                    # Feature flags
-                    "enable_lora":
-                    bool(vllm_config.lora_config),
-                    "enable_prompt_adapter":
-                    bool(vllm_config.prompt_adapter_config),
-                    "enable_prefix_caching":
-                    self.cache_config.enable_prefix_caching,
-                    "enforce_eager":
-                    self.model_config.enforce_eager,
-                    "disable_custom_all_reduce":
-                    parallel_config.disable_custom_all_reduce,
-                })
+        report_usage_stats(vllm_config, usage_context)
 
     @classmethod
     def from_vllm_config(
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index 32d8101f681..7f509c3f0d7 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -201,3 +201,51 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
     Returns the sliced target tensor.
""" return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) + + +def report_usage_stats(vllm_config, usage_context: str) -> None: + """Report usage statistics if enabled. + + Args: + vllm_config: The vLLM configuration object containing model_config, + cache_config, parallel_config, etc. + usage_context: The context string for usage reporting + """ + from vllm.usage.usage_lib import is_usage_stats_enabled, usage_message + + if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import ( + get_architecture_class_name) + + usage_message.report_usage( + get_architecture_class_name(vllm_config.model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": + str(vllm_config.model_config.dtype), + "tensor_parallel_size": + vllm_config.parallel_config.tensor_parallel_size, + "block_size": + vllm_config.cache_config.block_size, + "gpu_memory_utilization": + vllm_config.cache_config.gpu_memory_utilization, + + # Quantization + "quantization": + vllm_config.model_config.quantization, + "kv_cache_dtype": + str(vllm_config.cache_config.cache_dtype), + + # Feature flags + "enable_lora": + bool(vllm_config.lora_config), + "enable_prompt_adapter": + bool(vllm_config.prompt_adapter_config), + "enable_prefix_caching": + vllm_config.cache_config.enable_prefix_caching, + "enforce_eager": + vllm_config.model_config.enforce_eager, + "disable_custom_all_reduce": + vllm_config.parallel_config.disable_custom_all_reduce, + }) From b82124d21ae3e0044fe52a3e51b6d707c186569e Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 23 Apr 2025 01:51:37 +0000 Subject: [PATCH 3/4] Reformat Signed-off-by: mgoin --- vllm/v1/utils.py | 88 +++++++++++++++++++++++------------------------- 1 file changed, 42 insertions(+), 46 deletions(-) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 7f509c3f0d7..6744a068484 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -12,6 +12,8 @@ from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index +from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, + usage_message) from vllm.utils import get_mp_context, kill_process_tree if TYPE_CHECKING: @@ -203,49 +205,43 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) -def report_usage_stats(vllm_config, usage_context: str) -> None: - """Report usage statistics if enabled. - - Args: - vllm_config: The vLLM configuration object containing model_config, - cache_config, parallel_config, etc. 
-        usage_context: The context string for usage reporting
-    """
-    from vllm.usage.usage_lib import is_usage_stats_enabled, usage_message
-
-    if is_usage_stats_enabled():
-        from vllm.model_executor.model_loader import (
-            get_architecture_class_name)
-
-        usage_message.report_usage(
-            get_architecture_class_name(vllm_config.model_config),
-            usage_context,
-            extra_kvs={
-                # Common configuration
-                "dtype":
-                str(vllm_config.model_config.dtype),
-                "tensor_parallel_size":
-                vllm_config.parallel_config.tensor_parallel_size,
-                "block_size":
-                vllm_config.cache_config.block_size,
-                "gpu_memory_utilization":
-                vllm_config.cache_config.gpu_memory_utilization,
-
-                # Quantization
-                "quantization":
-                vllm_config.model_config.quantization,
-                "kv_cache_dtype":
-                str(vllm_config.cache_config.cache_dtype),
-
-                # Feature flags
-                "enable_lora":
-                bool(vllm_config.lora_config),
-                "enable_prompt_adapter":
-                bool(vllm_config.prompt_adapter_config),
-                "enable_prefix_caching":
-                vllm_config.cache_config.enable_prefix_caching,
-                "enforce_eager":
-                vllm_config.model_config.enforce_eager,
-                "disable_custom_all_reduce":
-                vllm_config.parallel_config.disable_custom_all_reduce,
-            })
+def report_usage_stats(vllm_config, usage_context: UsageContext) -> None:
+    """Report usage statistics if enabled."""
+
+    if not is_usage_stats_enabled():
+        return
+
+    from vllm.model_executor.model_loader import get_architecture_class_name
+
+    usage_message.report_usage(
+        get_architecture_class_name(vllm_config.model_config),
+        usage_context,
+        extra_kvs={
+            # Common configuration
+            "dtype":
+            str(vllm_config.model_config.dtype),
+            "tensor_parallel_size":
+            vllm_config.parallel_config.tensor_parallel_size,
+            "block_size":
+            vllm_config.cache_config.block_size,
+            "gpu_memory_utilization":
+            vllm_config.cache_config.gpu_memory_utilization,
+
+            # Quantization
+            "quantization":
+            vllm_config.model_config.quantization,
+            "kv_cache_dtype":
+            str(vllm_config.cache_config.cache_dtype),
+
+            # Feature flags
+            "enable_lora":
+            bool(vllm_config.lora_config),
+            "enable_prompt_adapter":
+            bool(vllm_config.prompt_adapter_config),
+            "enable_prefix_caching":
+            vllm_config.cache_config.enable_prefix_caching,
+            "enforce_eager":
+            vllm_config.model_config.enforce_eager,
+            "disable_custom_all_reduce":
+            vllm_config.parallel_config.disable_custom_all_reduce,
+        })

From 58a9714df1a51effcb3cfdcd2eda6f4b43469736 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Wed, 23 Apr 2025 16:09:03 -0700
Subject: [PATCH 4/4] init cuda in separate process

Signed-off-by: Nick Hill
---
 vllm/usage/usage_lib.py |  8 ++++----
 vllm/utils.py           | 20 +++++++++++++++++++-
 2 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py
index 2ee3f9104d1..7132681050e 100644
--- a/vllm/usage/usage_lib.py
+++ b/vllm/usage/usage_lib.py
@@ -19,6 +19,7 @@
 
 import vllm.envs as envs
 from vllm.connections import global_http_connection
+from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties
 from vllm.version import __version__ as VLLM_VERSION
 
 _config_home = envs.VLLM_CONFIG_ROOT
@@ -168,10 +169,9 @@ def _report_usage_once(self, model_architecture: str,
         # Platform information
         from vllm.platforms import current_platform
         if current_platform.is_cuda_alike():
-            device_property = torch.cuda.get_device_properties(0)
-            self.gpu_count = torch.cuda.device_count()
-            self.gpu_type = device_property.name
-            self.gpu_memory_per_device = device_property.total_memory
+            self.gpu_count = cuda_device_count_stateless()
+            self.gpu_type, self.gpu_memory_per_device = (
+                cuda_get_device_properties(0, ("name", "total_memory")))
         if current_platform.is_cuda():
             self.cuda_runtime = torch.version.cuda
         self.provider = _detect_cloud_provider()
diff --git a/vllm/utils.py b/vllm/utils.py
index c6e2afff72d..c1bbc2b6e2b 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -38,11 +38,13 @@
 from collections import UserDict, defaultdict
 from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
                              Iterable, Iterator, KeysView, Mapping)
+from concurrent.futures.process import ProcessPoolExecutor
 from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from types import MappingProxyType
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
-                    Optional, Tuple, Type, TypeVar, Union, cast, overload)
+                    Optional, Sequence, Tuple, Type, TypeVar, Union, cast,
+                    overload)
 from uuid import uuid4
 
 import cachetools
@@ -1235,6 +1237,22 @@ def cuda_is_initialized() -> bool:
     return torch.cuda.is_initialized()
 
 
+def cuda_get_device_properties(device,
+                               names: Sequence[str],
+                               init_cuda=False) -> tuple[Any, ...]:
+    """Get specified CUDA device property values without initializing CUDA in
+    the current process."""
+    if init_cuda or cuda_is_initialized():
+        props = torch.cuda.get_device_properties(device)
+        return tuple(getattr(props, name) for name in names)
+
+    # Run in subprocess to avoid initializing CUDA as a side effect.
+    mp_ctx = multiprocessing.get_context("fork")
+    with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
+        return executor.submit(cuda_get_device_properties, device, names,
+                               True).result()
+
+
 def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
     """Make an instance method that weakly references
     its associated instance and no-ops once that