                                                     Tekkenizer)

from vllm.logger import init_logger
+from vllm.utils import is_list_of

if TYPE_CHECKING:
    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam


@dataclass
class Encoding:
-    input_ids: List[int]
+    input_ids: Union[List[int], List[List[int]]]


def maybe_serialize_tool_calls(request: ChatCompletionRequest):
@@ -223,17 +224,25 @@ def __len__(self) -> int:

    def __call__(
        self,
-        prompt: str,
+        prompt: Union[str, List[str], List[int]],
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
-        # Mistral Tokenizers should not add special tokens
-        input_ids = self.encode(prompt)
-
-        if truncation:
-            input_ids = input_ids[:max_length]
-
+        input_ids: Union[List[int], List[List[int]]]
+        # For List[str]: a batch of original prompt texts.
+        if is_list_of(prompt, str):
+            input_ids_: List[List[int]] = []
+            for p in prompt:
+                each_input_ids = self.encode_one(p, truncation, max_length)
+                input_ids_.append(each_input_ids)
+            input_ids = input_ids_
+        # For List[int]: chat-template output, already tokenized.
+        elif is_list_of(prompt, int):
+            input_ids = prompt
+        # For str: a single prompt text.
+        else:
+            input_ids = self.encode_one(prompt, truncation, max_length)
        return Encoding(input_ids=input_ids)

    def get_vocab(self) -> Dict[str, int]:
@@ -245,6 +254,19 @@ def get_added_vocab(self) -> Dict[str, int]:
        # Mistral tokenizers have no added vocabulary
        return {}

+    def encode_one(
+        self,
+        prompt: str,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ) -> List[int]:
+        # Mistral Tokenizers should not add special tokens
+        input_ids = self.encode(prompt)
+
+        if truncation:
+            input_ids = input_ids[:max_length]
+        return input_ids
+
    def encode(self, prompt: str) -> List[int]:
        # `encode` should only be used for prompt completion
        # it should never be used for chat_completion.
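For reference, a minimal usage sketch of the three prompt shapes the updated `__call__` accepts. This is an illustration only, not part of the diff: `tokenizer` stands for an instance of the Mistral tokenizer wrapper being modified above (its class name is not shown in these hunks), and the prompt strings and token IDs are made up.

# Sketch only: `tokenizer` is assumed to be the wrapper patched above.

# str -> Encoding.input_ids is a List[int]
enc = tokenizer("Hello, world!", truncation=True, max_length=8)

# List[str] -> Encoding.input_ids is a List[List[int]], one entry per prompt
enc_batch = tokenizer(["Hello", "Bonjour"], truncation=True, max_length=8)

# List[int] (e.g. chat-template output) -> returned as-is; truncation is not applied
enc_tokens = tokenizer([1, 5, 17, 29])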