
Commit b37d827

[Model] Upgrade Aria to transformers 4.48 (#12203)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 3127e97 commit b37d827

10 files changed: +178 -379 lines changed


examples/offline_inference/vision_language.py

Lines changed: 0 additions & 3 deletions
@@ -26,11 +26,8 @@ def run_aria(question: str, modality: str):
 
     # NOTE: Need L40 (or equivalent) to avoid OOM
     llm = LLM(model=model_name,
-              tokenizer_mode="slow",
-              dtype="bfloat16",
               max_model_len=4096,
               max_num_seqs=2,
-              trust_remote_code=True,
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 
     prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
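Since transformers 4.48 ships native Aria support, the example no longer needs the slow tokenizer, an explicit bfloat16 dtype, or trust_remote_code. For illustration only, a minimal sketch of the resulting call outside the example script: the model name and image prompt format come from this diff, the closing chat tokens follow the prompt format used in the test file below, and the question text and image path are placeholders.

    from PIL import Image
    from vllm import LLM, SamplingParams

    # Sketch of the simplified Aria setup after this commit.
    llm = LLM(model="rhymes-ai/Aria",
              max_model_len=4096,
              max_num_seqs=2)

    question = "What is in this image?"  # placeholder question
    image = Image.open("example.jpg")    # placeholder image path
    prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
              "<|im_end|>\n<|im_start|>assistant\n")

    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": image}},
        SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)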

tests/models/decoder_only/vision_language/test_models.py

Lines changed: 2 additions & 5 deletions
@@ -10,7 +10,6 @@
 import pytest
 from transformers import AutoModelForVision2Seq
 from transformers import __version__ as TRANSFORMERS_VERSION
-from transformers.utils import is_flash_attn_2_available
 
 from vllm.platforms import current_platform
 from vllm.utils import identity
@@ -140,9 +139,7 @@
     #### Extended model tests
     "aria": VLMTestInfo(
         models=["rhymes-ai/Aria"],
-        tokenizer_mode="slow",
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
-        dtype="bfloat16",
         prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
         img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
         max_model_len=4096,
@@ -158,8 +155,8 @@
         max_tokens=64,
         marks=[
             pytest.mark.skipif(
-                not is_flash_attn_2_available(),
-                reason="Model needs flash-attn for numeric convergence.",
+                TRANSFORMERS_VERSION < "4.48.0",
+                reason="HF model requires transformers>=4.48.0",
             ),
             large_gpu_mark(min_gb=64),
         ],

tests/models/multimodal/processing/test_common.py

Lines changed: 5 additions & 7 deletions
@@ -11,6 +11,7 @@
 from vllm.multimodal.utils import cached_get_tokenizer
 
 from ....multimodal.utils import random_audio, random_image, random_video
+from ...registry import HF_EXAMPLE_MODELS
 
 
 def _test_processing_correctness(
@@ -20,12 +21,9 @@ def _test_processing_correctness(
     num_batches: int,
     simplify_rate: float,
 ):
-    if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
-        hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
-    elif model_id == "deepseek-ai/deepseek-vl2-tiny":
-        hf_overrides = {"architectures": ["DeepseekVLV2ForCausalLM"]}
-    else:
-        hf_overrides = {}
+    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
+    model_info.check_available_online(on_fail="skip")
+    model_info.check_transformers_version(on_fail="skip")
 
     limit_mm_per_prompt = {
         modality: 3 if supports_multi else 1
@@ -41,7 +39,7 @@
         seed=0,
         dtype="float16",
         revision=None,
-        hf_overrides=hf_overrides,
+        hf_overrides=model_info.hf_overrides,
         limit_mm_per_prompt=limit_mm_per_prompt,
     )
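The per-model if/elif chain is gone: override metadata now lives on the registry entries added below in tests/models/registry.py. A rough sketch of the lookup pattern in isolation (the absolute import path is assumed; the test itself uses a relative import):

    from tests.models.registry import HF_EXAMPLE_MODELS  # assumed import path

    # Resolve example-model metadata by HF model ID instead of hardcoding it.
    info = HF_EXAMPLE_MODELS.find_hf_info("deepseek-ai/deepseek-vl2-tiny")
    info.check_available_online(on_fail="skip")      # pytest.skip() if not downloadable
    info.check_transformers_version(on_fail="skip")  # pytest.skip() if transformers too old
    print(info.hf_overrides)  # {'architectures': ['DeepseekVLV2ForCausalLM']}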

tests/models/registry.py

Lines changed: 62 additions & 5 deletions
@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
-from typing import AbstractSet, Mapping, Optional
+from typing import AbstractSet, Any, Literal, Mapping, Optional
+
+import pytest
+from packaging.version import Version
+from transformers import __version__ as TRANSFORMERS_VERSION
 
 
 @dataclass(frozen=True)
@@ -38,6 +42,50 @@ class _HfExamplesInfo:
     trust_remote_code: bool = False
     """The ``trust_remote_code`` level required to load the model."""
 
+    hf_overrides: dict[str, Any] = field(default_factory=dict)
+    """The ``hf_overrides`` required to load the model."""
+
+    def check_transformers_version(
+        self,
+        *,
+        on_fail: Literal["error", "skip"],
+    ) -> None:
+        """
+        If the installed transformers version does not meet the requirements,
+        perform the given action.
+        """
+        if self.min_transformers_version is None:
+            return
+
+        current_version = TRANSFORMERS_VERSION
+        required_version = self.min_transformers_version
+        if Version(current_version) < Version(required_version):
+            msg = (
+                f"You have `transformers=={current_version}` installed, but "
+                f"`transformers>={required_version}` is required to run this "
+                "model")
+
+            if on_fail == "error":
+                raise RuntimeError(msg)
+            else:
+                pytest.skip(msg)
+
+    def check_available_online(
+        self,
+        *,
+        on_fail: Literal["error", "skip"],
+    ) -> None:
+        """
+        If the model is not available online, perform the given action.
+        """
+        if not self.is_available_online:
+            msg = "Model is not available online"
+
+            if on_fail == "error":
+                raise RuntimeError(msg)
+            else:
+                pytest.skip(msg)
+
 
 # yapf: disable
 _TEXT_GENERATION_EXAMPLE_MODELS = {
@@ -48,8 +96,6 @@ class _HfExamplesInfo:
                                          trust_remote_code=True),
     "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
                                          trust_remote_code=True),
-    "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
-                                                    trust_remote_code=True),
     "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
                                            trust_remote_code=True),
     "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
@@ -176,14 +222,17 @@ class _HfExamplesInfo:
 
 _MULTIMODAL_EXAMPLE_MODELS = {
     # [Decoder-only]
+    "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
+                                                    min_transformers_version="4.48"),
     "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
     "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
     "ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b",
                                     extras={"text_only": "THUDM/chatglm3-6b"},
                                     trust_remote_code=True),
     "ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
                                                        is_available_online=False),
-    "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny"), # noqa: E501
+    "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
+                                               hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
     "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
     "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
     "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
@@ -194,7 +243,8 @@ class _HfExamplesInfo:
     "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
     "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
     "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
-    "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501
+    "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3", # noqa: E501
+                                                      hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501
     "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
                                 trust_remote_code=True),
     "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
@@ -247,5 +297,12 @@ def get_supported_archs(self) -> AbstractSet[str]:
     def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
         return self.hf_models[model_arch]
 
+    def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
+        for info in self.hf_models.values():
+            if info.default == model_id:
+                return info
+
+        raise ValueError(f"No example model defined for {model_id}")
+
 
 HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
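The two new helpers centralize the skip logic that test_initialization.py and test_common.py previously duplicated, and find_hf_info makes entries addressable by HF model ID as well as by architecture. A rough usage sketch, assuming the Aria entry above with min_transformers_version="4.48":

    # HF_EXAMPLE_MODELS is the registry instance defined at the bottom of this file.
    info = HF_EXAMPLE_MODELS.get_hf_info("AriaForConditionalGeneration")

    # Inside a test: skip gracefully when the checkpoint is unavailable or the
    # installed transformers is older than the entry's minimum version.
    info.check_available_online(on_fail="skip")
    info.check_transformers_version(on_fail="skip")

    # Outside a test: fail loudly instead of skipping.
    info.check_transformers_version(on_fail="error")  # RuntimeError if transformers < 4.48

    # Lookup by model ID rather than architecture name:
    assert HF_EXAMPLE_MODELS.find_hf_info("rhymes-ai/Aria") is info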

tests/models/test_initialization.py

Lines changed: 2 additions & 12 deletions
@@ -1,9 +1,7 @@
 from unittest.mock import patch
 
 import pytest
-from packaging.version import Version
 from transformers import PretrainedConfig
-from transformers import __version__ as TRANSFORMERS_VERSION
 
 from vllm import LLM
 
@@ -13,16 +11,8 @@
 @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
 def test_can_initialize(model_arch):
     model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
-    if not model_info.is_available_online:
-        pytest.skip("Model is not available online")
-    if model_info.min_transformers_version is not None:
-        current_version = TRANSFORMERS_VERSION
-        required_version = model_info.min_transformers_version
-        if Version(current_version) < Version(required_version):
-            pytest.skip(
-                f"You have `transformers=={current_version}` installed, but "
-                f"`transformers>={required_version}` is required to run this "
-                "model")
+    model_info.check_available_online(on_fail="skip")
+    model_info.check_transformers_version(on_fail="skip")
 
     # Avoid OOM
     def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:

tests/models/test_registry.py

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,9 @@
 
 @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
 def test_registry_imports(model_arch):
+    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
+    model_info.check_transformers_version(on_fail="skip")
+
     # Ensure all model classes can be imported successfully
     model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)