@@ -104,6 +104,42 @@ def find_tokenizer_file(files: List[str]):
104
104
return matched_files [0 ]
105
105
106
106
107
+ def make_mistral_chat_completion_request (
108
+ messages : List ["ChatCompletionMessageParam" ],
109
+ tools : Optional [List [Dict [str ,
110
+ Any ]]] = None ) -> "ChatCompletionRequest" :
111
+ last_message = cast (Dict [str , Any ], messages [- 1 ])
112
+ if last_message ["role" ] == "assistant" :
113
+ last_message ["prefix" ] = True
114
+
115
+ last_message = cast (Dict [str , Any ], messages [- 1 ])
116
+ if last_message ["role" ] == "assistant" :
117
+ last_message ["prefix" ] = True
118
+
119
+ # mistral-common requires AssistantMessage content to be string [1].
120
+ #
121
+ # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
122
+ for message in messages :
123
+ if message .get ("role" ) == "assistant" :
124
+ content = message .get ("content" )
125
+ if isinstance (content , list ):
126
+ content = "\n " .join (chunk .get ("text" ) for chunk in content )
127
+ message ["content" ] = content
128
+
129
+ # The Mistral client, in comparison to the OpenAI client, requires the
130
+ # "parameters" dict to be present, even if it's empty.
131
+ if tools :
132
+ for function in [
133
+ tool ["function" ] for tool in tools
134
+ if tool ["type" ] == "function"
135
+ ]:
136
+ function .setdefault ("parameters" , {})
137
+
138
+ from mistral_common .protocol .instruct .request import ChatCompletionRequest
139
+ return ChatCompletionRequest (messages = messages ,
140
+ tools = tools ) # type: ignore[type-var]
141
+
142
+
107
143
class MistralTokenizer :
108
144
109
145
def __init__ (self , tokenizer : "PublicMistralTokenizer" ) -> None :
@@ -283,27 +319,10 @@ def encode(self, prompt: str) -> List[int]:
283
319
284
320
def apply_chat_template (self ,
285
321
messages : List ["ChatCompletionMessageParam" ],
286
- tools : Optional [Dict [str , Any ]] = None ,
322
+ tools : Optional [List [ Dict [str , Any ] ]] = None ,
287
323
** kwargs ) -> List [int ]:
288
324
289
- last_message = cast (Dict [str , Any ], messages [- 1 ])
290
- if last_message ["role" ] == "assistant" :
291
- last_message ["prefix" ] = True
292
-
293
- from mistral_common .protocol .instruct .request import (
294
- ChatCompletionRequest )
295
-
296
- # mistral-common requires AssistantMessage content to be string [1].
297
- #
298
- # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
299
- for message in messages :
300
- if message .get ("role" ) == "assistant" :
301
- content = message .get ("content" )
302
- if isinstance (content , list ):
303
- content = "\n " .join (chunk .get ("text" ) for chunk in content )
304
- message ["content" ] = content
305
- request = ChatCompletionRequest (messages = messages ,
306
- tools = tools ) # type: ignore[type-var]
325
+ request = make_mistral_chat_completion_request (messages , tools )
307
326
encoded = self .mistral .encode_chat_completion (request )
308
327
309
328
# encode-decode to get clean prompt
0 commit comments