
Commit 3d68c5c

njhill authored and Robert Shaw committed
[V1][Frontend] Coalesce bunched RequestOutputs (vllm-project#12298)
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
1 parent: 68cc6b8

File tree

3 files changed: +65, -18 lines

tests/v1/engine/test_async_llm.py

Lines changed: 35 additions & 16 deletions
@@ -1,11 +1,13 @@
 import asyncio
+from contextlib import ExitStack
 from typing import List, Tuple
 
 import pytest
 
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.platforms import current_platform
+from vllm.sampling_params import RequestOutputKind
 from vllm.v1.engine.async_llm import AsyncLLM
 
 if not current_platform.is_cuda():
@@ -18,28 +20,39 @@
 
 
 async def generate(engine: AsyncLLM, request_id: str,
+                   output_kind: RequestOutputKind,
                    max_tokens: int) -> Tuple[int, str]:
     count = 0
-    async for _ in engine.generate(request_id=request_id,
-                                   prompt="Hello my name is Robert and",
-                                   sampling_params=SamplingParams(
-                                       max_tokens=max_tokens, temperature=0)):
+    sampling_params = SamplingParams(max_tokens=max_tokens,
+                                     output_kind=output_kind,
+                                     temperature=0)
+    async for out in engine.generate(request_id=request_id,
+                                     prompt="Hello my name is Robert and",
+                                     sampling_params=sampling_params):
+
+        num_tokens = len(out.outputs[0].token_ids)
+        if output_kind == RequestOutputKind.DELTA:
+            count += num_tokens
+        else:
+            count = num_tokens
 
-        count += 1
         await asyncio.sleep(0.)
 
     return count, request_id
 
 
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_load(monkeypatch):
+async def test_load(monkeypatch, output_kind: RequestOutputKind):
     # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
     # so that in the future when we switch, we don't have to change all the
     # tests.
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)
 
         NUM_REQUESTS = 10000
         NUM_EXPECTED_TOKENS = 10
@@ -51,26 +64,33 @@ async def test_load(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))
 
         # Confirm that we got all the EXPECTED tokens from the requests.
-        for task in tasks:
+        done, pending = await asyncio.wait(tasks,
+                                           return_when=asyncio.FIRST_EXCEPTION)
+        for task in pending:
+            task.cancel()
+        for task in done:
             num_generated_tokens, request_id = await task
             assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                 f"{request_id} generated {num_generated_tokens} but "
                 f"expected {NUM_EXPECTED_TOKENS}")
 
         assert not engine.output_processor.has_unfinished_requests()
-        engine.shutdown()
 
 
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_abort(monkeypatch):
+async def test_abort(monkeypatch, output_kind: RequestOutputKind):
 
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
         NUM_EXPECTED_TOKENS = 100
@@ -83,7 +103,8 @@ async def test_abort(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))
 
         # API server cancels requests when they disconnect.
         for idx in REQUEST_IDS_TO_ABORT:
@@ -108,9 +129,7 @@ async def test_abort(monkeypatch):
         # Confirm we can do another generation.
         request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
         task = asyncio.create_task(
-            generate(engine, request_id, NUM_EXPECTED_TOKENS))
+            generate(engine, request_id, output_kind, NUM_EXPECTED_TOKENS))
         num_generated_tokens, request_id = await task
         assert num_generated_tokens == NUM_EXPECTED_TOKENS
         assert not engine.output_processor.has_unfinished_requests()
-
-        engine.shutdown()
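
The heart of the test change is the token-counting rule in generate(): with RequestOutputKind.DELTA each yielded output carries only the newly generated tokens, so their lengths are summed, whereas FINAL_ONLY (like cumulative output) carries the full sequence generated so far, so the most recent length is the total. A minimal standalone sketch of that rule, using made-up token chunks rather than real engine outputs:

def count_tokens(chunks, delta: bool) -> int:
    """Total generated tokens as seen by the consumer of a stream."""
    count = 0
    for token_ids in chunks:
        if delta:
            # DELTA: each output holds only the newly generated tokens.
            count += len(token_ids)
        else:
            # Cumulative/FINAL_ONLY: each output holds all tokens so far.
            count = len(token_ids)
    return count

# DELTA stream: three chunks of new tokens -> 5 tokens in total.
assert count_tokens([[1, 2], [3], [4, 5]], delta=True) == 5
# Cumulative view of the same generation -> also 5 tokens in total.
assert count_tokens([[1, 2], [1, 2, 3], [1, 2, 3, 4, 5]], delta=False) == 5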

vllm/outputs.py

Lines changed: 21 additions & 1 deletion
@@ -1,6 +1,6 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, Generic, List, Optional
+from typing import Dict, Generic, List, MutableSequence, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
@@ -162,6 +162,26 @@ def new(
             finished=finished,
         )
 
+    def add(self, next_output: "RequestOutput") -> None:
+        """Merge subsequent RequestOutput into this one"""
+
+        self.prompt = next_output.prompt
+        self.prompt_token_ids = next_output.prompt_token_ids
+        self.prompt_logprobs = next_output.prompt_logprobs
+        self.finished |= next_output.finished
+
+        #TODO assuming n == 1 for now
+        completion = self.outputs[0]
+        next_completion = next_output.outputs[0]
+        completion.text += next_completion.text
+        if not isinstance(completion.token_ids, MutableSequence):
+            completion.token_ids = list(completion.token_ids)
+        completion.token_ids.extend(next_completion.token_ids)
+        if next_completion.logprobs:
+            assert completion.logprobs is not None
+            completion.logprobs.extend(next_completion.logprobs)
+        completion.cumulative_logprob = next_completion.cumulative_logprob
+
     @classmethod
     def from_seq_group(
         cls, seq_group: SequenceGroup, use_cache: bool,
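
The new RequestOutput.add() merges a later output into an earlier one: prompt fields are taken from the newer output, finished is OR-ed in, and (assuming n == 1) the first completion's text, token_ids, and logprobs are extended while cumulative_logprob is overwritten with the newer value. Below is a standalone sketch of those merge semantics; _Chunk is a simplified stand-in for the real RequestOutput/CompletionOutput classes, whose full definitions are not part of this diff:

from dataclasses import dataclass
from typing import List


@dataclass
class _Chunk:
    text: str
    token_ids: List[int]
    finished: bool = False

    def add(self, nxt: "_Chunk") -> None:
        # Append the newer delta's text and tokens onto this chunk and
        # latch `finished` once any merged delta marks the request done.
        self.text += nxt.text
        self.token_ids.extend(nxt.token_ids)
        self.finished |= nxt.finished


first = _Chunk(text="Hello", token_ids=[101])
first.add(_Chunk(text=" world", token_ids=[102, 103], finished=True))
assert first.text == "Hello world"
assert first.token_ids == [101, 102, 103]
assert first.finished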

vllm/v1/engine/async_llm.py

Lines changed: 9 additions & 1 deletion
@@ -15,7 +15,7 @@
 from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
@@ -214,6 +214,14 @@ async def generate(
                 # task switching under load which helps performance).
                 out = q.get_nowait() if not q.empty() else await q.get()
 
+                # Coalesce any additional queued outputs
+                while not q.empty():
+                    next_out = q.get_nowait()
+                    if sampling_params.output_kind == RequestOutputKind.DELTA:
+                        out.add(next_out)
+                    else:
+                        out = next_out
+
                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
                 finished = out.finished
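
The change to generate() drains any outputs that piled up in the per-request queue while the consumer was busy: for DELTA streams the bunched outputs are merged into one via RequestOutput.add(), and for cumulative or final-only streams only the newest output is kept. A standalone sketch of that queue-draining pattern; the string items and the merge flag are illustrative stand-ins, not vLLM objects:

import asyncio


async def drain_one(q: asyncio.Queue, merge: bool) -> str:
    # Await only if the queue is empty; otherwise grab items without yielding.
    out = q.get_nowait() if not q.empty() else await q.get()
    # Coalesce anything else that is already queued.
    while not q.empty():
        nxt = q.get_nowait()
        if merge:
            out = out + nxt  # delta-style: concatenate the bunched outputs
        else:
            out = nxt        # cumulative-style: the newest output supersedes
    return out


async def main() -> None:
    q: asyncio.Queue = asyncio.Queue()
    for chunk in ("a", "b", "c"):          # delta chunks
        q.put_nowait(chunk)
    assert await drain_one(q, merge=True) == "abc"
    for chunk in ("a", "ab", "abc"):       # cumulative snapshots
        q.put_nowait(chunk)
    assert await drain_one(q, merge=False) == "abc"


asyncio.run(main())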
