@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
-from typing import AbstractSet, Mapping, Optional
+from typing import AbstractSet, Any, Literal, Mapping, Optional
+
+import pytest
+from packaging.version import Version
+from transformers import __version__ as TRANSFORMERS_VERSION
 
 
 @dataclass(frozen=True)
@@ -38,6 +42,50 @@ class _HfExamplesInfo:
     trust_remote_code: bool = False
     """The ``trust_remote_code`` level required to load the model."""
 
+    hf_overrides: dict[str, Any] = field(default_factory=dict)
+    """The ``hf_overrides`` required to load the model."""
+
+    def check_transformers_version(
+        self,
+        *,
+        on_fail: Literal["error", "skip"],
+    ) -> None:
+        """
+        If the installed transformers version does not meet the requirements,
+        perform the given action.
+        """
+        if self.min_transformers_version is None:
+            return
+
+        current_version = TRANSFORMERS_VERSION
+        required_version = self.min_transformers_version
+        if Version(current_version) < Version(required_version):
+            msg = (
+                f"You have `transformers=={current_version}` installed, but "
+                f"`transformers>={required_version}` is required to run this "
+                "model")
+
+            if on_fail == "error":
+                raise RuntimeError(msg)
+            else:
+                pytest.skip(msg)
+
+    def check_available_online(
+        self,
+        *,
+        on_fail: Literal["error", "skip"],
+    ) -> None:
+        """
+        If the model is not available online, perform the given action.
+        """
+        if not self.is_available_online:
+            msg = "Model is not available online"
+
+            if on_fail == "error":
+                raise RuntimeError(msg)
+            else:
+                pytest.skip(msg)
+
 
 # yapf: disable
 _TEXT_GENERATION_EXAMPLE_MODELS = {
@@ -48,8 +96,6 @@ class _HfExamplesInfo:
                                          trust_remote_code=True),
     "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
                                          trust_remote_code=True),
-    "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
-                                                    trust_remote_code=True),
     "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
                                            trust_remote_code=True),
     "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
@@ -176,14 +222,17 @@ class _HfExamplesInfo:
 
 _MULTIMODAL_EXAMPLE_MODELS = {
     # [Decoder-only]
+    "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
+                                                    min_transformers_version="4.48"),
     "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"),  # noqa: E501
     "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"),  # noqa: E501
     "ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b",
                                     extras={"text_only": "THUDM/chatglm3-6b"},
                                     trust_remote_code=True),
     "ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
                                                        is_available_online=False),
-    "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny"),  # noqa: E501
+    "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny",  # noqa: E501
+                                               hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}),  # noqa: E501
     "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
     "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
     "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
@@ -194,7 +243,8 @@ class _HfExamplesInfo:
     "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"),  # noqa: E501
     "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"),  # noqa: E501
     "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"),  # noqa: E501
-    "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"),  # noqa: E501
+    "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3",  # noqa: E501
+                                                      hf_overrides={"architectures": ["MantisForConditionalGeneration"]}),  # noqa: E501
     "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
                                 trust_remote_code=True),
     "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
@@ -247,5 +297,12 @@ def get_supported_archs(self) -> AbstractSet[str]:
     def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
         return self.hf_models[model_arch]
 
+    def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
+        for info in self.hf_models.values():
+            if info.default == model_id:
+                return info
+
+        raise ValueError(f"No example model defined for {model_id}")
+
 
 HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
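
Below is a minimal usage sketch of how a test might consume the helpers added in this diff. It is not part of the patch: the import path `tests.models.registry` and the commented-out engine construction are assumptions, while `HF_EXAMPLE_MODELS`, `get_supported_archs`, `get_hf_info`, `find_hf_info`, `check_available_online`, `check_transformers_version`, and the `hf_overrides` field all come from the code above.

# Usage sketch only; `tests.models.registry` is an assumed import path.
import pytest

from tests.models.registry import HF_EXAMPLE_MODELS


@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_example_model_loads(model_arch: str):
    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)

    # Skip (rather than fail) when the example model cannot be exercised
    # in the current environment.
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    # `hf_overrides` would then be forwarded when constructing the engine so
    # that checkpoints whose HF config reports a different architecture
    # (e.g. DeepSeek-VL2, Mantis) still map to the intended implementation.
    # The call below is illustrative, not part of the patch:
    #
    #     llm = LLM(model=model_info.default,
    #               trust_remote_code=model_info.trust_remote_code,
    #               hf_overrides=model_info.hf_overrides)


def test_find_hf_info_reverse_lookup():
    # `find_hf_info` resolves a model ID back to its registry entry.
    aria_info = HF_EXAMPLE_MODELS.find_hf_info("rhymes-ai/Aria")
    assert aria_info.min_transformers_version == "4.48"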