Skip to content

Commit 9880de7

Browse files
committed
fix: support missing tool parameters in mistral tokenizer
The Mistral schema is a bit more restrictive and does not accept functions without a parameters dictionary. This is no problem with the OpenAI schema, though, so the Mistral tokenizer should adapt the request before passing it on. Signed-off-by: Florian Greinacher <[email protected]>
1 parent 2431371 commit 9880de7

File tree

2 files changed

+88
-19
lines changed

2 files changed

+88
-19
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import pytest
4+
from mistral_common.protocol.instruct.messages import UserMessage
5+
from mistral_common.protocol.instruct.request import ChatCompletionRequest
6+
from mistral_common.protocol.instruct.tool_calls import Function, Tool
7+
8+
from vllm.transformers_utils.tokenizers.mistral import (
9+
make_mistral_chat_completion_request)
10+
11+
12+
# yapf: enable

# An OpenAI-style payload whose tool function deliberately omits the
# "parameters" key — the conversion must tolerate this.
_OPENAI_REQUEST = {
    "messages": [{
        "role": "user",
        "content": "What is the current local date and time?",
    }],
    "tools": [{
        "type": "function",
        "function": {
            "description": "Fetch the current local date and time.",
            "name": "get_current_time",
        },
    }],
}

# The equivalent mistral-common request; note the normalized empty
# ``parameters`` dict that the Mistral schema requires.
_EXPECTED_REQUEST = ChatCompletionRequest(
    messages=[
        UserMessage(content="What is the current local date and time?")
    ],
    tools=[
        Tool(
            type="function",
            function=Function(
                name="get_current_time",
                description="Fetch the current local date and time.",
                parameters={},
            ),
        )
    ],
)


@pytest.mark.parametrize(
    "openai_request,expected_mistral_request",
    [(_OPENAI_REQUEST, _EXPECTED_REQUEST)],
)
def test_make_mistral_chat_completion_request(openai_request,
                                              expected_mistral_request):
    """OpenAI requests lacking tool parameters convert to valid
    Mistral requests with an empty parameters dict."""
    actual = make_mistral_chat_completion_request(openai_request["messages"],
                                                  openai_request["tools"])
    assert actual == expected_mistral_request

vllm/transformers_utils/tokenizers/mistral.py

+38-19
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,42 @@ def find_tokenizer_file(files: List[str]):
104104
return matched_files[0]
105105

106106

107+
def make_mistral_chat_completion_request(
108+
messages: List["ChatCompletionMessageParam"],
109+
tools: Optional[List[Dict[str,
110+
Any]]] = None) -> "ChatCompletionRequest":
111+
last_message = cast(Dict[str, Any], messages[-1])
112+
if last_message["role"] == "assistant":
113+
last_message["prefix"] = True
114+
115+
last_message = cast(Dict[str, Any], messages[-1])
116+
if last_message["role"] == "assistant":
117+
last_message["prefix"] = True
118+
119+
# mistral-common requires AssistantMessage content to be string [1].
120+
#
121+
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
122+
for message in messages:
123+
if message.get("role") == "assistant":
124+
content = message.get("content")
125+
if isinstance(content, list):
126+
content = "\n".join(chunk.get("text") for chunk in content)
127+
message["content"] = content
128+
129+
# The Mistral client, in comparison to the OpenAI client, requires the
130+
# "parameters" dict to be present, even if it's empty.
131+
if tools:
132+
for function in [
133+
tool["function"] for tool in tools
134+
if tool["type"] == "function"
135+
]:
136+
function.setdefault("parameters", {})
137+
138+
from mistral_common.protocol.instruct.request import ChatCompletionRequest
139+
return ChatCompletionRequest(messages=messages,
140+
tools=tools) # type: ignore[type-var]
141+
142+
107143
class MistralTokenizer:
108144

109145
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
@@ -283,27 +319,10 @@ def encode(self, prompt: str) -> List[int]:
283319

284320
def apply_chat_template(self,
285321
messages: List["ChatCompletionMessageParam"],
286-
tools: Optional[Dict[str, Any]] = None,
322+
tools: Optional[List[Dict[str, Any]]] = None,
287323
**kwargs) -> List[int]:
288324

289-
last_message = cast(Dict[str, Any], messages[-1])
290-
if last_message["role"] == "assistant":
291-
last_message["prefix"] = True
292-
293-
from mistral_common.protocol.instruct.request import (
294-
ChatCompletionRequest)
295-
296-
# mistral-common requires AssistantMessage content to be string [1].
297-
#
298-
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
299-
for message in messages:
300-
if message.get("role") == "assistant":
301-
content = message.get("content")
302-
if isinstance(content, list):
303-
content = "\n".join(chunk.get("text") for chunk in content)
304-
message["content"] = content
305-
request = ChatCompletionRequest(messages=messages,
306-
tools=tools) # type: ignore[type-var]
325+
request = make_mistral_chat_completion_request(messages, tools)
307326
encoded = self.mistral.encode_chat_completion(request)
308327

309328
# encode-decode to get clean prompt

0 commit comments

Comments
 (0)