Commit ba93f18

patrickvonplaten and DN6 authored; Jimmy committed
[Safetensors] Make safetensors the default way of saving weights (huggingface#4235)
* make safetensors default
* set default save method as safetensors
* update tests
* update to support saving safetensors
* update test to account for safetensors default
* update example tests to use safetensors
* update example to support safetensors
* update unet tests for safetensors
* fix failing loader tests
* fix qc issues
* fix pipeline tests
* fix example test

Co-authored-by: Dhruv Nair <[email protected]>
1 parent: d6d09b6 · commit: ba93f18

17 files changed: +126 −97 lines
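For context (not part of the diff): with safetensors as the default, `save_pretrained` on a diffusers model now writes a `.safetensors` file instead of a pickle-based `.bin`, which is exactly what the updated smoke tests below assert. A minimal sketch — the repo id and output directories are placeholders:

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# New default: writes sd-unet/diffusion_pytorch_model.safetensors.
unet.save_pretrained("sd-unet")

# Old behavior on request: writes a pickle-based diffusion_pytorch_model.bin.
unet.save_pretrained("sd-unet-bin", safe_serialization=False)
```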

examples/custom_diffusion/train_custom_diffusion.py

+30-6
@@ -26,6 +26,7 @@
 from pathlib import Path
 
 import numpy as np
+import safetensors
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -296,14 +297,19 @@ def __getitem__(self, index):
         return example
 
 
-def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir):
+def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir, safe_serialization=True):
     """Saves the new token embeddings from the text encoder."""
     logger.info("Saving embeddings")
     learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
     for x, y in zip(modifier_token_id, args.modifier_token):
         learned_embeds_dict = {}
         learned_embeds_dict[y] = learned_embeds[x]
-        torch.save(learned_embeds_dict, f"{output_dir}/{y}.bin")
+        filename = f"{output_dir}/{y}.bin"
+
+        if safe_serialization:
+            safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
+        else:
+            torch.save(learned_embeds_dict, filename)
 
 
 def parse_args(input_args=None):
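The core of this change is the branch above between `safetensors.torch.save_file` and `torch.save`. A standalone sketch of both paths with a toy tensor (the token name and file names here are made up for illustration):

```python
import torch
import safetensors.torch

# Stand-in for one learned token embedding; the real script pulls rows out of
# the text encoder's embedding matrix.
learned_embeds_dict = {"<new1>": torch.randn(768)}

# Default path: pickle-free safetensors. metadata={"format": "pt"} marks the
# file as holding a PyTorch-style state dict.
safetensors.torch.save_file(learned_embeds_dict, "new1.safetensors", metadata={"format": "pt"})

# Opt-out path (--no_safe_serialization): classic pickle-based torch.save.
torch.save(learned_embeds_dict, "new1.bin")
```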
@@ -605,6 +611,11 @@ def parse_args(input_args=None):
         action="store_true",
         help="Dont apply augmentation during data augmentation when this flag is enabled.",
     )
+    parser.add_argument(
+        "--no_safe_serialization",
+        action="store_true",
+        help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1244,8 +1255,15 @@ def main(args):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = unet.to(torch.float32)
-        unet.save_attn_procs(args.output_dir)
-        save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.output_dir)
+        unet.save_attn_procs(args.output_dir, safe_serialization=not args.no_safe_serialization)
+        save_new_embed(
+            text_encoder,
+            modifier_token_id,
+            accelerator,
+            args,
+            args.output_dir,
+            safe_serialization=not args.no_safe_serialization,
+        )
 
     # Final inference
     # Load previous pipeline
@@ -1256,9 +1274,15 @@
         pipeline = pipeline.to(accelerator.device)
 
         # load attention processors
-        pipeline.unet.load_attn_procs(args.output_dir, weight_name="pytorch_custom_diffusion_weights.bin")
+        weight_name = (
+            "pytorch_custom_diffusion_weights.safetensors"
+            if not args.no_safe_serialization
+            else "pytorch_custom_diffusion_weights.bin"
+        )
+        pipeline.unet.load_attn_procs(args.output_dir, weight_name=weight_name)
         for token in args.modifier_token:
-            pipeline.load_textual_inversion(args.output_dir, weight_name=f"{token}.bin")
+            token_weight_name = f"{token}.safetensors" if not args.no_safe_serialization else f"{token}.bin"
+            pipeline.load_textual_inversion(args.output_dir, weight_name=token_weight_name)
 
         # run inference
         if args.validation_prompt and args.num_validation_images > 0:
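Either weight file can be inspected after training; `safetensors.torch.load_file` returns a plain dict of tensors (the path here is illustrative):

```python
import safetensors.torch

state_dict = safetensors.torch.load_file("output_dir/pytorch_custom_diffusion_weights.safetensors")
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
```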

examples/dreambooth/train_dreambooth_lora.py

+1-1
@@ -1374,7 +1374,7 @@ def compute_text_embeddings(prompt):
         pipeline = pipeline.to(accelerator.device)
 
         # load attention processors
-        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")
+        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
 
         # run inference
         images = []
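Not part of the diff — once training has produced `pytorch_lora_weights.safetensors`, the inference-time load looks like this (model id and output directory are placeholders):

```python
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.load_lora_weights("path/to/output_dir", weight_name="pytorch_lora_weights.safetensors")

image = pipeline("a photo of sks dog", num_inference_steps=25).images[0]
```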

examples/test_examples.py

+25-21
@@ -23,7 +23,7 @@
 import unittest
 from typing import List
 
-import torch
+import safetensors
 from accelerate.utils import write_basic_config
 
 from diffusers import DiffusionPipeline, UNet2DConditionModel
@@ -93,7 +93,7 @@ def test_train_unconditional(self):
 
             run_command(self._launch_args + test_args, return_stdout=True)
             # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
             self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
 
     def test_textual_inversion(self):
@@ -144,7 +144,7 @@ def test_dreambooth(self):
 
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
 
     def test_dreambooth_if(self):
@@ -170,7 +170,7 @@ def test_dreambooth_if(self):
 
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
 
     def test_dreambooth_checkpointing(self):
@@ -272,10 +272,10 @@ def test_dreambooth_lora(self):
 
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
 
            # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

@@ -305,10 +305,10 @@ def test_dreambooth_lora_with_text_encoder(self):
 
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
 
            # check `text_encoder` is present at all.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            keys = lora_state_dict.keys()
            is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
            self.assertTrue(is_text_encoder_present)
@@ -341,10 +341,10 @@ def test_dreambooth_lora_if_model(self):
 
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
 
            # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

@@ -373,10 +373,10 @@ def test_dreambooth_lora_sdxl(self):
 
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
 
            # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

@@ -406,10 +406,10 @@ def test_dreambooth_lora_sdxl_with_text_encoder(self):
 
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
 
            # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

@@ -437,6 +437,7 @@ def test_custom_diffusion(self):
                --lr_scheduler constant
                --lr_warmup_steps 0
                --modifier_token <new1>
+                --no_safe_serialization
                --output_dir {tmpdir}
                """.split()

@@ -466,7 +467,7 @@ def test_text_to_image(self):
 
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
 
     def test_text_to_image_checkpointing(self):
@@ -778,7 +779,7 @@ def test_text_to_image_sdxl(self):
 
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
 
     def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
@@ -1373,7 +1374,7 @@ def test_controlnet_sdxl(self):
 
            run_command(self._launch_args + test_args)
 
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
 
     def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -1390,6 +1391,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
                --max_train_steps=6
                --checkpoints_total_limit=2
                --checkpointing_steps=2
+                --no_safe_serialization
                """.split()
 
            run_command(self._launch_args + test_args)
@@ -1413,6 +1415,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
                --dataloader_num_workers=0
                --max_train_steps=9
                --checkpointing_steps=2
+                --no_safe_serialization
                """.split()
 
            run_command(self._launch_args + test_args)
@@ -1436,6 +1439,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-8
                --checkpoints_total_limit=3
+                --no_safe_serialization
                """.split()
 
            run_command(self._launch_args + resume_run_args)
@@ -1464,10 +1468,10 @@ def test_text_to_image_lora_sdxl(self):
 
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
 
            # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

@@ -1491,10 +1495,10 @@ def test_text_to_image_lora_sdxl_with_text_encoder(self):
 
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
 
            # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)
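The switch from `torch.load` to `safetensors.torch.load_file` in these tests also removes the unpickling step entirely: safetensors parses a fixed binary layout and never executes code from the file. The same key checks outside the test harness (file path is illustrative):

```python
import safetensors.torch

lora_state_dict = safetensors.torch.load_file("pytorch_lora_weights.safetensors")

# Mirrors the assertions above: all keys belong to LoRA layers, and
# text-encoder weights are prefixed with "text_encoder".
assert all("lora" in k for k in lora_state_dict)
print(any(k.startswith("text_encoder") for k in lora_state_dict))
```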

examples/textual_inversion/textual_inversion.py

+28-4
@@ -24,6 +24,7 @@
 
 import numpy as np
 import PIL
+import safetensors
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -157,15 +158,19 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
     return images
 
 
-def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path):
+def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True):
     logger.info("Saving embeddings")
     learned_embeds = (
         accelerator.unwrap_model(text_encoder)
         .get_input_embeddings()
         .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
     )
     learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
-    torch.save(learned_embeds_dict, save_path)
+
+    if safe_serialization:
+        safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"})
+    else:
+        torch.save(learned_embeds_dict, save_path)
 
 
 def parse_args():
@@ -409,6 +414,11 @@ def parse_args():
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--no_safe_serialization",
+        action="store_true",
+        help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -878,7 +888,14 @@ def main():
                global_step += 1
                if global_step % args.save_steps == 0:
                    save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
-                    save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
+                    save_progress(
+                        text_encoder,
+                        placeholder_token_ids,
+                        accelerator,
+                        args,
+                        save_path,
+                        safe_serialization=not args.no_safe_serialization,
+                    )
 
                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
@@ -936,7 +953,14 @@ def main():
        pipeline.save_pretrained(args.output_dir)
        # Save the newly trained embeddings
        save_path = os.path.join(args.output_dir, "learned_embeds.bin")
-        save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
+        save_progress(
+            text_encoder,
+            placeholder_token_ids,
+            accelerator,
+            args,
+            save_path,
+            safe_serialization=not args.no_safe_serialization,
+        )
 
        if args.push_to_hub:
            save_model_card(
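A hedged round-trip sketch of what `save_progress` writes in the default mode and how it loads back (paths, token, and embedding size are illustrative; SD 1.x text embeddings are 768-dimensional):

```python
import torch
import safetensors.torch
from diffusers import StableDiffusionPipeline

# One entry per placeholder token, as save_progress writes it. Toy tensor here;
# the real script saves the trained embedding rows.
safetensors.torch.save_file(
    {"<cat-toy>": torch.randn(768)}, "learned_embeds.safetensors", metadata={"format": "pt"}
)

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_textual_inversion(".", weight_name="learned_embeds.safetensors")
```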

src/diffusers/loaders.py

+7-3
@@ -497,7 +497,8 @@ def save_attn_procs(
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
+        **kwargs,
    ):
        r"""
        Save an attention processor to a directory so that it can be reloaded using the
@@ -514,7 +515,8 @@
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
-
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        """
        from .models.attention_processor import (
            CustomDiffusionAttnProcessor,
@@ -1414,7 +1416,7 @@
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -1435,6 +1437,8 @@
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        """
        # Create a flat dictionary.
        state_dict = {}
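With these defaults flipped, both savers produce `.safetensors` files unless callers opt out. A minimal usage sketch (the LoRA state dict and its key name are fabricated for illustration; a real one comes out of a training loop):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
unet_lora_layers = {"mid_block.attentions.0.lora.down.weight": torch.randn(4, 320)}

# Default: writes out/pytorch_lora_weights.safetensors.
pipe.save_lora_weights("out", unet_lora_layers=unet_lora_layers)

# Opt-out: writes the pickle-based out/pytorch_lora_weights.bin.
pipe.save_lora_weights("out", unet_lora_layers=unet_lora_layers, safe_serialization=False)
```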
