Skip to content

[V1][Spec Decode] Always use argmax for sampling draft tokens #16899

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions tests/v1/spec_decode/test_max_len.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
"""Test whether spec decoding handles the max model length properly."""

import pytest

from vllm import LLM, SamplingParams

# Prompts shared by every test in this module; each test passes the whole
# list to a single `llm.generate` call.
# NOTE(review): the first two prompts (a repeated token, and a request to
# repeat a sentence) presumably encourage repetitive output that a draft
# proposer can match easily — confirm the intent with the author.
_PROMPTS = [
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
"Repeat the following sentence 10 times: Consistency is key to mastering any skill.", # noqa: E501
"Who won the Turing Award in 2018, and for what contribution? Describe in detail.", # noqa: E501
]


@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    """Smoke-test n-gram speculative decoding at the max model length.

    The engine is built with ``max_model_len=100`` while every request asks
    for ``max_tokens=100`` with EOS ignored, so a non-empty prompt plus the
    requested output cannot fit within the model length. The test contains
    no assertions: it passes as long as engine construction and generation
    complete without raising.

    Args:
        monkeypatch: pytest fixture used to set ``VLLM_USE_V1`` for the
            duration of the test only.
        num_speculative_tokens: Number of draft tokens requested per step.
    """
    ngram_spec_config = {
        "method": "ngram",
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 3,
        "num_speculative_tokens": num_speculative_tokens,
    }

    with monkeypatch.context() as env_patch:
        # The environment variable must be set before the engine is created.
        env_patch.setenv("VLLM_USE_V1", "1")

        engine = LLM(
            model="facebook/opt-125m",
            max_model_len=100,
            enforce_eager=True,  # For faster initialization.
            speculative_config=ngram_spec_config,
        )
        engine.generate(
            _PROMPTS,
            SamplingParams(max_tokens=100, ignore_eos=True),
        )


@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_eagle_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    """Smoke-test EAGLE speculative decoding at the max model length.

    The engine is built with ``max_model_len=100`` while every request asks
    for ``max_tokens=100`` with EOS ignored, so a non-empty prompt plus the
    requested output cannot fit within the model length. The test contains
    no assertions: it passes as long as engine construction and generation
    complete without raising.

    Args:
        monkeypatch: pytest fixture used to set ``VLLM_USE_V1`` for the
            duration of the test only.
        num_speculative_tokens: Number of draft tokens requested per step.
    """
    eagle_spec_config = {
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
        "num_speculative_tokens": num_speculative_tokens,
    }

    with monkeypatch.context() as env_patch:
        # The environment variable must be set before the engine is created.
        env_patch.setenv("VLLM_USE_V1", "1")

        engine = LLM(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            enforce_eager=True,  # For faster initialization.
            speculative_config=eagle_spec_config,
            max_model_len=100,
        )
        engine.generate(
            _PROMPTS,
            SamplingParams(max_tokens=100, ignore_eos=True),
        )
7 changes: 7 additions & 0 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,13 @@ def schedule(self) -> SchedulerOutput:
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0

# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - request.num_computed_tokens)
assert num_new_tokens > 0

# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
Expand Down
14 changes: 7 additions & 7 deletions vllm/v1/sample/rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def rejection_sample(
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
NO_DRAFT_PROBS=draft_probs is None,
num_warps=1,
)
return output_token_ids
Expand Down Expand Up @@ -423,7 +423,7 @@ def sample_recovered_tokens(
q,
vocab_size,
triton.next_power_of_2(vocab_size),
IS_NGRAM=draft_probs is None,
NO_DRAFT_PROBS=draft_probs is None,
)
return recovered_token_ids

Expand Down Expand Up @@ -490,7 +490,7 @@ def rejection_random_sample_kernel(
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
IS_NGRAM: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
):
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
Expand All @@ -509,7 +509,7 @@ def rejection_random_sample_kernel(
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
if IS_NGRAM:
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
Expand Down Expand Up @@ -575,7 +575,7 @@ def sample_recovered_tokens_kernel(
q_ptr, # [batch_size, vocab_size]
vocab_size,
PADDED_VOCAB_SIZE: tl.constexpr,
IS_NGRAM: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
):
req_idx = tl.program_id(0)
if req_idx == 0:
Expand All @@ -591,7 +591,7 @@ def sample_recovered_tokens_kernel(
return

vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
if IS_NGRAM:
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
draft_token_id)
Expand Down Expand Up @@ -624,7 +624,7 @@ def sample_recovered_tokens_kernel(
recovered_id = tl.argmax(prob / q, axis=-1)
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)

if IS_NGRAM:
if NO_DRAFT_PROBS:
# Restore the original probability.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
Expand Down
90 changes: 38 additions & 52 deletions vllm/v1/spec_decode/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata

PADDING_SLOT_ID = -1


class EagleProposer:

Expand All @@ -23,6 +25,7 @@ def __init__(
self.vllm_config = vllm_config
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
Expand All @@ -48,7 +51,7 @@ def propose(
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
Expand Down Expand Up @@ -91,17 +94,15 @@ def propose(
)
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
logits, sampling_metadata)
draft_token_ids = logits.argmax(dim=-1)

# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1] and [batch_size, 1, vocab_size]
return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)
# [batch_size, 1]
return draft_token_ids.view(-1, 1)

# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
draft_probs_list = [draft_probs]

positions = target_positions[last_token_indices]
hidden_states = sample_hidden_states
Expand All @@ -112,34 +113,56 @@ def propose(
# Update the inputs.
input_ids = draft_token_ids_list[-1]
positions += 1

# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)

# Increment the sequence lengths.
attn_metadata.max_seq_len += 1
attn_metadata.seq_lens += 1
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

# Compute the slot mapping.
block_numbers = positions // self.block_size
block_numbers = clamped_positions // self.block_size
block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
positions % self.block_size)
clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
PADDING_SLOT_ID)

