diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py
index 2c805e18eeb..10f783b21a9 100644
--- a/tests/v1/engine/test_async_llm.py
+++ b/tests/v1/engine/test_async_llm.py
@@ -1,4 +1,5 @@
 import asyncio
+from contextlib import ExitStack
 from typing import List, Tuple
 
 import pytest
@@ -6,6 +7,7 @@
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.platforms import current_platform
+from vllm.sampling_params import RequestOutputKind
 from vllm.v1.engine.async_llm import AsyncLLM
 
 if not current_platform.is_cuda():
@@ -18,28 +20,39 @@
 
 
 async def generate(engine: AsyncLLM, request_id: str,
+                   output_kind: RequestOutputKind,
                    max_tokens: int) -> Tuple[int, str]:
     count = 0
-    async for _ in engine.generate(request_id=request_id,
-                                   prompt="Hello my name is Robert and",
-                                   sampling_params=SamplingParams(
-                                       max_tokens=max_tokens, temperature=0)):
+    sampling_params = SamplingParams(max_tokens=max_tokens,
+                                     output_kind=output_kind,
+                                     temperature=0)
+    async for out in engine.generate(request_id=request_id,
+                                     prompt="Hello my name is Robert and",
+                                     sampling_params=sampling_params):
+
+        num_tokens = len(out.outputs[0].token_ids)
+        if output_kind == RequestOutputKind.DELTA:
+            count += num_tokens
+        else:
+            count = num_tokens
 
-        count += 1
         await asyncio.sleep(0.)
 
     return count, request_id
 
 
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_load(monkeypatch):
+async def test_load(monkeypatch, output_kind: RequestOutputKind):
     # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
     # so that in the future when we switch, we don't have to change all the
     # tests.
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)
 
         NUM_REQUESTS = 10000
         NUM_EXPECTED_TOKENS = 10
@@ -51,26 +64,33 @@ async def test_load(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))
 
         # Confirm that we got all the EXPECTED tokens from the requests.
-        for task in tasks:
+        done, pending = await asyncio.wait(tasks,
+                                           return_when=asyncio.FIRST_EXCEPTION)
+        for task in pending:
+            task.cancel()
+        for task in done:
             num_generated_tokens, request_id = await task
             assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                 f"{request_id} generated {num_generated_tokens} but "
                 f"expected {NUM_EXPECTED_TOKENS}")
 
         assert not engine.output_processor.has_unfinished_requests()
-        engine.shutdown()
 
 
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_abort(monkeypatch):
+async def test_abort(monkeypatch, output_kind: RequestOutputKind):
 
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
         NUM_EXPECTED_TOKENS = 100
@@ -83,7 +103,8 @@ async def test_abort(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))
 
         # API server cancels requests when they disconnect.
         for idx in REQUEST_IDS_TO_ABORT:
@@ -108,9 +129,7 @@ async def test_abort(monkeypatch):
         # Confirm we can do another generation.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
         task = asyncio.create_task(
-            generate(engine, request_id, NUM_EXPECTED_TOKENS))
+            generate(engine, request_id, output_kind, NUM_EXPECTED_TOKENS))
         num_generated_tokens, request_id = await task
         assert num_generated_tokens == NUM_EXPECTED_TOKENS
         assert not engine.output_processor.has_unfinished_requests()
-
-        engine.shutdown()
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 63df7dcf519..25b2265285d 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -1,6 +1,6 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, Generic, List, Optional
+from typing import Dict, Generic, List, MutableSequence, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
@@ -162,6 +162,26 @@ def new(
             finished=finished,
         )
 
+    def add(self, next_output: "RequestOutput") -> None:
+        """Merge subsequent RequestOutput into this one"""
+
+        self.prompt = next_output.prompt
+        self.prompt_token_ids = next_output.prompt_token_ids
+        self.prompt_logprobs = next_output.prompt_logprobs
+        self.finished |= next_output.finished
+
+        #TODO assuming n == 1 for now
+        completion = self.outputs[0]
+        next_completion = next_output.outputs[0]
+        completion.text += next_completion.text
+        if not isinstance(completion.token_ids, MutableSequence):
+            completion.token_ids = list(completion.token_ids)
+        completion.token_ids.extend(next_completion.token_ids)
+        if next_completion.logprobs:
+            assert completion.logprobs is not None
+            completion.logprobs.extend(next_completion.logprobs)
+        completion.cumulative_logprob = next_completion.cumulative_logprob
+
     @classmethod
     def from_seq_group(
         cls, seq_group: SequenceGroup, use_cache: bool,
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 1505b62504a..6dc68b3a160 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -15,7 +15,7 @@
 from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
@@ -214,6 +214,14 @@ async def generate(
                 # task switching under load which helps performance).
                 out = q.get_nowait() if not q.empty() else await q.get()
 
+                # Coalesce any additional queued outputs
+                while not q.empty():
+                    next_out = q.get_nowait()
+                    if sampling_params.output_kind == RequestOutputKind.DELTA:
+                        out.add(next_out)
+                    else:
+                        out = next_out
+
                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
                 finished = out.finished
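Illustrative sketch (not part of the diff): how the new RequestOutput.add() is expected to coalesce two bunched DELTA outputs, assuming n == 1. The request ID, prompt, and token IDs below are made up, and the keyword constructor arguments follow vllm.outputs as of this change; they may differ in other versions.

# Minimal sketch of RequestOutput.add() coalescing, under the assumptions above.
from vllm.outputs import CompletionOutput, RequestOutput

first = RequestOutput(
    request_id="req-0",
    prompt="Hello my name is Robert and",
    prompt_token_ids=[1, 2, 3],
    prompt_logprobs=None,
    outputs=[
        CompletionOutput(index=0,
                         text=" I",
                         token_ids=[10],
                         cumulative_logprob=None,
                         logprobs=None)
    ],
    finished=False)

delta = RequestOutput(
    request_id="req-0",
    prompt="Hello my name is Robert and",
    prompt_token_ids=[1, 2, 3],
    prompt_logprobs=None,
    outputs=[
        CompletionOutput(index=0,
                         text=" am",
                         token_ids=[11],
                         cumulative_logprob=None,
                         logprobs=None)
    ],
    finished=True)

# Merges text and token_ids of the single completion and carries the
# finished flag forward, matching the add() implementation above.
first.add(delta)
assert first.outputs[0].text == " I am"
assert list(first.outputs[0].token_ids) == [10, 11]
assert first.finished

For FINAL_ONLY requests, the generate() loop in async_llm.py instead keeps only the newest queued output, since each output already carries the full result.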