[Safetensors] Make safetensors the default way of saving weights #4235
Conversation
The documentation is not available anymore as the PR was closed or merged.
@DN6 let me know if you need any help here
@DN6 we should try to get this one merged - it has been stale for three weeks now
I'll take a look this week @patrickvonplaten
…huggingface/diffusers into make_safe_serialization_default
```diff
@@ -1374,7 +1374,7 @@ def compute_text_embeddings(prompt):
     pipeline = pipeline.to(accelerator.device)

     # load attention processors
-    pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")
+    pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
```
Does `save_lora_weights()` have `safe_serialization` set to `True` by default? Otherwise, this will fail.
Also, we don't actually need to explicitly specify `weight_name`, since the naming follows the `diffusers` format.
`safe_serialization` defaults to `True` in the `LoraLoaderMixin`:

diffusers/src/diffusers/loaders.py, line 1418 in 1dd3441:

```python
safe_serialization: bool = True,
```
Regarding `weight_name`: not setting it while saving is fine; however, loading looks like it needs a path to the weights.

diffusers/src/diffusers/loaders.py, line 906 in 1dd3441:

```python
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
```
Yeah, you only need `pipeline.load_lora_weights(args.output_dir)`.
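For illustration, a minimal sketch of the loading pattern being suggested; the base model id and output directory below are placeholders, not paths from this PR:

```python
import torch

from diffusers import StableDiffusionPipeline

base_model = "runwayml/stable-diffusion-v1-5"  # placeholder base checkpoint
output_dir = "lora-output"  # placeholder: directory the training script saved to

pipeline = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)

# No weight_name needed: load_lora_weights picks up the default diffusers
# file name in the directory (now pytorch_lora_weights.safetensors).
pipeline.load_lora_weights(output_dir)
```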
```diff
@@ -437,6 +437,7 @@ def test_custom_diffusion(self):
     --lr_scheduler constant
     --lr_warmup_steps 0
     --modifier_token <new1>
+    --no_safe_serialization
```
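For context, a sketch of how a training script can expose this kind of opt-out flag; the parser below is illustrative, not the script's actual argument setup:

```python
import argparse

parser = argparse.ArgumentParser(description="Illustrative training-script arguments")
parser.add_argument(
    "--no_safe_serialization",
    action="store_true",
    help="If set, save weights in the legacy .bin format instead of safetensors.",
)
args = parser.parse_args()

# Downstream save calls can then invert the flag, e.g.:
# pipeline.save_pretrained(output_dir, safe_serialization=not args.no_safe_serialization)
```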
Ah I see. Cool.
Looks good. Thanks for tackling this very important component!
I think we're still missing a couple of examples, no?
I would have expected `safe_serialization` to be set to `True` in all the officially supported training examples, but I didn't see that. What am I missing?
@sayakpaul It's possible that a couple have slipped by me. Do you happen to have links to training examples where it's still missing?
DreamBooth and ControlNet, for example.
An oversight on my part. This PR is ready to merge!
```diff
@@ -884,7 +884,7 @@ def test_custom_model_and_pipeline(self):
         )

         with tempfile.TemporaryDirectory() as tmpdirname:
-            pipe.save_pretrained(tmpdirname)
+            pipe.save_pretrained(tmpdirname, safe_serialization=False)
```
Do we also test that saving with safe serialization works fine? The default is now that state dicts are saved with `safe_serialization`, so we should test this as well, no?
We have these tests here for saving:

diffusers/tests/pipelines/test_pipelines.py, line 126 in 1dd3441:

```python
error = f"{traceback.format_exc()}"
```

diffusers/tests/pipelines/test_pipelines.py, lines 578 to 603 in 1dd3441:
```python
def test_local_save_load_index(self):
    prompt = "hello"
    for variant in [None, "fp16"]:
        for use_safe in [True, False]:
            pipe = StableDiffusionPipeline.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-pipe-indexes",
                variant=variant,
                use_safetensors=use_safe,
                safety_checker=None,
            )
            pipe = pipe.to(torch_device)
            generator = torch.manual_seed(0)
            out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

            with tempfile.TemporaryDirectory() as tmpdirname:
                pipe.save_pretrained(tmpdirname)
                pipe_2 = StableDiffusionPipeline.from_pretrained(
                    tmpdirname, safe_serialization=use_safe, variant=variant
                )
                pipe_2 = pipe_2.to(torch_device)

            generator = torch.manual_seed(0)
            out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

            assert np.max(np.abs(out - out_2)) < 1e-3
```
Are there some specific things we should also be testing here?
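One possibility (a sketch, not something in the PR) would be to assert that the new default actually writes `.safetensors` files to disk, reusing the tiny test checkpoint from the test above:

```python
import os
import tempfile

from diffusers import StableDiffusionPipeline


def test_save_pretrained_defaults_to_safetensors():
    # Sketch: with safe_serialization defaulting to True, no .bin weight files
    # should be written for this pipeline.
    pipe = StableDiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-pipe-indexes", safety_checker=None
    )
    with tempfile.TemporaryDirectory() as tmpdirname:
        pipe.save_pretrained(tmpdirname)
        saved_files = [f for _, _, files in os.walk(tmpdirname) for f in files]
        assert any(f.endswith(".safetensors") for f in saved_files)
        assert not any(f.endswith(".bin") for f in saved_files)
```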
…gingface#4235)

* make safetensors default
* set default save method as safetensors
* update tests
* update to support saving safetensors
* update test to account for safetensors default
* update example tests to use safetensors
* update example to support safetensors
* update unet tests for safetensors
* fix failing loader tests
* fix qc issues
* fix pipeline tests
* fix example test

---------

Co-authored-by: Dhruv Nair <[email protected]>
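For a quick picture of the user-facing effect of the change, a sketch of the new default and the explicit opt-out; the checkpoint id below is the tiny test pipeline reused from the tests above:

```python
import tempfile

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-pipe-indexes", safety_checker=None
)

with tempfile.TemporaryDirectory() as tmpdirname:
    # New default: weights are serialized as .safetensors files.
    pipe.save_pretrained(tmpdirname)

    # The previous .bin behaviour is still available by opting out explicitly.
    pipe.save_pretrained(tmpdirname, safe_serialization=False)
```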
What does this PR do?
Makes safetensors the default format for saving weights across diffusers models, pipelines, loaders, and the training examples.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.