1
1
# SPDX-License-Identifier: Apache-2.0
2
2
3
- import asyncio
4
- import os
5
- import sys
6
- from typing import Optional
7
- from unittest .mock import patch
8
-
9
3
import pytest
10
4
from transformers import AutoTokenizer , PreTrainedTokenizerBase
11
5
12
- from vllm .transformers_utils .tokenizer_group import (TokenizerGroup ,
13
- get_tokenizer_group )
14
- from vllm .transformers_utils .tokenizer_group .ray_tokenizer_group import (
15
- RayTokenizerGroupPool )
16
-
17
- from ..conftest import get_tokenizer_pool_config
18
-
19
-
20
- class CustomTokenizerGroup (TokenizerGroup ):
21
-
22
- def __init__ (self , * args , ** kwargs ):
23
- super ().__init__ (* args , ** kwargs )
24
- self ._i = 0
25
-
26
- def encode (self , * args , ** kwargs ):
27
- self ._i += 1
28
- return super ().encode (* args , ** kwargs )
6
+ from vllm .transformers_utils .tokenizer_group import TokenizerGroup
29
7
30
8
31
9
@pytest .mark .asyncio
32
- @pytest .mark .parametrize ("tokenizer_group_type" ,
33
- [None , "ray" , CustomTokenizerGroup ])
34
- async def test_tokenizer_group (tokenizer_group_type ):
10
+ async def test_tokenizer_group ():
35
11
reference_tokenizer = AutoTokenizer .from_pretrained ("gpt2" )
36
- tokenizer_group = get_tokenizer_group (
37
- get_tokenizer_pool_config (tokenizer_group_type ),
12
+ tokenizer_group = TokenizerGroup (
38
13
tokenizer_id = "gpt2" ,
39
14
enable_lora = False ,
40
15
max_num_seqs = 1 ,
@@ -49,159 +24,3 @@ async def test_tokenizer_group(tokenizer_group_type):
49
24
PreTrainedTokenizerBase )
50
25
assert tokenizer_group .get_lora_tokenizer (
51
26
None ) == await tokenizer_group .get_lora_tokenizer_async (None )
52
- if tokenizer_group_type is CustomTokenizerGroup :
53
- assert tokenizer_group ._i > 0
54
-
55
-
56
- @pytest .mark .asyncio
57
- @pytest .mark .parametrize ("tokenizer_group_type" , ["ray" ])
58
- async def test_tokenizer_group_pool (tokenizer_group_type ):
59
- reference_tokenizer = AutoTokenizer .from_pretrained ("gpt2" )
60
- tokenizer_group_pool = get_tokenizer_group (
61
- get_tokenizer_pool_config (tokenizer_group_type ),
62
- tokenizer_id = "gpt2" ,
63
- enable_lora = False ,
64
- max_num_seqs = 1 ,
65
- max_input_length = None ,
66
- )
67
- # Send multiple requests to the tokenizer group pool
68
- # (more than the pool size)
69
- # and check that all requests are processed correctly.
70
- num_requests = tokenizer_group_pool .pool_size * 5
71
- requests = [
72
- tokenizer_group_pool .encode_async (prompt = f"prompt { i } " ,
73
- lora_request = None )
74
- for i in range (num_requests )
75
- ]
76
- results = await asyncio .gather (* requests )
77
- expected_results = [
78
- reference_tokenizer .encode (f"prompt { i } " ) for i in range (num_requests )
79
- ]
80
- assert results == expected_results
81
-
82
-
83
- @pytest .mark .asyncio
84
- @pytest .mark .parametrize ("tokenizer_group_type" , ["ray" ])
85
- async def test_tokenizer_group_ray_pool_env_var_propagation (
86
- tokenizer_group_type ):
87
- """Test that env vars from caller process are propagated to
88
- tokenizer Ray actors."""
89
- env_var = "MY_ENV_VAR"
90
-
91
- class EnvVarCheckerTokenizerGroup (TokenizerGroup ):
92
-
93
- def ping (self ):
94
- assert os .environ .get (env_var ) == "1"
95
- return super ().ping ()
96
-
97
- class EnvVarCheckerRayTokenizerGroupPool (RayTokenizerGroupPool ):
98
- _worker_cls = EnvVarCheckerTokenizerGroup
99
-
100
- tokenizer_pool_config = get_tokenizer_pool_config (tokenizer_group_type )
101
- tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool .from_config (
102
- tokenizer_pool_config ,
103
- tokenizer_id = "gpt2" ,
104
- enable_lora = False ,
105
- max_num_seqs = 1 ,
106
- max_input_length = None )
107
- with pytest .raises (AssertionError ):
108
- tokenizer_pool .ping ()
109
-
110
- with patch .dict (os .environ , {env_var : "1" }):
111
- tokenizer_pool_config = get_tokenizer_pool_config (tokenizer_group_type )
112
- tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool .from_config (
113
- tokenizer_pool_config ,
114
- tokenizer_id = "gpt2" ,
115
- enable_lora = False ,
116
- max_num_seqs = 1 ,
117
- max_input_length = None )
118
- tokenizer_pool .ping ()
119
-
120
-
121
- @pytest .mark .asyncio
122
- @pytest .mark .parametrize ("tokenizer_group_type" , ["ray" ])
123
- async def test_tokenizer_group_ray_pool_fault_tolerance (tokenizer_group_type ):
124
- """Test that Ray tokenizer pool group can recover from failures and
125
- if that's not possible, mark itself as unhealthy."""
126
-
127
- class FailingTokenizerGroup (TokenizerGroup ):
128
-
129
- def __init__ (self ,
130
- * args ,
131
- fail_at : Optional [list [int ]] = None ,
132
- ** kwargs ):
133
- super ().__init__ (* args , ** kwargs )
134
- self .i = 0
135
- self .fail_at = fail_at or []
136
-
137
- def encode (self , * args , ** kwargs ):
138
- self .i += 1
139
- if self .i in self .fail_at :
140
- sys .exit (1 )
141
- return super ().encode (* args , ** kwargs )
142
-
143
- class FailingRayTokenizerGroupPool (RayTokenizerGroupPool ):
144
- _worker_cls = FailingTokenizerGroup
145
-
146
- # Fail at first iteration
147
- fail_at = [1 ]
148
- tokenizer_pool_config = get_tokenizer_pool_config (tokenizer_group_type )
149
- tokenizer_group_pool = FailingRayTokenizerGroupPool .from_config (
150
- tokenizer_pool_config ,
151
- tokenizer_id = "gpt2" ,
152
- enable_lora = False ,
153
- max_num_seqs = 1 ,
154
- max_input_length = None ,
155
- fail_at = fail_at )
156
- tokenizer_actors = tokenizer_group_pool .tokenizer_actors .copy ()
157
-
158
- # Modify fail at to not fail at all (will be re-read when actor is
159
- # re-initialized).
160
- fail_at [0 ] = 1000
161
-
162
- # We should recover successfully.
163
- await tokenizer_group_pool .encode_async (prompt = "prompt" , lora_request = None )
164
- await tokenizer_group_pool .encode_async (prompt = "prompt" , lora_request = None )
165
-
166
- # Check that we have a new actor
167
- assert len (tokenizer_group_pool .tokenizer_actors ) == len (tokenizer_actors )
168
- assert tokenizer_group_pool .tokenizer_actors != tokenizer_actors
169
-
170
- # Fail at first iteration
171
- fail_at = [1 ]
172
- tokenizer_group_pool = FailingRayTokenizerGroupPool .from_config (
173
- tokenizer_pool_config ,
174
- tokenizer_id = "gpt2" ,
175
- enable_lora = False ,
176
- max_num_seqs = 1 ,
177
- max_input_length = None ,
178
- fail_at = fail_at )
179
-
180
- # We should fail after re-initialization.
181
- with pytest .raises (RuntimeError ):
182
- await tokenizer_group_pool .encode_async (prompt = "prompt" ,
183
- lora_request = None )
184
-
185
- # check_health should raise the same thing
186
- with pytest .raises (RuntimeError ):
187
- tokenizer_group_pool .check_health ()
188
-
189
- # Ensure that non-ActorDiedErrors are still propagated correctly and do not
190
- # cause a re-initialization.
191
- fail_at = []
192
- tokenizer_group_pool = FailingRayTokenizerGroupPool .from_config (
193
- tokenizer_pool_config ,
194
- tokenizer_id = "gpt2" ,
195
- enable_lora = False ,
196
- max_num_seqs = 1 ,
197
- max_input_length = 2 ,
198
- fail_at = fail_at )
199
- tokenizer_actors = tokenizer_group_pool .tokenizer_actors .copy ()
200
-
201
- # Prompt too long error
202
- with pytest .raises (ValueError ):
203
- await tokenizer_group_pool .encode_async (prompt = "prompt" * 100 ,
204
- lora_request = None )
205
- await tokenizer_group_pool .encode_async (prompt = "prompt" , lora_request = None )
206
- # Actors should stay the same.
207
- assert tokenizer_group_pool .tokenizer_actors == tokenizer_actors
0 commit comments