1
1
import asyncio
2
+ from contextlib import ExitStack
2
3
from typing import List , Tuple
3
4
4
5
import pytest
5
6
6
7
from vllm import SamplingParams
7
8
from vllm .engine .arg_utils import AsyncEngineArgs
8
9
from vllm .platforms import current_platform
10
+ from vllm .sampling_params import RequestOutputKind
9
11
from vllm .v1 .engine .async_llm import AsyncLLM
10
12
11
13
if not current_platform .is_cuda ():
18
20
19
21
20
22
async def generate(engine: AsyncLLM, request_id: str,
                   output_kind: RequestOutputKind,
                   max_tokens: int) -> Tuple[int, str]:
    """Drive a single generation request and count its output tokens.

    Args:
        engine: the AsyncLLM engine the request is submitted to.
        request_id: unique identifier for this request.
        output_kind: controls how the engine streams results; DELTA
            outputs carry only tokens produced since the previous output,
            otherwise each output carries the cumulative token list.
        max_tokens: generation budget forwarded to SamplingParams.

    Returns:
        A ``(num_generated_tokens, request_id)`` tuple so callers can match
        results back to the request they launched.
    """
    count = 0
    sampling_params = SamplingParams(max_tokens=max_tokens,
                                     output_kind=output_kind,
                                     temperature=0)
    async for out in engine.generate(request_id=request_id,
                                     prompt="Hello my name is Robert and",
                                     sampling_params=sampling_params):

        num_tokens = len(out.outputs[0].token_ids)
        if output_kind == RequestOutputKind.DELTA:
            # Each DELTA output holds only the newly generated tokens,
            # so the per-output counts are summed.
            count += num_tokens
        else:
            # Non-delta outputs hold the cumulative token list; the last
            # observed length is the total.
            count = num_tokens

        # Yield to the event loop so concurrently running request tasks
        # get a chance to make progress.
        await asyncio.sleep(0.)

    return count, request_id
32
42
33
43
44
+ @pytest .mark .parametrize (
45
+ "output_kind" , [RequestOutputKind .DELTA , RequestOutputKind .FINAL_ONLY ])
34
46
@pytest .mark .asyncio
35
- async def test_load (monkeypatch ):
47
+ async def test_load (monkeypatch , output_kind : RequestOutputKind ):
36
48
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1
37
49
# so that in the future when we switch, we don't have to change all the
38
50
# tests.
39
- with monkeypatch .context () as m :
51
+ with monkeypatch .context () as m , ExitStack () as after :
40
52
m .setenv ("VLLM_USE_V1" , "1" )
41
53
42
54
engine = AsyncLLM .from_engine_args (ENGINE_ARGS )
55
+ after .callback (engine .shutdown )
43
56
44
57
NUM_REQUESTS = 10000
45
58
NUM_EXPECTED_TOKENS = 10
@@ -51,26 +64,33 @@ async def test_load(monkeypatch):
51
64
for request_id in request_ids :
52
65
tasks .append (
53
66
asyncio .create_task (
54
- generate (engine , request_id , NUM_EXPECTED_TOKENS )))
67
+ generate (engine , request_id , output_kind ,
68
+ NUM_EXPECTED_TOKENS )))
55
69
56
70
# Confirm that we got all the EXPECTED tokens from the requests.
57
- for task in tasks :
71
+ done , pending = await asyncio .wait (tasks ,
72
+ return_when = asyncio .FIRST_EXCEPTION )
73
+ for task in pending :
74
+ task .cancel ()
75
+ for task in done :
58
76
num_generated_tokens , request_id = await task
59
77
assert num_generated_tokens == NUM_EXPECTED_TOKENS , (
60
78
f"{ request_id } generated { num_generated_tokens } but "
61
79
f"expected { NUM_EXPECTED_TOKENS } " )
62
80
63
81
assert not engine .output_processor .has_unfinished_requests ()
64
- engine .shutdown ()
65
82
66
83
84
+ @pytest .mark .parametrize (
85
+ "output_kind" , [RequestOutputKind .DELTA , RequestOutputKind .FINAL_ONLY ])
67
86
@pytest .mark .asyncio
68
- async def test_abort (monkeypatch ):
87
+ async def test_abort (monkeypatch , output_kind : RequestOutputKind ):
69
88
70
- with monkeypatch .context () as m :
89
+ with monkeypatch .context () as m , ExitStack () as after :
71
90
m .setenv ("VLLM_USE_V1" , "1" )
72
91
73
92
engine = AsyncLLM .from_engine_args (ENGINE_ARGS )
93
+ after .callback (engine .shutdown )
74
94
75
95
NUM_REQUESTS = 100
76
96
NUM_EXPECTED_TOKENS = 100
@@ -83,7 +103,8 @@ async def test_abort(monkeypatch):
83
103
for request_id in request_ids :
84
104
tasks .append (
85
105
asyncio .create_task (
86
- generate (engine , request_id , NUM_EXPECTED_TOKENS )))
106
+ generate (engine , request_id , output_kind ,
107
+ NUM_EXPECTED_TOKENS )))
87
108
88
109
# API server cancels requests when they disconnect.
89
110
for idx in REQUEST_IDS_TO_ABORT :
@@ -108,9 +129,7 @@ async def test_abort(monkeypatch):
108
129
# Confirm we can do another generation.
109
130
request_id = f"request-{ REQUEST_IDS_TO_ABORT [0 ]} "
110
131
task = asyncio .create_task (
111
- generate (engine , request_id , NUM_EXPECTED_TOKENS ))
132
+ generate (engine , request_id , output_kind , NUM_EXPECTED_TOKENS ))
112
133
num_generated_tokens , request_id = await task
113
134
assert num_generated_tokens == NUM_EXPECTED_TOKENS
114
135
assert not engine .output_processor .has_unfinished_requests ()
115
-
116
- engine .shutdown ()
0 commit comments