
Commit 3ff92d0

Update AsyncLLM tests

Signed-off-by: Nick Hill <[email protected]>

1 parent 0d2f3dd

File tree: 1 file changed (+11, −10)


tests/v1/engine/test_async_llm.py

Lines changed: 11 additions & 10 deletions
@@ -1,4 +1,5 @@
 import asyncio
+from contextlib import ExitStack
 from typing import List, Tuple
 
 import pytest
@@ -20,12 +21,13 @@
 async def generate(engine: AsyncLLM, request_id: str,
                    max_tokens: int) -> Tuple[int, str]:
     count = 0
-    async for _ in engine.generate(request_id=request_id,
-                                   prompt="Hello my name is Robert and",
-                                   sampling_params=SamplingParams(
-                                       max_tokens=max_tokens, temperature=0)):
+    async for out in engine.generate(
+            request_id=request_id,
+            prompt="Hello my name is Robert and",
+            sampling_params=SamplingParams(max_tokens=max_tokens,
+                                           temperature=0)):
 
-        count += 1
+        count += len(out.outputs[0].token_ids)
         await asyncio.sleep(0.)
 
     return count, request_id
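
The counting change in the hunk above makes the helper robust to outputs that carry more than one new token: instead of incrementing once per loop iteration, it sums the token ids attached to each streamed output. Below is a minimal sketch of the same pattern; the helper name is illustrative, the import path is the one this test file uses, and it assumes delta-style streaming in which each output carries only newly generated tokens:

import asyncio

from vllm import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM


async def count_generated_tokens(engine: AsyncLLM, request_id: str,
                                 max_tokens: int) -> int:
    """Return the number of tokens generated for one request."""
    count = 0
    async for out in engine.generate(
            request_id=request_id,
            prompt="Hello my name is Robert and",
            sampling_params=SamplingParams(max_tokens=max_tokens,
                                           temperature=0)):
        # Sum the token ids carried by each streamed output rather than
        # counting iterations; one output may deliver several tokens.
        count += len(out.outputs[0].token_ids)
        # Yield to the event loop so other request tasks can make progress.
        await asyncio.sleep(0.)
    return count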
@@ -36,10 +38,11 @@ async def test_load(monkeypatch):
     # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
     # so that in the future when we switch, we don't have to change all the
     # tests.
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)
 
         NUM_REQUESTS = 10000
         NUM_EXPECTED_TOKENS = 10
@@ -61,16 +64,16 @@ async def test_load(monkeypatch):
                 f"expected {NUM_EXPECTED_TOKENS}")
 
         assert not engine.output_processor.has_unfinished_requests()
-        engine.shutdown()
 
 
 @pytest.mark.asyncio
 async def test_abort(monkeypatch):
 
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
         NUM_EXPECTED_TOKENS = 100
@@ -112,5 +115,3 @@ async def test_abort(monkeypatch):
         num_generated_tokens, request_id = await task
         assert num_generated_tokens == NUM_EXPECTED_TOKENS
         assert not engine.output_processor.has_unfinished_requests()
-
-        engine.shutdown()
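
The remaining hunks all serve one cleanup fix: engine.shutdown() used to run only on the success path at the end of each test, so a failed assertion leaked the engine process. Registering the shutdown with contextlib.ExitStack immediately after construction guarantees teardown on every exit path and keeps setup and teardown adjacent. A minimal, self-contained sketch of the pattern (FakeEngine is a stand-in for AsyncLLM):

from contextlib import ExitStack


class FakeEngine:
    """Stand-in for AsyncLLM; only the shutdown hook matters here."""

    def shutdown(self) -> None:
        print("engine shut down")


def run_test() -> None:
    with ExitStack() as after:
        engine = FakeEngine()
        # Callbacks run in LIFO order when the with-block exits, whether
        # it exits normally or via an exception such as a failed assert.
        after.callback(engine.shutdown)

        assert 1 + 1 == 2  # test body; shutdown still runs if this raises


run_test()  # prints "engine shut down" even when the assert fails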
