Skip to content

Commit aa3030e

Browse files
fgreinacher and kwang1012
authored and committed
[Bugfix] Support missing tool parameters in mistral tokenizer (vllm-project#12884)
Signed-off-by: Florian Greinacher <[email protected]>
1 parent 16e5b69 commit aa3030e

File tree

2 files changed

+88
-19
lines changed

2 files changed

+88
-19
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import pytest
4+
from mistral_common.protocol.instruct.messages import UserMessage
5+
from mistral_common.protocol.instruct.request import ChatCompletionRequest
6+
from mistral_common.protocol.instruct.tool_calls import Function, Tool
7+
8+
from vllm.transformers_utils.tokenizers.mistral import (
9+
make_mistral_chat_completion_request)
10+
11+
12+
# yapf: enable
@pytest.mark.parametrize(
    "openai_request,expected_mistral_request",
    [
        (
            {
                "messages": [
                    {
                        "role": "user",
                        "content": "What is the current local date and time?",
                    },
                ],
                "tools": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_time",
                            "description":
                            "Fetch the current local date and time.",
                        },
                    },
                ],
            },
            ChatCompletionRequest(
                messages=[
                    UserMessage(
                        content="What is the current local date and time?")
                ],
                tools=[
                    Tool(
                        type="function",
                        function=Function(
                            name="get_current_time",
                            description="Fetch the current local date and time.",
                            parameters={},
                        ),
                    )
                ],
            ),
        ),
    ],
)
def test_make_mistral_chat_completion_request(openai_request,
                                              expected_mistral_request):
    """A tool whose function omits "parameters" is normalized to an empty
    parameters dict in the resulting mistral-common request."""
    actual = make_mistral_chat_completion_request(openai_request["messages"],
                                                  openai_request["tools"])
    assert actual == expected_mistral_request

vllm/transformers_utils/tokenizers/mistral.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,42 @@ def find_tokenizer_file(files: List[str]):
104104
return matched_files[0]
105105

106106

107+
def make_mistral_chat_completion_request(
        messages: List["ChatCompletionMessageParam"],
        tools: Optional[List[Dict[str,
                                  Any]]] = None) -> "ChatCompletionRequest":
    """Build a mistral-common ``ChatCompletionRequest`` from OpenAI-style
    chat ``messages`` and (optionally) ``tools``.

    Normalizations applied before construction:
    - A trailing assistant message is marked as a prefix so the model
      continues it instead of starting a new turn.
    - Assistant message content given as a list of text chunks is joined
      into a single string (mistral-common requires string content).
    - Each function tool gets a "parameters" dict, defaulting to empty.

    Note: ``messages``/``tools`` entries are mutated in place.
    """
    # NOTE(review): the original had this prefix stanza duplicated verbatim;
    # setting it once is sufficient (the second copy was a no-op).
    last_message = cast(Dict[str, Any], messages[-1])
    if last_message["role"] == "assistant":
        last_message["prefix"] = True

    # mistral-common requires AssistantMessage content to be string [1].
    #
    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
    for message in messages:
        if message.get("role") == "assistant":
            content = message.get("content")
            if isinstance(content, list):
                content = "\n".join(chunk.get("text") for chunk in content)
            message["content"] = content

    # The Mistral client, in comparison to the OpenAI client, requires the
    # "parameters" dict to be present, even if it's empty.
    if tools:
        for function in [
                tool["function"] for tool in tools
                if tool["type"] == "function"
        ]:
            function.setdefault("parameters", {})

    # Imported lazily so the module can be imported without mistral-common.
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    return ChatCompletionRequest(messages=messages,
                                 tools=tools)  # type: ignore[type-var]
141+
142+
107143
class MistralTokenizer:
108144

109145
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
@@ -283,27 +319,10 @@ def encode(self, prompt: str) -> List[int]:
283319

284320
def apply_chat_template(self,
285321
messages: List["ChatCompletionMessageParam"],
286-
tools: Optional[Dict[str, Any]] = None,
322+
tools: Optional[List[Dict[str, Any]]] = None,
287323
**kwargs) -> List[int]:
288324

289-
last_message = cast(Dict[str, Any], messages[-1])
290-
if last_message["role"] == "assistant":
291-
last_message["prefix"] = True
292-
293-
from mistral_common.protocol.instruct.request import (
294-
ChatCompletionRequest)
295-
296-
# mistral-common requires AssistantMessage content to be string [1].
297-
#
298-
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
299-
for message in messages:
300-
if message.get("role") == "assistant":
301-
content = message.get("content")
302-
if isinstance(content, list):
303-
content = "\n".join(chunk.get("text") for chunk in content)
304-
message["content"] = content
305-
request = ChatCompletionRequest(messages=messages,
306-
tools=tools) # type: ignore[type-var]
325+
request = make_mistral_chat_completion_request(messages, tools)
307326
encoded = self.mistral.encode_chat_completion(request)
308327

309328
# encode-decode to get clean prompt

0 commit comments

Comments
 (0)