Skip to content

[Bugfix] Enable V1 usage stats #16986

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
StatLoggerBase)
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.utils import report_usage_stats

logger = init_logger(__name__)

Expand Down Expand Up @@ -114,6 +115,9 @@ def __init__(
except RuntimeError:
pass

# If usage stat is enabled, collect relevant info.
report_usage_stats(vllm_config, usage_context)

@classmethod
def from_vllm_config(
cls,
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.utils import report_usage_stats

logger = init_logger(__name__)

Expand Down Expand Up @@ -99,6 +100,9 @@ def __init__(
# for v0 compatibility
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore

# If usage stat is enabled, collect relevant info.
report_usage_stats(vllm_config, usage_context)

@classmethod
def from_vllm_config(
cls,
Expand Down
48 changes: 48 additions & 0 deletions vllm/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,51 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
Returns the sliced target tensor.
"""
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)


def report_usage_stats(vllm_config, usage_context: str) -> None:
"""Report usage statistics if enabled.

Args:
vllm_config: The vLLM configuration object containing model_config,
cache_config, parallel_config, etc.
usage_context: The context string for usage reporting
"""
from vllm.usage.usage_lib import is_usage_stats_enabled, usage_message

if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (
get_architecture_class_name)

usage_message.report_usage(
get_architecture_class_name(vllm_config.model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype":
str(vllm_config.model_config.dtype),
"tensor_parallel_size":
vllm_config.parallel_config.tensor_parallel_size,
"block_size":
vllm_config.cache_config.block_size,
"gpu_memory_utilization":
vllm_config.cache_config.gpu_memory_utilization,

# Quantization
"quantization":
vllm_config.model_config.quantization,
"kv_cache_dtype":
str(vllm_config.cache_config.cache_dtype),

# Feature flags
"enable_lora":
bool(vllm_config.lora_config),
"enable_prompt_adapter":
bool(vllm_config.prompt_adapter_config),
"enable_prefix_caching":
vllm_config.cache_config.enable_prefix_caching,
"enforce_eager":
vllm_config.model_config.enforce_eager,
"disable_custom_all_reduce":
vllm_config.parallel_config.disable_custom_all_reduce,
})