
[Frontend] Support chat_template_kwargs in LLM.chat #17356


Merged (3 commits) on Apr 29, 2025
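
This PR adds a `chat_template_kwargs` argument to the offline `LLM.chat` API, forwarding extra variables to the chat template when the prompt is rendered. A minimal usage sketch of the new argument (the model name and sampling settings are illustrative, not part of the diff):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-0.6B", enforce_eager=True)

messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "What is 1+1?"},
]

# Extra template variables (here, Qwen3's enable_thinking switch) are passed
# straight through to the chat template via the new chat_template_kwargs.
outputs = llm.chat(
    messages,
    sampling_params=SamplingParams(max_tokens=32),
    chat_template_kwargs={"enable_thinking": False},
)
print(outputs[0].outputs[0].text)
```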
109 changes: 93 additions & 16 deletions tests/entrypoints/llm/test_chat.py
@@ -1,15 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
import weakref

import pytest

from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory

from ..openai.test_vision import TEST_IMAGE_URLS


def test_chat():
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
@pytest.fixture(scope="function")
def text_llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
seed=0)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)

del llm

cleanup_dist_env_and_memory()


def test_chat(text_llm):
prompt1 = "Explain the concept of entropy."
messages = [
{
@@ -21,13 +37,11 @@ def test_chat():
"content": prompt1
},
]
outputs = llm.chat(messages)
outputs = text_llm.chat(messages)
assert len(outputs) == 1


def test_multi_chat():
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")

def test_multi_chat(text_llm):
prompt1 = "Explain the concept of entropy."
prompt2 = "Explain what among us is."

@@ -55,22 +69,35 @@ def test_multi_chat():

messages = [conversation1, conversation2]

outputs = llm.chat(messages)
outputs = text_llm.chat(messages)
assert len(outputs) == 2


@pytest.mark.parametrize("image_urls",
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: list[str]):
@pytest.fixture(scope="function")
def vision_llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
max_model_len=4096,
max_num_seqs=5,
enforce_eager=True,
trust_remote_code=True,
limit_mm_per_prompt={"image": 2},
seed=0,
)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)

del llm

cleanup_dist_env_and_memory()


@pytest.mark.parametrize("image_urls",
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(vision_llm, image_urls: list[str]):
messages = [{
"role":
"user",
@@ -87,16 +114,15 @@ def test_chat_multi_image(image_urls: list[str]):
},
],
}]
outputs = llm.chat(messages)
outputs = vision_llm.chat(messages)
assert len(outputs) >= 0


def test_llm_chat_tokenization_no_double_bos():
def test_llm_chat_tokenization_no_double_bos(text_llm):
"""
LLM.chat() should not add special tokens when using chat templates.
Check we get a single BOS token for llama chat.
"""
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True)
messages = [
{
"role": "system",
@@ -107,13 +133,64 @@ def test_llm_chat_tokenization_no_double_bos():
"content": "Hello!"
},
]
outputs = llm.chat(messages)
outputs = text_llm.chat(messages)
assert len(outputs) == 1
prompt_token_ids = getattr(outputs[0], "prompt_token_ids", None)

prompt_token_ids = outputs[0].prompt_token_ids
assert prompt_token_ids is not None

bos_token = llm.get_tokenizer().bos_token_id
bos_token = text_llm.get_tokenizer().bos_token_id

# Ensure we have a single BOS
assert prompt_token_ids[0] == bos_token
assert prompt_token_ids[1] != bos_token, "Double BOS"


@pytest.fixture(scope="function")
def thinking_llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(
model="Qwen/Qwen3-0.6B",
max_model_len=4096,
enforce_eager=True,
seed=0,
)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)

del llm

cleanup_dist_env_and_memory()


@pytest.mark.parametrize("enable_thinking", [True, False])
def test_chat_extra_kwargs(thinking_llm, enable_thinking):
messages = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "What is 1+1?"
},
]

outputs = thinking_llm.chat(
messages,
chat_template_kwargs={"enable_thinking": enable_thinking},
)
assert len(outputs) == 1

prompt_token_ids = outputs[0].prompt_token_ids
assert prompt_token_ids is not None

think_id = thinking_llm.get_tokenizer().get_vocab()["<think>"]

if enable_thinking:
assert think_id not in prompt_token_ids
else:
# The chat template inserts a dummy (empty) thinking block
assert think_id in prompt_token_ids
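
The direction of these assertions follows from what Qwen3's chat template does with `enable_thinking`: when disabled, the template itself appends an empty `<think>...</think>` block to the generation prompt. A standalone sketch of that behaviour using Hugging Face transformers rather than vLLM (model name taken from the fixture above; everything else is illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
messages = [{"role": "user", "content": "What is 1+1?"}]

for enable_thinking in (True, False):
    # Extra kwargs such as enable_thinking are forwarded into the template.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    )
    # Expect "<think>" only when enable_thinking is False, matching the test.
    print(enable_thinking, "<think>" in prompt)
```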
21 changes: 13 additions & 8 deletions vllm/entrypoints/llm.py
@@ -656,6 +656,7 @@ def chat(
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: Optional[list[dict[str, Any]]] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
) -> list[RequestOutput]:
"""
@@ -696,6 +697,8 @@ def chat(
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be
``True`` if ``add_generation_prompt`` is also ``True``.
chat_template_kwargs: Additional kwargs to pass to the chat
template.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
@@ -726,6 +729,14 @@ def chat(
trust_remote_code=model_config.trust_remote_code,
)

_chat_template_kwargs: dict[str, Any] = dict(
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
)
_chat_template_kwargs.update(chat_template_kwargs or {})

prompts: list[Union[TokensPrompt, TextPrompt]] = []

for msgs in list_of_messages:
@@ -743,20 +754,14 @@ def chat(
prompt_token_ids = apply_mistral_chat_template(
tokenizer,
messages=msgs,
chat_template=chat_template,
tools=tools,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
**_chat_template_kwargs,
)
else:
prompt_str = apply_hf_chat_template(
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
chat_template=chat_template,
tools=tools,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
**_chat_template_kwargs,
)
# Special tokens are already included in chat templates so
# should not be added by the tokenizer in this case.
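
The `dict(...)` followed by `.update()` above means any key supplied in `chat_template_kwargs` overrides the explicit arguments collected into `_chat_template_kwargs`. A standalone sketch of that precedence (the helper name is illustrative, not part of the diff):

```python
from typing import Any, Optional


def merge_template_kwargs(
    chat_template: Optional[str] = None,
    add_generation_prompt: bool = True,
    continue_final_message: bool = False,
    tools: Optional[list[dict[str, Any]]] = None,
    chat_template_kwargs: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    # Explicit arguments first, then user-supplied overrides, mirroring the
    # dict(...) + .update() pattern in the diff above.
    merged: dict[str, Any] = dict(
        chat_template=chat_template,
        add_generation_prompt=add_generation_prompt,
        continue_final_message=continue_final_message,
        tools=tools,
    )
    merged.update(chat_template_kwargs or {})
    return merged


# A caller can therefore flip add_generation_prompt via chat_template_kwargs:
assert merge_template_kwargs(
    chat_template_kwargs={"add_generation_prompt": False}
)["add_generation_prompt"] is False
```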