Skip to content

Commit 54cacf0

Browse files
authored
[Bugfix] Mistral tokenizer encode accept list of str (#12149)
Signed-off-by: Kunshang Ji <[email protected]>
1 parent 58fd57f commit 54cacf0

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

vllm/transformers_utils/tokenizers/mistral.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Tekkenizer)
1919

2020
from vllm.logger import init_logger
21+
from vllm.utils import is_list_of
2122

2223
if TYPE_CHECKING:
2324
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@@ -27,7 +28,7 @@
2728

2829
@dataclass
2930
class Encoding:
30-
input_ids: List[int]
31+
input_ids: Union[List[int], List[List[int]]]
3132

3233

3334
def maybe_serialize_tool_calls(request: ChatCompletionRequest):
@@ -223,17 +224,25 @@ def __len__(self) -> int:
223224

224225
def __call__(
    self,
    prompt: Union[str, List[str], List[int]],
    add_special_tokens: bool = False,
    truncation: bool = False,
    max_length: Optional[int] = None,
):
    """Tokenize *prompt* and wrap the result in an :class:`Encoding`.

    Accepts a single prompt string, a batch of prompt strings, or an
    already-tokenized list of token ids (chat-template output).
    ``add_special_tokens`` is accepted for interface compatibility but
    not used here — Mistral tokenizers should not add special tokens.

    NOTE(review): the ``List[int]`` passthrough branch does not apply
    truncation — presumably intentional since the ids come from the
    chat template; confirm against callers.
    """
    input_ids: Union[List[int], List[List[int]]]
    if is_list_of(prompt, str):
        # Batch of raw prompt strings: tokenize each one individually.
        input_ids = [
            self.encode_one(text, truncation, max_length) for text in prompt
        ]
    elif is_list_of(prompt, int):
        # Already token ids (chat-template output): pass through as-is.
        input_ids = prompt
    else:
        # A single prompt string.
        input_ids = self.encode_one(prompt, truncation, max_length)
    return Encoding(input_ids=input_ids)
238247

239248
def get_vocab(self) -> Dict[str, int]:
@@ -245,6 +254,19 @@ def get_added_vocab(self) -> Dict[str, int]:
245254
# Mistral tokenizers have no added vocabulary
246255
return {}
247256

257+
def encode_one(
    self,
    prompt: str,
    truncation: bool = False,
    max_length: Optional[int] = None,
) -> List[int]:
    """Tokenize a single prompt string into token ids.

    Mistral tokenizers should not add special tokens, so this delegates
    straight to :meth:`encode`.  When *truncation* is true the result is
    clipped to the first *max_length* tokens (a ``None`` *max_length*
    leaves the sequence unclipped).
    """
    token_ids = self.encode(prompt)
    return token_ids[:max_length] if truncation else token_ids
269+
248270
def encode(self, prompt: str) -> List[int]:
249271
# `encode` should only be used for prompt completion
250272
# it should never be used for chat_completion.

0 commit comments

Comments (0)