Skip to content

Commit c6db213

Browse files
authored
bugfix: Fix signature mismatch in benchmark's get_tokenizer function (#11982)
Signed-off-by: elijah <[email protected]>
1 parent a7d5968 commit c6db213

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

benchmarks/backend_request_func.py

Lines changed: 24 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -417,14 +417,35 @@ def get_model(pretrained_model_name_or_path: str) -> str:
417 417

418 418

419 419
def get_tokenizer(
420-
pretrained_model_name_or_path: str, trust_remote_code: bool
420+
pretrained_model_name_or_path: str,
421+
tokenizer_mode: str = "auto",
422+
trust_remote_code: bool = False,
423+
**kwargs,
421 424
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
422 425
if pretrained_model_name_or_path is not None and not os.path.exists(
423 426
pretrained_model_name_or_path):
424 427
pretrained_model_name_or_path = get_model(
425 428
pretrained_model_name_or_path)
426-
return AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
427-
trust_remote_code=trust_remote_code)
429+
if tokenizer_mode == "slow":
430+
if kwargs.get("use_fast", False):
431+
raise ValueError(
432+
"Cannot use the fast tokenizer in slow tokenizer mode.")
433+
kwargs["use_fast"] = False
434+
if tokenizer_mode == "mistral":
435+
try:
436+
from vllm.transformers_utils.tokenizer import MistralTokenizer
437+
except ImportError as e:
438+
raise ImportError("MistralTokenizer requires vllm package.\n"
439+
"Please install it with `pip install vllm` "
440+
"to use mistral tokenizer mode.") from e
441+
return MistralTokenizer.from_pretrained(
442+
str(pretrained_model_name_or_path))
443+
else:
444+
return AutoTokenizer.from_pretrained(
445+
pretrained_model_name_or_path,
446+
trust_remote_code=trust_remote_code,
447+
**kwargs,
448+
)
428 449

429 450

430 451
ASYNC_REQUEST_FUNCS = {

0 commit comments

Comments
 (0)