Load Kohya-ss style LoRAs with auxiliary states #4147


Merged

Conversation

isidentical
Contributor

@isidentical isidentical commented Jul 18, 2023

What does this PR do?

This PR is a revival of the original #3756 (as discussed in #3756 (comment)). The majority of the credit goes to @takuma104 for creating the original PR and explaining the impact and changes in detail. On top of that work, I've made the text encoder patching modular (as done in #3778), added support for unload_loras() with the auxiliary parts, and done general cleanups and test fixes.

Fixes #3725

Examples

On a LoRA model for a random pet, the current revision of diffusers produces a related, although not particularly precise, image:

[image: output with the current diffusers revision]

And with this PR, it now looks like the original "pet" that the LoRA was trained on:

[image: output with this PR]
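For reference, here is a minimal usage sketch. The checkpoint filename and prompt are hypothetical, and it assumes diffusers' standard `load_lora_weights` / `unload_lora_weights` pipeline API:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# With this PR, the text encoder MLP (fc1/fc2) LoRA weights carried by
# Kohya-ss checkpoints are applied too, not just the attention projections.
pipe.load_lora_weights(".", weight_name="pet_lora.safetensors")
image = pipe("a photo of a sks pet").images[0]

# Unloading also removes the auxiliary (MLP) patches.
pipe.unload_lora_weights()
```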

@isidentical
Contributor Author

CC: @sayakpaul @takuma104 (I sent an invite to my diffusers fork to both of you, in case you want to push changes)

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jul 18, 2023

The documentation is not available anymore as the PR was closed or merged.

@isidentical isidentical force-pushed the kohya-lora-aux-features branch from 873d04a to b90a44e Compare July 18, 2023 20:09
@isidentical isidentical marked this pull request as ready for review July 18, 2023 20:15
@isidentical isidentical force-pushed the kohya-lora-aux-features branch from b90a44e to 4048fb1 Compare July 18, 2023 20:22
Member

@sayakpaul sayakpaul left a comment

Thanks so much, @isidentical! I left some initial comments.

Let me run some qualitative tests and get back to you.

I think we can add some tests to ensure things are robust.

@sayakpaul
Member

sayakpaul commented Jul 19, 2023

Quick update:

I ran the tests with `RUN_SLOW=1` enabled:

```
RUN_SLOW=1 pytest tests/models/test_lora_layers.py
```

But `test_unload()` and `test_a1111()` currently fail with assertion errors. Looking into it. Note that these failures don't happen on `main`.

@sayakpaul
Member

Left some discussions here: isidentical#1.

Overall, I think this PR is in a good state already, except for this issue: isidentical#1 (comment).

So, to summarize, the following TODOs remain:

  • Reflect this enhanced support in the docs by listing a couple of non-diffusers LoRAs.
  • Address the above issue.

Then in a future PR we can tackle SDXL LoRA support, what say?

@sayakpaul
Member

> Quick update:
>
> I ran the tests with `RUN_SLOW=1` enabled:
>
> ```
> RUN_SLOW=1 pytest tests/models/test_lora_layers.py
> ```
>
> But `test_unload()` and `test_a1111()` currently fail with assertion errors. Looking into it. Note that these failures don't happen on `main`.

isidentical#1 should fix all of these :)

@isidentical
Contributor Author

isidentical commented Jul 19, 2023

Thanks a lot for the feedback @sayakpaul, I'll address it today and let you know for the second round of PR reviews!

@isidentical isidentical force-pushed the kohya-lora-aux-features branch from f778d91 to c72fb16 Compare July 19, 2023 17:33
@isidentical
Contributor Author

@sayakpaul I just noticed that loading + unloading + loading LoRAs doesn't work with the new unet unpatching system, since setting .forward to the old forward causes it to never recognize LoRAs again (so loading a LoRA and then unloading works, but after that we can never load a LoRA again). Instead of patching forward, I've changed that logic in 84348fb to just nullify the LoRA layers, which has the exact same effect without the said bug. WDYT?
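To illustrate the difference, here is a rough sketch of the two strategies. This is not the exact diffusers code; it assumes the patched text encoder projections expose their LoRA layer via a `lora_linear_layer` attribute:

```python
# Old approach (buggy): restoring the original forward means the module
# never routes through a LoRA layer again, so re-loading becomes a no-op:
#     module.forward = old_forward

# Approach in 84348fb: keep the patched forward but nullify the LoRA
# layers, which restores the original outputs while staying re-loadable.
def unload_loras(text_encoder):
    for module in text_encoder.modules():
        if hasattr(module, "lora_linear_layer"):
            module.lora_linear_layer = None
```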

Contributor

@patrickvonplaten patrickvonplaten left a comment

Super cool! Thanks for iterating here @isidentical

@williamberman could you take a final look? :-)

Comment on lines 1131 to 1140
```python
for name, _ in text_encoder_mlp_modules(text_encoder):
    for direction in ["up", "down"]:
        for layer in ["fc1", "fc2"]:
            original_key = f"{name}.{layer}.lora.{direction}.weight"
            replacement_key = f"{name}.{layer}.lora_linear_layer.{direction}.weight"
            if original_key in text_encoder_lora_state_dict:
                text_encoder_lora_state_dict[replacement_key] = text_encoder_lora_state_dict.pop(
                    original_key
                )
```
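For example, the loop above performs the following remapping (toy state dict; the key is representative of a real CLIP text encoder MLP key, but the value is made up):

```python
sd = {"text_model.encoder.layers.0.mlp.fc1.lora.up.weight": 0}
name, layer, direction = "text_model.encoder.layers.0.mlp", "fc1", "up"
original_key = f"{name}.{layer}.lora.{direction}.weight"
replacement_key = f"{name}.{layer}.lora_linear_layer.{direction}.weight"
if original_key in sd:
    sd[replacement_key] = sd.pop(original_key)
print(sd)
# {'text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.up.weight': 0}
```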

Contributor

If these layers are newly supported in this PR, why would they be added to the legacy naming convention conversion code?

Contributor Author

I went with the legacy convention because the rest of the conversion logic uses it (which made it match the locality of `_convert_kohya_ss_state_dict`), but I also understand this is a new addition to the legacy part of the codebase, so I'll try to write it directly in the new convention with a brief comment about why it doesn't match the rest of the function (it should be refactored separately, I think, maybe even after the proposed #4247).

Contributor

Ah, does `_convert_kohya_ss_state_dict` currently return keys in the legacy naming convention?

Contributor Author

I think so.

Member

> Ah, does `_convert_kohya_ss_state_dict` currently return keys in the legacy naming convention?

It does, actually, IMO. This is why the integration tests also pass. Otherwise, there would have been assertion problems, I think.

Contributor

OK, I think we should follow up then and update the rest of `_convert_kohya_ss_state_dict` to use the new naming convention.

@williamberman
Contributor

A few brief questions, but this looks really good!

@isidentical isidentical force-pushed the kohya-lora-aux-features branch 2 times, most recently from 7a441d4 to 424a344 Compare July 25, 2023 00:55
@sayakpaul
Member

Taking care of the final nits:

  • Resolving conflicts
  • Adding comments

@isidentical I will drop a PR to your branch and then let's go :)

@sayakpaul
Member

@isidentical I added a comment to elaborate on the load + unload + load test and also changed its name to reflect that it's using a Kohya-style checkpoint.

I think only you (as the owner of the repo) can rebase on main to resolve the conflicts. Let me know if you face any difficulties.

Also, just for sanity, the SLOW tests are passing.

@williamberman could you give this one final look?

@sayakpaul sayakpaul requested a review from williamberman July 25, 2023 13:31
@isidentical isidentical force-pushed the kohya-lora-aux-features branch from d05a8d2 to 66a775a Compare July 25, 2023 13:43
@isidentical
Contributor Author

@sayakpaul done! Let's get this landed 🚀

@@ -1092,8 +1139,9 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr

```python
rank = text_encoder_lora_state_dict[
    "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
].shape[1]
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
```
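For context, a small sketch of what the `patch_mlp` check above reacts to; the keys are made up but follow the converted naming convention:

```python
keys = [
    "text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.up.weight",
    "text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.up.weight",
]
# Any ".mlp." key means the checkpoint carries fc1/fc2 LoRA weights,
# so the text encoder's MLP layers must be patched as well.
patch_mlp = any(".mlp." in key for key in keys)
print(patch_mlp)  # True
```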
Contributor

OK with this for now, but it would be really nice to avoid this in the future if possible. The meta point to think about here is that once we have checks like this anywhere in the code, we have to consider the implications for any state-dict checking or mutating code every time we touch a model definition.

This part specifically checks against a model definition from a separate library, which is even hairier. We're lucky that transformers very rarely changes model definitions once they're written, but in general that's not something we should rely on.

I think a good analogue is applications on your computer that serialize their state as locally stored files. That's all state dicts are: an application serialization format. Almost all applications will tell you not to make assumptions about the format or modify the files they store. When files are user-editable, that's usually very explicitly documented, whereas our state-dict formats are implicitly documented through a combination of code in different libraries and how diffusers elects to monkey patch updated model definitions.

Contributor

@williamberman williamberman left a comment

nice!

@patrickvonplaten
Contributor

Amazing job here @isidentical - merging!

@patrickvonplaten patrickvonplaten merged commit ff8f580 into huggingface:main Jul 25, 2023
orpatashnik pushed a commit to orpatashnik/diffusers that referenced this pull request Aug 1, 2023
* Support to load Kohya-ss style LoRA file format (without restrictions)

Co-Authored-By: Takuma Mori <[email protected]>
Co-Authored-By: Sayak Paul <[email protected]>

* tmp: add sdxl to mlp_modules

---------

Co-authored-by: Takuma Mori <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
orpatashnik pushed the same commit to orpatashnik/diffusers two more times on Aug 1, 2023
yoonseokjin pushed a commit with the same message to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit with the same message to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024