Skip to content

Commit 5873877

Browse files
authored
[Bugfix] Mistral tool calling when content is list (#18729)
Signed-off-by: mgoin <[email protected]>
1 parent 696259c commit 5873877

File tree

2 files changed

+115
-7
lines changed

2 files changed

+115
-7
lines changed

tests/tokenization/test_mistral_tokenizer.py

Lines changed: 110 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0

 import pytest
-from mistral_common.protocol.instruct.messages import UserMessage
+from mistral_common.protocol.instruct.messages import (AssistantMessage,
+                                                       ToolMessage,
+                                                       UserMessage)
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
-from mistral_common.protocol.instruct.tool_calls import Function, Tool
+from mistral_common.protocol.instruct.tool_calls import (Function,
+                                                         FunctionCall, Tool,
+                                                         ToolCall)

 from vllm.transformers_utils.tokenizers.mistral import (
     make_mistral_chat_completion_request)


-# yapf: enable
 @pytest.mark.parametrize(
     "openai_request,expected_mistral_request",
     [(
@@ -78,6 +81,107 @@
 )
 def test_make_mistral_chat_completion_request(openai_request,
                                               expected_mistral_request):
-    assert (make_mistral_chat_completion_request(
-        openai_request["messages"],
-        openai_request["tools"]) == expected_mistral_request)
+    actual_request = make_mistral_chat_completion_request(
+        openai_request["messages"], openai_request["tools"])
+    assert actual_request == expected_mistral_request
+
+
+# Tool use with list content and reasoning_content
+@pytest.mark.parametrize("openai_request,expected_mistral_request", [(
+    {
+        "messages": [
+            {
+                "role": "user",
+                "content": "What's the weather in Paris?",
+            },
+            {
+                "role":
+                "assistant",
+                "reasoning_content":
+                None,
+                "content":
+                None,
+                "tool_calls": [{
+                    "id": "call123",
+                    "type": "function",
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": '{"city": "Paris"}',
+                    },
+                }],
+            },
+            {
+                "role": "tool",
+                "content": [{
+                    "type": "text",
+                    "text": "Rainy"
+                }],
+                "name": "get_weather",
+                "tool_call_id": "call123",
+            },
+        ],
+        "tools": [{
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "description": "Gets the current weather in a city.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "city": {
+                            "type": "string",
+                            "description": "The city name"
+                        }
+                    },
+                    "required": ["city"],
+                },
+            },
+        }],
+    },
+    ChatCompletionRequest(
+        messages=[
+            UserMessage(content="What's the weather in Paris?"),
+            AssistantMessage(
+                content=None,
+                tool_calls=[
+                    ToolCall(
+                        id="call123",
+                        function=FunctionCall(
+                            name="get_weather",
+                            arguments='{"city": "Paris"}',
+                        ),
+                    )
+                ],
+            ),
+            ToolMessage(
+                content="Rainy",
+                tool_call_id="call123",
+                name="get_weather",
+            ),
+        ],
+        tools=[
+            Tool(
+                type="function",
+                function=Function(
+                    name="get_weather",
+                    description="Gets the current weather in a city.",
+                    parameters={
+                        "type": "object",
+                        "properties": {
+                            "city": {
+                                "type": "string",
+                                "description": "The city name"
+                            }
+                        },
+                        "required": ["city"],
+                    },
+                ),
+            )
+        ],
+    ),
+)])
+def test_make_mistral_chat_completion_request_list_content(
+        openai_request, expected_mistral_request):
+    actual_request = make_mistral_chat_completion_request(
+        openai_request["messages"], openai_request["tools"])
+    assert actual_request == expected_mistral_request

vllm/transformers_utils/tokenizers/mistral.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,11 @@ def make_mistral_chat_completion_request(
     #
     # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
     for message in messages:
-        if message.get("role") == "assistant":
+        # Remove reasoning_content as unsupported by Mistral
+        _ = message.pop("reasoning_content", None)  # type: ignore
+
+        # Convert list text content to string
+        if message.get("role") in ("assistant", "tool"):
             content = message.get("content")
             if isinstance(content, list):
                 content = "\n".join(chunk.get("text") for chunk in content)

0 commit comments

Comments
 (0)