10
10
from huggingface_hub import (file_exists , hf_hub_download , list_repo_files ,
11
11
try_to_load_from_cache )
12
12
from huggingface_hub .utils import (EntryNotFoundError , HfHubHTTPError ,
13
- LocalEntryNotFoundError ,
13
+ HFValidationError , LocalEntryNotFoundError ,
14
14
RepositoryNotFoundError ,
15
15
RevisionNotFoundError )
16
16
from torch import nn
@@ -265,49 +265,66 @@ def get_config(
265
265
return config
266
266
267
267
268
+ def try_get_local_file (model : Union [str , Path ],
269
+ file_name : str ,
270
+ revision : Optional [str ] = 'main' ) -> Optional [Path ]:
271
+ file_path = Path (model ) / file_name
272
+ if file_path .is_file ():
273
+ return file_path
274
+ else :
275
+ try :
276
+ cached_filepath = try_to_load_from_cache (repo_id = model ,
277
+ filename = file_name ,
278
+ revision = revision )
279
+ if isinstance (cached_filepath , str ):
280
+ return Path (cached_filepath )
281
+ except HFValidationError :
282
+ ...
283
+ return None
284
+
285
+
268
286
def get_hf_file_to_dict (file_name : str ,
269
287
model : Union [str , Path ],
270
288
revision : Optional [str ] = 'main' ):
271
289
"""
272
- Downloads a file from the Hugging Face Hub and returns
290
+ Downloads a file from the Hugging Face Hub and returns
273
291
its contents as a dictionary.
274
292
275
293
Parameters:
276
294
- file_name (str): The name of the file to download.
277
295
- model (str): The name of the model on the Hugging Face Hub.
278
- - revision (str): The specific version of the model.
296
+ - revision (str): The specific version of the model.
279
297
280
298
Returns:
281
- - config_dict (dict): A dictionary containing
299
+ - config_dict (dict): A dictionary containing
282
300
the contents of the downloaded file.
283
301
"""
284
- file_path = Path (model ) / file_name
285
302
286
- if file_or_path_exists (model = model ,
287
- config_name = file_name ,
288
- revision = revision ):
303
+ file_path = try_get_local_file (model = model ,
304
+ file_name = file_name ,
305
+ revision = revision )
289
306
290
- if not file_path .is_file ():
291
- try :
292
- hf_hub_file = hf_hub_download (model ,
293
- file_name ,
294
- revision = revision )
295
- except (RepositoryNotFoundError , RevisionNotFoundError ,
296
- EntryNotFoundError , LocalEntryNotFoundError ) as e :
297
- logger .debug ("File or repository not found in hf_hub_download" ,
298
- e )
299
- return None
300
- except HfHubHTTPError as e :
301
- logger .warning (
302
- "Cannot connect to Hugging Face Hub. Skipping file "
303
- "download for '%s':" ,
304
- file_name ,
305
- exc_info = e )
306
- return None
307
- file_path = Path (hf_hub_file )
307
+ if file_path is None and file_or_path_exists (
308
+ model = model , config_name = file_name , revision = revision ):
309
+ try :
310
+ hf_hub_file = hf_hub_download (model , file_name , revision = revision )
311
+ except (RepositoryNotFoundError , RevisionNotFoundError ,
312
+ EntryNotFoundError , LocalEntryNotFoundError ) as e :
313
+ logger .debug ("File or repository not found in hf_hub_download" , e )
314
+ return None
315
+ except HfHubHTTPError as e :
316
+ logger .warning (
317
+ "Cannot connect to Hugging Face Hub. Skipping file "
318
+ "download for '%s':" ,
319
+ file_name ,
320
+ exc_info = e )
321
+ return None
322
+ file_path = Path (hf_hub_file )
308
323
324
+ if file_path is not None and file_path .is_file ():
309
325
with open (file_path ) as file :
310
326
return json .load (file )
327
+
311
328
return None
312
329
313
330
@@ -328,7 +345,12 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
328
345
"""
329
346
330
347
modules_file_name = "modules.json"
331
- modules_dict = get_hf_file_to_dict (modules_file_name , model , revision )
348
+
349
+ modules_dict = None
350
+ if file_or_path_exists (model = model ,
351
+ config_name = modules_file_name ,
352
+ revision = revision ):
353
+ modules_dict = get_hf_file_to_dict (modules_file_name , model , revision )
332
354
333
355
if modules_dict is None :
334
356
return None
@@ -382,17 +404,17 @@ def get_sentence_transformer_tokenizer_config(model: str,
382
404
revision : Optional [str ] = 'main'
383
405
):
384
406
"""
385
- Returns the tokenization configuration dictionary for a
407
+ Returns the tokenization configuration dictionary for a
386
408
given Sentence Transformer BERT model.
387
409
388
410
Parameters:
389
- - model (str): The name of the Sentence Transformer
411
+ - model (str): The name of the Sentence Transformer
390
412
BERT model.
391
413
- revision (str, optional): The revision of the m
392
414
odel to use. Defaults to 'main'.
393
415
394
416
Returns:
395
- - dict: A dictionary containing the configuration parameters
417
+ - dict: A dictionary containing the configuration parameters
396
418
for the Sentence Transformer BERT model.
397
419
"""
398
420
sentence_transformer_config_files = [
@@ -404,20 +426,33 @@ def get_sentence_transformer_tokenizer_config(model: str,
404
426
"sentence_xlm-roberta_config.json" ,
405
427
"sentence_xlnet_config.json" ,
406
428
]
407
- try :
408
- # If model is on HuggingfaceHub, get the repo files
409
- repo_files = list_repo_files (model , revision = revision , token = HF_TOKEN )
410
- except Exception as e :
411
- logger .debug ("Error getting repo files" , e )
412
- repo_files = []
413
-
414
429
encoder_dict = None
415
- for config_name in sentence_transformer_config_files :
416
- if config_name in repo_files or Path (model ).exists ():
417
- encoder_dict = get_hf_file_to_dict (config_name , model , revision )
430
+
431
+ for config_file in sentence_transformer_config_files :
432
+ if try_get_local_file (model = model ,
433
+ file_name = config_file ,
434
+ revision = revision ) is not None :
435
+ encoder_dict = get_hf_file_to_dict (config_file , model , revision )
418
436
if encoder_dict :
419
437
break
420
438
439
+ if not encoder_dict :
440
+ try :
441
+ # If model is on HuggingfaceHub, get the repo files
442
+ repo_files = list_repo_files (model ,
443
+ revision = revision ,
444
+ token = HF_TOKEN )
445
+ except Exception as e :
446
+ logger .debug ("Error getting repo files" , e )
447
+ repo_files = []
448
+
449
+ for config_name in sentence_transformer_config_files :
450
+ if config_name in repo_files :
451
+ encoder_dict = get_hf_file_to_dict (config_name , model ,
452
+ revision )
453
+ if encoder_dict :
454
+ break
455
+
421
456
if not encoder_dict :
422
457
return None
423
458
0 commit comments