Commit 2ec0270

LiuXiaoxuanPKU authored and jimpang committed
[V1][Spec Decode] Ngram Spec Decode (vllm-project#12193)
Signed-off-by: LiuXiaoxuanPKU <[email protected]>
1 parent 1198197 commit 2ec0270

Showing 21 changed files with 1025 additions and 84 deletions.
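
For context: n-gram ("prompt lookup") speculative decoding drafts tokens without a separate draft model by matching the most recent n-gram of the token sequence against earlier positions and proposing the tokens that followed the match. The following is a minimal illustrative sketch of that lookup; the function name and defaults are mine, not vLLM's implementation.

    # Illustrative sketch of n-gram (prompt-lookup) drafting; not vLLM code.
    from typing import List, Optional


    def propose_ngram_draft(token_ids: List[int],
                            min_n: int = 3,
                            max_n: int = 5,
                            k: int = 3) -> Optional[List[int]]:
        """Return up to k draft tokens, or None if no n-gram match is found."""
        # Prefer longer n-grams; fall back to shorter ones.
        for n in range(max_n, min_n - 1, -1):
            if len(token_ids) < n + 1:
                continue
            suffix = token_ids[-n:]
            # Scan earlier positions, most recent first, for the same n-gram.
            for start in range(len(token_ids) - n - 1, -1, -1):
                if token_ids[start:start + n] == suffix:
                    follow = token_ids[start + n:start + n + k]
                    if follow:
                        return follow
        return None


    # The 3-gram (7, 8, 9) occurred earlier, so the tokens after it are drafted.
    print(propose_ngram_draft([5, 7, 8, 9, 10, 11, 6, 7, 8, 9], k=2))  # [10, 11]

The drafted tokens are then verified by the target model in a single forward pass, so a wrong draft costs little and a correct one saves decode steps.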

tests/v1/core/test_scheduler.py

+196 -6
@@ -4,10 +4,12 @@
 from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
-from vllm.v1.core.scheduler import Scheduler
+from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 
+EOS_TOKEN_ID = 50256
+
 
 def create_scheduler(
     model: str = "facebook/opt-125m",
@@ -38,6 +40,7 @@ def create_scheduler(
     return Scheduler(scheduler_config,
                      model_config,
                      cache_config,
+                     speculative_config=None,
                      lora_config=None,
                      log_stats=True)
 
@@ -46,8 +49,12 @@ def create_requests(
     num_requests: int,
     num_tokens: int = 10,
     mm_positions: Optional[List[PlaceholderRange]] = None,
+    max_tokens: int = 16,
+    stop_token_ids: Optional[List[int]] = None,
 ):
-    sampling_params = SamplingParams()
+    sampling_params = SamplingParams(ignore_eos=False,
+                                     max_tokens=max_tokens,
+                                     stop_token_ids=stop_token_ids)
     requests = []
     for i in range(num_requests):
         if mm_positions is not None:
@@ -64,7 +71,7 @@ def create_requests(
             multi_modal_inputs=mm_inputs,
             multi_modal_placeholders=mm_position,
             multi_modal_hashes=None,
-            eos_token_id=None,
+            eos_token_id=EOS_TOKEN_ID,
             arrival_time=0,
         )
         requests.append(request)
@@ -195,7 +202,7 @@ def test_schedule_partial_requests():
     model_runner_output = ModelRunnerOutput(
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
-        sampled_token_ids=[0] * len(requests),
+        sampled_token_ids=[[0] for _ in range(len(requests))],
         logprobs=None,
         prompt_logprobs_dict={},
     )
@@ -215,6 +222,189 @@ def test_schedule_partial_requests():
     assert requests[2].request_id not in output.num_scheduled_tokens
 
 
+def test_stop_via_update_from_output():
+    """Test stopping behavior through update_from_output"""
+    scheduler = create_scheduler()
+
+    # Test case 1: Stop on EOS token
+    requests = create_requests(num_requests=2, max_tokens=10)
+    for req in requests:
+        req.num_computed_tokens = req.num_tokens
+        scheduler.requests[req.request_id] = req
+        scheduler.running.append(req)
+        scheduler.scheduled_req_ids.add(req.request_id)
+
+    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
+                                       scheduled_cached_reqs=[],
+                                       num_scheduled_tokens={
+                                           requests[0].request_id: 1,
+                                           requests[1].request_id: 2
+                                       },
+                                       total_num_scheduled_tokens=3,
+                                       scheduled_encoder_inputs={},
+                                       scheduled_spec_decode_tokens={
+                                           requests[0].request_id: [],
+                                           requests[1].request_id: [10]
+                                       },
+                                       num_common_prefix_blocks=0,
+                                       finished_req_ids=set(),
+                                       free_encoder_input_ids=[])
+
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(requests)
+        },
+        sampled_token_ids=[[EOS_TOKEN_ID],
+                           [10,
+                            11]],  # First request hits EOS, second continues
+        logprobs=None,
+        prompt_logprobs_dict={})
+
+    scheduler.update_from_output(scheduler_output, model_output)
+
+    # Verify first request stopped, second continues
+    assert len(scheduler.running) == 1
+    assert scheduler.running[0].request_id == requests[1].request_id
+    assert requests[0].status == RequestStatus.FINISHED_STOPPED
+    assert requests[0].request_id in scheduler.finished_req_ids
+    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
+    assert list(requests[1].output_token_ids) == [10, 11]
+
+    # Test case 2: Stop on custom stop token
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=2,
+                               max_tokens=10,
+                               stop_token_ids=[42, 43])
+    for req in requests:
+        req.num_computed_tokens = req.num_tokens
+        scheduler.requests[req.request_id] = req
+        scheduler.running.append(req)
+        scheduler.scheduled_req_ids.add(req.request_id)
+
+    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
+                                       scheduled_cached_reqs=[],
+                                       num_scheduled_tokens={
+                                           requests[0].request_id: 3,
+                                           requests[1].request_id: 2
+                                       },
+                                       total_num_scheduled_tokens=5,
+                                       scheduled_encoder_inputs={},
+                                       scheduled_spec_decode_tokens={
+                                           requests[0].request_id: [10, 42],
+                                           requests[1].request_id: [13]
+                                       },
+                                       num_common_prefix_blocks=0,
+                                       finished_req_ids=set(),
+                                       free_encoder_input_ids=[])
+
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(requests)
+        },
+        sampled_token_ids=[[10, 42, 12],
+                           [13, 14]],  # First request hits stop token
+        logprobs=None,
+        prompt_logprobs_dict={})
+
+    scheduler.update_from_output(scheduler_output, model_output)
+
+    # Verify first request stopped on custom token
+    assert len(scheduler.running) == 1
+    assert scheduler.running[0].request_id == requests[1].request_id
+    assert requests[0].status == RequestStatus.FINISHED_STOPPED
+    assert requests[0].stop_reason == 42
+    assert requests[0].request_id in scheduler.finished_req_ids
+    assert list(requests[0].output_token_ids) == [10, 42]
+    assert list(requests[1].output_token_ids) == [13, 14]
+
+    # Test case 3: Stop on max tokens
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=2, max_tokens=2)
+    for req in requests:
+        req.num_computed_tokens = req.num_tokens
+        scheduler.requests[req.request_id] = req
+        scheduler.running.append(req)
+        scheduler.scheduled_req_ids.add(req.request_id)
+
+    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
+                                       scheduled_cached_reqs=[],
+                                       num_scheduled_tokens={
+                                           requests[0].request_id: 3,
+                                           requests[1].request_id: 1
+                                       },
+                                       total_num_scheduled_tokens=4,
+                                       scheduled_encoder_inputs={},
+                                       scheduled_spec_decode_tokens={
+                                           requests[0].request_id: [10, 11],
+                                           requests[1].request_id: []
+                                       },
+                                       num_common_prefix_blocks=0,
+                                       finished_req_ids=set(),
+                                       free_encoder_input_ids=[])
+
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(requests)
+        },
+        sampled_token_ids=[[10, 11, 12],
+                           [13]],  # First request exceeds max_tokens
+        logprobs=None,
+        prompt_logprobs_dict={})
+
+    scheduler.update_from_output(scheduler_output, model_output)
+
+    # Verify first request stopped due to length
+    assert len(scheduler.running) == 1
+    assert scheduler.running[0].request_id == requests[1].request_id
+    assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
+    assert requests[0].request_id in scheduler.finished_req_ids
+    assert list(requests[0].output_token_ids) == [10, 11
+                                                  ]  # Truncated to max_tokens
+    assert list(requests[1].output_token_ids) == [13]
+
+    # Test case 4: Ignore EOS flag
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=1, max_tokens=10)
+    requests[0].sampling_params.ignore_eos = True
+    requests[0].num_computed_tokens = requests[0].num_tokens
+    scheduler.requests[requests[0].request_id] = requests[0]
+    scheduler.running.append(requests[0])
+    scheduler.scheduled_req_ids.add(requests[0].request_id)
+
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=[],
+        num_scheduled_tokens={requests[0].request_id: 3},
+        total_num_scheduled_tokens=3,
+        scheduled_encoder_inputs={},
+        scheduled_spec_decode_tokens={
+            requests[0].request_id: [EOS_TOKEN_ID, 10]
+        },
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[])
+
+    model_output = ModelRunnerOutput(
+        req_ids=[requests[0].request_id],
+        req_id_to_index={requests[0].request_id: 0},
+        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
+        logprobs=None,
+        prompt_logprobs_dict={})
+
+    scheduler.update_from_output(scheduler_output, model_output)
+
+    # Verify request continues past EOS
+    assert len(scheduler.running) == 1
+    assert not requests[0].is_finished()
+    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
+
+
 def test_schedule_concurrent_batches():
     scheduler = create_scheduler(
         max_num_batched_tokens=1024,
@@ -243,7 +433,7 @@ def test_schedule_concurrent_batches():
     model_runner_output = ModelRunnerOutput(
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
-        sampled_token_ids=[0],
+        sampled_token_ids=[[0]],
        logprobs=None,
         prompt_logprobs_dict={},
     )
@@ -259,7 +449,7 @@ def test_schedule_concurrent_batches():
     model_runner_output = ModelRunnerOutput(
         req_ids=[requests[1].request_id],
         req_id_to_index={requests[1].request_id: 0},
-        sampled_token_ids=[0],
+        sampled_token_ids=[[0]],
         logprobs=None,
         prompt_logprobs_dict={},
     )
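
Note the shape change above: sampled_token_ids went from one token per request (List[int]) to a list of tokens per request (List[List[int]]), because a request can accept several tokens in a single step under speculative decoding. The new test checks that the scheduler cuts a request's output at the first stop condition found inside that per-step list. A rough sketch of that behavior under assumed names (apply_stop is hypothetical, not the scheduler's API):

    # Hypothetical helper mirroring the behavior the test asserts; not vLLM code.
    from typing import List, Optional, Tuple


    def apply_stop(output_ids: List[int],
                   new_token_ids: List[int],
                   eos_token_id: int,
                   stop_token_ids: List[int],
                   max_tokens: int,
                   ignore_eos: bool = False) -> Tuple[List[int], Optional[str]]:
        """Append tokens one by one and stop at the first finish condition."""
        for tok in new_token_ids:
            output_ids.append(tok)
            if not ignore_eos and tok == eos_token_id:
                return output_ids, "stop"    # like FINISHED_STOPPED on EOS
            if tok in stop_token_ids:
                return output_ids, "stop"    # like FINISHED_STOPPED on a stop token
            if len(output_ids) >= max_tokens:
                return output_ids, "length"  # like FINISHED_LENGTH_CAPPED
        return output_ids, None              # request keeps running


    # Mirrors test case 3 above: max_tokens=2 truncates [10, 11, 12] to [10, 11].
    print(apply_stop([], [10, 11, 12], eos_token_id=50256,
                     stop_token_ids=[], max_tokens=2))  # ([10, 11], 'length')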
+49
@@ -0,0 +1,49 @@
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+
+from vllm import LLM, SamplingParams
+
+
+@pytest.fixture
+def test_prompts():
+    return [
+        "Can you repeat the sentence ten times, this is a sentence.",
+        "Can you repeat the sentence ten times, this is a test.",
+    ]
+
+
+@pytest.fixture
+def sampling_config():
+    # Only support greedy for now
+    return SamplingParams(temperature=0, max_tokens=30, ignore_eos=False)
+
+
+@pytest.fixture
+def model_name():
+    return "meta-llama/Meta-Llama-3-8B-Instruct"
+
+
+def test_ngram_correctness(monkeypatch, test_prompts, sampling_config,
+                           model_name):
+    '''
+    Compare the outputs of an original LLM and a speculative LLM;
+    they should be the same when using ngram speculative decoding.
+    '''
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        ref_llm = LLM(model=model_name)
+        ref_outputs = ref_llm.generate(test_prompts, sampling_config)
+        del ref_llm
+
+        spec_llm = LLM(model=model_name,
+                       speculative_model='[ngram]',
+                       ngram_prompt_lookup_max=5,
+                       ngram_prompt_lookup_min=3,
+                       num_speculative_tokens=3)
+        spec_outputs = spec_llm.generate(test_prompts, sampling_config)
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            assert ref_output.outputs[0].text == spec_output.outputs[0].text, \
+                (f"ref_output: {ref_output.outputs[0].text},"
+                 f"spec_output: {spec_output.outputs[0].text}")
+        del spec_llm
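
As the end-to-end test above shows, the n-gram drafter is selected with speculative_model='[ngram]' together with the prompt-lookup window bounds and the number of draft tokens per step. A standalone usage sketch mirroring that configuration (same model and parameters as the test; the V1 engine is opted into via the VLLM_USE_V1 environment variable, set here with os.environ instead of pytest's monkeypatch):

    # Usage sketch mirroring the test configuration above; not a recommendation.
    import os

    os.environ["VLLM_USE_V1"] = "1"  # the ngram path exercised here targets V1

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",
              speculative_model="[ngram]",   # prompt-lookup (n-gram) drafter
              ngram_prompt_lookup_max=5,     # longest n-gram to match
              ngram_prompt_lookup_min=3,     # shortest n-gram to match
              num_speculative_tokens=3)      # draft tokens proposed per step

    # Only greedy sampling is supported for now (see the fixture above).
    params = SamplingParams(temperature=0, max_tokens=30, ignore_eos=False)
    outputs = llm.generate(
        ["Can you repeat the sentence ten times, this is a test."], params)
    print(outputs[0].outputs[0].text)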

0 commit comments
