
Commit 60a08d9

youngkent authored and SzymonOzog committed

[RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM (vllm-project#12518)

Signed-off-by: Keyun Tong <[email protected]>
Signed-off-by: SzymonOzog <[email protected]>

1 parent f056049 commit 60a08d9

File tree

11 files changed: +343 -41 lines

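At a glance, the commit adds a "custom" tokenizer mode backed by a new TokenizerRegistry. A minimal end-to-end sketch of the intended flow, modeled on the test added below; the module path my_project.tokenizer, the class MyTokenizer, and the name "my_tokenizer" are hypothetical placeholders, not part of vLLM:

    from vllm.transformers_utils.tokenizer import get_tokenizer
    from vllm.transformers_utils.tokenizer_base import TokenizerRegistry

    # Map a name to the module and class implementing TokenizerBase.
    # MyTokenizer (hypothetical) must provide a from_pretrained() classmethod.
    TokenizerRegistry.register("my_tokenizer",
                               "my_project.tokenizer",
                               "MyTokenizer")

    # With tokenizer_mode="custom", get_tokenizer resolves the name through
    # the registry instead of treating it as a HuggingFace repo or path.
    tokenizer = get_tokenizer("my_tokenizer", tokenizer_mode="custom")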

benchmarks/benchmark_serving.py (+3 -2)

@@ -1275,11 +1275,12 @@ def main(args: argparse.Namespace):
         '--tokenizer-mode',
         type=str,
         default="auto",
-        choices=['auto', 'slow', 'mistral'],
+        choices=['auto', 'slow', 'mistral', 'custom'],
         help='The tokenizer mode.\n\n* "auto" will use the '
         'fast tokenizer if available.\n* "slow" will '
         'always use the slow tokenizer. \n* '
-        '"mistral" will always use the `mistral_common` tokenizer.')
+        '"mistral" will always use the `mistral_common` tokenizer. \n*'
+        '"custom" will use --tokenizer to select the preregistered tokenizer.')
 
     parser.add_argument("--served-model-name",
                         type=str,
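Presumably this is then driven as python benchmarks/benchmark_serving.py --tokenizer-mode custom --tokenizer my_tokenizer ... (the name my_tokenizer is hypothetical); since TokenizerRegistry is an in-process mapping, the registration call must have run in the same process before the name can be resolved.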
tests/tokenization/test_tokenizer_registry.py (+123 -0)

@@ -0,0 +1,123 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
+                                                    TokenizerRegistry)
+
+if TYPE_CHECKING:
+    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+
+
+class TestTokenizer(TokenizerBase):
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
+        return TestTokenizer()
+
+    @property
+    def all_special_tokens_extended(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    def all_special_ids(self) -> List[int]:
+        raise NotImplementedError()
+
+    @property
+    def bos_token_id(self) -> int:
+        return 0
+
+    @property
+    def eos_token_id(self) -> int:
+        return 1
+
+    @property
+    def sep_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    def pad_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    def is_fast(self) -> bool:
+        raise NotImplementedError()
+
+    @property
+    def vocab_size(self) -> int:
+        raise NotImplementedError()
+
+    @property
+    def max_token_id(self) -> int:
+        raise NotImplementedError()
+
+    def __call__(
+        self,
+        text: Union[str, List[str], List[int]],
+        text_pair: Optional[str] = None,
+        add_special_tokens: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ):
+        raise NotImplementedError()
+
+    def get_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    def get_added_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    def encode_one(
+        self,
+        text: str,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ) -> List[int]:
+        raise NotImplementedError()
+
+    def encode(self,
+               text: str,
+               add_special_tokens: Optional[bool] = None) -> List[int]:
+        raise NotImplementedError()
+
+    def apply_chat_template(self,
+                            messages: List["ChatCompletionMessageParam"],
+                            tools: Optional[List[Dict[str, Any]]] = None,
+                            **kwargs) -> List[int]:
+        raise NotImplementedError()
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        raise NotImplementedError()
+
+    def decode(self,
+               ids: Union[List[int], int],
+               skip_special_tokens: bool = True) -> str:
+        raise NotImplementedError()
+
+    def convert_ids_to_tokens(
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
+        raise NotImplementedError()
+
+
+def test_customized_tokenizer():
+    TokenizerRegistry.register("test_tokenizer",
+                               "tests.tokenization.test_tokenizer_registry",
+                               "TestTokenizer")
+
+    tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
+    assert isinstance(tokenizer, TestTokenizer)
+    assert tokenizer.bos_token_id == 0
+    assert tokenizer.eos_token_id == 1
+
+    tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
+    assert isinstance(tokenizer, TestTokenizer)
+    assert tokenizer.bos_token_id == 0
+    assert tokenizer.eos_token_id == 1
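The tokenizer_base.py module that defines TokenizerBase and TokenizerRegistry is among the files not shown in this excerpt. Judging purely from how the test and get_tokenizer use it, the registry is a small name-to-(module, class) map resolved lazily; a rough sketch of that shape, as an assumption rather than the actual vLLM implementation:

    from importlib import import_module
    from typing import Dict, Tuple


    class TokenizerRegistry:
        # Registered name -> (module path, class name); the class is only
        # imported when the tokenizer is actually requested.
        REGISTRY: Dict[str, Tuple[str, str]] = {}

        @staticmethod
        def register(name: str, module: str, class_name: str) -> None:
            TokenizerRegistry.REGISTRY[name] = (module, class_name)

        @staticmethod
        def get_tokenizer(tokenizer_name: str, *args, **kwargs):
            if tokenizer_name not in TokenizerRegistry.REGISTRY:
                raise ValueError(
                    f"Tokenizer {tokenizer_name} is not registered.")
            module_name, class_name = TokenizerRegistry.REGISTRY[
                tokenizer_name]
            tokenizer_cls = getattr(import_module(module_name), class_name)
            # Every TokenizerBase subclass exposes a from_pretrained()
            # classmethod (see TestTokenizer above) that builds the instance.
            return tokenizer_cls.from_pretrained(*args, **kwargs)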

vllm/config.py (+5 -4)

@@ -102,8 +102,9 @@ class ModelConfig:
         it; otherwise, you must specify explicitly which task to use.
     tokenizer: Name or path of the huggingface tokenizer to use.
     tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
-        available, "slow" will always use the slow tokenizer, and
-        "mistral" will always use the tokenizer from `mistral_common`.
+        available, "slow" will always use the slow tokenizer,
+        "mistral" will always use the tokenizer from `mistral_common`, and
+        "custom" will use --tokenizer to select the preregistered tokenizer.
     trust_remote_code: Trust remote code (e.g., from HuggingFace) when
         downloading the model and tokenizer.
     allowed_local_media_path: Allowing API requests to read local images or

@@ -472,10 +473,10 @@ def _init_has_inner_state(self) -> bool:
 
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
-        if tokenizer_mode not in ["auto", "slow", "mistral"]:
+        if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
             raise ValueError(
                 f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
-                "either 'auto', 'slow' or 'mistral'.")
+                "either 'auto', 'slow', 'mistral' or 'custom'.")
         self.tokenizer_mode = tokenizer_mode
 
     def _get_preferred_task(

vllm/engine/arg_utils.py (+4 -2)

@@ -285,11 +285,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             '--tokenizer-mode',
             type=str,
             default=EngineArgs.tokenizer_mode,
-            choices=['auto', 'slow', 'mistral'],
+            choices=['auto', 'slow', 'mistral', 'custom'],
             help='The tokenizer mode.\n\n* "auto" will use the '
             'fast tokenizer if available.\n* "slow" will '
             'always use the slow tokenizer. \n* '
-            '"mistral" will always use the `mistral_common` tokenizer.')
+            '"mistral" will always use the `mistral_common` tokenizer. \n* '
+            '"custom" will use --tokenizer to select the '
+            'preregistered tokenizer.')
         parser.add_argument('--trust-remote-code',
                             action='store_true',
                             help='Trust remote code from huggingface.')

vllm/entrypoints/llm.py (+19 -12)

@@ -1051,9 +1051,9 @@ def _embedding_score(
 
     def _cross_encoding_score(
         self,
-        tokenizer: Union[AnyTokenizer],
-        text_1: List[Union[str, TextPrompt, TokensPrompt]],
-        text_2: List[Union[str, TextPrompt, TokensPrompt]],
+        tokenizer: AnyTokenizer,
+        text_1: List[str],
+        text_2: List[str],
         truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,

@@ -1176,29 +1176,36 @@ def ensure_str(prompt: SingletonPrompt):
         if isinstance(text_1, (str, dict)):
             # Convert a single prompt to a list.
             text_1 = [text_1]
-        text_1 = [ensure_str(t) for t in text_1]
+        input_text_1: List[str] = [ensure_str(t) for t in text_1]
 
         if isinstance(text_2, (str, dict)):
             # Convert a single prompt to a list.
             text_2 = [text_2]
-        text_2 = [ensure_str(t) for t in text_2]
+        input_text_2: List[str] = [ensure_str(t) for t in text_2]
 
-        if len(text_1) > 1 and len(text_1) != len(text_2):
+        if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
             raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
-        if len(text_1) == 0:
+        if len(input_text_1) == 0:
             raise ValueError("At least one text element must be given")
-        if len(text_2) == 0:
+        if len(input_text_2) == 0:
             raise ValueError("At least one text_pair element must be given")
 
         if self.llm_engine.model_config.is_cross_encoder:
-            return self._cross_encoding_score(tokenizer, text_1, text_2,
+            return self._cross_encoding_score(tokenizer, input_text_1,
+                                              input_text_2,
                                               truncate_prompt_tokens, use_tqdm,
                                               lora_request,
                                               prompt_adapter_request)
         else:
-            return self._embedding_score(tokenizer, text_1, text_2,
-                                         truncate_prompt_tokens, use_tqdm,
-                                         lora_request, prompt_adapter_request)
+
+            return self._embedding_score(
+                tokenizer,
+                input_text_1,  # type: ignore[arg-type]
+                input_text_2,  # type: ignore[arg-type]
+                truncate_prompt_tokens,
+                use_tqdm,
+                lora_request,
+                prompt_adapter_request)
 
     def start_profile(self) -> None:
         self.llm_engine.start_profile()
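A note on the hunk above: after ensure_str the lists are known to contain only str, so binding the results to the new input_text_1/input_text_2 names typed as List[str] satisfies the narrowed _cross_encoding_score signature, while _embedding_score apparently still declares the broader prompt types, hence the # type: ignore[arg-type] markers on that call.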

vllm/entrypoints/openai/serving_engine.py (+1 -2)

@@ -400,8 +400,7 @@ async def _preprocess_chat(
         _chat_template_kwargs.update(chat_template_kwargs or {})
 
         request_prompt: Union[str, List[int]]
-        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
-        if is_mistral_tokenizer:
+        if isinstance(tokenizer, MistralTokenizer):
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,

vllm/entrypoints/openai/serving_score.py (+1 -1)

@@ -121,7 +121,7 @@ async def create_score(
 
         tokenize_async = make_async(tokenizer.__call__,
                                     executor=self._tokenizer_executor)
-        prompt_inputs = await tokenize_async(text=q,
+        prompt_inputs = await tokenize_async(q,
                                              text_pair=t,
                                              **tokenization_kwargs)

vllm/logits_process.py (+1 -1)

@@ -31,7 +31,7 @@ def get_bad_words_logits_processors(
 
     if isinstance(tokenizer, MistralTokenizer):
         # Mistral tokenizers should not add special tokens
-        prompt_token_ids = tokenizer.encode(prompt=prompt)
+        prompt_token_ids = tokenizer.encode(text=prompt)
     else:
         prompt_token_ids = tokenizer.encode(text=prompt,
                                             add_special_tokens=False)
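This one-line fix, together with the positional q in serving_score.py above, presumably tracks MistralTokenizer being reworked to conform to the TokenizerBase interface, whose encode takes text rather than prompt (compare TestTokenizer.encode in the new test file); the MistralTokenizer diff itself is among the files not shown in this excerpt.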

vllm/transformers_utils/tokenizer.py (+12 -6)

@@ -14,14 +14,16 @@
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
+                                                    TokenizerRegistry)
 from vllm.transformers_utils.tokenizers import MistralTokenizer
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import make_async
 
 logger = init_logger(__name__)
 
 AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
-                     MistralTokenizer]
+                     TokenizerBase]
 
 
 def decode_tokens(

@@ -47,11 +49,7 @@ def encode_tokens(
     Backend-agnostic equivalent of HF's
     :code:`tokenizer.encode(text, add_special_tokens=...)`.
     """
-    if isinstance(tokenizer, MistralTokenizer):
-        return tokenizer.tokenizer.encode(text,
-                                          bos=add_special_tokens,
-                                          eos=add_special_tokens)
-    elif add_special_tokens is not None:
+    if add_special_tokens is not None:
         return tokenizer.encode(text, add_special_tokens=add_special_tokens)
     return tokenizer.encode(text)

@@ -183,9 +181,17 @@ def get_tokenizer(
         'encoding and decoding.',
         FutureWarning,
         stacklevel=2)
+
+    tokenizer: AnyTokenizer
     if tokenizer_mode == "mistral":
         tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                      revision=revision)
+    elif tokenizer_mode == "custom":
+        tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
+                                                    *args,
+                                                    revision=revision,
+                                                    download_dir=download_dir,
+                                                    **kwargs)
     else:
         try:
             tokenizer = AutoTokenizer.from_pretrained(
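With the MistralTokenizer special case removed, encode_tokens goes through the shared encode signature for every backend. A small usage sketch, assuming the full signature is encode_tokens(tokenizer, text, add_special_tokens=None) (the hunk shows only its docstring) and using a stock HF tokenizer name for illustration:

    from vllm.transformers_utils.tokenizer import encode_tokens, get_tokenizer

    # Any AnyTokenizer instance works here; "gpt2" is just a convenient
    # HuggingFace example, not something this commit depends on.
    tokenizer = get_tokenizer("gpt2")

    # Same code path for HF and custom tokenizers: no BOS/EOS added.
    ids = encode_tokens(tokenizer, "hello world", add_special_tokens=False)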

0 commit comments