@@ -1,15 +1,31 @@
 # SPDX-License-Identifier: Apache-2.0
+import weakref
 
 import pytest
 
 from vllm import LLM
+from vllm.distributed import cleanup_dist_env_and_memory
 
 from ..openai.test_vision import TEST_IMAGE_URLS
 
 
-def test_chat():
-    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
+@pytest.fixture(scope="function")
+def text_llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              seed=0)
 
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+
+    del llm
+
+    cleanup_dist_env_and_memory()
+
+
+def test_chat(text_llm):
     prompt1 = "Explain the concept of entropy."
     messages = [
         {
@@ -21,13 +37,11 @@ def test_chat():
             "content": prompt1
         },
     ]
-    outputs = llm.chat(messages)
+    outputs = text_llm.chat(messages)
     assert len(outputs) == 1
 
 
-def test_multi_chat():
-    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
-
+def test_multi_chat(text_llm):
     prompt1 = "Explain the concept of entropy."
     prompt2 = "Explain what among us is."
 
@@ -55,22 +69,35 @@ def test_multi_chat():
 
     messages = [conversation1, conversation2]
 
-    outputs = llm.chat(messages)
+    outputs = text_llm.chat(messages)
     assert len(outputs) == 2
 
 
-@pytest.mark.parametrize("image_urls",
-                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
-def test_chat_multi_image(image_urls: list[str]):
+@pytest.fixture(scope="function")
+def vision_llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
     llm = LLM(
         model="microsoft/Phi-3.5-vision-instruct",
         max_model_len=4096,
         max_num_seqs=5,
         enforce_eager=True,
         trust_remote_code=True,
         limit_mm_per_prompt={"image": 2},
+        seed=0,
     )
 
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+
+    del llm
+
+    cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize("image_urls",
+                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
+def test_chat_multi_image(vision_llm, image_urls: list[str]):
     messages = [{
         "role":
         "user",
@@ -87,16 +114,15 @@ def test_chat_multi_image(image_urls: list[str]):
             },
         ],
     }]
-    outputs = llm.chat(messages)
+    outputs = vision_llm.chat(messages)
     assert len(outputs) >= 0
 
 
-def test_llm_chat_tokenization_no_double_bos():
+def test_llm_chat_tokenization_no_double_bos(text_llm):
     """
     LLM.chat() should not add special tokens when using chat templates.
     Check we get a single BOS token for llama chat.
     """
-    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True)
     messages = [
         {
             "role": "system",
@@ -107,13 +133,64 @@ def test_llm_chat_tokenization_no_double_bos():
             "content": "Hello!"
         },
     ]
-    outputs = llm.chat(messages)
+    outputs = text_llm.chat(messages)
     assert len(outputs) == 1
-    prompt_token_ids = getattr(outputs[0], "prompt_token_ids", None)
+
+    prompt_token_ids = outputs[0].prompt_token_ids
     assert prompt_token_ids is not None
 
-    bos_token = llm.get_tokenizer().bos_token_id
+    bos_token = text_llm.get_tokenizer().bos_token_id
 
     # Ensure we have a single BOS
    assert prompt_token_ids[0] == bos_token
     assert prompt_token_ids[1] != bos_token, "Double BOS"
+
+
+@pytest.fixture(scope="function")
+def thinking_llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(
+        model="Qwen/Qwen3-0.6B",
+        max_model_len=4096,
+        enforce_eager=True,
+        seed=0,
+    )
+
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+
+    del llm
+
+    cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize("enable_thinking", [True, False])
+def test_chat_extra_kwargs(thinking_llm, enable_thinking):
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": "What is 1+1?"
+        },
+    ]
+
+    outputs = thinking_llm.chat(
+        messages,
+        chat_template_kwargs={"enable_thinking": enable_thinking},
+    )
+    assert len(outputs) == 1
+
+    prompt_token_ids = outputs[0].prompt_token_ids
+    assert prompt_token_ids is not None
+
+    think_id = thinking_llm.get_tokenizer().get_vocab()["<think>"]
+
+    if enable_thinking:
+        assert think_id not in prompt_token_ids
+    else:
+        # The chat template includes dummy thinking process
+        assert think_id in prompt_token_ids
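Side note (not part of the diff): the "we use weakref.proxy to enable garbage collection" comments are the point of the new fixtures. pytest keeps a reference to whatever a fixture yields, so yielding a weak proxy leaves the fixture body holding the only strong reference; `del llm` plus `cleanup_dist_env_and_memory()` can then actually release the engine and its GPU memory between tests. A minimal stdlib-only sketch of that mechanism follows; `Engine` and `fixture_like` are made-up illustrations, and the immediate free relies on CPython reference counting.

import weakref


class Engine:
    """Stand-in for the real LLM engine (hypothetical, for illustration)."""

    def __del__(self):
        print("engine freed")


def fixture_like():
    # Same shape as text_llm/vision_llm/thinking_llm above: keep the only
    # strong reference locally, hand the caller a weak proxy.
    engine = Engine()
    yield weakref.proxy(engine)

    # Teardown: dropping the local reference frees the object right away,
    # because the caller only ever held a proxy to it.
    del engine


gen = fixture_like()
proxy = next(gen)   # "setup": the test receives only a weak proxy
next(gen, None)     # "teardown": prints "engine freed"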