Skip to content

Commit b0adb6a

Browse files
maxdebayser, lulmer
authored and committed
Further reduce the HTTP calls to huggingface.co (vllm-project#13107)
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 3c4784f commit b0adb6a

File tree

1 file changed

+79
-56
lines changed

1 file changed

+79
-56
lines changed

vllm/transformers_utils/config.py

Lines changed: 79 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
import json
55
import os
66
import time
7+
from functools import cache
78
from pathlib import Path
8-
from typing import Any, Dict, Literal, Optional, Type, Union
9+
from typing import Any, Callable, Dict, Literal, Optional, Type, Union
910

1011
import huggingface_hub
11-
from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
12-
try_to_load_from_cache)
12+
from huggingface_hub import hf_hub_download
13+
from huggingface_hub import list_repo_files as hf_list_repo_files
14+
from huggingface_hub import try_to_load_from_cache
1315
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
1416
HFValidationError, LocalEntryNotFoundError,
1517
RepositoryNotFoundError,
@@ -86,6 +88,65 @@ class ConfigFormat(str, enum.Enum):
8688
MISTRAL = "mistral"
8789

8890

91+
def with_retry(func: Callable[[], Any],
               log_msg: str,
               max_retries: int = 2,
               retry_delay: int = 2):
    """Call ``func`` and retry with exponential backoff if it raises.

    Args:
        func: Zero-argument callable to invoke.
        log_msg: Prefix used for the error log lines emitted on failure.
        max_retries: Total number of attempts to make. Values below 1 are
            treated as 1 so that ``func`` is always called at least once.
        retry_delay: Seconds to sleep after the first failed attempt; the
            delay is doubled after every subsequent failure.

    Returns:
        Whatever ``func`` returns on the first successful attempt.

    Raises:
        Exception: The exception raised by ``func`` on the final attempt is
            logged and re-raised unchanged.
    """
    # Without this clamp a non-positive max_retries would skip the loop
    # entirely and silently return None without ever calling func.
    attempts = max(1, max_retries)

    for attempt in range(attempts):
        try:
            return func()
        except Exception as e:
            if attempt == attempts - 1:
                # Out of attempts: report and propagate the original error.
                logger.error("%s: %s", log_msg, e)
                raise
            logger.error("%s: %s, retrying %d of %d", log_msg, e, attempt + 1,
                         attempts)
            time.sleep(retry_delay)
            retry_delay *= 2
106+
107+
108+
# @cache doesn't cache exceptions
109+
@cache
def list_repo_files(
    repo_id: str,
    *,
    revision: Optional[str] = None,
    repo_type: Optional[str] = None,
    token: Union[str, bool, None] = None,
) -> list[str]:
    """Return the file listing of a Hugging Face Hub repository.

    The lookup is delegated to ``hf_list_repo_files`` and wrapped in
    ``with_retry``. Successful results are memoized per argument
    combination by ``@cache``.

    Args:
        repo_id: Repository identifier passed through to the Hub client.
        revision: Optional revision passed through to the Hub client.
        repo_type: Optional repository type passed through to the Hub client.
        token: Optional token passed through to the Hub client.

    Returns:
        The list returned by ``hf_list_repo_files``, or an empty list when
        the Hub client reports that offline mode is enabled. The empty list
        is a normal return value and is therefore memoized as well.
    """

    def _fetch_listing():
        try:
            listing = hf_list_repo_files(repo_id,
                                         revision=revision,
                                         repo_type=repo_type,
                                         token=token)
        except huggingface_hub.errors.OfflineModeIsEnabled:
            # Offline mode is not an error here: the only thing we can
            # conclude is that the file is not available from the cache,
            # so report "no files" instead of raising.
            return []
        else:
            return listing

    return with_retry(_fetch_listing, "Error retrieving file list")
131+
132+
133+
def file_exists(
    repo_id: str,
    file_name: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    token: Union[str, bool, None] = None,
) -> bool:
    """Tell whether ``file_name`` is listed among a repository's files.

    Args:
        repo_id: Repository identifier forwarded to ``list_repo_files``.
        file_name: Name to look for in the returned file listing.
        repo_type: Optional repository type forwarded to ``list_repo_files``.
        revision: Optional revision forwarded to ``list_repo_files``.
        token: Optional token forwarded to ``list_repo_files``.

    Returns:
        ``True`` if ``file_name`` is an element of the listing returned by
        ``list_repo_files``, ``False`` otherwise.
    """
    # Keyword order is kept stable on purpose: list_repo_files is memoized
    # and functools.cache may treat differently ordered keywords as
    # distinct cache entries.
    return file_name in list_repo_files(
        repo_id, repo_type=repo_type, revision=revision, token=token)
