Commit 7011645

[BugFix][Frontend] Fix LLM.chat() tokenization (#16081)
Signed-off-by: Nick Hill <[email protected]>
1 parent 65e262b commit 7011645
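
Background on the fix: for HuggingFace tokenizers, LLM.chat() previously rendered the chat template to a string and submitted it as a TextPrompt, so the engine-side tokenizer prepended its own BOS token on top of the one already emitted by the template, yielding a double BOS. The diff below instead encodes the rendered string with add_special_tokens=False and submits token IDs directly. The sketch below is illustrative only (not part of this commit) and reproduces the underlying tokenizer behaviour; it assumes access to the gated meta-llama/Llama-3.2-1B-Instruct tokenizer, but any Llama-style tokenizer whose chat template already emits BOS behaves the same way.

# Illustrative sketch of the double-BOS issue addressed by this commit (not part of the diff).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
messages = [{"role": "user", "content": "Hello!"}]

# The chat template already renders the BOS token into the prompt string.
prompt_str = tokenizer.apply_chat_template(messages, tokenize=False,
                                           add_generation_prompt=True)

with_special = tokenizer.encode(prompt_str)  # default add_special_tokens=True
without_special = tokenizer.encode(prompt_str, add_special_tokens=False)

assert with_special[0] == with_special[1] == tokenizer.bos_token_id  # double BOS (the bug)
assert without_special[0] == tokenizer.bos_token_id                  # single BOS (the fix)
assert without_special[1] != tokenizer.bos_token_id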

2 files changed: +43 -14 lines changed

tests/entrypoints/llm/test_chat.py

Lines changed: 28 additions & 0 deletions
@@ -89,3 +89,31 @@ def test_chat_multi_image(image_urls: list[str]):
     }]
     outputs = llm.chat(messages)
     assert len(outputs) >= 0
+
+
+def test_llm_chat_tokenization_no_double_bos():
+    """
+    LLM.chat() should not add special tokens when using chat templates.
+    Check we get a single BOS token for llama chat.
+    """
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True)
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": "Hello!"
+        },
+    ]
+    outputs = llm.chat(messages)
+    assert len(outputs) == 1
+    prompt_token_ids = getattr(outputs[0], "prompt_token_ids", None)
+    assert prompt_token_ids is not None
+
+    bos_token = llm.get_tokenizer().bos_token_id
+
+    # Ensure we have a single BOS
+    assert prompt_token_ids[0] == bos_token
+    assert prompt_token_ids[1] != bos_token, "Double BOS"

vllm/entrypoints/llm.py

Lines changed: 15 additions & 14 deletions
@@ -117,7 +117,7 @@ class LLM:
         disable_async_output_proc: Disable async output processing.
             This may result in lower performance.
         hf_token: The token to use as HTTP bearer authorization for remote files
-            . If `True`, will use the token generated when running
+            . If `True`, will use the token generated when running
             `huggingface-cli login` (stored in `~/.huggingface`).
         hf_overrides: If a dictionary, contains arguments to be forwarded to the
             HuggingFace config. If a callable, it is called to update the
@@ -251,8 +251,12 @@ def __init__(
         self.request_counter = Counter()
         self.default_sampling_params: Union[dict[str, Any], None] = None
 
-    def get_tokenizer(self) -> AnyTokenizer:
-        return self.llm_engine.get_tokenizer_group().tokenizer
+    def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AnyTokenizer:
+        return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
+            lora_request)
 
     def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
         tokenizer_group = self.llm_engine.get_tokenizer_group()
@@ -712,7 +716,7 @@ def chat(
                 cast(list[ChatCompletionMessageParam], messages)
             ]
 
-        tokenizer = self.get_tokenizer()
+        tokenizer = self.get_tokenizer(lora_request)
         model_config = self.llm_engine.get_model_config()
         resolved_content_format = resolve_chat_template_content_format(
             chat_template,
@@ -735,9 +739,8 @@ def chat(
                 content_format=resolved_content_format,
             )
 
-            prompt_data: Union[str, list[int]]
             if isinstance(tokenizer, MistralTokenizer):
-                prompt_data = apply_mistral_chat_template(
+                prompt_token_ids = apply_mistral_chat_template(
                     tokenizer,
                     messages=msgs,
                     chat_template=chat_template,
@@ -746,7 +749,7 @@ def chat(
                     continue_final_message=continue_final_message,
                 )
             else:
-                prompt_data = apply_hf_chat_template(
+                prompt_str = apply_hf_chat_template(
                     tokenizer,
                     trust_remote_code=model_config.trust_remote_code,
                     conversation=conversation,
@@ -755,12 +758,12 @@ def chat(
                     add_generation_prompt=add_generation_prompt,
                     continue_final_message=continue_final_message,
                 )
+                # Special tokens are already included in chat templates so
+                # should not be added by the tokenizer in this case.
+                prompt_token_ids = tokenizer.encode(prompt_str,
+                                                    add_special_tokens=False)
 
-            prompt: Union[TokensPrompt, TextPrompt]
-            if is_list_of(prompt_data, int):
-                prompt = TokensPrompt(prompt_token_ids=prompt_data)
-            else:
-                prompt = TextPrompt(prompt=prompt_data)
+            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
 
             if mm_data is not None:
                 prompt["multi_modal_data"] = mm_data
@@ -1059,8 +1062,6 @@ def _embedding_score(
         if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)
 
-        scores: list[PoolingRequestOutput] = []
-
         scores = _cosine_similarity(tokenizer=tokenizer,
                                     embed_1=encoded_output_1,
                                     embed_2=encoded_output_2)
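
For reference, a minimal usage sketch of the widened get_tokenizer() signature (not part of the diff). The adapter name and path are hypothetical placeholders, and enable_lora=True is only needed when an adapter is actually loaded.

# Illustrative sketch of the updated LLM.get_tokenizer() signature (not part of the diff).
from vllm import LLM
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          enable_lora=True, enforce_eager=True)

# Hypothetical adapter name/path, purely for illustration.
lora_request = LoRARequest("my-adapter", 1, "/path/to/adapter")

base_tokenizer = llm.get_tokenizer()              # base-model tokenizer, as before
lora_tokenizer = llm.get_tokenizer(lora_request)  # per-adapter tokenizer, which chat() now uses

Passing the LoRA request through keeps the chat-template rendering and the final encoding on the same tokenizer that serves the rest of the request.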
