Safe Serialization does not work for Custom Diffusion #4634


Closed
DN6 opened this issue Aug 16, 2023 · 5 comments
Assignees
Labels
bug Something isn't working stale Issues that haven't received updates

Comments

@DN6
Collaborator

DN6 commented Aug 16, 2023

Describe the bug

Ran into this issue when working on: #4235

Diffusers now defaults to saving models via safetensors.

In certain cases:
https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion/train_custom_diffusion.py

saving custom UNet attention processors via safetensors is not possible. This is because of the following lines in loaders.py:

state_dict = model_to_save.state_dict()
for name, attn in self.attn_processors.items():
    if len(attn.state_dict()) == 0:
        state_dict[name] = {}

If there are layers of the UNet with empty weights, creating a state_dict via model_to_save.state_dict() will omit them, and you will get an error when trying to load the weights via load_attn_procs since there are missing keys.
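As a minimal illustration (a toy sketch, not code from diffusers), submodules that register no parameters or buffers simply never appear in state_dict(), so their keys are silently absent after saving:

```python
import torch
import torch.nn as nn

class EmptyProcessor(nn.Module):
    """Stand-in for an attention processor with no trainable weights."""
    def forward(self, x):
        return x

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = nn.Linear(3, 3)   # has weights
        self.attn2 = EmptyProcessor()  # no weights at all

keys = list(ToyModel().state_dict().keys())
print(keys)  # ['attn1.weight', 'attn1.bias'] -- no 'attn2' entry
```

A strict load into a model that expects an attn2 entry would then complain about missing keys, which is exactly the situation load_attn_procs runs into.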

The current approach is to manually add the keys of the empty-weight layers to the state_dict with an empty dictionary value. This is fine when using torch.save but fails when trying to save the file using safetensors, with the following error:

Traceback (most recent call last):
  File "examples/custom_diffusion/train_custom_diffusion.py", line 1330, in <module>
    main(args)
  File "examples/custom_diffusion/train_custom_diffusion.py", line 1258, in main
    unet.save_attn_procs(args.output_dir, safe_serialization=not args.no_safe_serialization)
  File "/diffusers/src/diffusers/loaders.py", line 568, in save_attn_procs
    save_function(state_dict, os.path.join(save_directory, weight_name))
  File "/diffusers/src/diffusers/loaders.py", line 534, in save_function
    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
  File "/opt/venv/lib/python3.8/site-packages/safetensors/torch.py", line 232, in save_file
    serialize_file(_flatten(tensors), filename, metadata=metadata)
  File "/opt/venv/lib/python3.8/site-packages/safetensors/torch.py", line 383, in _flatten
    raise ValueError(f"Key `{k}` is invalid, expected torch.Tensor but received {type(v)}")
ValueError: Key `down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor` is invalid, expected torch.Tensor but received <class 'dict'>

safetensors does not support saving dict-type values. You can verify this is an issue by running the following snippet:

import torch
from safetensors.torch import save_file

state_dict = {
    "attn1": torch.randn((1, 3)),
    "attn2": {}
}
save_file(state_dict, "my_attn.safetensor", metadata={"format": "pt"})

You will get the following error.

ValueError: Key `attn2` is invalid, expected torch.Tensor but received <class 'dict'>

Trying to replace the dict with an empty tensor in loaders.py

for name, attn in self.attn_processors.items():
    if len(attn.state_dict()) == 0:
        state_dict[name] = torch.tensor([]) if safe_serialization else {}

leads to the following error:

Traceback (most recent call last):
  File "examples/custom_diffusion/train_custom_diffusion.py", line 1330, in <module>
    main(args)
  File "examples/custom_diffusion/train_custom_diffusion.py", line 1258, in main
    unet.save_attn_procs(args.output_dir, safe_serialization=not args.no_safe_serialization)
  File "/diffusers/src/diffusers/loaders.py", line 568, in save_attn_procs
    save_function(state_dict, os.path.join(save_directory, weight_name))
  File "/diffusers/src/diffusers/loaders.py", line 534, in save_function
    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
  File "/opt/venv/lib/python3.8/site-packages/safetensors/torch.py", line 232, in save_file
    serialize_file(_flatten(tensors), filename, metadata=metadata)
  File "/opt/venv/lib/python3.8/site-packages/safetensors/torch.py", line 394, in _flatten
    raise RuntimeError(
RuntimeError:
            Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'mid_block.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.0.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.0.attentions.2.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.0.attentions.0.transformer_blocks.0.attn1.processor'}].
            A potential way to correctly save your model is to use `save_model`.
            More information at https://huggingface.co/docs/safetensors/torch_shared_tensors

IMO, we shouldn't introduce keys with empty values when saving the state_dict. It would be better to omit the empty keys while saving, and deal with missing keys when loading the weights (perhaps via load_state_dict(state_dict, strict=False)?).
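A minimal sketch of that proposal (hypothetical helper names, not the diffusers API): don't inject placeholder entries at save time, and pass strict=False at load time so the parameter-less processors are simply left untouched rather than triggering a missing-key error.

```python
import os
import tempfile
import torch
import torch.nn as nn

def save_attn_state(model: nn.Module, path: str):
    # state_dict() already excludes parameter-less submodules, so no
    # placeholder {} entries are added -- the dict stays tensor-only,
    # which is what safetensors.torch.save_file would also require.
    torch.save(model.state_dict(), path)

def load_attn_state(model: nn.Module, path: str):
    state_dict = torch.load(path)
    # strict=False tolerates keys absent from the file (the empty
    # processors) instead of raising "Missing key(s) in state_dict".
    result = model.load_state_dict(state_dict, strict=False)
    return result.missing_keys, result.unexpected_keys

model = nn.Sequential(nn.Linear(3, 3))
path = os.path.join(tempfile.mkdtemp(), "attn.bin")
save_attn_state(model, path)
missing, unexpected = load_attn_state(nn.Sequential(nn.Linear(3, 3)), path)
print(missing, unexpected)  # [] []
```

load_state_dict returns the lists of missing and unexpected keys, so a loader could still validate that only the known empty-processor keys are absent rather than silently accepting arbitrary mismatches.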

Reproduction

Run the train_custom_diffusion.py script.

python examples/custom_diffusion/train_custom_diffusion.py --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe --instance_data_dir docs/source/en/imgs --instance_prompt "<new1>" --resolution 64 --train_batch_size 1 --gradient_accumulation_steps 1 --max_train_steps 2 --learning_rate 1.0e-05 --scale_lr --modifier_token "<new1>" --output_dir ./testing_results/

Logs

No response

System Info

N/A

Who can help?

No response

@DN6 DN6 added the bug Something isn't working label Aug 16, 2023
@sayakpaul
Member

I think the best approach here is not to use safetensors in this case, to avoid complexity. This is my opinion.

@patrickvonplaten
Contributor

BTW don't think this PR is a prio at the moment - do we have any stats about how much custom diffusion is used?

@patrickvonplaten
Contributor

Also cc @Narsil this is an example where PyTorch saving != safe_serialization saving

@Narsil
Contributor

Narsil commented Aug 28, 2023

Thanks for the ping. Will follow if it becomes a prio. (I'm not sure how saving empty tensors/dicts in a file is interesting, aside from making torch happy).

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Oct 18, 2023
@github-actions github-actions bot closed this as completed Nov 9, 2023