Commit d0b45c9

Make safetensors import optional for now. Improve avg/clean checkpoint extension handling a bit (more consistent).
1 parent 7d9e321 commit d0b45c9

File tree: 4 files changed, +127 −62 lines

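The change repeated across all four files is to make the safetensors import optional: import it if available, record its absence otherwise, and only fail (with an install hint) when a .safetensors file is actually requested. A minimal standalone sketch of the pattern; the save_weights helper is illustrative, not part of the commit:

import torch

try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False


def save_weights(state_dict, path):
    # Illustrative helper: dispatch on extension, failing lazily when the
    # optional dependency is missing but actually needed.
    if path.endswith('.safetensors'):
        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
        safetensors.torch.save_file(state_dict, path)
    else:
        torch.save(state_dict, path)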

avg_checkpoints.py  (+33 −19)

@@ -17,10 +17,14 @@
 import glob
 import hashlib
 from timm.models import load_state_dict
-import safetensors.torch
+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False

-DEFAULT_OUTPUT = "./average.pth"
-DEFAULT_SAFE_OUTPUT = "./average.safetensors"
+DEFAULT_OUTPUT = "./averaged.pth"
+DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"

 parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
 parser.add_argument('--input', default='', type=str, metavar='PATH',
@@ -38,6 +42,7 @@
 parser.add_argument('--safetensors', action='store_true',
                     help='Save weights using safetensors instead of the default torch way (pickle).')

+
 def checkpoint_metric(checkpoint_path):
     if not checkpoint_path or not os.path.isfile(checkpoint_path):
         return {}
@@ -63,14 +68,20 @@ def main():
     if args.safetensors and args.output == DEFAULT_OUTPUT:
         # Default path changes if using safetensors
         args.output = DEFAULT_SAFE_OUTPUT
-    if args.safetensors and not args.output.endswith(".safetensors"):
+
+    output, output_ext = os.path.splitext(args.output)
+    if not output_ext:
+        output_ext = ('.safetensors' if args.safetensors else '.pth')
+    output = output + output_ext
+
+    if args.safetensors and not output_ext == ".safetensors":
         print(
             "Warning: saving weights as safetensors but output file extension is not "
             f"set to '.safetensors': {args.output}"
         )

-    if os.path.exists(args.output):
-        print("Error: Output filename ({}) already exists.".format(args.output))
+    if os.path.exists(output):
+        print("Error: Output filename ({}) already exists.".format(output))
         exit(1)

     pattern = args.input
@@ -87,22 +98,27 @@ def main():
             checkpoint_metrics.append((metric, c))
         checkpoint_metrics = list(sorted(checkpoint_metrics))
         checkpoint_metrics = checkpoint_metrics[-args.n:]
-        print("Selected checkpoints:")
-        [print(m, c) for m, c in checkpoint_metrics]
+        if checkpoint_metrics:
+            print("Selected checkpoints:")
+            [print(m, c) for m, c in checkpoint_metrics]
         avg_checkpoints = [c for m, c in checkpoint_metrics]
     else:
         avg_checkpoints = checkpoints
-        print("Selected checkpoints:")
-        [print(c) for c in checkpoints]
+        if avg_checkpoints:
+            print("Selected checkpoints:")
+            [print(c) for c in checkpoints]
+
+    if not avg_checkpoints:
+        print('Error: No checkpoints found to average.')
+        exit(1)

     avg_state_dict = {}
     avg_counts = {}
     for c in avg_checkpoints:
         new_state_dict = load_state_dict(c, args.use_ema)
         if not new_state_dict:
-            print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
+            print(f"Error: Checkpoint ({c}) doesn't exist")
             continue
-
         for k, v in new_state_dict.items():
             if k not in avg_state_dict:
                 avg_state_dict[k] = v.clone().to(dtype=torch.float64)
@@ -122,16 +138,14 @@ def main():
         final_state_dict[k] = v.to(dtype=torch.float32)

     if args.safetensors:
-        safetensors.torch.save_file(final_state_dict, args.output)
+        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
+        safetensors.torch.save_file(final_state_dict, output)
     else:
-        try:
-            torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
-        except:
-            torch.save(final_state_dict, args.output)
+        torch.save(final_state_dict, output)

-    with open(args.output, 'rb') as f:
+    with open(output, 'rb') as f:
         sha_hash = hashlib.sha256(f.read()).hexdigest()
-    print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))
+    print(f"=> Saved state_dict to '{output}, SHA256: {sha_hash}'")


 if __name__ == '__main__':
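The new splitext-based output handling is small enough to exercise on its own; a sketch of the same logic (resolve_output is a hypothetical name, not in the script):

import os

def resolve_output(output, use_safetensors=False):
    # Mirror of the logic added to main() above: split off any extension
    # and fill in a default of .safetensors or .pth when none was given.
    base, ext = os.path.splitext(output)
    if not ext:
        ext = '.safetensors' if use_safetensors else '.pth'
    return base + ext, ext

print(resolve_output('./averaged'))                        # ('./averaged.pth', '.pth')
print(resolve_output('./averaged', use_safetensors=True))  # ('./averaged.safetensors', '.safetensors')
print(resolve_output('./averaged.pth', use_safetensors=True))  # extension kept; the script warns instead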

clean_checkpoint.py  (+48 −18)

@@ -11,9 +11,14 @@
 import argparse
 import os
 import hashlib
-import safetensors.torch
 import shutil
+import tempfile
 from timm.models import load_state_dict
+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False

 parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@@ -22,13 +27,13 @@
                     help='output path')
 parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                     help='use ema version of weights if present')
+parser.add_argument('--no-hash', dest='no_hash', action='store_true',
+                    help='no hash in output filename')
 parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
                     help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
 parser.add_argument('--safetensors', action='store_true',
                     help='Save weights using safetensors instead of the default torch way (pickle).')

-_TEMP_NAME = './_checkpoint.pth'
-

 def main():
     args = parser.parse_args()
@@ -37,10 +42,24 @@ def main():
         print("Error: Output filename ({}) already exists.".format(args.output))
         exit(1)

-    clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors)
+    clean_checkpoint(
+        args.checkpoint,
+        args.output,
+        not args.no_use_ema,
+        args.no_hash,
+        args.clean_aux_bn,
+        safe_serialization=args.safetensors,
+    )


-def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False):
+def clean_checkpoint(
+        checkpoint,
+        output,
+        use_ema=True,
+        no_hash=False,
+        clean_aux_bn=False,
+        safe_serialization: bool=False,
+):
     # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
     if checkpoint and os.path.isfile(checkpoint):
         print("=> Loading checkpoint '{}'".format(checkpoint))
@@ -55,25 +74,36 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, sa
                 new_state_dict[name] = v
         print("=> Loaded state_dict from '{}'".format(checkpoint))

+        ext = ''
+        if output:
+            checkpoint_root, checkpoint_base = os.path.split(output)
+            checkpoint_base, ext = os.path.splitext(checkpoint_base)
+        else:
+            checkpoint_root = ''
+            checkpoint_base = os.path.split(checkpoint)[1]
+            checkpoint_base = os.path.splitext(checkpoint_base)[0]
+
+        temp_filename = '__' + checkpoint_base
         if safe_serialization:
-            safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
+            assert _has_safetensors, "`pip install safetensors` to use .safetensors"
+            safetensors.torch.save_file(new_state_dict, temp_filename)
         else:
-            try:
-                torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
-            except:
-                torch.save(new_state_dict, _TEMP_NAME)
+            torch.save(new_state_dict, temp_filename)

-        with open(_TEMP_NAME, 'rb') as f:
+        with open(temp_filename, 'rb') as f:
             sha_hash = hashlib.sha256(f.read()).hexdigest()

-        if output:
-            checkpoint_root, checkpoint_base = os.path.split(output)
-            checkpoint_base = os.path.splitext(checkpoint_base)[0]
+        if ext:
+            final_ext = ext
         else:
-            checkpoint_root = ''
-            checkpoint_base = os.path.splitext(checkpoint)[0]
-        final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth')
-        shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
+            final_ext = ('.safetensors' if safe_serialization else '.pth')
+
+        if no_hash:
+            final_filename = checkpoint_base + final_ext
+        else:
+            final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + final_ext
+
+        shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename))
         print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
         return final_filename
     else:
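With the new no_hash argument the cleaned checkpoint keeps a plain name instead of gaining a -<sha256[:8]> suffix, and an explicit output extension now wins over the serialization default. A hedged usage sketch (paths are hypothetical):

from clean_checkpoint import clean_checkpoint

# Writes ./model_best.safetensors (no hash suffix) via a '__model_best'
# temp file; requires the safetensors package for safe_serialization=True.
clean_checkpoint(
    'output/train/model_best.pth.tar',
    'model_best.safetensors',
    use_ema=True,
    no_hash=True,
    clean_aux_bn=False,
    safe_serialization=True,
)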

timm/models/_helpers.py  (+6 −1)

@@ -7,7 +7,11 @@
 from collections import OrderedDict

 import torch
-import safetensors.torch
+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False

 import timm.models._builder

@@ -29,6 +33,7 @@ def load_state_dict(checkpoint_path, use_ema=True):
     if checkpoint_path and os.path.isfile(checkpoint_path):
         # Check if safetensors or not and load weights accordingly
         if str(checkpoint_path).endswith(".safetensors"):
+            assert _has_safetensors, "`pip install safetensors` to use .safetensors"
             checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
         else:
             checkpoint = torch.load(checkpoint_path, map_location='cpu')
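Caller code is unchanged: format dispatch keys off the file extension, and the assert turns a missing optional dependency into an actionable message instead of an import-time failure. A short sketch (checkpoint names are hypothetical):

from timm.models import load_state_dict

state_dict = load_state_dict('averaged.safetensors')  # needs safetensors installed
state_dict = load_state_dict('averaged.pth')          # plain torch.load path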

timm/models/_hub.py  (+40 −24)

@@ -7,15 +7,21 @@
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Iterable, Optional, Union
+
 import torch
 from torch.hub import HASH_REGEX, download_url_to_file, urlparse
-import safetensors.torch

 try:
     from torch.hub import get_dir
 except ImportError:
     from torch.hub import _get_torch_home as get_dir

+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False
+
 if sys.version_info >= (3, 8):
     from typing import Literal
 else:
@@ -45,6 +51,7 @@
 HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
 HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version

+
 def get_cache_dir(child_dir=''):
     """
     Returns the location of the directory where models are cached (and creates it if necessary).
@@ -164,21 +171,28 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     hf_model_id, hf_revision = hf_split(model_id)

     # Look for .safetensors alternatives and load from it if it exists
-    for safe_filename in _get_safe_alternatives(filename):
-        try:
-            cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
-            _logger.info(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
-            return safetensors.torch.load_file(cached_safe_file, device="cpu")
-        except EntryNotFoundError:
-            pass
+    if _has_safetensors:
+        for safe_filename in _get_safe_alternatives(filename):
+            try:
+                cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
+                _logger.info(
+                    f"[{model_id}] Safe alternative available for '{filename}' "
+                    f"(as '{safe_filename}'). Loading weights using safetensors.")
+                return safetensors.torch.load_file(cached_safe_file, device="cpu")
+            except EntryNotFoundError:
+                pass

     # Otherwise, load using pytorch.load
     cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
-    _logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
+    _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
     return torch.load(cached_file, map_location='cpu')


-def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
+def save_config_for_hf(
+        model,
+        config_path: str,
+        model_config: Optional[dict] = None
+):
     model_config = model_config or {}
     hf_config = {}
     pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
@@ -220,15 +234,16 @@ def save_for_hf(
         model,
         save_directory: str,
         model_config: Optional[dict] = None,
-        safe_serialization: Union[bool, Literal["both"]] = False
-):
+        safe_serialization: Union[bool, Literal["both"]] = False,
+):
     assert has_hf_hub(True)
     save_directory = Path(save_directory)
     save_directory.mkdir(exist_ok=True, parents=True)

     # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
     tensors = model.state_dict()
     if safe_serialization is True or safe_serialization == "both":
+        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
         safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
     if safe_serialization is False or safe_serialization == "both":
         torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
@@ -238,16 +253,16 @@ def save_for_hf(


 def push_to_hf_hub(
-    model,
-    repo_id: str,
-    commit_message: str = 'Add model',
-    token: Optional[str] = None,
-    revision: Optional[str] = None,
-    private: bool = False,
-    create_pr: bool = False,
-    model_config: Optional[dict] = None,
-    model_card: Optional[dict] = None,
-    safe_serialization: Union[bool, Literal["both"]] = False
+        model,
+        repo_id: str,
+        commit_message: str = 'Add model',
+        token: Optional[str] = None,
+        revision: Optional[str] = None,
+        private: bool = False,
+        create_pr: bool = False,
+        model_config: Optional[dict] = None,
+        model_card: Optional[dict] = None,
+        safe_serialization: Union[bool, Literal["both"]] = False,
 ):
     """
     Arguments:
@@ -341,6 +356,7 @@ def generate_readme(model_card: dict, model_name: str):
         readme_text += f"```bibtex\n{c}\n```\n"
     return readme_text

+
 def _get_safe_alternatives(filename: str) -> Iterable[str]:
     """Returns potential safetensors alternatives for a given filename.

@@ -350,5 +366,5 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
     """
     if filename == HF_WEIGHTS_NAME:
         yield HF_SAFE_WEIGHTS_NAME
-    if filename.endswith(".bin"):
-        yield filename[:-4] + ".safetensors"
+    if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
+        return filename[:-4] + ".safetensors"
