@@ -2,20 +2,22 @@
 import gc
 import logging
 import os
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import pytest
 import torch
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoProcessor,
-                          LlavaForConditionalGeneration)
+from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
+                          LlavaConfig, LlavaForConditionalGeneration)
 
 from tests.utils.logging import make_logger
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
+from vllm.logger import init_logger
 from vllm.sequence import MultiModalData
-from vllm.transformers_utils.tokenizer import get_tokenizer
+
+logger = init_logger(__name__)
 
 _TEST_DIR = os.path.dirname(__file__)
 _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
@@ -131,9 +133,7 @@ def example_long_prompts() -> List[str]:
     "float": torch.float,
 }
 
-_VISION_LANGUAGE_MODELS = {
-    "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
-}
+AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)
 
 _EMBEDDING_MODELS = [
     "intfloat/e5-mistral-7b-instruct",
@@ -145,23 +145,14 @@ class HfRunner:
     def __init__(
         self,
         model_name: str,
-        tokenizer_name: Optional[str] = None,
         dtype: str = "half",
     ) -> None:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+
         self.model_name = model_name
-        if model_name in _VISION_LANGUAGE_MODELS:
-            self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-                trust_remote_code=True,
-            ).cuda()
-            self.processor = AutoProcessor.from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-            )
-        elif model_name in _EMBEDDING_MODELS:
+
+        if model_name in _EMBEDDING_MODELS:
             # Lazy init required for AMD CI
             from sentence_transformers import SentenceTransformer
             self.model = SentenceTransformer(
@@ -174,10 +165,24 @@ def __init__(
                 torch_dtype=torch_dtype,
                 trust_remote_code=True,
             ).cuda()
-            self.processor = None
-        if tokenizer_name is None:
-            tokenizer_name = model_name
-        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            torch_dtype=torch_dtype,
+            trust_remote_code=True,
+        )
+
+        try:
+            self.processor = AutoProcessor.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                trust_remote_code=True,
+            )
+        except Exception:
+            logger.warning(
+                "Unable to auto-load processor from HuggingFace for "
+                "model %s. Using tokenizer instead.", model_name)
+            self.processor = self.tokenizer
 
     def generate(
         self,
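
Note on the try/except above: whether AutoProcessor.from_pretrained succeeds or raises for a text-only checkpoint varies with the model and the transformers version, and the fallback works because a plain tokenizer accepts the same text=/return_tensors= call signature that generate() uses below. A rough sketch of the resulting behavior (the model name is illustrative):

    from transformers import AutoProcessor, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    try:
        processor = AutoProcessor.from_pretrained("facebook/opt-125m")
    except Exception:
        # Tokenizers are callable with text= and return_tensors=, so they
        # can stand in for a processor on text-only models.
        processor = tokenizer

    inputs = processor(text="Hello, world", return_tensors="pt")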
@@ -189,19 +194,19 @@ def generate(
         if images:
             assert len(prompts) == len(images)
         for i, prompt in enumerate(prompts):
-            if self.model_name not in _VISION_LANGUAGE_MODELS:
-                input_ids = self.tokenizer(prompt,
-                                           return_tensors="pt").input_ids
-                inputs = {"input_ids": input_ids.cuda()}
-            else:
-                image = images[i] if images else None
-                inputs = self.processor(text=prompt,
-                                        images=image,
-                                        return_tensors="pt")
-                inputs = {
-                    key: value.cuda() if value is not None else None
-                    for key, value in inputs.items()
-                }
+            processor_kwargs: Dict[str, Any] = {
+                "text": prompt,
+                "return_tensors": "pt",
+            }
+            if images is not None and images[i] is not None:
+                processor_kwargs["images"] = images[i]
+
+            inputs = self.processor(**processor_kwargs)
+            inputs = {
+                key: value.cuda() if value is not None else None
+                for key, value in inputs.items()
+            }
+
             output_ids = self.model.generate(
                 **inputs,
                 use_cache=True,
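
With the processor/tokenizer fallback from __init__, the rewritten loop treats every model the same way: the call returns a dict-like BatchEncoding/BatchFeature whose tensors are moved to the GPU one by one. A sketch of what the two paths produce (model names and prompt are illustrative):

    from PIL import Image
    from transformers import AutoProcessor, AutoTokenizer

    # Text-only: the tokenizer stand-in yields input_ids / attention_mask.
    tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
    print(sorted(tok(text="Hello", return_tensors="pt").keys()))

    # Multimodal: a real processor adds pixel_values when images= is given.
    proc = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
    out = proc(text="USER: <image>\nDescribe the image. ASSISTANT:",
               images=Image.new("RGB", (336, 336)),
               return_tensors="pt")
    print(sorted(out.keys()))  # includes 'pixel_values'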