|
28 | 28 | # yapf: disable
|
29 | 29 | if TYPE_CHECKING:
|
30 | 30 | import huggingface_hub as hfhub
|
31 |
| - import huggingface_hub.utils as hfhub_utils |
32 |
| - from transformers import GenerationConfig, PretrainedConfig |
| 31 | + import huggingface_hub.errors as hfhub_errors |
| 32 | + from transformers.configuration_utils import PretrainedConfig |
| 33 | + from transformers.generation.configuration_utils import GenerationConfig |
33 | 34 | else:
|
34 | 35 | hfhub = LazyLoader("hfhub", globals(), "huggingface_hub")
|
35 |
| - hfhub_utils = LazyLoader("hfhub_utils", globals(), "huggingface_hub.utils") |
| 36 | + hfhub_errors = LazyLoader("hfhub_errors", globals(), "huggingface_hub.errors") |
36 | 37 |
|
37 | 38 | _CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, str] = {
|
38 | 39 | "mllama": "MllamaConfig"
|
39 | 40 | }
|
40 | 41 |
|
41 |
| -_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { |
42 |
| - "chatglm": ChatGLMConfig, |
43 |
| - "cohere2": Cohere2Config, |
44 |
| - "dbrx": DbrxConfig, |
45 |
| - "deepseek_vl_v2": DeepseekVLV2Config, |
46 |
| - "kimi_vl": KimiVLConfig, |
47 |
| - "mpt": MPTConfig, |
48 |
| - "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) |
49 |
| - "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) |
50 |
| - "jais": JAISConfig, |
51 |
| - "mlp_speculator": MLPSpeculatorConfig, |
52 |
| - "medusa": MedusaConfig, |
53 |
| - "eagle": EAGLEConfig, |
54 |
| - "exaone": ExaoneConfig, |
55 |
| - "h2ovl_chat": H2OVLChatConfig, |
56 |
| - "internvl_chat": InternVLChatConfig, |
57 |
| - "nemotron": NemotronConfig, |
58 |
| - "NVLM_D": NVLM_D_Config, |
59 |
| - "olmo2": Olmo2Config, |
60 |
| - "solar": SolarConfig, |
61 |
| - "skywork_chat": SkyworkR1VChatConfig, |
62 |
| - "telechat": Telechat2Config, |
63 |
| - "ultravox": UltravoxConfig, |
| 42 | +_CONFIG_REGISTRY: Dict[str, str] = { |
| 43 | + "chatglm": "ChatGLMConfig", |
| 44 | + "cohere2": "Cohere2Config", |
| 45 | + "dbrx": "DbrxConfig", |
| 46 | + "deepseek_vl_v2": "DeepseekVLV2Config", |
| 47 | + "kimi_vl": "KimiVLConfig", |
| 48 | + "mpt": "MPTConfig", |
| 49 | + "RefinedWeb": "RWConfig", # For tiiuae/falcon-40b(-instruct) |
| 50 | + "RefinedWebModel": "RWConfig", # For tiiuae/falcon-7b(-instruct) |
| 51 | + "jais": "JAISConfig", |
| 52 | + "mlp_speculator": "MLPSpeculatorConfig", |
| 53 | + "medusa": "MedusaConfig", |
| 54 | + "eagle": "EAGLEConfig", |
| 55 | + "exaone": "ExaoneConfig", |
| 56 | + "h2ovl_chat": "H2OVLChatConfig", |
| 57 | + "internvl_chat": "InternVLChatConfig", |
| 58 | + "nemotron": "NemotronConfig", |
| 59 | + "NVLM_D": "NVLM_D_Config", |
| 60 | + "olmo2": "Olmo2Config", |
| 61 | + "solar": "SolarConfig", |
| 62 | + "skywork_chat": "SkyworkR1VChatConfig", |
| 63 | + "telechat": "Telechat2Config", |
| 64 | + "ultravox": "UltravoxConfig", |
64 | 65 | **_CONFIG_REGISTRY_OVERRIDE_HF
|
65 | 66 | }
|
66 | 67 |
|
@@ -371,7 +372,7 @@ def try_get_local_file(model: Union[str, Path],
|
371 | 372 | revision=revision)
|
372 | 373 | if isinstance(cached_filepath, str):
|
373 | 374 | return Path(cached_filepath)
|
374 |
| - except hfhub_utils.HFValidationError: |
| 375 | + except hfhub_errors.HFValidationError: |
375 | 376 | ...
|
376 | 377 | return None
|
377 | 378 |
|
|
0 commit comments