Update handle single blocks on _convert_xlabs_flux_lora_to_diffusers #9915
```diff
@@ -636,10 +636,14 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
             block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
             new_key = f"transformer.single_transformer_blocks.{block_num}"
 
-            if "proj_lora1" in old_key or "proj_lora2" in old_key:
+            # if "proj_lora1" in old_key or "proj_lora2" in old_key:
+            if "proj_lora" in old_key:
                 new_key += ".proj_out"
-            elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
-                new_key += ".norm.linear"
+            # elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
+            elif "qkv_lora" in old_key and "up" not in old_key:
+                handle_qkv(old_state_dict, new_state_dict, old_key, [
+                    f"transformer.single_transformer_blocks.{block_num}.norm.linear"
+                ])
```
Reviewer: Can you explain this change to me?

Author: Sure. Here again, I forgot to remove the commented line; it is left over from the previous code snippet.

`qkv_lora1` and `qkv_lora2` are not present in single blocks; the key is `qkv_lora`. So I used the same logic and function already used for the double blocks, i.e. `handle_qkv`, which writes the converted entries into `new_state_dict` and removes the handled keys from `old_state_dict`. Then, in the last part of the code, all "qkv" keys for double and single blocks have been handled and the `ValueError` is not raised.
```diff
 
             if "down" in old_key:
                 new_key += ".lora_A.weight"
@@ -657,4 +661,4 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     if len(old_state_dict) > 0:
         raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
 
-    return new_state_dict
+    return new_state_dict
```
Reviewer: What's with the comments?

Author: This comment is the previous code and should be removed.

The reason for the change in this part of the code is that the single blocks in an XLabs Flux LoRA do not contain "proj_lora1" or "proj_lora2"; the string is "proj_lora". See the example below of the `old_state_dict` of a LoRA model where single blocks 1 to 4 are trained (only the keys for double block 9 and the single blocks are shown):
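An illustrative sketch of that key pattern (the block indices and the exact `processor.` segment shown here are assumptions; the real checkpoint keys may carry an additional prefix):

```
double_blocks.9.processor.proj_lora1.down.weight
double_blocks.9.processor.proj_lora1.up.weight
double_blocks.9.processor.qkv_lora1.down.weight
double_blocks.9.processor.qkv_lora1.up.weight
double_blocks.9.processor.proj_lora2.down.weight
double_blocks.9.processor.proj_lora2.up.weight
double_blocks.9.processor.qkv_lora2.down.weight
double_blocks.9.processor.qkv_lora2.up.weight
single_blocks.1.processor.proj_lora.down.weight
single_blocks.1.processor.proj_lora.up.weight
single_blocks.1.processor.qkv_lora.down.weight
single_blocks.1.processor.qkv_lora.up.weight
...
single_blocks.4.processor.qkv_lora.down.weight
single_blocks.4.processor.qkv_lora.up.weight
```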
If we used the previous code, the single-block keys would never be written to `new_state_dict` or removed from `old_state_dict`.
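A minimal check, assuming the illustrative key format above, shows why the old condition never matches a single-block key (so it would stay in `old_state_dict` and trip the final `ValueError`), while the new condition does:

```python
# Assumed single-block key format (illustrative, see the sketch above).
old_key = "single_blocks.1.processor.proj_lora.down.weight"

# Previous condition: never True for single blocks, so the key was never converted.
print("proj_lora1" in old_key or "proj_lora2" in old_key)  # False

# Updated condition: True, so the key is mapped to the new diffusers name.
print("proj_lora" in old_key)  # True
```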