
Commit 3c71780

mgoin and njhill authored and committed
[Bugfix] Enable V1 usage stats (vllm-project#16986)
Signed-off-by: mgoin <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Zijing Liu <[email protected]>
1 parent 5b40c90 commit 3c71780

File tree

5 files changed: 134 additions, 46 deletions


vllm/usage/usage_lib.py

Lines changed: 4 additions & 4 deletions

@@ -19,6 +19,7 @@
 
 import vllm.envs as envs
 from vllm.connections import global_http_connection
+from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties
 from vllm.version import __version__ as VLLM_VERSION
 
 _config_home = envs.VLLM_CONFIG_ROOT
@@ -168,10 +169,9 @@ def _report_usage_once(self, model_architecture: str,
         # Platform information
         from vllm.platforms import current_platform
         if current_platform.is_cuda_alike():
-            device_property = torch.cuda.get_device_properties(0)
-            self.gpu_count = torch.cuda.device_count()
-            self.gpu_type = device_property.name
-            self.gpu_memory_per_device = device_property.total_memory
+            self.gpu_count = cuda_device_count_stateless()
+            self.gpu_type, self.gpu_memory_per_device = (
+                cuda_get_device_properties(0, ("name", "total_memory")))
         if current_platform.is_cuda():
             self.cuda_runtime = torch.version.cuda
         self.provider = _detect_cloud_provider()
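
Note on the change above: calling torch.cuda.get_device_properties(0) in the reporting process initializes a CUDA context there as a side effect, which the new helpers avoid. A minimal sketch of that side effect (assumes a CUDA-capable machine with PyTorch installed; not part of the diff):

import torch

# A fresh process has no CUDA context yet.
print(torch.cuda.is_initialized())       # False

# Querying device properties directly initializes CUDA as a side effect,
# which is what the patched reporting code now avoids in its own process.
_ = torch.cuda.get_device_properties(0)
print(torch.cuda.is_initialized())       # True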

vllm/utils.py

Lines changed: 19 additions & 1 deletion

@@ -38,11 +38,13 @@
 from collections import UserDict, defaultdict
 from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
                              Iterable, Iterator, KeysView, Mapping)
+from concurrent.futures.process import ProcessPoolExecutor
 from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from types import MappingProxyType
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
-                    Optional, Tuple, Type, TypeVar, Union, cast, overload)
+                    Optional, Sequence, Tuple, Type, TypeVar, Union, cast,
+                    overload)
 from uuid import uuid4
 
 import cachetools
@@ -1235,6 +1237,22 @@ def cuda_is_initialized() -> bool:
     return torch.cuda.is_initialized()
 
 
+def cuda_get_device_properties(device,
+                               names: Sequence[str],
+                               init_cuda=False) -> tuple[Any, ...]:
+    """Get specified CUDA device property values without initializing CUDA in
+    the current process."""
+    if init_cuda or cuda_is_initialized():
+        props = torch.cuda.get_device_properties(device)
+        return tuple(getattr(props, name) for name in names)
+
+    # Run in subprocess to avoid initializing CUDA as a side effect.
+    mp_ctx = multiprocessing.get_context("fork")
+    with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
+        return executor.submit(cuda_get_device_properties, device, names,
+                               True).result()
+
+
 def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
     """Make an instance method that weakly references
     its associated instance and no-ops once that
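
A minimal usage sketch for the new helper, mirroring the call site added in usage_lib.py (the device index and property names are taken from that call; the print is illustrative only):

from vllm.utils import cuda_get_device_properties

# The properties are read in a short-lived forked worker that passes
# init_cuda=True, so this process never initializes a CUDA context;
# only plain Python values come back.
name, total_memory = cuda_get_device_properties(0, ("name", "total_memory"))
print(f"GPU 0: {name}, {total_memory / 2**30:.1f} GiB")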

vllm/v1/engine/async_llm.py

Lines changed: 4 additions & 0 deletions

@@ -35,6 +35,7 @@
 from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
                                      setup_default_loggers)
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+from vllm.v1.utils import report_usage_stats
 
 logger = init_logger(__name__)
 
@@ -131,6 +132,9 @@ def __init__(
         except RuntimeError:
             pass
 
+        # If usage stat is enabled, collect relevant info.
+        report_usage_stats(vllm_config, usage_context)
+
     @classmethod
     def from_vllm_config(
         cls,

vllm/v1/engine/llm_engine.py

Lines changed: 63 additions & 41 deletions

@@ -4,9 +4,9 @@
 from copy import copy
 from typing import Any, Callable, Optional, Union
 
-from typing_extensions import TypeVar
-
 import vllm.envs as envs
+
+from typing_extensions import TypeVar
 from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import stateless_destroy_torch_distributed_process_group
 from vllm.engine.arg_utils import EngineArgs
@@ -19,7 +19,9 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import (
-    BaseTokenizerGroup, init_tokenizer_from_configs)
+    BaseTokenizerGroup,
+    init_tokenizer_from_configs,
+)
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device
 from vllm.v1.engine.core_client import EngineCoreClient
@@ -28,6 +30,7 @@
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.loggers import StatLoggerFactory
+from vllm.v1.utils import report_usage_stats
 
 logger = init_logger(__name__)
 
@@ -54,12 +57,14 @@ def __init__(
                 "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
                 "This should not happen. As a workaround, try using "
                 "LLMEngine.from_vllm_config(...) or explicitly set "
-                "VLLM_USE_V1=0 or 1 and report this issue on Github.")
+                "VLLM_USE_V1=0 or 1 and report this issue on Github."
+            )
 
         if stat_loggers is not None:
             raise NotImplementedError(
                 "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
-                "Set VLLM_USE_V1=0 and file and issue on Github.")
+                "Set VLLM_USE_V1=0 and file and issue on Github."
+            )
 
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -79,17 +84,17 @@ def __init__(
             model_config=vllm_config.model_config,
             scheduler_config=vllm_config.scheduler_config,
             parallel_config=vllm_config.parallel_config,
-            lora_config=vllm_config.lora_config)
+            lora_config=vllm_config.lora_config,
+        )
         self.tokenizer.ping()
 
         # Processor (convert Inputs --> EngineCoreRequests)
-        self.processor = Processor(vllm_config=vllm_config,
-                                   tokenizer=self.tokenizer,
-                                   mm_registry=mm_registry)
+        self.processor = Processor(
+            vllm_config=vllm_config, tokenizer=self.tokenizer, mm_registry=mm_registry
+        )
 
         # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
-        self.output_processor = OutputProcessor(self.tokenizer,
-                                                log_stats=False)
+        self.output_processor = OutputProcessor(self.tokenizer, log_stats=False)
 
         # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
         self.engine_core = EngineCoreClient.make_client(
@@ -104,6 +109,9 @@ def __init__(
         # for v0 compatibility
         self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore
 
+        # If usage stat is enabled, collect relevant info.
+        report_usage_stats(vllm_config, usage_context)
+
     @classmethod
     def from_vllm_config(
         cls,
@@ -112,12 +120,14 @@ def from_vllm_config(
         stat_loggers: Optional[list[StatLoggerFactory]] = None,
         disable_log_stats: bool = False,
     ) -> "LLMEngine":
-        return cls(vllm_config=vllm_config,
-                   executor_class=Executor.get_class(vllm_config),
-                   log_stats=(not disable_log_stats),
-                   usage_context=usage_context,
-                   stat_loggers=stat_loggers,
-                   multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)
+        return cls(
+            vllm_config=vllm_config,
+            executor_class=Executor.get_class(vllm_config),
+            log_stats=(not disable_log_stats),
+            usage_context=usage_context,
+            stat_loggers=stat_loggers,
+            multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING,
+        )
 
     @classmethod
     def from_engine_args(
@@ -138,12 +148,14 @@ def from_engine_args(
             enable_multiprocessing = True
 
         # Create the LLMEngine.
-        return cls(vllm_config=vllm_config,
-                   executor_class=executor_class,
-                   log_stats=not engine_args.disable_log_stats,
-                   usage_context=usage_context,
-                   stat_loggers=stat_loggers,
-                   multiprocess_mode=enable_multiprocessing)
+        return cls(
+            vllm_config=vllm_config,
+            executor_class=executor_class,
+            log_stats=not engine_args.disable_log_stats,
+            usage_context=usage_context,
+            stat_loggers=stat_loggers,
+            multiprocess_mode=enable_multiprocessing,
+        )
 
     def get_num_unfinished_requests(self) -> int:
         return self.output_processor.get_num_unfinished_requests()
@@ -156,7 +168,8 @@ def has_unfinished_requests(self) -> bool:
 
     def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
         aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
-            self.dp_group, has_unfinished)
+            self.dp_group, has_unfinished
+        )
         if not has_unfinished and aggregated_has_unfinished:
             self.should_execute_dummy_batch = True
         return aggregated_has_unfinished
@@ -183,11 +196,16 @@ def add_request(
         priority: int = 0,
     ) -> None:
         # Process raw inputs into the request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
+        request = self.processor.process_inputs(
+            request_id,
+            prompt,
+            params,
+            arrival_time,
+            lora_request,
+            trace_headers,
+            prompt_adapter_request,
+            priority,
+        )
 
         n = params.n if isinstance(params, SamplingParams) else 1
 
@@ -222,8 +240,7 @@ def step(self) -> list[RequestOutput]:
         outputs = self.engine_core.get_output()
 
         # 2) Process EngineCoreOutputs.
-        processed_outputs = self.output_processor.process_outputs(
-            outputs.outputs)
+        processed_outputs = self.output_processor.process_outputs(outputs.outputs)
 
         # 3) Abort any reqs that finished due to stop strings.
         self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
@@ -261,12 +278,15 @@ def get_tokenizer_group(
         tokenizer_group = self.tokenizer
 
         if tokenizer_group is None:
-            raise ValueError("Unable to get tokenizer because "
-                             "skip_tokenizer_init is True")
+            raise ValueError(
+                "Unable to get tokenizer because " "skip_tokenizer_init is True"
+            )
         if not isinstance(tokenizer_group, group_type):
-            raise TypeError("Invalid type of tokenizer group. "
-                            f"Expected type: {group_type}, but "
-                            f"found type: {type(tokenizer_group)}")
+            raise TypeError(
+                "Invalid type of tokenizer group. "
+                f"Expected type: {group_type}, but "
+                f"found type: {type(tokenizer_group)}"
+            )
 
         return tokenizer_group
 
@@ -286,11 +306,13 @@ def pin_lora(self, lora_id: int) -> bool:
         """Prevent an adapter from being evicted."""
         return self.engine_core.pin_lora(lora_id)
 
-    def collective_rpc(self,
-                       method: Union[str, Callable[..., _R]],
-                       timeout: Optional[float] = None,
-                       args: tuple = (),
-                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+    def collective_rpc(
+        self,
+        method: Union[str, Callable[..., _R]],
+        timeout: Optional[float] = None,
+        args: tuple = (),
+        kwargs: Optional[dict[str, Any]] = None,
+    ) -> list[_R]:
         return self.engine_core.collective_rpc(method, timeout, args, kwargs)
 
     def __del__(self):
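
Taken together with the async_llm.py change above, constructing a V1 engine now triggers a usage report at initialization time. A rough illustration of the effect (the model name is only an example; nothing is reported unless usage-stats collection is enabled):

from vllm import LLM

# LLM builds an LLMEngine under the hood; with V1 enabled, __init__ now calls
# report_usage_stats(vllm_config, usage_context) as part of construction.
llm = LLM(model="facebook/opt-125m")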

vllm/v1/utils.py

Lines changed: 44 additions & 0 deletions

@@ -12,6 +12,8 @@
 
 from vllm.logger import init_logger
 from vllm.model_executor.models.utils import extract_layer_index
+from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
+                                  usage_message)
 from vllm.utils import get_mp_context, kill_process_tree
 
 if TYPE_CHECKING:
@@ -201,3 +203,45 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
     Returns the sliced target tensor.
     """
     return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
+
+
+def report_usage_stats(vllm_config, usage_context: UsageContext) -> None:
+    """Report usage statistics if enabled."""
+
+    if not is_usage_stats_enabled():
+        return
+
+    from vllm.model_executor.model_loader import get_architecture_class_name
+
+    usage_message.report_usage(
+        get_architecture_class_name(vllm_config.model_config),
+        usage_context,
+        extra_kvs={
+            # Common configuration
+            "dtype":
+            str(vllm_config.model_config.dtype),
+            "tensor_parallel_size":
+            vllm_config.parallel_config.tensor_parallel_size,
+            "block_size":
+            vllm_config.cache_config.block_size,
+            "gpu_memory_utilization":
+            vllm_config.cache_config.gpu_memory_utilization,
+
+            # Quantization
+            "quantization":
+            vllm_config.model_config.quantization,
+            "kv_cache_dtype":
+            str(vllm_config.cache_config.cache_dtype),
+
+            # Feature flags
+            "enable_lora":
+            bool(vllm_config.lora_config),
+            "enable_prompt_adapter":
+            bool(vllm_config.prompt_adapter_config),
+            "enable_prefix_caching":
+            vllm_config.cache_config.enable_prefix_caching,
+            "enforce_eager":
+            vllm_config.model_config.enforce_eager,
+            "disable_custom_all_reduce":
+            vllm_config.parallel_config.disable_custom_all_reduce,
+        })
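
The report is gated entirely by is_usage_stats_enabled(). A small sketch of opting out before constructing an engine (the variable names below come from vLLM's usage-stats documentation, not from this diff, so treat them as an assumption):

import os

from vllm.usage.usage_lib import is_usage_stats_enabled

# Per vLLM's usage-stats docs, setting VLLM_NO_USAGE_STATS=1 or DO_NOT_TRACK=1
# (or creating ~/.config/vllm/do_not_track) opts out of collection.
os.environ["VLLM_NO_USAGE_STATS"] = "1"

# report_usage_stats() above returns early whenever this is False.
print(is_usage_stats_enabled())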
