Load Kohya-ss style LoRAs with auxiliary states #4147
Conversation
CC: @sayakpaul @takuma104 (sent an invite to my diffusers fork for both of you, in case you might want to push changes)
The documentation is not available anymore as the PR was closed or merged.
Force-pushed: 873d04a → b90a44e
Force-pushed: b90a44e → 4048fb1
Thanks so much, @isidentical! I left some initial comments.
Let me run some qualitative tests and get back to you.
I think we can add some tests to ensure things are robust.
Quick update: I ran the tests with RUN_SLOW=1 pytest tests/models/test_lora_layers.py, but it fails.
Left some discussions here: isidentical#1. Overall, I think this PR is in a good state already, except for this issue: isidentical#1 (comment). So, to summarize, the following TODOs remain:
Then in a future PR we can tackle SDXL LoRA support, what say?
isidentical#1 should fix all of these :)
Thanks a lot for the feedback @sayakpaul, I'll address them today and will let you know for the second round of PR reviews!
Force-pushed: f778d91 → c72fb16
@sayakpaul I just noticed that loading + unloading + loading LoRAs doesn't work with the new unet unpatching system, since setting …
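For reference, a minimal sketch of the sequence being discussed, assuming the public load_lora_weights / unload_lora_weights pipeline API (method names may differ in the revision under review; the model ID and file path are placeholders, not from this PR):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.load_lora_weights("path/to/kohya_lora.safetensors")  # first load patches the UNet / text encoder
pipe.unload_lora_weights()                                # should restore the original modules
pipe.load_lora_weights("path/to/kohya_lora.safetensors")  # second load must re-patch cleanly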
Super cool! Thanks for iterating here @isidentical
@williamberman could you take a final look? :-)
src/diffusers/loaders.py (outdated)
for name, _ in text_encoder_mlp_modules(text_encoder):
    for direction in ["up", "down"]:
        for layer in ["fc1", "fc2"]:
            # Rename legacy-style MLP LoRA keys ("...lora.{up,down}.weight")
            # to the "lora_linear_layer" naming expected by diffusers.
            original_key = f"{name}.{layer}.lora.{direction}.weight"
            replacement_key = f"{name}.{layer}.lora_linear_layer.{direction}.weight"
            if original_key in text_encoder_lora_state_dict:
                text_encoder_lora_state_dict[replacement_key] = text_encoder_lora_state_dict.pop(
                    original_key
                )
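For illustration, a minimal, self-contained sketch of what that rename does to a single hypothetical key:

import torch

# Hypothetical state dict containing one legacy-style MLP LoRA key.
sd = {"text_model.encoder.layers.0.mlp.fc1.lora.up.weight": torch.zeros(8, 4)}

# The loop above rewrites the ".lora." infix to ".lora_linear_layer.".
sd["text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.up.weight"] = sd.pop(
    "text_model.encoder.layers.0.mlp.fc1.lora.up.weight"
)

print(list(sd)[0])
# text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.up.weight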
If these layers are newly supported in this PR, why would they be added to the legacy naming convention conversion code?
I went with the legacy convention because the rest of the conversion logic used the legacy convention (which made it match the locality of _convert_kohya_ss_state_dict), but I also understand this is a new addition to the legacy part of the codebase, so I'll try to write it directly in the new convention with a brief comment about why it doesn't match the actual function (it should be refactored separately, I think, maybe even after the proposed #4247).
Ah, does _convert_kohya_ss_state_dict currently return keys in the legacy naming convention?
I think so.
> Ah, does _convert_kohya_ss_state_dict currently return keys in the legacy naming convention?
It does, actually, IMO. This is why the integration tests also pass. Otherwise, there would have been assertion problems, I think.
OK, I think we should follow up then and update the rest of _convert_kohya_ss_state_dict to use the new naming convention.
A few brief questions, but this looks really good!
Force-pushed: 7a441d4 → 424a344
Taking care of the final nits:
@isidentical will drop a PR to your branch and then let's go :)
@isidentical I added a comment to elaborate the load + unload + load test and also changed its name to reflect that it's using a Kohya-style checkpoint. I think only you (as the owner of the repo) can rebase the … Also, just for sanity: the SLOW tests are passing. @williamberman could you give this one final look?
Co-Authored-By: Takuma Mori <[email protected]>
Co-Authored-By: Sayak Paul <[email protected]>
Force-pushed: d05a8d2 → 66a775a
@sayakpaul done! Let's get this landed 🚀
@@ -1092,8 +1139,9 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr
        rank = text_encoder_lora_state_dict[
            "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
        ].shape[1]
        patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
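For context, a minimal sketch (with hypothetical keys) of what this detection does:

# If any key addresses an ".mlp." submodule, the text encoder MLP layers
# need to be patched in addition to the attention layers.
keys = [
    "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight",
    "text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.up.weight",
]
patch_mlp = any(".mlp." in key for key in keys)
print(patch_mlp)  # True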
OK with this for now, but it would be really nice to avoid this in the future if possible. I think the meta point to consider is that once we have checks like this anywhere inside the code, we have to think about the implications for any state-dict-checking or state-dict-changing code every time we touch a model definition.
This part specifically checks a model definition from a separate library, which is even hairier. We're lucky that, the way transformers is written, model definitions very rarely change once they're written, but in general that's not something we should rely on.
I think a good analogue is applications on your computer that serialize their state as locally stored files. That's all state dicts are: an application serialization format. Almost all applications will say you should not make any assumptions about the format of, or make modifications to, the files they store. If they do say files are user-editable, that's usually very explicitly documented, whereas our state dict formats are implicitly documented through a combination of code in different libraries and how diffusers elects to monkey-patch updated model definitions.
nice!
Amazing job here @isidentical - merging!
* Support to load Kohya-ss style LoRA file format (without restrictions)

Co-Authored-By: Takuma Mori <[email protected]>
Co-Authored-By: Sayak Paul <[email protected]>

* tmp: add sdxl to mlp_modules

---------

Co-authored-by: Takuma Mori <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
What does this PR do?
This PR is a revival of the original #3756 (as discussed in #3756 (comment)). The majority of the credit goes to @takuma104 for making the original PR as well as explaining the impacts/changes in detail. On top of that work, I've moved the text encoder patching parts to a modular design (as done in #3778), added support for unload_loras() with the auxiliary parts, and made general cleanups / test fixes.

Fixes #3725
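For reference, a minimal usage sketch of the feature; the model ID, checkpoint path, and prompt are placeholders, not from this PR:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Kohya-ss style .safetensors LoRA files now load with their auxiliary
# (e.g. text encoder MLP) weights applied as well.
pipe.load_lora_weights("path/to/kohya_style_lora.safetensors")
image = pipe("a photo of a pet", num_inference_steps=30).images[0]
image.save("pet.png")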
Examples
On a LoRA model for a random pet, the current revision of diffusers produces a related, although not very precise, image:

And with this PR, it now looks like the original "pet" that the LoRA was trained on:
