
[training] add ds support to lora sd3. #10378


Merged: sayakpaul merged 4 commits into main from ds-support-sd3-lora on Dec 30, 2024

Conversation

sayakpaul (Member)

What does this PR do?

Fixes #10252.

Cc: @leisuzz, could you give this a try?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

leisuzz (Contributor) commented Dec 30, 2024

@sayakpaul Please take a look at the comments; the third comment in particular has to be changed!

sayakpaul (Member, Author)

> Please take a look at the comments; the third comment in particular has to be changed!

No idea what this means.

leisuzz (Contributor) commented Dec 30, 2024

@sayakpaul At line 1852, I added a comment on the change to if args.checkpoints_total_limit is not None:
It should be if accelerator.is_main_process and args.checkpoints_total_limit is not None:, otherwise errors will occur.

sayakpaul (Member, Author)

Can you comment on the changes directly instead? That will be helpful and easier.

@@ -1292,11 +1292,13 @@ def save_model_hook(models, weights, output_dir):
     text_encoder_two_lora_layers_to_save = None
 
     for model in models:
-        if isinstance(model, type(unwrap_model(transformer))):
+        if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+            model = unwrap_model(model)
leisuzz (Contributor):

We can modify this to:

transformer_model = unwrap_model(model)
if args.upcast_before_saving:
    transformer_model = transformer_model.to(torch.float32)
else:
    transformer_model = transformer_model.to(weight_dtype)
transformer_lora_layers_to_save = get_peft_model_state_dict(transformer_model)

As

sayakpaul (Member, Author):

The else should not be needed, as the model would already be type-cast.
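
For illustration, a minimal runnable sketch of that point; the flag and dtype values below are stand-ins for the script's args.upcast_before_saving and weight_dtype, not its actual arguments:

import torch
import torch.nn as nn

upcast_before_saving = True          # stand-in for args.upcast_before_saving
weight_dtype = torch.float16         # stand-in for the training weight_dtype

model = nn.Linear(4, 4).to(weight_dtype)   # the model is already kept in weight_dtype during training
if upcast_before_saving:
    model = model.to(torch.float32)
# No else branch needed: when not upcasting, the weights are already in weight_dtype.
print(next(model.parameters()).dtype)      # torch.float32 here; torch.float16 if the flag is False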

sayakpaul (Member, Author):

Addressed

linoytsaban (Collaborator):

Why is the change from isinstance(model, ...) to isinstance(unwrap_model(model), ...) needed in the if statement?

sayakpaul (Member, Author):

That is a DeepSpeed-specific change: with DeepSpeed, the model gets wrapped in a Module.
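
For context, a minimal self-contained sketch of the wrapping problem; the Wrapper class and unwrap helper below are illustrative stand-ins for the DeepSpeed engine and the script's unwrap_model helper, not the real classes:

import torch.nn as nn

class Wrapper(nn.Module):
    """Stand-in for the Module that DeepSpeed wraps the prepared model in."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

def unwrap(model: nn.Module) -> nn.Module:
    # Roughly what the training script's unwrap_model helper does: peel off the wrapper.
    return model.module if hasattr(model, "module") else model

transformer = nn.Linear(4, 4)    # stand-in for the SD3 transformer
prepared = Wrapper(transformer)  # what the loop iterates over when DeepSpeed is enabled

print(isinstance(prepared, type(transformer)))          # False: the wrapper hides the underlying class
print(isinstance(unwrap(prepared), type(transformer)))  # True: compare against the unwrapped model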

         transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-    elif isinstance(model, type(unwrap_model(text_encoder_one))): # or text_encoder_two
+    elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))): # or text_encoder_two
leisuzz (Contributor):

I think we should add and args.train_text_encoder, because if train_text_encoder is False, it should be None.

sayakpaul (Member, Author):

Addressed

linoytsaban (Collaborator):

Isn't it already reflected in the models sent to save_model_hook?

linoytsaban (Collaborator):

nvm addressed

@@ -1829,7 +1846,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     progress_bar.update(1)
     global_step += 1
 
-    if accelerator.is_main_process:
+    if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
         if global_step % args.checkpointing_steps == 0:
             # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
             if args.checkpoints_total_limit is not None:
leisuzz (Contributor):

I think it should be if args.checkpoints_total_limit is not None and accelerator.is_main_process:

leisuzz (Contributor):

This is a must! It has to be changed!!! And I think the correct format should be if accelerator.is_main_process and args.checkpoints_total_limit is not None:

sayakpaul (Member, Author):

Calm down sir :)

The line is already under:

if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:

So, this should already take care of what you're suggesting.
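
To make the nesting concrete, here is a small runnable sketch of the guard structure from the diff above; FakeAccelerator and the plain strings are stand-ins for accelerate's Accelerator and DistributedType, not the real objects:

from dataclasses import dataclass

@dataclass
class FakeAccelerator:
    is_main_process: bool
    distributed_type: str  # stand-in for accelerate's DistributedType enum

def reaches_total_limit_check(accelerator, checkpoints_total_limit):
    # Outer guard from the diff: the main process, or any rank when running under DeepSpeed.
    if accelerator.is_main_process or accelerator.distributed_type == "DEEPSPEED":
        # The checkpoints_total_limit check only runs for processes that passed the outer guard.
        if checkpoints_total_limit is not None:
            return True
    return False

print(reaches_total_limit_check(FakeAccelerator(True, "MULTI_GPU"), 3))    # True: main process
print(reaches_total_limit_check(FakeAccelerator(False, "MULTI_GPU"), 3))   # False: non-main ranks never reach the inner check
print(reaches_total_limit_check(FakeAccelerator(False, "DEEPSPEED"), 3))   # True: under DeepSpeed every rank passes the outer guard (the point of this PR's change)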

leisuzz (Contributor) commented Dec 30, 2024

@sayakpaul Sorry about that, I didn't notice the comments' status was still pending; just submitted them.

linoytsaban (Collaborator) left a comment

thanks a lot @leisuzz! LGTM, just a couple of small comments

sayakpaul merged commit 5f72473 into main on Dec 30, 2024
12 checks passed
sayakpaul deleted the ds-support-sd3-lora branch on December 30, 2024, 14:01
hulefei pushed a commit to hulefei/diffusers that referenced this pull request Dec 31, 2024
* add ds support to lora sd3.

Co-authored-by: leisuzz <[email protected]>

* style.

---------

Co-authored-by: leisuzz <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>
stevhliu pushed a commit that referenced this pull request Dec 31, 2024
* Update pix2pix.md

fix hyperlink error

* fix md link typos

* fix md typo - remove ".md" at the end of links

* [Fix] Broken links in hunyuan docs (#10402)

* fix-hunyuan-broken-links

* [Fix] docs broken links hunyuan

* [training] add ds support to lora sd3. (#10378)

* add ds support to lora sd3.

Co-authored-by: leisuzz <[email protected]>

* style.

---------

Co-authored-by: leisuzz <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>

* fix md typo - remove ".md" at the end of links

* fix md link typos

* fix md typo - remove ".md" at the end of links

---------

Co-authored-by: SahilCarterr <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: leisuzz <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>