@@ -417,14 +417,35 @@ def get_model(pretrained_model_name_or_path: str) -> str:
 
 
 def get_tokenizer(
-    pretrained_model_name_or_path: str, trust_remote_code: bool
+    pretrained_model_name_or_path: str,
+    tokenizer_mode: str = "auto",
+    trust_remote_code: bool = False,
+    **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     if pretrained_model_name_or_path is not None and not os.path.exists(
             pretrained_model_name_or_path):
         pretrained_model_name_or_path = get_model(
             pretrained_model_name_or_path)
-    return AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
-                                         trust_remote_code=trust_remote_code)
+    if tokenizer_mode == "slow":
+        if kwargs.get("use_fast", False):
+            raise ValueError(
+                "Cannot use the fast tokenizer in slow tokenizer mode.")
+        kwargs["use_fast"] = False
+    if tokenizer_mode == "mistral":
+        try:
+            from vllm.transformers_utils.tokenizer import MistralTokenizer
+        except ImportError as e:
+            raise ImportError("MistralTokenizer requires vllm package.\n"
+                              "Please install it with `pip install vllm` "
+                              "to use mistral tokenizer mode.") from e
+        return MistralTokenizer.from_pretrained(
+            str(pretrained_model_name_or_path))
+    else:
+        return AutoTokenizer.from_pretrained(
+            pretrained_model_name_or_path,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
 
 
 ASYNC_REQUEST_FUNCS = {
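
For context, a minimal usage sketch of the updated helper (not part of the diff; the model IDs and the import path below are illustrative assumptions, not taken from the commit):

# Illustrative only: module name and model IDs are assumptions.
from backend_request_func import get_tokenizer

# Default "auto" mode falls through to AutoTokenizer.from_pretrained,
# which prefers a fast tokenizer when one is available.
tok = get_tokenizer("gpt2")

# "slow" mode forces use_fast=False; passing use_fast=True together
# with tokenizer_mode="slow" raises a ValueError.
slow_tok = get_tokenizer("gpt2", tokenizer_mode="slow")

# "mistral" mode requires the vllm package and returns its
# MistralTokenizer instead of an AutoTokenizer instance.
mistral_tok = get_tokenizer("mistralai/Mistral-7B-Instruct-v0.3",
                            tokenizer_mode="mistral")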