
[Safetensors] Make safetensors the default way of saving weights #4235


Merged · 17 commits · Aug 17, 2023
36 changes: 30 additions & 6 deletions examples/custom_diffusion/train_custom_diffusion.py
@@ -26,6 +26,7 @@
from pathlib import Path

import numpy as np
import safetensors
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
@@ -296,14 +297,19 @@ def __getitem__(self, index):
return example


def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir):
def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir, safe_serialization=True):
"""Saves the new token embeddings from the text encoder."""
logger.info("Saving embeddings")
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
for x, y in zip(modifier_token_id, args.modifier_token):
learned_embeds_dict = {}
learned_embeds_dict[y] = learned_embeds[x]
torch.save(learned_embeds_dict, f"{output_dir}/{y}.bin")
filename = f"{output_dir}/{y}.bin"

if safe_serialization:
safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
else:
torch.save(learned_embeds_dict, filename)


def parse_args(input_args=None):
@@ -605,6 +611,11 @@ def parse_args(input_args=None):
action="store_true",
help="Dont apply augmentation during data augmentation when this flag is enabled.",
)
parser.add_argument(
"--no_safe_serialization",
action="store_true",
help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
)

if input_args is not None:
args = parser.parse_args(input_args)
@@ -1244,8 +1255,15 @@ def main(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unet.to(torch.float32)
unet.save_attn_procs(args.output_dir)
save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.output_dir)
unet.save_attn_procs(args.output_dir, safe_serialization=not args.no_safe_serialization)
save_new_embed(
text_encoder,
modifier_token_id,
accelerator,
args,
args.output_dir,
safe_serialization=not args.no_safe_serialization,
)

# Final inference
# Load previous pipeline
@@ -1256,9 +1274,15 @@
pipeline = pipeline.to(accelerator.device)

# load attention processors
pipeline.unet.load_attn_procs(args.output_dir, weight_name="pytorch_custom_diffusion_weights.bin")
weight_name = (
"pytorch_custom_diffusion_weights.safetensors"
if not args.no_safe_serialization
else "pytorch_custom_diffusion_weights.bin"
)
pipeline.unet.load_attn_procs(args.output_dir, weight_name=weight_name)
for token in args.modifier_token:
pipeline.load_textual_inversion(args.output_dir, weight_name=f"{token}.bin")
token_weight_name = f"{token}.safetensors" if not args.no_safe_serialization else f"{token}.bin"
pipeline.load_textual_inversion(args.output_dir, weight_name=token_weight_name)

# run inference
if args.validation_prompt and args.num_validation_images > 0:
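With these changes, a default run of this script writes `pytorch_custom_diffusion_weights.safetensors` plus one `<token>.safetensors` file per modifier token, which the final-inference block above reads back. A minimal standalone loading sketch, where the base model, output directory, and `<new1>` token are illustrative assumptions rather than output of a real run:

```python
import torch
from diffusers import DiffusionPipeline

output_dir = "custom-diffusion-output"  # hypothetical --output_dir of a training run

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Attention processors are now saved as safetensors by default.
pipe.unet.load_attn_procs(
    output_dir, weight_name="pytorch_custom_diffusion_weights.safetensors"
)
# One embedding file per modifier token, e.g. <new1> as in the tests below.
pipe.load_textual_inversion(output_dir, weight_name="<new1>.safetensors")

image = pipe("a <new1> cat sitting on a bench", num_inference_steps=25).images[0]
```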
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_lora.py
@@ -1374,7 +1374,7 @@ def compute_text_embeddings(prompt):
pipeline = pipeline.to(accelerator.device)

# load attention processors
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
Member: Does `save_lora_weights()` have `safe_serialization` set to `True` by default? Otherwise, this will fail.

Member: Also, we don't need to explicitly specify the `weight_name`, since the naming follows the diffusers format.

Collaborator: `safe_serialization` defaults to `True` in the `LoraLoaderMixin`:

safe_serialization: bool = True,

Collaborator: Regarding `weight_name`: not setting it while saving is fine; however, loading looks like it needs a path to the weights:

def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):

Member: Yeah, you only need `pipeline.load_lora_weights(args.output_dir)`.


# run inference
images = []
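As the thread above settles, `safe_serialization` already defaults to `True` in `LoraLoaderMixin`, and `weight_name` can be dropped on both sides because the file follows the diffusers naming convention. A minimal sketch of the simplified load, assuming a finished training run in a placeholder directory:

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# The run wrote pytorch_lora_weights.safetensors into this directory;
# load_lora_weights resolves that default name without an explicit weight_name.
pipe.load_lora_weights("dreambooth-lora-output")
```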
@@ -1370,7 +1370,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = latents # compute loss against the denoised latents
target = latents # compute loss against the denoised latents
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
46 changes: 25 additions & 21 deletions examples/test_examples.py
@@ -23,7 +23,7 @@
import unittest
from typing import List

import torch
import safetensors
from accelerate.utils import write_basic_config

from diffusers import DiffusionPipeline, UNet2DConditionModel
@@ -93,7 +93,7 @@ def test_train_unconditional(self):

run_command(self._launch_args + test_args, return_stdout=True)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

def test_textual_inversion(self):
@@ -144,7 +144,7 @@ def test_dreambooth(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

def test_dreambooth_if(self):
@@ -170,7 +170,7 @@ def test_dreambooth_if(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

def test_dreambooth_checkpointing(self):
@@ -272,10 +272,10 @@ def test_dreambooth_lora(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

@@ -305,10 +305,10 @@ def test_dreambooth_lora_with_text_encoder(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# check `text_encoder` is present at all.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
keys = lora_state_dict.keys()
is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
self.assertTrue(is_text_encoder_present)
@@ -341,10 +341,10 @@ def test_dreambooth_lora_if_model(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

@@ -373,10 +373,10 @@ def test_dreambooth_lora_sdxl(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

@@ -406,10 +406,10 @@ def test_dreambooth_lora_sdxl_with_text_encoder(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

@@ -437,6 +437,7 @@ def test_custom_diffusion(self):
--lr_scheduler constant
--lr_warmup_steps 0
--modifier_token <new1>
--no_safe_serialization
Member: Ah I see. Cool.

--output_dir {tmpdir}
""".split()

@@ -466,7 +467,7 @@ def test_text_to_image(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

def test_text_to_image_checkpointing(self):
@@ -778,7 +779,7 @@ def test_text_to_image_sdxl(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
@@ -1373,7 +1374,7 @@ def test_controlnet_sdxl(self):

run_command(self._launch_args + test_args)

self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))

def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
@@ -1390,6 +1391,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
--no_safe_serialization
""".split()

run_command(self._launch_args + test_args)
@@ -1413,6 +1415,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
--dataloader_num_workers=0
--max_train_steps=9
--checkpointing_steps=2
--no_safe_serialization
""".split()

run_command(self._launch_args + test_args)
@@ -1436,6 +1439,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--no_safe_serialization
""".split()

run_command(self._launch_args + resume_run_args)
@@ -1464,10 +1468,10 @@ def test_text_to_image_lora_sdxl(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

@@ -1491,10 +1495,10 @@ def test_text_to_image_lora_sdxl_with_text_encoder(self):

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

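The pattern in these test updates is mechanical: `torch.load` on a pickled `.bin` becomes `safetensors.torch.load_file` on the matching `.safetensors` file, and both return a plain dict of parameter names to tensors, so the assertions stay unchanged. A standalone sketch of the check with a placeholder path (note the explicit submodule import, which does not rely on `safetensors.torch` having been registered by another library):

```python
import safetensors.torch

# load_file returns Dict[str, torch.Tensor], just like torch.load on a state dict.
lora_state_dict = safetensors.torch.load_file("output/pytorch_lora_weights.safetensors")

# The same smoke check the tests perform: every key names a LoRA parameter.
assert all("lora" in k for k in lora_state_dict)
```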
32 changes: 28 additions & 4 deletions examples/textual_inversion/textual_inversion.py
@@ -24,6 +24,7 @@

import numpy as np
import PIL
import safetensors
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
@@ -157,15 +158,19 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
return images


def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path):
def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True):
logger.info("Saving embeddings")
learned_embeds = (
accelerator.unwrap_model(text_encoder)
.get_input_embeddings()
.weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
)
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, save_path)

if safe_serialization:
safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"})
else:
torch.save(learned_embeds_dict, save_path)


def parse_args():
@@ -409,6 +414,11 @@ def parse_args():
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--no_safe_serialization",
action="store_true",
help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
)

args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -878,7 +888,14 @@ def main():
global_step += 1
if global_step % args.save_steps == 0:
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
save_progress(
text_encoder,
placeholder_token_ids,
accelerator,
args,
save_path,
safe_serialization=not args.no_safe_serialization,
)

if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
@@ -936,7 +953,14 @@ def main():
pipeline.save_pretrained(args.output_dir)
# Save the newly trained embeddings
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
save_progress(
text_encoder,
placeholder_token_ids,
accelerator,
args,
save_path,
safe_serialization=not args.no_safe_serialization,
)

if args.push_to_hub:
save_model_card(
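The resulting `learned_embeds.safetensors` maps the placeholder token to its embedding tensor and can be inspected directly or loaded for inference. A minimal sketch, assuming a run trained `<cat-toy>` into a placeholder `textual-inversion-output` directory:

```python
import safetensors.torch
from diffusers import StableDiffusionPipeline

# Direct inspection: a single entry keyed by the placeholder token.
embeds = safetensors.torch.load_file("textual-inversion-output/learned_embeds.safetensors")
print(list(embeds.keys()))  # e.g. ['<cat-toy>']

# Or load the embedding into a pipeline for inference.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_textual_inversion("textual-inversion-output", weight_name="learned_embeds.safetensors")
```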
10 changes: 7 additions & 3 deletions src/diffusers/loaders.py
@@ -497,7 +497,8 @@ def save_attn_procs(
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = False,
safe_serialization: bool = True,
**kwargs,
):
r"""
Save an attention processor to a directory so that it can be reloaded using the
@@ -514,7 +515,8 @@
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.

safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
from .models.attention_processor import (
CustomDiffusionAttnProcessor,
@@ -1413,7 +1415,7 @@ def save_lora_weights(
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = False,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -1434,6 +1436,8 @@
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
# Create a flat dictionary.
state_dict = {}
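Since only the defaults flip, callers that still need pickled `.bin` files can opt out per call. A sketch of both paths for `save_attn_procs`, using untrained LoRA processors purely so there is something to serialize (the model ID and output directories are placeholders, and the processor-wiring loop is the standard boilerplate from the example scripts of this era, not part of this PR):

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Attach a LoRA processor to every attention module; each one needs the
# hidden size of the block it lives in, and a cross-attention dim for attn2.
lora_procs = {}
for name in unet.attn_processors:
    cross_attention_dim = (
        None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    )
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )
unet.set_attn_processor(lora_procs)

# New default: writes pytorch_lora_weights.safetensors.
unet.save_attn_procs("attn-procs")

# Explicit opt-out restores the previous pickled pytorch_lora_weights.bin.
unet.save_attn_procs("attn-procs-pickle", safe_serialization=False)
```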