from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Iterable, Optional, Union
+
import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
-import safetensors.torch

try:
    from torch.hub import get_dir
except ImportError:
    from torch.hub import _get_torch_home as get_dir

+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False
+
if sys.version_info >= (3, 8):
    from typing import Literal
else:
@@ -45,6 +51,7 @@
HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version

+
def get_cache_dir(child_dir=''):
    """
    Returns the location of the directory where models are cached (and creates it if necessary).
@@ -164,21 +171,28 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
    hf_model_id, hf_revision = hf_split(model_id)

    # Look for .safetensors alternatives and load from it if it exists
-    for safe_filename in _get_safe_alternatives(filename):
-        try:
-            cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
-            _logger.info(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
-            return safetensors.torch.load_file(cached_safe_file, device="cpu")
-        except EntryNotFoundError:
-            pass
+    if _has_safetensors:
+        for safe_filename in _get_safe_alternatives(filename):
+            try:
+                cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
+                _logger.info(
+                    f"[{model_id}] Safe alternative available for '{filename}' "
+                    f"(as '{safe_filename}'). Loading weights using safetensors.")
+                return safetensors.torch.load_file(cached_safe_file, device="cpu")
+            except EntryNotFoundError:
+                pass

    # Otherwise, load using pytorch.load
    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
-    _logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
+    _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
    return torch.load(cached_file, map_location='cpu')


-def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
+def save_config_for_hf(
+        model,
+        config_path: str,
+        model_config: Optional[dict] = None
+):
    model_config = model_config or {}
    hf_config = {}
    pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
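Note on the loading path above: `_get_safe_alternatives` is only consulted when the optional `safetensors` package imported successfully; otherwise the code goes straight to `torch.load` on the `.bin` checkpoint. A minimal usage sketch follows, assuming `load_state_dict_from_hf` is importable from timm's hub module and that `huggingface_hub` is installed; the import path and model id are illustrative assumptions, not part of this diff.

from timm.models._hub import load_state_dict_from_hf  # assumed import path

# Prefers a 'model.safetensors' file from the repo when the safetensors
# package is installed; otherwise falls back to 'pytorch_model.bin'
# loaded via torch.load(map_location='cpu').
state_dict = load_state_dict_from_hf('timm/resnet50.a1_in1k')  # illustrative model id
print(f'{len(state_dict)} tensors loaded')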
@@ -220,15 +234,16 @@ def save_for_hf(
        model,
        save_directory: str,
        model_config: Optional[dict] = None,
-        safe_serialization: Union[bool, Literal["both"]] = False
-):
+        safe_serialization: Union[bool, Literal["both"]] = False,
+):
    assert has_hf_hub(True)
    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

    # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
    tensors = model.state_dict()
    if safe_serialization is True or safe_serialization == "both":
+        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
        safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
    if safe_serialization is False or safe_serialization == "both":
        torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
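For reference, the `safe_serialization` values accepted above map to the weight files written in `save_directory`: `False` keeps only the legacy `pytorch_model.bin`, `True` writes only `model.safetensors` (and now asserts that `safetensors` is installed), and `"both"` writes the two side by side. A short sketch under the assumption that `save_for_hf` is importable from timm's hub module and `huggingface_hub` is available; the import path is an assumption, not stated by this diff.

import timm
from timm.models._hub import save_for_hf  # assumed import path

model = timm.create_model('resnet18', pretrained=False)
save_for_hf(model, 'out_legacy')                           # pytorch_model.bin (default)
save_for_hf(model, 'out_safe', safe_serialization=True)    # model.safetensors, requires safetensors
save_for_hf(model, 'out_both', safe_serialization='both')  # writes both weight files
# Each call also writes the model's config.json alongside the weights.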
@@ -238,16 +253,16 @@ def save_for_hf(


def push_to_hf_hub(
-    model,
-    repo_id: str,
-    commit_message: str = 'Add model',
-    token: Optional[str] = None,
-    revision: Optional[str] = None,
-    private: bool = False,
-    create_pr: bool = False,
-    model_config: Optional[dict] = None,
-    model_card: Optional[dict] = None,
-    safe_serialization: Union[bool, Literal["both"]] = False
+        model,
+        repo_id: str,
+        commit_message: str = 'Add model',
+        token: Optional[str] = None,
+        revision: Optional[str] = None,
+        private: bool = False,
+        create_pr: bool = False,
+        model_config: Optional[dict] = None,
+        model_card: Optional[dict] = None,
+        safe_serialization: Union[bool, Literal["both"]] = False,
):
    """
    Arguments:
@@ -341,6 +356,7 @@ def generate_readme(model_card: dict, model_name: str):
            readme_text += f"```bibtex\n{c}\n```\n"
    return readme_text

+
def _get_safe_alternatives(filename: str) -> Iterable[str]:
    """Returns potential safetensors alternatives for a given filename.
@@ -350,5 +366,5 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
    """
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
-    if filename.endswith(".bin"):
-        yield filename[:-4] + ".safetensors"
+    if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
+        yield filename[:-4] + ".safetensors"
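To make the filename mapping at the end concrete, here is a standalone restatement of `_get_safe_alternatives` as it stands after this change, with the module constants inlined; it is a sketch for illustration only.

from typing import Iterable

HF_WEIGHTS_NAME = "pytorch_model.bin"
HF_SAFE_WEIGHTS_NAME = "model.safetensors"

def _get_safe_alternatives(filename: str) -> Iterable[str]:
    # Map a pytorch checkpoint name to the safetensors name(s) to try first.
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
    if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
        yield filename[:-4] + ".safetensors"

print(list(_get_safe_alternatives("pytorch_model.bin")))   # ['model.safetensors']
print(list(_get_safe_alternatives("custom_weights.bin")))  # ['custom_weights.safetensors']
print(list(_get_safe_alternatives("weights.pth")))         # []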