# Run the model.
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
hidden_states=hidden_states,
positions=positions,
positions=clamped_positions,
)
logits = self.model.compute_logits(hidden_states, None)
draft_token_ids, probs = compute_probs_and_sample_next_token(
logits, sampling_metadata)
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
draft_probs_list.append(probs)

# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
# [batch_size, num_speculative_tokens, vocab_size]
draft_probs = torch.stack(draft_probs_list, dim=1)
return draft_token_ids, draft_probs
return draft_token_ids

@staticmethod
def prepare_inputs(
Expand Down Expand Up @@ -209,43 +232,6 @@ def load_model(self, target_model: nn.Module) -> None:
self.model.lm_head = target_model.lm_head


# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per request and return the draft distribution.

    Args:
        logits: Draft-model logits, one row per request. NOTE: on the
            non-greedy path this tensor is scaled by the temperature
            in place.
        sampling_metadata: Only ``all_greedy``, ``all_random`` and
            ``temperature`` are read. A temperature of ``-1`` marks a
            greedy request.

    Returns:
        A ``(next_token_ids, probs)`` tuple. ``next_token_ids`` holds one
        token id per request. ``probs`` is the float32 softmax of the
        temperature-scaled logits, except when every request is greedy, in
        which case the raw ``logits`` are returned unchanged.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        probs = logits
        next_token_ids = logits.argmax(dim=-1)
        return next_token_ids, probs

    is_greedy = sampling_metadata.temperature == -1
    # Greedy rows use a neutral temperature of 1.0 so the division below is
    # well defined for them; their token is replaced by the argmax later.
    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
    # Do NOT use the in-place `probs.div_(q)` here: `probs` is returned to the
    # caller as the draft distribution and is also used for the greedy argmax
    # below, so it must stay a valid probability tensor.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(
            is_greedy,
            greedy_token_ids,
            next_token_ids,
        )
    return next_token_ids, probs


@triton.jit
def prepare_input_kernel(
out_ptr,
Expand Down
7 changes: 6 additions & 1 deletion vllm/v1/spec_decode/ngram_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def __init__(self, vllm_config: VllmConfig):
# tokens follow the match, we will return the maximum amount of
# tokens until the end.
self.k = vllm_config.speculative_config.num_speculative_tokens
# Maximum length of the model.
self.max_model_len = vllm_config.model_config.max_model_len

# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.propose(np.zeros(1024, dtype=np.int32))
Expand Down Expand Up @@ -50,9 +53,11 @@ def propose(
followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match.
"""
# Do not generate draft tokens beyond the max model length.
k = min(self.k, self.max_model_len - context_token_ids.shape[0])
# TODO(woosuk): Optimize this.
for n in range(self.max_n, self.min_n - 1, -1):
result = _find_subarray_kmp(context_token_ids, n, self.k)
result = _find_subarray_kmp(context_token_ids, n, k)
if result is not None:
return result
return None
Expand Down
13 changes: 8 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,7 @@ def execute_model(
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]

draft_token_ids, draft_probs = self.drafter.propose(
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
Expand All @@ -1240,9 +1240,6 @@ def execute_model(
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
# TODO(woosuk): Cache draft_probs and use it for rejection sampling
# in the next step.
del draft_probs

# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
Expand Down Expand Up @@ -1271,7 +1268,8 @@ def generate_draft_token_ids(
draft_token_ids.append([])
continue

# Skip requests that require top-p, top-k, etc.
# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = self.input_batch.req_ids[i]
if not is_spec_decode_supported(req_id, self.input_batch):
draft_token_ids.append([])
Expand All @@ -1280,6 +1278,11 @@ def generate_draft_token_ids(
# Add sampled_token_ids to token_ids_cpu.
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids
if end_idx >= self.max_model_len:
# Skip requests that have already reached the max model length.
draft_token_ids.append([])
continue

self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx])
Expand Down