
Commit 86757ac

[V1][Metrics] Add API for accessing in-memory Prometheus metrics
Signed-off-by: Mark McLoughlin <[email protected]>
1 parent 069c723 commit 86757ac

5 files changed: +238 -4 lines changed

examples/offline_inference/metrics.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from vllm import LLM, SamplingParams
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+
+def main():
+    # Create an LLM.
+    llm = LLM(model="facebook/opt-125m", disable_log_stats=False)
+
+    # Generate texts from the prompts.
+    outputs = llm.generate(prompts, sampling_params)
+
+    # Print the outputs.
+    print("-" * 50)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 50)
+
+    # Dump all metrics
+    metrics = llm.get_metrics()
+    for metric in metrics.list_metrics():
+        metric_type = metrics.get_type(metric)
+        if metric_type in ("counter", "gauge"):
+            print(f"{metric} ({metric_type}) = {metrics.get_value(metric)}")
+        else:
+            assert metric_type == "histogram"
+            print(f"{metric} ({metric_type})")
+            print(f"  sum = {metrics.get_histogram_sum(metric)}")
+            print(f"  count = {metrics.get_histogram_count(metric)}")
+            for bucket_le in metrics.get_histogram_buckets(metric):
+                print(
+                    f"    {bucket_le} = {metrics.get_value(metric, bucket_le)}"
+                )
+
+
+if __name__ == "__main__":
+    main()

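The snapshot API above also makes derived statistics straightforward. As a minimal sketch (not part of this commit), assuming the `llm` object from the example has finished generating: the mean generation length per request can be computed from a histogram's sum and count. The `mean_of_histogram` helper is hypothetical; the metric name comes from the test below.

def mean_of_histogram(snapshot, name: str) -> float:
    # Hypothetical helper: a histogram's mean is its sum of observed
    # values divided by its observation count.
    return (snapshot.get_histogram_sum(name) /
            snapshot.get_histogram_count(name))

snapshot = llm.get_metrics()
print(mean_of_histogram(snapshot, "vllm:request_generation_tokens"))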
tests/v1/engine/test_llm_engine.py

Lines changed: 45 additions & 0 deletions
@@ -97,3 +97,48 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
         raise AssertionError(
             f"{len(completion_counts)} unique completions; expected"
             f" {n}. Repeats: {repeats}")
+
+
+def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
+    max_tokens = 100
+    with vllm_runner(
+            MODEL,
+            disable_log_stats=False,
+    ) as vllm_model:
+        model: LLM = vllm_model.model
+        sampling_params = SamplingParams(temperature=0.0,
+                                         max_tokens=max_tokens)
+        outputs = model.generate(example_prompts, sampling_params)
+
+        n_prompts = len(example_prompts)
+        assert len(outputs) == n_prompts
+
+        total_tokens = 0
+        for out in outputs:
+            assert len(out.outputs) == 1
+            total_tokens += len(out.outputs[0].token_ids)
+        assert total_tokens == max_tokens * n_prompts
+
+        snapshot = model.get_metrics()
+        metrics = snapshot.list_metrics()
+
+        assert "vllm:num_requests_running" in metrics
+        assert snapshot.get_type("vllm:num_requests_running") == "gauge"
+        assert snapshot.get_value("vllm:num_requests_running") == 0
+
+        assert "vllm:generation_tokens" in metrics
+        assert snapshot.get_type("vllm:generation_tokens") == "counter"
+        assert snapshot.get_value("vllm:generation_tokens") == total_tokens
+
+        assert "vllm:request_generation_tokens" in metrics
+        assert snapshot.get_type(
+            "vllm:request_generation_tokens") == "histogram"
+        buckets = snapshot.get_histogram_buckets(
+            "vllm:request_generation_tokens")
+        assert buckets[-1] == "+Inf"
+        assert snapshot.get_value("vllm:request_generation_tokens",
+                                  buckets[-1]) == n_prompts
+        assert snapshot.get_histogram_count(
+            "vllm:request_generation_tokens") == n_prompts
+        assert snapshot.get_histogram_sum(
+            "vllm:request_generation_tokens") == total_tokens

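The final assertions above rely on a Prometheus convention worth spelling out: histogram buckets are cumulative, so the "+Inf" bucket always equals the total observation count (here, one observation per finished request). A standalone sketch using only prometheus_client, with demo metric names rather than vLLM's:

from prometheus_client import CollectorRegistry, Histogram

registry = CollectorRegistry()
hist = Histogram("demo_latency", "Cumulative bucket demo",
                 buckets=(10, 100), registry=registry)
for value in (5, 50, 500):
    hist.observe(value)

metric = next(m for m in registry.collect() if m.name == "demo_latency")
for sample in metric.samples:
    if sample.name == "demo_latency_bucket":
        # Cumulative counts: le="10.0" -> 1.0, le="100.0" -> 2.0,
        # le="+Inf" -> 3.0 (every observation lands in "+Inf").
        print(sample.labels["le"], "->", sample.value)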
vllm/entrypoints/llm.py

Lines changed: 15 additions & 0 deletions
@@ -44,6 +44,7 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                         is_list_of)
+from vllm.v1.metrics.reader import MetricsSnapshot
 
 logger = init_logger(__name__)
 
@@ -1269,6 +1270,20 @@ def wake_up(self, tags: Optional[list[str]] = None):
         """
         self.llm_engine.wake_up(tags)
 
+    def get_metrics(self) -> MetricsSnapshot:
+        """Return a snapshot of aggregated metrics from Prometheus.
+
+        Returns:
+            A ``MetricsSnapshot`` instance capturing the current state
+            of all aggregated metrics from Prometheus.
+
+        Note:
+            This method is only available with the V1 LLM engine.
+        """
+        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
+        assert isinstance(self.llm_engine, V1LLMEngine)
+        return self.llm_engine.get_metrics()
+
     # LEGACY
     def _convert_v1_inputs(
         self,

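Because `get_metrics()` asserts that the V1 engine is in use (and, in the engine below, that stat logging is enabled), code that may run against either engine could guard the call. A minimal sketch, not part of this commit:

from vllm import LLM

llm = LLM(model="facebook/opt-125m", disable_log_stats=False)
llm.generate(["Hello, my name is"])

try:
    snapshot = llm.get_metrics()
except AssertionError:
    snapshot = None  # V0 engine (or stats disabled): API unavailable
if snapshot is not None:
    print(sorted(snapshot.list_metrics()))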
vllm/v1/engine/llm_engine.py

Lines changed: 24 additions & 4 deletions
@@ -10,7 +10,6 @@
 from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import stateless_destroy_torch_distributed_process_group
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.metrics_types import StatLoggerBase
 from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -28,6 +27,9 @@
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.metrics.loggers import PrometheusStatLogger, StatLoggerBase
+from vllm.v1.metrics.reader import MetricsSnapshot
+from vllm.v1.metrics.stats import IterationStats
 
 logger = init_logger(__name__)
 
@@ -60,6 +62,11 @@ def __init__(
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
 
+        self.log_stats = log_stats
+        self.stat_logger: Optional[StatLoggerBase] = None
+        if self.log_stats:
+            self.stat_logger = PrometheusStatLogger(vllm_config)
+
         # important: init dp group before init the engine_core
         # In the decoupled engine case this is handled in EngineCoreProc.
         parallel_config = vllm_config.parallel_config
@@ -84,15 +91,15 @@ def __init__(
 
         # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
         self.output_processor = OutputProcessor(self.tokenizer,
-                                                log_stats=False)
+                                                log_stats=self.log_stats)
 
         # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
         self.engine_core = EngineCoreClient.make_client(
             multiprocess_mode=multiprocess_mode,
             asyncio_mode=False,
             vllm_config=vllm_config,
             executor_class=executor_class,
-            log_stats=False,  # FIXME: implement
+            log_stats=self.log_stats,
         )
 
         if not multiprocess_mode:
@@ -222,12 +229,21 @@ def step(self) -> list[RequestOutput]:
         outputs = self.engine_core.get_output()
 
         # 2) Process EngineCoreOutputs.
+        iteration_stats = IterationStats() if self.log_stats else None
         processed_outputs = self.output_processor.process_outputs(
-            outputs.outputs)
+            outputs.outputs,
+            engine_core_timestamp=outputs.timestamp,
+            iteration_stats=iteration_stats)
 
         # 3) Abort any reqs that finished due to stop strings.
         self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
 
+        # 4) Record stats
+        if self.stat_logger is not None:
+            assert outputs.scheduler_stats is not None
+            self.stat_logger.record(scheduler_stats=outputs.scheduler_stats,
+                                    iteration_stats=iteration_stats)
+
         return processed_outputs.request_outputs
 
     def get_vllm_config(self):
@@ -254,6 +270,10 @@ def wake_up(self, tags: Optional[list[str]] = None):
     def is_sleeping(self) -> bool:
         return self.engine_core.is_sleeping()
 
+    def get_metrics(self) -> Optional[MetricsSnapshot]:
+        assert self.log_stats, "Stat logging disabled"
+        return MetricsSnapshot()
+
     def get_tokenizer_group(
         self,
         group_type: type[_G] = BaseTokenizerGroup,

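Note that `get_metrics()` returns a freshly constructed `MetricsSnapshot`, which (as `reader.py` below shows) copies registry state at construction time. Snapshots are therefore point-in-time values that can be diffed across calls; a sketch with a hypothetical `counter_delta` helper, assuming an `llm` with stats enabled:

def counter_delta(before, after, name: str) -> float:
    # Hypothetical helper: counters only increase, so the difference
    # between two snapshots is the amount accumulated in between.
    return after.get_value(name) - before.get_value(name)

before = llm.get_metrics()
llm.generate(["The capital of France is"])
after = llm.get_metrics()
print(counter_delta(before, after, "vllm:generation_tokens"))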
vllm/v1/metrics/reader.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from collections.abc import Iterable
+from typing import Optional
+
+from prometheus_client import REGISTRY, Metric
+from prometheus_client.samples import Sample
+
+
+class MetricsSnapshot:
+    """An API for accessing in-memory Prometheus metrics.
+
+    Example:
+        >>> metrics = llm.get_metrics()
+        >>> for metric in metrics.list_metrics():
+        ...     if metrics.get_type(metric) in ("counter", "gauge"):
+        ...         print(f"{metric} = {metrics.get_value(metric)}")
+        ...     else:
+        ...         print(f"{metric}")
+        ...         print(f"  sum = {metrics.get_histogram_sum(metric)}")
+        ...         print(f"  count = {metrics.get_histogram_count(metric)}")
+        ...         for bucket_le in metrics.get_histogram_buckets(metric):
+        ...             print(f"  {bucket_le} = "
+        ...                   f"{metrics.get_value(metric, bucket_le)}")
+    """
+
+    def __init__(self):
+        self._collected: dict[str, Metric] = {}
+        for metric in REGISTRY.collect():
+            self._collected[metric.name] = metric
+
+    def list_metrics(self) -> list[str]:
+        """Returns a list of all metric names currently available."""
+        return list(self._collected.keys())
+
+    def get_type(self, name: str) -> str:
+        """Return the type of a metric - gauge, counter, or histogram."""
+        return self._collected[name].type
+
+    def get_value(self, name: str, bucket_le: Optional[str] = None) -> float:
+        """Retrieves the value of a specific metric by its name.
+
+        Only supports gauge, counter, and histogram.
+
+        Raises an exception if the metric name is not found, or if it
+        is an unknown type.
+
+        Args:
+            name: The name of the metric.
+            bucket_le: The bucket label for histograms, if applicable.
+        """
+        metric = self._collected[name]
+        if metric.type == "gauge":
+            sample = self._must_get_sample(metric, name)
+        elif metric.type == "counter":
+            sample = self._must_get_sample(metric, name, "_total")
+        elif metric.type == "histogram":
+            assert bucket_le is not None
+            for sample in self._get_samples(metric, name, "_bucket"):
+                if sample.labels["le"] == bucket_le:
+                    break
+            else:
+                raise KeyError(f"No bucket {bucket_le} for {name}")
+        else:
+            raise AssertionError(f"Unknown metric type {metric.type}")
+        assert sample is not None
+        return sample.value
+
+    def get_histogram_buckets(self, name: str) -> list[str]:
+        """Returns the bucket labels for a histogram metric."""
+        histogram = self._must_get_histogram(name)
+        buckets = self._get_samples(histogram, name, "_bucket")
+        return [s.labels["le"] for s in buckets]
+
+    def get_histogram_count(self, name: str) -> float:
+        """Returns the count of samples for a histogram metric."""
+        histogram = self._must_get_histogram(name)
+        return self._must_get_sample(histogram, name, "_count").value
+
+    def get_histogram_sum(self, name: str) -> float:
+        """Returns the sum of samples for a histogram metric."""
+        histogram = self._must_get_histogram(name)
+        return self._must_get_sample(histogram, name, "_sum").value
+
+    #
+    # Helper methods
+    #
+    def _must_get_histogram(self, name: str) -> Metric:
+        metric = self._collected[name]
+        assert metric.type == "histogram"
+        return metric
+
+    @staticmethod
+    def _must_get_sample(metric: Metric,
+                         name: str,
+                         suffix: Optional[str] = None) -> Sample:
+        fullname = (name + suffix) if suffix is not None else name
+        return next(s for s in metric.samples if s.name == fullname)
+
+    @staticmethod
+    def _get_samples(metric: Metric,
+                     name: str,
+                     suffix: Optional[str] = None) -> Iterable[Sample]:
+        fullname = (name + suffix) if suffix is not None else name
+        return (s for s in metric.samples if s.name == fullname)

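The `_total`, `_bucket`, `_count`, and `_sum` suffixes that `get_value()` and the helpers append reflect how prometheus_client flattens each metric into multiple samples when collected. A standalone sketch with demo metric names, not vLLM code:

from prometheus_client import CollectorRegistry, Counter, Histogram

registry = CollectorRegistry()
Counter("demo_requests", "Demo counter", registry=registry).inc(3)
Histogram("demo_tokens", "Demo histogram", buckets=(1, 10),
          registry=registry).observe(4)

for metric in registry.collect():
    for sample in metric.samples:
        # The counter yields demo_requests_total (plus a _created
        # timestamp); the histogram yields demo_tokens_bucket{le=...},
        # demo_tokens_count, and demo_tokens_sum.
        print(metric.type, sample.name, dict(sample.labels), sample.value)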