Skip to content

[V1][Spec Decoding] Add num_drafts and num_accepted_tokens_per_position metrics #16665

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions tests/v1/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch

from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, VllmConfig)
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput
Expand All @@ -31,6 +31,7 @@ def create_scheduler(
num_blocks: int = 10000,
block_size: int = 16,
max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
) -> Scheduler:
'''Create scheduler under test.

Expand Down Expand Up @@ -81,11 +82,17 @@ def create_scheduler(
kv_connector_extra_config={"shared_storage_path": "local_storage"},
) if use_kv_connector else None

speculative_config: Optional[SpeculativeConfig] = None
if num_speculative_tokens is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=num_speculative_tokens)

vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
Expand Down Expand Up @@ -429,7 +436,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):

def test_stop_via_update_from_output():
"""Test stopping behavior through update_from_output"""
scheduler = create_scheduler()
scheduler = create_scheduler(num_speculative_tokens=1)

# Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10)
Expand Down Expand Up @@ -481,7 +488,7 @@ def test_stop_via_update_from_output():
assert list(requests[1].output_token_ids) == [10, 11]

# Test case 2: Stop on custom stop token
scheduler = create_scheduler()
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2,
max_tokens=10,
stop_token_ids=[42, 43])
Expand Down Expand Up @@ -533,7 +540,7 @@ def test_stop_via_update_from_output():
assert list(requests[1].output_token_ids) == [13, 14]

# Test case 3: Stop on max tokens
scheduler = create_scheduler()
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2, max_tokens=2)
for req in requests:
req.num_computed_tokens = req.num_tokens
Expand Down Expand Up @@ -583,7 +590,7 @@ def test_stop_via_update_from_output():
assert list(requests[1].output_token_ids) == [13]

# Test case 4: Ignore EOS flag
scheduler = create_scheduler()
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens
Expand Down Expand Up @@ -686,13 +693,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
@pytest.mark.parametrize(
"spec_tokens,output_tokens,expected",
[
([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)), # perfect match
([[1, 2, 3]], [[1, 5]], (3, 1)), # early mismatch
([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)), # multiple sequences
([[1]], [[1, 2]], (1, 1)), # single token sequence
([[]], [[5]], (0, 0)), # empty sequence
([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
([[1, 2], [3]], [[1, 2, 5], [3, 4]],
(2, 3, 3, [2, 1])), # multiple sequences
([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
([[]], [[5]], (0, 0, 0, [0])), # empty sequence
([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
(6, 3)), # multiple mismatches
(2, 6, 3, [2, 1, 0])), # multiple mismatches
])
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
"""Test scheduling behavior with speculative decoding.
Expand All @@ -701,7 +709,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens
"""
scheduler = create_scheduler()
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
req_ids = []
req_to_index = {}
Expand Down Expand Up @@ -774,8 +783,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
else:
assert scheduler_stats.spec_decoding_stats is not None
stats = scheduler_stats.spec_decoding_stats
assert stats.num_draft_tokens == expected[0]
assert stats.num_accepted_tokens == expected[1]
assert stats.num_drafts == expected[0]
assert stats.num_draft_tokens == expected[1]
assert stats.num_accepted_tokens == expected[2]
assert stats.num_accepted_tokens_per_pos == expected[3]


def _assert_right_scheduler_output(
Expand Down
16 changes: 9 additions & 7 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,12 @@ def __init__(
self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size)

self.num_lookahead_tokens = 0
speculative_config = vllm_config.speculative_config
if speculative_config and speculative_config.method == "eagle":
self.num_lookahead_tokens = \
speculative_config.num_speculative_tokens
self.num_spec_tokens = self.num_lookahead_tokens = 0
if speculative_config:
self.num_spec_tokens = speculative_config.num_speculative_tokens
if speculative_config.method == "eagle":
self.num_lookahead_tokens = self.num_spec_tokens

def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
Expand Down Expand Up @@ -827,7 +828,8 @@ def make_spec_decoding_stats(
if not self.log_stats:
return None
if spec_decoding_stats is None:
spec_decoding_stats = SpecDecodingStats()
spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens,
num_accepted_tokens=num_accepted_tokens)
spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens)
spec_decoding_stats.observe_draft(
num_draft_tokens=num_draft_tokens,
num_accepted_tokens=num_accepted_tokens)
return spec_decoding_stats
35 changes: 9 additions & 26 deletions vllm/v1/metrics/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingMetrics
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm

logger = init_logger(__name__)

Expand All @@ -39,7 +39,7 @@ def __init__(self, engine_index: int = 0):
# Prefix cache metrics. This cannot be reset.
# TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics()
self.spec_decoding_metrics = SpecDecodingMetrics()
self.spec_decoding_logging = SpecDecodingLogging()
self.last_prompt_throughput: float = 0.0
self.last_generation_throughput: float = 0.0

Expand Down Expand Up @@ -70,7 +70,7 @@ def record(self, scheduler_stats: SchedulerStats,
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)

if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_metrics.observe(
self.spec_decoding_logging.observe(
scheduler_stats.spec_decoding_stats)

self.last_scheduler_stats = scheduler_stats
Expand Down Expand Up @@ -112,7 +112,7 @@ def log(self):
)

if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_metrics.log(log_fn=log_fn)
self.spec_decoding_logging.log(log_fn=log_fn)


class PrometheusStatLogger(StatLoggerBase):
Expand All @@ -133,6 +133,9 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):

max_model_len = vllm_config.model_config.max_model_len

self.spec_decoding_prom = SpecDecodingProm(
vllm_config.speculative_config, labelnames, labelvalues)

#
# Scheduler state
#
Expand Down Expand Up @@ -323,24 +326,6 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self.labelname_running_lora_adapters,
])

#
# Speculative Decoding metrics
# The acceptance rate can be calculated using a PromQL query:
#
# rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
# rate(vllm:spec_decode_num_draft_tokens_total[$interval])
#
self.counter_spec_decode_num_draft_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames).labels(*labelvalues)
self.counter_spec_decode_num_accepted_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames).labels(*labelvalues)

#
# Cache config info metric
#
Expand Down Expand Up @@ -378,10 +363,8 @@ def record(self, scheduler_stats: SchedulerStats,
scheduler_stats.prefix_cache_stats.hits)

if scheduler_stats.spec_decoding_stats is not None:
self.counter_spec_decode_num_draft_tokens.inc(
scheduler_stats.spec_decoding_stats.num_draft_tokens)
self.counter_spec_decode_num_accepted_tokens.inc(
scheduler_stats.spec_decoding_stats.num_accepted_tokens)
self.spec_decoding_prom.observe(
scheduler_stats.spec_decoding_stats)

if iteration_stats is None:
return
Expand Down
Loading