4
4
import json
5
5
import os
6
6
import time
7
+ from functools import cache
7
8
from pathlib import Path
8
- from typing import Any , Dict , Literal , Optional , Type , Union
9
+ from typing import Any , Callable , Dict , Literal , Optional , Type , Union
9
10
10
11
import huggingface_hub
11
- from huggingface_hub import (file_exists , hf_hub_download , list_repo_files ,
12
- try_to_load_from_cache )
12
+ from huggingface_hub import hf_hub_download
13
+ from huggingface_hub import list_repo_files as hf_list_repo_files
14
+ from huggingface_hub import try_to_load_from_cache
13
15
from huggingface_hub .utils import (EntryNotFoundError , HfHubHTTPError ,
14
16
HFValidationError , LocalEntryNotFoundError ,
15
17
RepositoryNotFoundError ,
@@ -86,6 +88,65 @@ class ConfigFormat(str, enum.Enum):
86
88
MISTRAL = "mistral"
87
89
88
90
91
def with_retry(func: Callable[[], Any],
               log_msg: str,
               max_retries: int = 2,
               retry_delay: int = 2):
    """Call ``func`` and retry with exponential backoff when it raises.

    Args:
        func: Zero-argument callable to invoke.
        log_msg: Prefix used for the error log lines.
        max_retries: Total number of attempts (not additional retries).
            Values below 1 are treated as 1 so that ``func`` is always
            called at least once instead of silently returning ``None``.
        retry_delay: Seconds to sleep after the first failed attempt;
            doubled after every subsequent failure.

    Returns:
        Whatever ``func`` returns on the first successful attempt.

    Raises:
        Exception: The exception raised by the final attempt is re-raised
            unchanged once all attempts are exhausted.
    """
    # Guard against max_retries <= 0, which would otherwise skip the loop
    # entirely and return None without ever calling func.
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            return func()
        except Exception as e:
            if attempt == attempts - 1:
                # Last attempt: log once and propagate the original error.
                logger.error("%s: %s", log_msg, e)
                raise
            logger.error("%s: %s, retrying %d of %d", log_msg, e, attempt + 1,
                         attempts)
            time.sleep(retry_delay)
            retry_delay *= 2
106
+
107
+
108
# functools.cache only memoizes successful returns; a call that raises is
# not cached and will be attempted again on the next invocation.
@cache
def list_repo_files(
    repo_id: str,
    *,
    revision: Optional[str] = None,
    repo_type: Optional[str] = None,
    token: Union[str, bool, None] = None,
) -> list[str]:
    """Return the file listing for ``repo_id``, memoized per argument set.

    The lookup is delegated to ``hf_list_repo_files`` and wrapped in
    ``with_retry``. When the hub client reports that offline mode is
    enabled, an empty list is returned instead of raising.
    """

    def _fetch_listing():
        try:
            listing = hf_list_repo_files(repo_id,
                                         revision=revision,
                                         repo_type=repo_type,
                                         token=token)
        except huggingface_hub.errors.OfflineModeIsEnabled:
            # Offline: nothing can be listed, so report no files rather
            # than failing. Handled here so it does not trigger a retry.
            return []
        return listing

    return with_retry(_fetch_listing, "Error retrieving file list")
131
+
132
+
133
def file_exists(
    repo_id: str,
    file_name: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    token: Union[str, bool, None] = None,
) -> bool:
    """Report whether ``file_name`` is in the file listing of ``repo_id``.

    The listing comes from ``list_repo_files``; membership is an exact
    match against the entries it returns.
    """
    # Keyword order is kept as-is: the cached lookup keys on it.
    return file_name in list_repo_files(repo_id,
                                        repo_type=repo_type,
                                        revision=revision,
                                        token=token)
147
+
148
+
149
+ # In offline mode the result can be a false negative
89
150
def file_or_path_exists (model : Union [str , Path ], config_name : str ,
90
151
revision : Optional [str ]) -> bool :
91
152
if Path (model ).exists ():
@@ -103,31 +164,10 @@ def file_or_path_exists(model: Union[str, Path], config_name: str,
103
164
# hf_hub. This will fail in offline mode.
104
165
105
166
# Call HF to check if the file exists
106
- # 2 retries and exponential backoff
107
- max_retries = 2
108
- retry_delay = 2
109
- for attempt in range (max_retries ):
110
- try :
111
- return file_exists (model ,
112
- config_name ,
113
- revision = revision ,
114
- token = HF_TOKEN )
115
- except huggingface_hub .errors .OfflineModeIsEnabled :
116
- # Don't raise in offline mode,
117
- # all we know is that we don't have this
118
- # file cached.
119
- return False
120
- except Exception as e :
121
- logger .error (
122
- "Error checking file existence: %s, retrying %d of %d" , e ,
123
- attempt + 1 , max_retries )
124
- if attempt == max_retries - 1 :
125
- logger .error ("Error checking file existence: %s" , e )
126
- raise
127
- time .sleep (retry_delay )
128
- retry_delay *= 2
129
- continue
130
- return False
167
+ return file_exists (str (model ),
168
+ config_name ,
169
+ revision = revision ,
170
+ token = HF_TOKEN )
131
171
132
172
133
173
def patch_rope_scaling (config : PretrainedConfig ) -> None :
@@ -208,32 +248,7 @@ def get_config(
208
248
revision = revision ):
209
249
config_format = ConfigFormat .MISTRAL
210
250
else :
211
- # If we're in offline mode and found no valid config format, then
212
- # raise an offline mode error to indicate to the user that they
213
- # don't have files cached and may need to go online.
214
- # This is conveniently triggered by calling file_exists().
215
-
216
- # Call HF to check if the file exists
217
- # 2 retries and exponential backoff
218
- max_retries = 2
219
- retry_delay = 2
220
- for attempt in range (max_retries ):
221
- try :
222
- file_exists (model ,
223
- HF_CONFIG_NAME ,
224
- revision = revision ,
225
- token = HF_TOKEN )
226
- except Exception as e :
227
- logger .error (
228
- "Error checking file existence: %s, retrying %d of %d" ,
229
- e , attempt + 1 , max_retries )
230
- if attempt == max_retries :
231
- logger .error ("Error checking file existence: %s" , e )
232
- raise e
233
- time .sleep (retry_delay )
234
- retry_delay *= 2
235
-
236
- raise ValueError (f"No supported config format found in { model } " )
251
+ raise ValueError (f"No supported config format found in { model } ." )
237
252
238
253
if config_format == ConfigFormat .HF :
239
254
config_dict , _ = PretrainedConfig .get_config_dict (
@@ -339,10 +354,11 @@ def get_hf_file_to_dict(file_name: str,
339
354
file_name = file_name ,
340
355
revision = revision )
341
356
342
- if file_path is None and file_or_path_exists (
343
- model = model , config_name = file_name , revision = revision ):
357
+ if file_path is None :
344
358
try :
345
359
hf_hub_file = hf_hub_download (model , file_name , revision = revision )
360
+ except huggingface_hub .errors .OfflineModeIsEnabled :
361
+ return None
346
362
except (RepositoryNotFoundError , RevisionNotFoundError ,
347
363
EntryNotFoundError , LocalEntryNotFoundError ) as e :
348
364
logger .debug ("File or repository not found in hf_hub_download" , e )
@@ -363,6 +379,7 @@ def get_hf_file_to_dict(file_name: str,
363
379
return None
364
380
365
381
382
+ @cache
366
383
def get_pooling_config (model : str , revision : Optional [str ] = 'main' ):
367
384
"""
368
385
This function gets the pooling and normalize
@@ -390,6 +407,8 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
390
407
if modules_dict is None :
391
408
return None
392
409
410
+ logger .info ("Found sentence-transformers modules configuration." )
411
+
393
412
pooling = next ((item for item in modules_dict
394
413
if item ["type" ] == "sentence_transformers.models.Pooling" ),
395
414
None )
@@ -408,6 +427,7 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
408
427
if pooling_type_name is not None :
409
428
pooling_type_name = get_pooling_config_name (pooling_type_name )
410
429
430
+ logger .info ("Found pooling configuration." )
411
431
return {"pooling_type" : pooling_type_name , "normalize" : normalize }
412
432
413
433
return None
@@ -435,6 +455,7 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
435
455
return None
436
456
437
457
458
+ @cache
438
459
def get_sentence_transformer_tokenizer_config (model : str ,
439
460
revision : Optional [str ] = 'main'
440
461
):
@@ -491,6 +512,8 @@ def get_sentence_transformer_tokenizer_config(model: str,
491
512
if not encoder_dict :
492
513
return None
493
514
515
+ logger .info ("Found sentence-transformers tokenize configuration." )
516
+
494
517
if all (k in encoder_dict for k in ("max_seq_length" , "do_lower_case" )):
495
518
return encoder_dict
496
519
return None
0 commit comments