Commit 88ad9ec

[Frontend] Support chat_template_kwargs in LLM.chat (#17356)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 40896bd commit 88ad9ec
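
As a quick illustration of the new parameter, here is a minimal offline usage sketch (the model name and the enable_thinking flag are taken from the tests in this commit and are illustrative, not required by the API):

from vllm import LLM

llm = LLM(model="Qwen/Qwen3-0.6B", enforce_eager=True)

messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "What is 1+1?"},
]

# chat_template_kwargs is forwarded to the chat template, e.g. to toggle
# Qwen3's thinking mode when rendering the prompt.
outputs = llm.chat(
    messages,
    chat_template_kwargs={"enable_thinking": False},
)
print(outputs[0].outputs[0].text)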

File tree

2 files changed: +106 -24 lines changed


tests/entrypoints/llm/test_chat.py

Lines changed: 93 additions & 16 deletions
@@ -1,15 +1,31 @@
 # SPDX-License-Identifier: Apache-2.0
+import weakref

 import pytest

 from vllm import LLM
+from vllm.distributed import cleanup_dist_env_and_memory

 from ..openai.test_vision import TEST_IMAGE_URLS


-def test_chat():
-    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
+@pytest.fixture(scope="function")
+def text_llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              seed=0)

+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+
+    del llm
+
+    cleanup_dist_env_and_memory()
+
+
+def test_chat(text_llm):
     prompt1 = "Explain the concept of entropy."
     messages = [
         {
@@ -21,13 +37,11 @@ def test_chat():
             "content": prompt1
         },
     ]
-    outputs = llm.chat(messages)
+    outputs = text_llm.chat(messages)
     assert len(outputs) == 1


-def test_multi_chat():
-    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
-
+def test_multi_chat(text_llm):
     prompt1 = "Explain the concept of entropy."
     prompt2 = "Explain what among us is."

@@ -55,22 +69,35 @@ def test_multi_chat():

     messages = [conversation1, conversation2]

-    outputs = llm.chat(messages)
+    outputs = text_llm.chat(messages)
     assert len(outputs) == 2


-@pytest.mark.parametrize("image_urls",
-                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
-def test_chat_multi_image(image_urls: list[str]):
+@pytest.fixture(scope="function")
+def vision_llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
     llm = LLM(
         model="microsoft/Phi-3.5-vision-instruct",
         max_model_len=4096,
         max_num_seqs=5,
         enforce_eager=True,
         trust_remote_code=True,
         limit_mm_per_prompt={"image": 2},
+        seed=0,
     )

+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+
+    del llm
+
+    cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize("image_urls",
+                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
+def test_chat_multi_image(vision_llm, image_urls: list[str]):
     messages = [{
         "role":
         "user",
@@ -87,16 +114,15 @@ def test_chat_multi_image(image_urls: list[str]):
             },
         ],
     }]
-    outputs = llm.chat(messages)
+    outputs = vision_llm.chat(messages)
     assert len(outputs) >= 0


-def test_llm_chat_tokenization_no_double_bos():
+def test_llm_chat_tokenization_no_double_bos(text_llm):
     """
     LLM.chat() should not add special tokens when using chat templates.
     Check we get a single BOS token for llama chat.
     """
-    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True)
     messages = [
         {
             "role": "system",
@@ -107,13 +133,64 @@ def test_llm_chat_tokenization_no_double_bos():
             "content": "Hello!"
         },
     ]
-    outputs = llm.chat(messages)
+    outputs = text_llm.chat(messages)
     assert len(outputs) == 1
-    prompt_token_ids = getattr(outputs[0], "prompt_token_ids", None)
+
+    prompt_token_ids = outputs[0].prompt_token_ids
     assert prompt_token_ids is not None

-    bos_token = llm.get_tokenizer().bos_token_id
+    bos_token = text_llm.get_tokenizer().bos_token_id

     # Ensure we have a single BOS
     assert prompt_token_ids[0] == bos_token
     assert prompt_token_ids[1] != bos_token, "Double BOS"
+
+
+@pytest.fixture(scope="function")
+def thinking_llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(
+        model="Qwen/Qwen3-0.6B",
+        max_model_len=4096,
+        enforce_eager=True,
+        seed=0,
+    )
+
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+
+    del llm
+
+    cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize("enable_thinking", [True, False])
+def test_chat_extra_kwargs(thinking_llm, enable_thinking):
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": "What is 1+1?"
+        },
+    ]
+
+    outputs = thinking_llm.chat(
+        messages,
+        chat_template_kwargs={"enable_thinking": enable_thinking},
+    )
+    assert len(outputs) == 1
+
+    prompt_token_ids = outputs[0].prompt_token_ids
+    assert prompt_token_ids is not None
+
+    think_id = thinking_llm.get_tokenizer().get_vocab()["<think>"]
+
+    if enable_thinking:
+        assert think_id not in prompt_token_ids
+    else:
+        # The chat template includes dummy thinking process
+        assert think_id in prompt_token_ids
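
The new test_chat_extra_kwargs relies on the Qwen3 chat template consuming enable_thinking. The same effect can be observed outside vLLM with the Hugging Face tokenizer, which forwards extra keyword arguments to the template; this is an illustrative sketch, not part of the commit:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
messages = [{"role": "user", "content": "What is 1+1?"}]

# Render the prompt with and without thinking mode enabled.
with_thinking = tok.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, enable_thinking=True)
without_thinking = tok.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)

# With enable_thinking=False the template appends a dummy (empty) <think>
# block to the prompt, which is what the test asserts via token ids.
assert "<think>" not in with_thinking
assert "<think>" in without_thinking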

vllm/entrypoints/llm.py

Lines changed: 13 additions & 8 deletions
@@ -656,6 +656,7 @@ def chat(
         add_generation_prompt: bool = True,
         continue_final_message: bool = False,
         tools: Optional[list[dict[str, Any]]] = None,
+        chat_template_kwargs: Optional[dict[str, Any]] = None,
         mm_processor_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[RequestOutput]:
         """
@@ -696,6 +697,8 @@
             continue_final_message: If True, continues the final message in
                 the conversation instead of starting a new one. Cannot be
                 ``True`` if ``add_generation_prompt`` is also ``True``.
+            chat_template_kwargs: Additional kwargs to pass to the chat
+                template.
             mm_processor_kwargs: Multimodal processor kwarg overrides for this
                 chat request. Only used for offline requests.

@@ -726,6 +729,14 @@
             trust_remote_code=model_config.trust_remote_code,
         )

+        _chat_template_kwargs: dict[str, Any] = dict(
+            chat_template=chat_template,
+            add_generation_prompt=add_generation_prompt,
+            continue_final_message=continue_final_message,
+            tools=tools,
+        )
+        _chat_template_kwargs.update(chat_template_kwargs or {})
+
         prompts: list[Union[TokensPrompt, TextPrompt]] = []

         for msgs in list_of_messages:
@@ -743,20 +754,14 @@
                 prompt_token_ids = apply_mistral_chat_template(
                     tokenizer,
                     messages=msgs,
-                    chat_template=chat_template,
-                    tools=tools,
-                    add_generation_prompt=add_generation_prompt,
-                    continue_final_message=continue_final_message,
+                    **_chat_template_kwargs,
                 )
             else:
                 prompt_str = apply_hf_chat_template(
                     tokenizer,
                     trust_remote_code=model_config.trust_remote_code,
                     conversation=conversation,
-                    chat_template=chat_template,
-                    tools=tools,
-                    add_generation_prompt=add_generation_prompt,
-                    continue_final_message=continue_final_message,
+                    **_chat_template_kwargs,
                 )
             # Special tokens are already included in chat templates so
             # should not be added by the tokenizer in this case.
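
One note on the merge logic added above: the explicit chat arguments are collected first and chat_template_kwargs is applied via .update() afterwards, so caller-supplied keys can both add template-specific options and override the explicit arguments. A small standalone sketch of that precedence (values here are illustrative):

from typing import Any

# Mirror of the merge order used in LLM.chat (values are illustrative).
_chat_template_kwargs: dict[str, Any] = dict(
    chat_template=None,
    add_generation_prompt=True,
    continue_final_message=False,
    tools=None,
)
# User-supplied kwargs are applied last and therefore take precedence.
_chat_template_kwargs.update({
    "enable_thinking": False,        # new template-specific key
    "add_generation_prompt": False,  # overrides the explicit argument
})

assert _chat_template_kwargs["enable_thinking"] is False
assert _chat_template_kwargs["add_generation_prompt"] is False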