147+
148+
149+
# In offline mode the result can be a false negative
89150
def file_or_path_exists(model: Union[str, Path], config_name: str,
90151
revision: Optional[str]) -> bool:
91152
if Path(model).exists():
@@ -103,31 +164,10 @@ def file_or_path_exists(model: Union[str, Path], config_name: str,
103164
# hf_hub. This will fail in offline mode.
104165

105166
# Call HF to check if the file exists
106-
# 2 retries and exponential backoff
107-
max_retries = 2
108-
retry_delay = 2
109-
for attempt in range(max_retries):
110-
try:
111-
return file_exists(model,
112-
config_name,
113-
revision=revision,
114-
token=HF_TOKEN)
115-
except huggingface_hub.errors.OfflineModeIsEnabled:
116-
# Don't raise in offline mode,
117-
# all we know is that we don't have this
118-
# file cached.
119-
return False
120-
except Exception as e:
121-
logger.error(
122-
"Error checking file existence: %s, retrying %d of %d", e,
123-
attempt + 1, max_retries)
124-
if attempt == max_retries - 1:
125-
logger.error("Error checking file existence: %s", e)
126-
raise
127-
time.sleep(retry_delay)
128-
retry_delay *= 2
129-
continue
130-
return False
167+
return file_exists(str(model),
168+
config_name,
169+
revision=revision,
170+
token=HF_TOKEN)
131171

132172

133173
def patch_rope_scaling(config: PretrainedConfig) -> None:
@@ -208,32 +248,7 @@ def get_config(
208248
revision=revision):
209249
config_format = ConfigFormat.MISTRAL
210250
else:
211-
# If we're in offline mode and found no valid config format, then
212-
# raise an offline mode error to indicate to the user that they
213-
# don't have files cached and may need to go online.
214-
# This is conveniently triggered by calling file_exists().
215-
216-
# Call HF to check if the file exists
217-
# 2 retries and exponential backoff
218-
max_retries = 2
219-
retry_delay = 2
220-
for attempt in range(max_retries):
221-
try:
222-
file_exists(model,
223-
HF_CONFIG_NAME,
224-
revision=revision,
225-
token=HF_TOKEN)
226-
except Exception as e:
227-
logger.error(
228-
"Error checking file existence: %s, retrying %d of %d",
229-
e, attempt + 1, max_retries)
230-
if attempt == max_retries:
231-
logger.error("Error checking file existence: %s", e)
232-
raise e
233-
time.sleep(retry_delay)
234-
retry_delay *= 2
235-
236-
raise ValueError(f"No supported config format found in {model}")
251+
raise ValueError(f"No supported config format found in {model}.")
237252

238253
if config_format == ConfigFormat.HF:
239254
config_dict, _ = PretrainedConfig.get_config_dict(
@@ -339,10 +354,11 @@ def get_hf_file_to_dict(file_name: str,
339354
file_name=file_name,
340355
revision=revision)
341356

342-
if file_path is None and file_or_path_exists(
343-
model=model, config_name=file_name, revision=revision):
357+
if file_path is None:
344358
try:
345359
hf_hub_file = hf_hub_download(model, file_name, revision=revision)
360+
except huggingface_hub.errors.OfflineModeIsEnabled:
361+
return None
346362
except (RepositoryNotFoundError, RevisionNotFoundError,
347363
EntryNotFoundError, LocalEntryNotFoundError) as e:
348364
logger.debug("File or repository not found in hf_hub_download", e)
@@ -363,6 +379,7 @@ def get_hf_file_to_dict(file_name: str,
363379
return None
364380

365381

382+
@cache
366383
def get_pooling_config(model: str, revision: Optional[str] = 'main'):
367384
"""
368385
This function gets the pooling and normalize
@@ -390,6 +407,8 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
390407
if modules_dict is None:
391408
return None
392409

410+
logger.info("Found sentence-transformers modules configuration.")
411+
393412
pooling = next((item for item in modules_dict
394413
if item["type"] == "sentence_transformers.models.Pooling"),
395414
None)
@@ -408,6 +427,7 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
408427
if pooling_type_name is not None:
409428
pooling_type_name = get_pooling_config_name(pooling_type_name)
410429

430+
logger.info("Found pooling configuration.")
411431
return {"pooling_type": pooling_type_name, "normalize": normalize}
412432

413433
return None
@@ -435,6 +455,7 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
435455
return None
436456

437457

458+
@cache
438459
def get_sentence_transformer_tokenizer_config(model: str,
439460
revision: Optional[str] = 'main'
440461
):
@@ -491,6 +512,8 @@ def get_sentence_transformer_tokenizer_config(model: str,
491512
if not encoder_dict:
492513
return None
493514

515+
logger.info("Found sentence-transformers tokenize configuration.")
516+
494517
if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")):
495518
return encoder_dict
496519
return None

0 commit comments

Comments
 (0)