[Safetensors] Make safetensors the default way of saving weights #4235


Merged: 17 commits merged into main on Aug 17, 2023

Conversation

patrickvonplaten
Contributor

What does this PR do?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@patrickvonplaten patrickvonplaten marked this pull request as draft July 24, 2023 14:15
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jul 24, 2023

The documentation is not available anymore as the PR was closed or merged.

@DN6 DN6 self-assigned this Jul 25, 2023
@patrickvonplaten
Contributor Author

@DN6 let me know if you need any help here

@patrickvonplaten
Contributor Author

@DN6 we should try to get this one merged - it has been stale now for three weeks

@DN6
Collaborator

DN6 commented Aug 14, 2023

I'll take a look this week @patrickvonplaten

@DN6 DN6 requested a review from yiyixuxu August 15, 2023 06:05
@DN6 DN6 marked this pull request as ready for review August 16, 2023 15:41
@DN6 DN6 requested a review from sayakpaul August 16, 2023 15:43
@@ -1374,7 +1374,7 @@ def compute_text_embeddings(prompt):
     pipeline = pipeline.to(accelerator.device)

     # load attention processors
-    pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")
+    pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
Member

Does save_lora_weights() have safe_serialization set to True by default? Otherwise, this will fail.

Member

Also, we don't actually need to explicitly specify weight_name, since the naming follows the diffusers format.

Collaborator

safe_serialization defaults to True in the LoraLoaderMixin

safe_serialization: bool = True,
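Since the discussion above hinges on what the safe_serialization flag controls, here is a minimal stdlib-only sketch (a hypothetical helper, not the actual diffusers implementation) of how the flag maps to the two file names that appear in this PR's diffs:

```python
def lora_weight_name(safe_serialization: bool = True) -> str:
    """Pick the LoRA weight file name that matches the serialization format.

    Mirrors the naming seen in the diffs above: safetensors by default,
    legacy torch pickle (.bin) only when explicitly opted out.
    """
    if safe_serialization:
        return "pytorch_lora_weights.safetensors"
    return "pytorch_lora_weights.bin"
```

With the new True default, a plain save produces the `.safetensors` name, which is why the training-example diff above switches the loaded file name accordingly.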

Collaborator

Regarding weight_name: not setting it while saving is fine; however, loading looks like it needs a path to the weights:

def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):

Member

Yeah, you only need pipeline.load_lora_weights(args.output_dir).
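The behaviour described above (passing only the directory and letting the loader pick the file) could be sketched with a stdlib-only lookup; this is a hypothetical helper, not the actual diffusers code, and the fallback order is an assumption:

```python
import tempfile
from pathlib import Path


def resolve_lora_weight_name(model_dir: str) -> str:
    """Resolve the LoRA weight file inside a directory when no weight_name is given.

    Prefers the safetensors file (the new default) and falls back to the
    legacy torch pickle file.
    """
    for name in ("pytorch_lora_weights.safetensors", "pytorch_lora_weights.bin"):
        if (Path(model_dir) / name).exists():
            return name
    raise FileNotFoundError(f"no LoRA weights found in {model_dir}")


# quick check against a stand-in output directory
with tempfile.TemporaryDirectory() as d:
    Path(d, "pytorch_lora_weights.safetensors").touch()
    resolved = resolve_lora_weight_name(d)
```

Under this scheme `pipeline.load_lora_weights(args.output_dir)` works for both old and new checkpoints, because the resolver tries both file names.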

@@ -437,6 +437,7 @@ def test_custom_diffusion(self):
     --lr_scheduler constant
     --lr_warmup_steps 0
     --modifier_token <new1>
+    --no_safe_serialization
Member

Ah I see. Cool.
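The `--no_safe_serialization` flag added in the test diff above suggests the training scripts expose an opt-out CLI switch. A minimal argparse sketch of that wiring (the argument name comes from the diff; everything else here is a hypothetical illustration):

```python
import argparse

# Hypothetical excerpt of a training script's argument parsing: safetensors is
# the default, and --no_safe_serialization opts back into torch.save (.bin).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--no_safe_serialization",
    action="store_true",
    help="Save weights with torch.save (.bin) instead of safetensors.",
)

opted_out = parser.parse_args(["--no_safe_serialization"])
default = parser.parse_args([])
```

The script would then pass `safe_serialization=not args.no_safe_serialization` to its save call, so tests that still exercise the `.bin` path (like the one above) only need to add the flag.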

@sayakpaul (Member) left a comment

Looks good. Thanks for tackling this very important component!

I think we're still missing a couple of examples, no?

I would have expected safe_serialization to be set to True in all the officially supported training examples, but I didn't see that. What am I missing?

@DN6
Copy link
Collaborator

DN6 commented Aug 17, 2023

@sayakpaul It's possible that a couple have slipped by me. Do you happen to have links for training examples where safe_serialization isn't enabled by default?

@sayakpaul
Member

Some examples:

- DreamBooth
- ControlNet

@sayakpaul
Member

An oversight on my part. This PR is ready to merge!

@sayakpaul sayakpaul merged commit 029fb41 into main Aug 17, 2023
@sayakpaul sayakpaul deleted the make_safe_serialization_default branch August 17, 2023 05:24
@@ -884,7 +884,7 @@ def test_custom_model_and_pipeline(self):
     )

     with tempfile.TemporaryDirectory() as tmpdirname:
-        pipe.save_pretrained(tmpdirname)
+        pipe.save_pretrained(tmpdirname, safe_serialization=False)
Contributor Author

Do we also test that saving with safe serialization works fine? The default is now that state dicts are saved with safe_serialization, so we should test this as well, no?
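One stdlib-only way to assert the new default, sketched below as a hypothetical test helper (the `diffusion_pytorch_model.safetensors` file name is assumed for illustration): after a save with default arguments, the output directory should contain `.safetensors` files and no `.bin` files.

```python
import tempfile
from pathlib import Path


def saved_with_safetensors(save_dir: str) -> bool:
    """Return True if the directory holds safetensors weights and no .bin files."""
    root = Path(save_dir)
    return any(root.rglob("*.safetensors")) and not any(root.rglob("*.bin"))


# quick self-check with dummy files standing in for a real pipeline checkpoint
with tempfile.TemporaryDirectory() as d:
    sub = Path(d, "unet")
    sub.mkdir()
    (sub / "diffusion_pytorch_model.safetensors").touch()
    ok = saved_with_safetensors(d)
```

A test like this would catch any component that silently falls back to torch pickle serialization despite the new default.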

Collaborator

We have these tests here for saving.

def test_local_save_load_index(self):
    prompt = "hello"
    for variant in [None, "fp16"]:
        for use_safe in [True, False]:
            pipe = StableDiffusionPipeline.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-pipe-indexes",
                variant=variant,
                use_safetensors=use_safe,
                safety_checker=None,
            )
            pipe = pipe.to(torch_device)
            generator = torch.manual_seed(0)
            out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

            with tempfile.TemporaryDirectory() as tmpdirname:
                pipe.save_pretrained(tmpdirname)
                pipe_2 = StableDiffusionPipeline.from_pretrained(
                    tmpdirname, safe_serialization=use_safe, variant=variant
                )
                pipe_2 = pipe_2.to(torch_device)
                generator = torch.manual_seed(0)
                out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

            assert np.max(np.abs(out - out_2)) < 1e-3

Are there some specific things we should also be testing here?

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…gingface#4235)

* make safetensors default

* set default save method as safetensors

* update tests

* update to support saving safetensors

* update test to account for safetensors default

* update example tests to use safetensors

* update example to support safetensors

* update unet tests for safetensors

* fix failing loader tests

* fix qc issues

* fix pipeline tests

* fix example test

---------

Co-authored-by: Dhruv Nair <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…gingface#4235)