Commit d21324f

danieljannai21 authored and Alvant committed
[Frontend] Added support for HF's new continue_final_message parameter (vllm-project#8942)
Signed-off-by: Alvant <[email protected]>
1 parent b846778 commit d21324f

7 files changed, +105 -31 lines changed


tests/entrypoints/openai/test_chat_template.py

Lines changed: 23 additions & 7 deletions
@@ -12,20 +12,28 @@
 
 # Define models, templates, and their corresponding expected outputs
 MODEL_TEMPLATE_GENERATON_OUTPUT = [
-    ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
+    ("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
 Hi there!<|im_end|>
 <|im_start|>user
 What is the capital of<|im_end|>
 <|im_start|>assistant
 """),
-    ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
+    ("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
 Hi there!<|im_end|>
 <|im_start|>user
-What is the capital of""")
+What is the capital of"""),
+    ("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user
+Hello<|im_end|>
+<|im_start|>assistant
+Hi there!<|im_end|>
+<|im_start|>user
+What is the capital of<|im_end|>
+<|im_start|>assistant
+The capital of"""),
 ]
 
 TEST_MESSAGES = [
@@ -42,6 +50,10 @@
         'content': 'What is the capital of'
     },
 ]
+ASSISTANT_MESSAGE_TO_CONTINUE = {
+    'role': 'assistant',
+    'content': 'The capital of'
+}
 
 
 def test_load_chat_template():
@@ -73,26 +85,30 @@ def test_no_load_chat_template_literallike():
 
 
 @pytest.mark.parametrize(
-    "model,template,add_generation_prompt,expected_output",
+    "model,template,add_generation_prompt,continue_final_message,expected_output",
     MODEL_TEMPLATE_GENERATON_OUTPUT)
 def test_get_gen_prompt(model, template, add_generation_prompt,
-                        expected_output):
+                        continue_final_message, expected_output):
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     template_content = load_chat_template(chat_template=template)
 
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
         model=model,
-        messages=TEST_MESSAGES,
-        add_generation_prompt=add_generation_prompt)
+        messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
+        if continue_final_message else TEST_MESSAGES,
+        add_generation_prompt=add_generation_prompt,
+        continue_final_message=continue_final_message,
+    )
 
     # Call the function and get the result
     result = apply_hf_chat_template(
         tokenizer,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
         add_generation_prompt=mock_request.add_generation_prompt,
+        continue_final_message=mock_request.continue_final_message,
     )
 
     # Test assertion

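The new third parametrization above shows what the flag is for: with `continue_final_message=True` (and `add_generation_prompt=False`), the rendered prompt stops inside the trailing assistant message instead of closing it with `<|im_end|>`, so the model resumes that message. Below is a minimal sketch of the same behavior against Hugging Face's `apply_chat_template` directly; it assumes a transformers version that accepts `continue_final_message` and uses an inline ChatML-style template as a stand-in for the test's `chatml_jinja_path` fixture.

from transformers import AutoTokenizer

# Inline ChatML-style template, standing in for the chatml_jinja_path fixture.
CHATML_TEMPLATE = (
    "{% for message in messages %}"
    "{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}"
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
prompt = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "What is the capital of"},
        {"role": "assistant", "content": "The capital of"},
    ],
    chat_template=CHATML_TEMPLATE,
    add_generation_prompt=False,
    continue_final_message=True,
    tokenize=False,
)
# The prompt ends with "<|im_start|>assistant\nThe capital of": no trailing
# <|im_end|> and no fresh assistant header, so generation continues the message.
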
tests/entrypoints/openai/test_tokenization.py

Lines changed: 34 additions & 22 deletions
@@ -104,28 +104,40 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
         "role": "user",
         "content": "Can I ask a question? vllm1"
     }]
-
-    prompt = tokenizer.apply_chat_template(
-        add_generation_prompt=add_generation,
-        conversation=conversation,
-        tokenize=False)
-    tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
-
-    response = requests.post(base_url + "/tokenize",
-                             json={
-                                 "add_generation_prompt":
-                                 add_generation,
-                                 "add_special_tokens": add_special,
-                                 "messages": conversation,
-                                 "model": model_name
-                             })
-    response.raise_for_status()
-
-    assert response.json() == {
-        "tokens": tokens,
-        "count": len(tokens),
-        "max_model_len": 8192
-    }
+    for continue_final in [False, True]:
+        if add_generation and continue_final:
+            continue
+        if continue_final:
+            conversation.append({
+                "role": "assistant",
+                "content": "Sure,"
+            })
+
+        prompt = tokenizer.apply_chat_template(
+            add_generation_prompt=add_generation,
+            continue_final_message=continue_final,
+            conversation=conversation,
+            tokenize=False)
+        tokens = tokenizer.encode(prompt,
+                                  add_special_tokens=add_special)
+
+        response = requests.post(base_url + "/tokenize",
+                                 json={
+                                     "add_generation_prompt":
+                                     add_generation,
+                                     "continue_final_message":
+                                     continue_final,
+                                     "add_special_tokens": add_special,
+                                     "messages": conversation,
+                                     "model": model_name
+                                 })
+        response.raise_for_status()
+
+        assert response.json() == {
+            "tokens": tokens,
+            "count": len(tokens),
+            "max_model_len": 8192
+        }
 
 
 @pytest.mark.asyncio

vllm/entrypoints/chat_utils.py

Lines changed: 8 additions & 0 deletions
@@ -542,6 +542,14 @@ def apply_mistral_chat_template(
     if chat_template is not None:
         logger.warning(
             "'chat_template' cannot be overridden for mistral tokenizer.")
+    if "add_generation_prompt" in kwargs:
+        logger.warning(
+            "'add_generation_prompt' is not supported for mistral tokenizer, "
+            "so it will be ignored.")
+    if "continue_final_message" in kwargs:
+        logger.warning(
+            "'continue_final_message' is not supported for mistral tokenizer, "
+            "so it will be ignored.")
 
     return tokenizer.apply_chat_template(
         messages=messages,

vllm/entrypoints/llm.py

Lines changed: 6 additions & 0 deletions
@@ -501,6 +501,7 @@ def chat(
         lora_request: Optional[LoRARequest] = None,
         chat_template: Optional[str] = None,
         add_generation_prompt: bool = True,
+        continue_final_message: bool = False,
         tools: Optional[List[Dict[str, Any]]] = None,
     ) -> List[RequestOutput]:
         """
@@ -528,6 +529,9 @@
                 If not provided, the model's default chat template will be used.
             add_generation_prompt: If True, adds a generation template
                 to each message.
+            continue_final_message: If True, continues the final message in
+                the conversation instead of starting a new one. Cannot be `True`
+                if `add_generation_prompt` is also `True`.
 
         Returns:
             A list of ``RequestOutput`` objects containing the generated
@@ -559,6 +563,7 @@
                 messages=msgs,
                 chat_template=chat_template,
                 add_generation_prompt=add_generation_prompt,
+                continue_final_message=continue_final_message,
                 tools=tools,
             )
         else:
@@ -567,6 +572,7 @@
                 conversation=conversation,
                 chat_template=chat_template,
                 add_generation_prompt=add_generation_prompt,
+                continue_final_message=continue_final_message,
                 tools=tools,
             )

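For the offline entrypoint, `LLM.chat` now forwards the flag to both the Mistral and HF template paths. A minimal usage sketch follows; the model name is only illustrative, and any chat-tuned model with a chat template behaves the same way.

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")  # illustrative model

messages = [
    {"role": "user", "content": "What is the capital of France?"},
    # Prefill the start of the assistant's reply; the model will continue it.
    {"role": "assistant", "content": "The capital of France is"},
]

outputs = llm.chat(
    messages,
    SamplingParams(temperature=0.0, max_tokens=16),
    add_generation_prompt=False,  # cannot be True together with the flag below
    continue_final_message=True,
)
print(outputs[0].outputs[0].text)
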
vllm/entrypoints/openai/protocol.py

Lines changed: 28 additions & 0 deletions
@@ -211,6 +211,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
     )
+    continue_final_message: bool = Field(
+        default=False,
+        description=
+        ("If this is set, the chat will be formatted so that the final "
+         "message in the chat is open-ended, without any EOS tokens. The "
+         "model will continue this message rather than starting a new one. "
+         "This allows you to \"prefill\" part of the model's response for it. "
+         "Cannot be used at the same time as `add_generation_prompt`."),
+    )
     add_special_tokens: bool = Field(
         default=False,
         description=(
@@ -431,6 +440,15 @@ def check_tool_usage(cls, data):
                     " of the specified `tools`")
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_generation_prompt(cls, data):
+        if data.get("continue_final_message") and data.get(
+                "add_generation_prompt"):
+            raise ValueError("Cannot set both `continue_final_message` and "
+                             "`add_generation_prompt` to True.")
+        return data
+
 
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
@@ -862,8 +880,18 @@ class TokenizeChatRequest(OpenAIBaseModel):
     messages: List[ChatCompletionMessageParam]
 
     add_generation_prompt: bool = Field(default=True)
+    continue_final_message: bool = Field(default=False)
     add_special_tokens: bool = Field(default=False)
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_generation_prompt(cls, data):
+        if data.get("continue_final_message") and data.get(
+                "add_generation_prompt"):
+            raise ValueError("Cannot set both `continue_final_message` and "
+                             "`add_generation_prompt` to True.")
+        return data
+
 
 TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]

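The `model_validator(mode="before")` hooks added above reject the contradictory combination before a request object is built; the same guard is duplicated on `TokenizeChatRequest`. A small sketch of the resulting behavior, assuming the server package is importable:

import pydantic
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

try:
    ChatCompletionRequest(
        model="facebook/opt-125m",
        messages=[{"role": "user", "content": "Hi"}],
        add_generation_prompt=True,
        continue_final_message=True,
    )
except pydantic.ValidationError as err:
    # check_generation_prompt raised: the two flags are mutually exclusive.
    print(err)
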
vllm/entrypoints/openai/serving_chat.py

Lines changed: 4 additions & 2 deletions
@@ -140,6 +140,7 @@ async def create_chat_completion(
                     messages=request.messages,
                     chat_template=request.chat_template or self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                     tools=tool_dicts,
                     documents=request.documents,
                     **(request.chat_template_kwargs or {}),
@@ -150,6 +151,7 @@
                     conversation=conversation,
                     chat_template=request.chat_template or self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                     tools=tool_dicts,
                     documents=request.documents,
                     **(request.chat_template_kwargs or {}),
@@ -361,7 +363,7 @@ async def chat_completion_stream_generator(
 
                     # Send response to echo the input portion of the
                     # last message
-                    if request.echo:
+                    if request.echo or request.continue_final_message:
                         last_msg_content: str = ""
                         if conversation and "content" in conversation[
                                 -1] and conversation[-1].get("role") == role:
@@ -716,7 +718,7 @@ async def chat_completion_full_generator(
                 stop_reason=output.stop_reason)
             choices.append(choice_data)
 
-        if request.echo:
+        if request.echo or request.continue_final_message:
             last_msg_content = ""
             if conversation and "content" in conversation[-1] and conversation[
                     -1].get("role") == role:

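On the OpenAI-compatible server the flag is read straight from the request body, and the widened `request.echo or request.continue_final_message` checks above prepend the prefilled text to the returned content, so clients receive the full continued message. A hedged request sketch against a locally running server; the URL and model name are placeholders:

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # placeholder server URL
    json={
        "model": "facebook/opt-125m",  # placeholder model name
        "messages": [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "The capital of France is"},
        ],
        "add_generation_prompt": False,
        "continue_final_message": True,
    },
)
resp.raise_for_status()
# Because of the echo-or-continue branch, the returned content starts with the
# prefilled "The capital of France is" followed by the model's continuation.
print(resp.json()["choices"][0]["message"]["content"])
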
vllm/entrypoints/openai/serving_tokenization.py

Lines changed: 2 additions & 0 deletions
@@ -87,13 +87,15 @@ async def create_tokenize(
                     messages=request.messages,
                     chat_template=self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                 )
             else:
                 prompt = apply_hf_chat_template(
                     tokenizer,
                     conversation=conversation,
                     chat_template=self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                 )
         else:
             prompt = request.prompt

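The `/tokenize` route accepts the same flag through `TokenizeChatRequest`, which makes it easy to inspect the exact prompt the server will render. A brief sketch, reusing the placeholder URL and model from the previous example:

import requests

resp = requests.post(
    "http://localhost:8000/tokenize",  # placeholder server URL
    json={
        "model": "facebook/opt-125m",  # placeholder model name
        "messages": [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "The capital of France is"},
        ],
        "add_generation_prompt": False,
        "continue_final_message": True,
    },
)
resp.raise_for_status()
print(resp.json())  # {"tokens": [...], "count": <int>, "max_model_len": <int>}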