[V1][Spec Decode] Ngram Spec Decode #12193

Merged
92 commits merged on Feb 16, 2025
Changes from 58 commits

Commits (92)
6071a4b
skeleton
LiuXiaoxuanPKU Jan 19, 2025
be798e7
runnable but incorrect
LiuXiaoxuanPKU Jan 19, 2025
f0976dd
fix
LiuXiaoxuanPKU Jan 19, 2025
6039933
pass for simple non spec case
LiuXiaoxuanPKU Jan 19, 2025
03cd3dd
pass args and minor variable name bug fix
LiuXiaoxuanPKU Jan 19, 2025
26ba690
minor
LiuXiaoxuanPKU Jan 19, 2025
ca5e0dd
minimal example
LiuXiaoxuanPKU Jan 19, 2025
704634c
Merge branch 'main' of github.com:LiuXiaoxuanPKU/vllm into ngram
LiuXiaoxuanPKU Jan 20, 2025
62012d1
minor
LiuXiaoxuanPKU Jan 20, 2025
7bd3f27
format
LiuXiaoxuanPKU Jan 20, 2025
b0c5d25
basic test
LiuXiaoxuanPKU Jan 20, 2025
d5ee081
minor
LiuXiaoxuanPKU Jan 20, 2025
4e11585
minor
LiuXiaoxuanPKU Jan 20, 2025
f915eda
stop checking
LiuXiaoxuanPKU Jan 20, 2025
bd8ac07
test for stop checking
LiuXiaoxuanPKU Jan 21, 2025
008a41e
style and disable scheduling chunked requests
LiuXiaoxuanPKU Jan 22, 2025
784b24a
signed-off-by
LiuXiaoxuanPKU Jan 22, 2025
f3f6ebc
ngram proposer
LiuXiaoxuanPKU Jan 24, 2025
5e7306e
style and minor output token fix
LiuXiaoxuanPKU Jan 24, 2025
a26df8d
partial cleanup & update the kmp
LiuXiaoxuanPKU Jan 28, 2025
eeab204
minor
LiuXiaoxuanPKU Jan 28, 2025
6772e07
minor
LiuXiaoxuanPKU Jan 28, 2025
a5932a7
fix comments
LiuXiaoxuanPKU Jan 31, 2025
a6245e8
minor
LiuXiaoxuanPKU Feb 3, 2025
7890287
merge
LiuXiaoxuanPKU Feb 3, 2025
5d3a31a
change sampled_token_ids to tensor
LiuXiaoxuanPKU Feb 3, 2025
f7f4c24
minor
LiuXiaoxuanPKU Feb 3, 2025
a1eecd3
remove double free
LiuXiaoxuanPKU Feb 3, 2025
c843121
fix bug in input batch token id update
LiuXiaoxuanPKU Feb 3, 2025
cdcace5
constant list for spec tokens
LiuXiaoxuanPKU Feb 3, 2025
97ea9f2
Merge branch 'main' into ngram
LiuXiaoxuanPKU Feb 3, 2025
2cab6e6
header
LiuXiaoxuanPKU Feb 3, 2025
65875e8
Merge branch 'main' into ngram
LiuXiaoxuanPKU Feb 3, 2025
e30bc0c
bug fix for invalid token id check
LiuXiaoxuanPKU Feb 4, 2025
18fba42
type
LiuXiaoxuanPKU Feb 4, 2025
7f08f9c
Merge commit 'e30bc0cd' into ngram
LiuXiaoxuanPKU Feb 4, 2025
ba1d0fd
prefix caching + sd
LiuXiaoxuanPKU Feb 4, 2025
fc69953
pass in max_spec_num
LiuXiaoxuanPKU Feb 4, 2025
970a91a
fix block calcaulation
LiuXiaoxuanPKU Feb 4, 2025
7ecb668
minor
LiuXiaoxuanPKU Feb 4, 2025
10b3fe6
merge
LiuXiaoxuanPKU Feb 5, 2025
acda923
fix comments
LiuXiaoxuanPKU Feb 5, 2025
2006c75
fix test
LiuXiaoxuanPKU Feb 5, 2025
2ad4f39
fix test
LiuXiaoxuanPKU Feb 5, 2025
faafcb6
fix test
LiuXiaoxuanPKU Feb 5, 2025
038e203
merge test
LiuXiaoxuanPKU Feb 5, 2025
4dc0f87
stop checking
LiuXiaoxuanPKU Feb 7, 2025
54508d5
kv cache manager
LiuXiaoxuanPKU Feb 8, 2025
c25b9eb
bug fix
LiuXiaoxuanPKU Feb 8, 2025
ace9518
merge
LiuXiaoxuanPKU Feb 8, 2025
f4ee865
fix scheduler
LiuXiaoxuanPKU Feb 8, 2025
d02844a
fix scheduler and tests
LiuXiaoxuanPKU Feb 8, 2025
50ab162
Simplify request
LiuXiaoxuanPKU Feb 8, 2025
f3b08f4
rejection sampling tests update
LiuXiaoxuanPKU Feb 8, 2025
036a23f
Merge branch 'ngram' of github.com:LiuXiaoxuanPKU/vllm into ngram
LiuXiaoxuanPKU Feb 8, 2025
840413b
optimize rejection sampler
LiuXiaoxuanPKU Feb 8, 2025
03f6bee
static
LiuXiaoxuanPKU Feb 8, 2025
5fb9ac1
format
LiuXiaoxuanPKU Feb 8, 2025
7c0497e
Update vllm/v1/core/scheduler.py
LiuXiaoxuanPKU Feb 10, 2025
e0bd8cc
Update vllm/v1/worker/gpu_model_runner.py
LiuXiaoxuanPKU Feb 10, 2025
95d34f0
fix comments
LiuXiaoxuanPKU Feb 10, 2025
3ff5ead
Merge branch 'ngram' of github.com:LiuXiaoxuanPKU/vllm into ngram
LiuXiaoxuanPKU Feb 10, 2025
4cc5f8d
minor
LiuXiaoxuanPKU Feb 10, 2025
e1654b9
merge
LiuXiaoxuanPKU Feb 10, 2025
4086a77
input prepare
LiuXiaoxuanPKU Feb 10, 2025
1e218af
fix input prepare
LiuXiaoxuanPKU Feb 11, 2025
633567a
simplify scheduleroutput
LiuXiaoxuanPKU Feb 11, 2025
888f183
change test case to make output more deterministic
LiuXiaoxuanPKU Feb 11, 2025
353c372
update cpu gpu sync
LiuXiaoxuanPKU Feb 11, 2025
9416792
vectorize rejection sampler
LiuXiaoxuanPKU Feb 11, 2025
ab22c2d
merge
LiuXiaoxuanPKU Feb 11, 2025
54c5fa5
fix comments
LiuXiaoxuanPKU Feb 11, 2025
00b9d69
merge
LiuXiaoxuanPKU Feb 11, 2025
4ea2fda
minor
LiuXiaoxuanPKU Feb 11, 2025
0d6d713
minor
LiuXiaoxuanPKU Feb 11, 2025
d064a1a
Merge branch 'main' into ngram
LiuXiaoxuanPKU Feb 11, 2025
af7322e
fix input prepare bug
LiuXiaoxuanPKU Feb 12, 2025
4dec71d
fix
LiuXiaoxuanPKU Feb 12, 2025
8758b96
fix test
LucasWilkinson Feb 12, 2025
8929ad1
fix comments
LiuXiaoxuanPKU Feb 15, 2025
65bb67f
minor fix
LiuXiaoxuanPKU Feb 15, 2025
992aab8
make test more deterministic
LiuXiaoxuanPKU Feb 15, 2025
6608e31
Merge branch 'main' into ngram
LiuXiaoxuanPKU Feb 15, 2025
e298bb3
merge conflict
LiuXiaoxuanPKU Feb 15, 2025
4329970
Merge branch 'ngram' of github.com:LiuXiaoxuanPKU/vllm into ngram
LiuXiaoxuanPKU Feb 15, 2025
4e015ae
fix
LiuXiaoxuanPKU Feb 15, 2025
5fc5264
fix rejection sampler tests
LiuXiaoxuanPKU Feb 15, 2025
b56a8e4
fix num_token
LiuXiaoxuanPKU Feb 15, 2025
a669c1c
merge
LiuXiaoxuanPKU Feb 15, 2025
2cbf57e
fix scheduler test
LiuXiaoxuanPKU Feb 15, 2025
29d3054
fix scheduler test, minor
LiuXiaoxuanPKU Feb 15, 2025
2dc7909
fix gpu model runner
LiuXiaoxuanPKU Feb 15, 2025
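
Background for reviewers: several commits above reference the n-gram proposer and its KMP-based lookup ("ngram proposer", "partial cleanup & update the kmp"). The sketch below only illustrates the prompt-lookup idea; the function name, signature, and the naive backwards scan are assumptions for illustration, not the PR's actual implementation.

from typing import List, Optional


def propose_ngram_tokens(token_ids: List[int],
                         n: int = 3,
                         num_speculative_tokens: int = 3) -> Optional[List[int]]:
    """Propose draft tokens by prompt lookup.

    Take the last n tokens of the sequence, look for an earlier occurrence
    of that n-gram, and propose the tokens that followed it. Returns None
    when the context is too short or no earlier match exists.
    """
    if len(token_ids) < n + 1:
        return None
    suffix = token_ids[-n:]
    # Scan backwards so the most recent earlier repetition wins; a real
    # implementation would use a KMP-style match instead of this quadratic scan.
    for start in range(len(token_ids) - n - 1, -1, -1):
        if token_ids[start:start + n] == suffix:
            next_pos = start + n
            draft = token_ids[next_pos:next_pos + num_speculative_tokens]
            return draft or None
    return None


# Example: the tail [4, 5, 6] already appeared earlier, so [7, 8, 9] is drafted.
assert propose_ngram_tokens([1, 4, 5, 6, 7, 8, 9, 4, 5, 6]) == [7, 8, 9]
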
198 changes: 194 additions & 4 deletions tests/v1/core/test_scheduler.py
@@ -4,10 +4,12 @@
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

EOS_TOKEN_ID = 50256


def create_scheduler(
model: str = "facebook/opt-125m",
@@ -38,15 +40,20 @@ def create_scheduler(
return Scheduler(scheduler_config,
model_config,
cache_config,
speculative_config=None,
lora_config=None)


def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[List[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[List[int]] = None,
):
sampling_params = SamplingParams()
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids)
requests = []
for i in range(num_requests):
if mm_positions is not None:
@@ -63,7 +70,7 @@ def create_requests(
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
)
requests.append(request)
@@ -194,7 +201,7 @@ def test_schedule_partial_requests():
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[0] * len(requests),
sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
)
@@ -212,3 +219,186 @@
assert output.num_scheduled_tokens[requests[0].request_id] == 1
assert output.num_scheduled_tokens[requests[1].request_id] == 700
assert requests[2].request_id not in output.num_scheduled_tokens


def test_stop_via_update_from_output():
"""Test stopping behavior through update_from_output"""
scheduler = create_scheduler()

# Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)

scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
use_spec_decode=True,
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])

model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[EOS_TOKEN_ID],
[10,
11]], # First request hits EOS, second continues
logprobs=None,
prompt_logprobs_dict={})

scheduler.update_from_output(scheduler_output, model_output)

# Verify first request stopped, second continues
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
assert list(requests[1].output_token_ids) == [10, 11]

# Test case 2: Stop on custom stop token
scheduler = create_scheduler()
requests = create_requests(num_requests=2,
max_tokens=10,
stop_token_ids=[42, 43])
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)

scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
use_spec_decode=True,
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])

model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
logprobs=None,
prompt_logprobs_dict={})

scheduler.update_from_output(scheduler_output, model_output)

# Verify first request stopped on custom token
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].stop_reason == 42
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 42]
assert list(requests[1].output_token_ids) == [13, 14]

# Test case 3: Stop on max tokens
scheduler = create_scheduler()
requests = create_requests(num_requests=2, max_tokens=2)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)

scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
use_spec_decode=True,
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])

model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
logprobs=None,
prompt_logprobs_dict={})

scheduler.update_from_output(scheduler_output, model_output)

# Verify first request stopped due to length
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 11
] # Truncated to max_tokens
assert list(requests[1].output_token_ids) == [13]

# Test case 4: Ignore EOS flag
scheduler = create_scheduler()
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])

scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
use_spec_decode=True,
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])

model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
logprobs=None,
prompt_logprobs_dict={})

scheduler.update_from_output(scheduler_output, model_output)

# Verify request continues past EOS
assert len(scheduler.running) == 1
assert not requests[0].is_finished()
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
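
Note on the spec-decode fields used in these test cases: the per-request entry of sampled_token_ids is the accepted draft tokens followed by one bonus token, and the stop checks (EOS, custom stop tokens, max_tokens) then run over that whole list. The sketch below is a simplified, single-request illustration of greedy draft verification under that reading; it is not the PR's vectorized rejection sampler, and the helper name and signature are made up for illustration.

from typing import List


def greedy_verify(draft_token_ids: List[int],
                  target_token_ids: List[int]) -> List[int]:
    """Greedy draft verification (simplified, single request).

    target_token_ids holds the target model's argmax token at each of the
    len(draft) + 1 positions evaluated in one forward pass. Drafts are
    accepted while they agree with the target; the first disagreeing target
    token replaces the rejected draft, otherwise the final bonus token is
    appended after a fully accepted draft.
    """
    assert len(target_token_ids) == len(draft_token_ids) + 1
    output: List[int] = []
    for draft, target in zip(draft_token_ids, target_token_ids):
        if draft != target:
            output.append(target)  # rejected: keep the target token, stop here
            return output
        output.append(draft)  # accepted
    output.append(target_token_ids[-1])  # bonus token after full acceptance
    return output


# Mirrors test case 2 above: drafts [10, 42] are both accepted and 12 is the
# bonus token, so sampled_token_ids for that request is [10, 42, 12]; the
# scheduler's stop check then truncates the output at stop token 42.
assert greedy_verify([10, 42], [10, 42, 12]) == [10, 42, 12]
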
49 changes: 49 additions & 0 deletions tests/v1/e2e/test_ngram_specdecode.py
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

from vllm import LLM, SamplingParams


@pytest.fixture
def test_prompts():
return [
"Can you repeat the sentence ten times, this is a sentence?",
"This is a basic spec decode test",
]


@pytest.fixture
def sampling_config():
# Only support greedy for now
return SamplingParams(temperature=0, max_tokens=100, ignore_eos=False)


@pytest.fixture
def model_name():
return "meta-llama/Meta-Llama-3-8B-Instruct"


def test_ngram_correctness(monkeypatch, test_prompts, sampling_config,
model_name):
'''
Check that the outputs of the original LLM and the speculative LLM
are the same when using ngram speculative decoding.
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

ref_llm = LLM(model=model_name)
ref_outputs = ref_llm.generate(test_prompts, sampling_config)
del ref_llm

spec_llm = LLM(model=model_name,
speculative_model='[ngram]',
ngram_prompt_lookup_max=5,
ngram_prompt_lookup_min=3,
num_speculative_tokens=3)
spec_outputs = spec_llm.generate(test_prompts, sampling_config)
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
assert ref_output.outputs[0].text == spec_output.outputs[0].text, \
(f"ref_output: {ref_output.outputs[0].text},"
f"spec_output: {spec_output.outputs[0].text}")
del spec_llm