[training] add ds support to lora sd3. #10378
Conversation
Co-authored-by: leisuzz <[email protected]>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@sayakpaul Please take a look at the comments; especially the third comment has to be changed!
No idea what this means.
@sayakpaul In line 1852, I added the modification on …
Can you comment on the changes directly instead? That will be helpful and easier.
@@ -1292,11 +1292,13 @@ def save_model_hook(models, weights, output_dir):
            text_encoder_two_lora_layers_to_save = None

            for model in models:
-               if isinstance(model, type(unwrap_model(transformer))):
+               if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                   model = unwrap_model(model)
We can modify this to:

    transformer_model = unwrap_model(model)
    if args.upcast_before_saving:
        transformer_model = transformer_model.to(torch.float32)
    else:
        transformer_model = transformer_model.to(weight_dtype)
    transformer_lora_layers_to_save = get_peft_model_state_dict(transformer_model)
As
The `else` should not be needed, as the model would already be type-casted.
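For illustration, a sketch of the simplified branch this comment points to, assuming the `args`, `unwrap_model`, `weight_dtype`, and `get_peft_model_state_dict` names from the suggestion above (not the exact merged code):

```python
# Sketch: the model is assumed to already be in weight_dtype, so only the
# optional upcast to fp32 is applied before extracting the LoRA weights.
transformer_model = unwrap_model(model)
if args.upcast_before_saving:
    transformer_model = transformer_model.to(torch.float32)
transformer_lora_layers_to_save = get_peft_model_state_dict(transformer_model)
```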
Addressed
Why is the change from `isinstance(model, ...)` to `isinstance(unwrap_model(model), ...)` needed in the if statement?
That is a DeepSpeed-specific change: with DeepSpeed, the model gets wrapped into a Module.
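For context, a minimal sketch of why the unwrapping matters. The `unwrap_model` helper here is an assumption mirroring the helper defined in the training script; `Accelerator.unwrap_model` is accelerate's public API.

```python
from accelerate import Accelerator

accelerator = Accelerator()

def unwrap_model(model):
    # Under DeepSpeed, accelerator.prepare() hands back an engine that wraps the
    # original nn.Module, so type(model) no longer matches the original class.
    # Strip that wrapper, plus a torch.compile wrapper if one is present.
    model = accelerator.unwrap_model(model)
    return getattr(model, "_orig_mod", model)

# Inside save_model_hook the wrapped and unwrapped types differ, hence the check
# compares unwrapped types:
#   isinstance(unwrap_model(model), type(unwrap_model(transformer)))
```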
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-               elif isinstance(model, type(unwrap_model(text_encoder_one))):  # or text_encoder_two
+               elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):  # or text_encoder_two
I think we should add `and args.train_text_encoder`, because if `train_text_encoder` is false, it should be `None`.
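A sketch of the suggested guard inside `save_model_hook`; the names (`args.train_text_encoder`, `unwrap_model`, the `*_lora_layers_to_save` variables) are taken from the diff and the surrounding script, not verified against the merged code:

```python
for model in models:
    unwrapped = unwrap_model(model)
    if isinstance(unwrapped, type(unwrap_model(transformer))):
        transformer_lora_layers_to_save = get_peft_model_state_dict(unwrapped)
    elif args.train_text_encoder and isinstance(unwrapped, type(unwrap_model(text_encoder_one))):
        # Only enter the text-encoder branch when text-encoder training is on;
        # otherwise text_encoder_one_lora_layers_to_save stays None.
        text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(unwrapped)
```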
Addressed
Isn't it already reflected in the models sent to `save_model_hook`?
nvm addressed
@@ -1829,7 +1846,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                progress_bar.update(1)
                global_step += 1

-               if accelerator.is_main_process:
+               if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
I think it should be `if args.checkpoints_total_limit is not None and accelerator.is_main_process:`
This is a must! It has to be changed!!! And I think the correct format should be `if accelerator.is_main_process and args.checkpoints_total_limit is not None:`
Calm down sir :)
The line is already under:
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
So, this should already take care of what you're suggesting.
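To make the nesting concrete, a condensed sketch of the structure being discussed (checkpoint rotation elided; `accelerator.save_state` and `DistributedType` are accelerate API, the remaining names are the script's own):

```python
import os
from accelerate.utils import DistributedType

if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
    # DeepSpeed needs every rank to take part in saving state, hence the extra condition.
    if global_step % args.checkpointing_steps == 0:
        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
        if args.checkpoints_total_limit is not None:
            ...  # rotate out the oldest checkpoints (unchanged from the script)
        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
        accelerator.save_state(save_path)
```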
@sayakpaul Sorry about that, I didn't notice the comment status was pending; just submitted.
Thanks a lot @leisuzz! LGTM, just a couple of small comments.
* add ds support to lora sd3.

Co-authored-by: leisuzz <[email protected]>

* style.

---------

Co-authored-by: leisuzz <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>
What does this PR do?
Fixes #10252.
Cc: @leisuzz, could you give this a try?