
[Bugfix] Enable V1 usage stats #16986


Merged · 5 commits · Apr 24, 2025
Changes from all commits
8 changes: 4 additions & 4 deletions vllm/usage/usage_lib.py
@@ -19,6 +19,7 @@

import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties
from vllm.version import __version__ as VLLM_VERSION

_config_home = envs.VLLM_CONFIG_ROOT
@@ -168,10 +169,9 @@ def _report_usage_once(self, model_architecture: str,
# Platform information
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
device_property = torch.cuda.get_device_properties(0)
self.gpu_count = torch.cuda.device_count()
self.gpu_type = device_property.name
self.gpu_memory_per_device = device_property.total_memory
self.gpu_count = cuda_device_count_stateless()
self.gpu_type, self.gpu_memory_per_device = (
cuda_get_device_properties(0, ("name", "total_memory")))
if current_platform.is_cuda():
self.cuda_runtime = torch.version.cuda
self.provider = _detect_cloud_provider()
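For context, the stats collector now gathers GPU info through the stateless helpers instead of touching `torch.cuda` directly. A minimal sketch of the new call pattern (assumes a CUDA-capable machine and this branch of vLLM installed); it collects the same fields as the removed calls, without initializing CUDA in the calling process:

```python
# Sketch of the new call pattern used by _report_usage_once.
from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties

gpu_count = cuda_device_count_stateless()
gpu_type, gpu_memory_per_device = cuda_get_device_properties(
    0, ("name", "total_memory"))
print(gpu_count, gpu_type, gpu_memory_per_device)
```
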
20 changes: 19 additions & 1 deletion vllm/utils.py
@@ -38,11 +38,13 @@
from collections import UserDict, defaultdict
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
Iterable, Iterator, KeysView, Mapping)
from concurrent.futures.process import ProcessPoolExecutor
from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, Tuple, Type, TypeVar, Union, cast, overload)
Optional, Sequence, Tuple, Type, TypeVar, Union, cast,
overload)
from uuid import uuid4

import cachetools
@@ -1235,6 +1237,22 @@ def cuda_is_initialized() -> bool:
return torch.cuda.is_initialized()


def cuda_get_device_properties(device,
names: Sequence[str],
init_cuda=False) -> tuple[Any, ...]:
"""Get specified CUDA device property values without initializing CUDA in
the current process."""
if init_cuda or cuda_is_initialized():
props = torch.cuda.get_device_properties(device)
return tuple(getattr(props, name) for name in names)

# Run in subprocess to avoid initializing CUDA as a side effect.
mp_ctx = multiprocessing.get_context("fork")
with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
return executor.submit(cuda_get_device_properties, device, names,
True).result()


def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
"""Make an instance method that weakly references
its associated instance and no-ops once that
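Per its docstring, the new helper queries device properties without initializing CUDA in the current process: if CUDA is not already initialized, it runs the query in a short-lived fork-based worker, so only that throwaway child ever creates a CUDA context. A standalone sketch of the same pattern (illustrative only, not vLLM code; fork is POSIX-only):

```python
# Illustrative version of the fork-and-query pattern: the CUDA context is
# only ever created inside a throwaway child process.
import multiprocessing
from concurrent.futures.process import ProcessPoolExecutor
from typing import Any, Sequence


def _query_props(device: int, names: Sequence[str]) -> tuple[Any, ...]:
    import torch  # CUDA gets initialized here, in the child only
    props = torch.cuda.get_device_properties(device)
    return tuple(getattr(props, name) for name in names)


def get_props_without_cuda_init(
        device: int = 0,
        names: Sequence[str] = ("name", "total_memory")) -> tuple[Any, ...]:
    mp_ctx = multiprocessing.get_context("fork")
    with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
        return executor.submit(_query_props, device, names).result()
```

Keeping the parent CUDA-free matters because an initialized CUDA context does not survive fork(), which would break later fork-based worker startup.
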
4 changes: 4 additions & 0 deletions vllm/v1/engine/async_llm.py
@@ -36,6 +36,7 @@
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
StatLoggerBase)
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.utils import report_usage_stats

logger = init_logger(__name__)

@@ -114,6 +115,9 @@ def __init__(
except RuntimeError:
pass

# If usage stat is enabled, collect relevant info.
report_usage_stats(vllm_config, usage_context)

@classmethod
def from_vllm_config(
cls,
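report_usage_stats() returns immediately when reporting is disabled, so the new call in AsyncLLM.__init__ is a no-op for users who have opted out. A quick sketch of the usual opt-out knobs (set before the engine is constructed; the file path is the default config location):

```python
# Disable usage-stats reporting. Any one of these works: either environment
# variable below, or an empty ~/.config/vllm/do_not_track file (default path).
import os

os.environ["VLLM_NO_USAGE_STATS"] = "1"
# os.environ["DO_NOT_TRACK"] = "1"

from vllm.usage.usage_lib import is_usage_stats_enabled

print(is_usage_stats_enabled())  # expected: False with the opt-out above
```
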
4 changes: 4 additions & 0 deletions vllm/v1/engine/llm_engine.py
@@ -28,6 +28,7 @@
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.utils import report_usage_stats

logger = init_logger(__name__)

@@ -99,6 +100,9 @@ def __init__(
# for v0 compatibility
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore

# If usage stat is enabled, collect relevant info.
report_usage_stats(vllm_config, usage_context)

@classmethod
def from_vllm_config(
cls,
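V1's synchronous LLMEngine gets the same hook, so offline use through the LLM class reports as well when enabled. A sketch of exercising that path; the model name is illustrative and a supported accelerator is assumed:

```python
# Constructing an LLM with the V1 engine goes through LLMEngine.__init__,
# so the usage report (if enabled) now fires on this path too.
import os

os.environ["VLLM_USE_V1"] = "1"  # force V1; the default varies by version

from vllm import LLM

llm = LLM(model="facebook/opt-125m")  # illustrative model
```
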
44 changes: 44 additions & 0 deletions vllm/v1/utils.py
@@ -12,6 +12,8 @@

from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import get_mp_context, kill_process_tree

if TYPE_CHECKING:
@@ -201,3 +203,45 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
Returns the sliced target tensor.
"""
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)


def report_usage_stats(vllm_config, usage_context: UsageContext) -> None:
"""Report usage statistics if enabled."""

if not is_usage_stats_enabled():
return

from vllm.model_executor.model_loader import get_architecture_class_name

usage_message.report_usage(
get_architecture_class_name(vllm_config.model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype":
str(vllm_config.model_config.dtype),
"tensor_parallel_size":
vllm_config.parallel_config.tensor_parallel_size,
"block_size":
vllm_config.cache_config.block_size,
"gpu_memory_utilization":
vllm_config.cache_config.gpu_memory_utilization,

# Quantization
"quantization":
vllm_config.model_config.quantization,
"kv_cache_dtype":
str(vllm_config.cache_config.cache_dtype),

# Feature flags
"enable_lora":
bool(vllm_config.lora_config),
"enable_prompt_adapter":
bool(vllm_config.prompt_adapter_config),
"enable_prefix_caching":
vllm_config.cache_config.enable_prefix_caching,
"enforce_eager":
vllm_config.model_config.enforce_eager,
"disable_custom_all_reduce":
vllm_config.parallel_config.disable_custom_all_reduce,
})
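
For reference, the helper only needs a populated VllmConfig plus the UsageContext describing the entry point; it bails out early when stats are disabled. A sketch of the call-site shape (EngineArgs and the model name are just for illustration; a working vLLM install and supported device are assumed):

```python
# Sketch of how report_usage_stats is driven by the V1 engines.
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.utils import report_usage_stats

vllm_config = EngineArgs(model="facebook/opt-125m").create_engine_config()
report_usage_stats(vllm_config, UsageContext.ENGINE_CONTEXT)
```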