
Commit 0a05ed5

Simplify TokenizerGroup (#16790)

Signed-off-by: Harry Mellor <[email protected]>

1 parent: 14288d1

24 files changed: +80 −752 lines
tests/conftest.py

Lines changed: 1 addition & 15 deletions
@@ -23,7 +23,7 @@
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
-from vllm.config import TaskOption, TokenizerPoolConfig, _get_and_verify_dtype
+from vllm.config import TaskOption, _get_and_verify_dtype
 from vllm.connections import global_http_connection
 from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
@@ -1010,20 +1010,6 @@ def vllm_runner():
     return VllmRunner
 
 
-def get_tokenizer_pool_config(tokenizer_group_type):
-    if tokenizer_group_type is None:
-        return None
-    if tokenizer_group_type == "ray":
-        return TokenizerPoolConfig(pool_size=1,
-                                   pool_type="ray",
-                                   extra_config={})
-    if isinstance(tokenizer_group_type, type):
-        return TokenizerPoolConfig(pool_size=1,
-                                   pool_type=tokenizer_group_type,
-                                   extra_config={})
-    raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
-
-
 @pytest.fixture()
 def temporary_enable_log_propagate():
     import logging
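
With the pool-config helper deleted from conftest, tests construct the group directly. A minimal sketch of the replacement pattern, built only from the constructor and encode keywords visible elsewhere in this commit:

```python
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

# Direct construction replaces get_tokenizer_group(get_tokenizer_pool_config(...), ...).
tokenizer_group = TokenizerGroup(
    tokenizer_id="gpt2",
    enable_lora=False,
    max_num_seqs=1,
    max_input_length=None,
)
# Synchronous tokenization; lora_request=None uses the base tokenizer.
token_ids = tokenizer_group.encode(prompt="Hello, world!", lora_request=None)
```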

tests/lora/test_tokenizer_group.py

Lines changed: 3 additions & 7 deletions
@@ -5,17 +5,14 @@
 
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizer import get_lora_tokenizer
-from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
-
-from ..conftest import get_tokenizer_pool_config
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
 async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
     reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
-    tokenizer_group = get_tokenizer_group(
-        get_tokenizer_pool_config(tokenizer_group_type),
+    tokenizer_group = TokenizerGroup(
         tokenizer_id="gpt2",
         enable_lora=True,
         max_num_seqs=1,
@@ -60,8 +57,7 @@ def test_get_lora_tokenizer(sql_lora_files, tmp_path):
 @pytest.mark.parametrize("max_num_seqs", [1, 2])
 @pytest.mark.parametrize("max_loras", [1, 2])
 def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras):
-    tokenizer_group = get_tokenizer_group(
-        get_tokenizer_pool_config(None),
+    tokenizer_group = TokenizerGroup(
         tokenizer_id="gpt2",
         enable_lora=enable_lora,
         max_num_seqs=max_num_seqs,
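
For LoRA, the group is constructed with enable_lora=True and per-adapter tokenizers are fetched through get_lora_tokenizer. A hedged sketch; the LoRARequest argument order and the adapter path are assumptions here, not taken from this diff:

```python
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

group = TokenizerGroup(
    tokenizer_id="gpt2",
    enable_lora=True,
    max_num_seqs=1,
    max_input_length=None,
)

# With no LoRA request, the base tokenizer comes back.
base_tokenizer = group.get_lora_tokenizer(None)

# Placeholder adapter: (name, int id, local path) are assumed fields.
lora_request = LoRARequest("sql-lora", 1, "/path/to/lora/adapter")
lora_tokenizer = group.get_lora_tokenizer(lora_request)
```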

tests/tokenization/test_detokenize.py

Lines changed: 2 additions & 7 deletions
@@ -10,7 +10,7 @@
 from vllm.inputs import token_inputs
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
 from vllm.transformers_utils.detokenizer import Detokenizer
-from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
@@ -212,7 +212,7 @@ def test_oov_decode(tokenizer, fast):
 
 @pytest.fixture
 def detokenizer(tokenizer_name: str) -> Detokenizer:
-    init_kwargs = dict(
+    tokenizer_group = TokenizerGroup(
         tokenizer_id=tokenizer_name,
         enable_lora=False,
         max_num_seqs=100,
@@ -222,11 +222,6 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
         revision=None,
     )
 
-    tokenizer_group = get_tokenizer_group(
-        None,
-        **init_kwargs,
-    )
-
     return Detokenizer(tokenizer_group)
 
 
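
The fixture now builds the group inline and hands it straight to Detokenizer. A minimal sketch of that wiring outside the fixture; "gpt2" is illustrative, since the fixture parametrizes tokenizer_name:

```python
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

# Build the group directly, with the same keywords the fixture passes.
tokenizer_group = TokenizerGroup(
    tokenizer_id="gpt2",
    enable_lora=False,
    max_num_seqs=100,
    max_input_length=None,
)
# The Detokenizer consumes the group as-is; no factory indirection remains.
detokenizer = Detokenizer(tokenizer_group)
```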

Lines changed: 3 additions & 184 deletions
@@ -1,40 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import asyncio
-import os
-import sys
-from typing import Optional
-from unittest.mock import patch
-
 import pytest
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
-from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
-                                                     get_tokenizer_group)
-from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
-    RayTokenizerGroupPool)
-
-from ..conftest import get_tokenizer_pool_config
-
-
-class CustomTokenizerGroup(TokenizerGroup):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._i = 0
-
-    def encode(self, *args, **kwargs):
-        self._i += 1
-        return super().encode(*args, **kwargs)
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("tokenizer_group_type",
-                         [None, "ray", CustomTokenizerGroup])
-async def test_tokenizer_group(tokenizer_group_type):
+async def test_tokenizer_group():
     reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    tokenizer_group = get_tokenizer_group(
-        get_tokenizer_pool_config(tokenizer_group_type),
+    tokenizer_group = TokenizerGroup(
         tokenizer_id="gpt2",
         enable_lora=False,
         max_num_seqs=1,
@@ -49,159 +24,3 @@ async def test_tokenizer_group(tokenizer_group_type):
                           PreTrainedTokenizerBase)
     assert tokenizer_group.get_lora_tokenizer(
         None) == await tokenizer_group.get_lora_tokenizer_async(None)
-    if tokenizer_group_type is CustomTokenizerGroup:
-        assert tokenizer_group._i > 0
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
-async def test_tokenizer_group_pool(tokenizer_group_type):
-    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    tokenizer_group_pool = get_tokenizer_group(
-        get_tokenizer_pool_config(tokenizer_group_type),
-        tokenizer_id="gpt2",
-        enable_lora=False,
-        max_num_seqs=1,
-        max_input_length=None,
-    )
-    # Send multiple requests to the tokenizer group pool
-    # (more than the pool size)
-    # and check that all requests are processed correctly.
-    num_requests = tokenizer_group_pool.pool_size * 5
-    requests = [
-        tokenizer_group_pool.encode_async(prompt=f"prompt {i}",
-                                          lora_request=None)
-        for i in range(num_requests)
-    ]
-    results = await asyncio.gather(*requests)
-    expected_results = [
-        reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests)
-    ]
-    assert results == expected_results
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
-async def test_tokenizer_group_ray_pool_env_var_propagation(
-        tokenizer_group_type):
-    """Test that env vars from caller process are propagated to
-    tokenizer Ray actors."""
-    env_var = "MY_ENV_VAR"
-
-    class EnvVarCheckerTokenizerGroup(TokenizerGroup):
-
-        def ping(self):
-            assert os.environ.get(env_var) == "1"
-            return super().ping()
-
-    class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
-        _worker_cls = EnvVarCheckerTokenizerGroup
-
-    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
-    tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
-        tokenizer_pool_config,
-        tokenizer_id="gpt2",
-        enable_lora=False,
-        max_num_seqs=1,
-        max_input_length=None)
-    with pytest.raises(AssertionError):
-        tokenizer_pool.ping()
-
-    with patch.dict(os.environ, {env_var: "1"}):
-        tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
-        tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
-            tokenizer_pool_config,
-            tokenizer_id="gpt2",
-            enable_lora=False,
-            max_num_seqs=1,
-            max_input_length=None)
-        tokenizer_pool.ping()
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
-async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
-    """Test that Ray tokenizer pool group can recover from failures and
-    if that's not possible, mark itself as unhealthy."""
-
-    class FailingTokenizerGroup(TokenizerGroup):
-
-        def __init__(self,
-                     *args,
-                     fail_at: Optional[list[int]] = None,
-                     **kwargs):
-            super().__init__(*args, **kwargs)
-            self.i = 0
-            self.fail_at = fail_at or []
-
-        def encode(self, *args, **kwargs):
-            self.i += 1
-            if self.i in self.fail_at:
-                sys.exit(1)
-            return super().encode(*args, **kwargs)
-
-    class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
-        _worker_cls = FailingTokenizerGroup
-
-    # Fail at first iteration
-    fail_at = [1]
-    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
-    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
-        tokenizer_pool_config,
-        tokenizer_id="gpt2",
-        enable_lora=False,
-        max_num_seqs=1,
-        max_input_length=None,
-        fail_at=fail_at)
-    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
-
-    # Modify fail at to not fail at all (will be re-read when actor is
-    # re-initialized).
-    fail_at[0] = 1000
-
-    # We should recover successfully.
-    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
-    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
-
-    # Check that we have a new actor
-    assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
-    assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors
-
-    # Fail at first iteration
-    fail_at = [1]
-    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
-        tokenizer_pool_config,
-        tokenizer_id="gpt2",
-        enable_lora=False,
-        max_num_seqs=1,
-        max_input_length=None,
-        fail_at=fail_at)
-
-    # We should fail after re-initialization.
-    with pytest.raises(RuntimeError):
-        await tokenizer_group_pool.encode_async(prompt="prompt",
-                                                lora_request=None)
-
-    # check_health should raise the same thing
-    with pytest.raises(RuntimeError):
-        tokenizer_group_pool.check_health()
-
-    # Ensure that non-ActorDiedErrors are still propagated correctly and do not
-    # cause a re-initialization.
-    fail_at = []
-    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
-        tokenizer_pool_config,
-        tokenizer_id="gpt2",
-        enable_lora=False,
-        max_num_seqs=1,
-        max_input_length=2,
-        fail_at=fail_at)
-    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
-
-    # Prompt too long error
-    with pytest.raises(ValueError):
-        await tokenizer_group_pool.encode_async(prompt="prompt" * 100,
-                                                lora_request=None)
-    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
-    # Actors should stay the same.
-    assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
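
The pared-down test above keeps one behavioral guarantee worth calling out: the sync and async encode paths must agree with the reference Hugging Face tokenizer. A hedged, self-contained sketch of that check outside pytest (the asyncio scaffolding is illustrative, not part of the commit):

```python
import asyncio

from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer_group import TokenizerGroup


async def check_parity() -> None:
    reference = AutoTokenizer.from_pretrained("gpt2")
    group = TokenizerGroup(
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )
    prompt = "prompt"
    sync_ids = group.encode(prompt=prompt, lora_request=None)
    async_ids = await group.encode_async(prompt=prompt, lora_request=None)
    # Both code paths should match the reference tokenizer exactly.
    assert sync_ids == async_ids == reference.encode(prompt)


asyncio.run(check_parity())
```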

tests/v1/engine/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
         tokenizer=tokenizer,
         tokenizer_group=init_tokenizer_from_configs(
             vllm_config.model_config, vllm_config.scheduler_config,
-            vllm_config.parallel_config, vllm_config.lora_config),
+            vllm_config.lora_config),
         vllm_config=vllm_config,
         full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS],
         prompt_tokens=prompt_tokens,
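
After this change, init_tokenizer_from_configs no longer takes a parallel_config, since there is no Ray tokenizer pool left to size from it. A sketch of the narrowed call; building the config via EngineArgs.create_engine_config() is an assumption for illustration, not shown in this diff:

```python
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs

# Assumed setup: derive a full engine config from minimal engine args.
vllm_config = EngineArgs(model="gpt2").create_engine_config()

# parallel_config has been dropped from the argument list.
tokenizer_group = init_tokenizer_from_configs(
    vllm_config.model_config,
    vllm_config.scheduler_config,
    vllm_config.lora_config,
)
```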

tests/v1/engine/utils.py

Lines changed: 2 additions & 3 deletions
@@ -8,8 +8,7 @@
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from vllm.engine.arg_utils import EngineArgs
-from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
-    BaseTokenizerGroup)
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreOutput, FinishReason
 from vllm.v1.outputs import LogprobsLists, LogprobsTensors
 
@@ -296,7 +295,7 @@ def generate_dummy_prompt_logprobs_tensors(
 class DummyOutputProcessorTestVectors:
     """Dummy test vectors for output processor tests"""
     tokenizer: GeneralTokenizerType
-    tokenizer_group: BaseTokenizerGroup
+    tokenizer_group: TokenizerGroup
     vllm_config: EngineArgs
     full_tokens: list[list[int]]  # Prompt + generated tokens
     prompt_tokens: list[list[int]]
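
With BaseTokenizerGroup gone, annotations reference the concrete class directly. A hypothetical holder type mirroring the field change above (the class name is invented for illustration):

```python
from dataclasses import dataclass

from vllm.transformers_utils.tokenizer_group import TokenizerGroup


@dataclass
class TokenizerGroupHolder:
    """Hypothetical stand-in for DummyOutputProcessorTestVectors."""
    # Annotate with the concrete TokenizerGroup; there is no abstract
    # base class left to program against.
    tokenizer_group: TokenizerGroup
```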
