Skip to content

Update handle single blocks on _convert_xlabs_flux_lora_to_diffusers #9915

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 20, 2024
12 changes: 8 additions & 4 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,10 +636,14 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
new_key = f"transformer.single_transformer_blocks.{block_num}"

if "proj_lora1" in old_key or "proj_lora2" in old_key:
# if "proj_lora1" in old_key or "proj_lora2" in old_key:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's with the comments?

Copy link
Contributor Author

@raulmosa raulmosa Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is the previous code and should be removed.
The reason for the change in this part of the code is that single blocks in Xlabs Flux LoRA do not contain "pros_lora1" or "proj_lora2", the string is "proj_lora".
See the example below of old_state_dict of a LoRA model where single blocks 1 to 4 are trained (only keys for double block 9 and single blocks are shown):

  • Double blocks keys example:
'double_blocks.9.processor.proj_lora1.down.weight', 'double_blocks.9.processor.proj_lora1.up.weight', 'double_blocks.9.processor.proj_lora2.down.weight', 'double_blocks.9.processor.proj_lora2.up.weight', 'double_blocks.9.processor.qkv_lora1.down.weight', 'double_blocks.9.processor.qkv_lora1.up.weight', 'double_blocks.9.processor.qkv_lora2.down.weight', 'double_blocks.9.processor.qkv_lora2.up.weight',
  • Single blocks key example:
 'single_blocks.1.processor.proj_lora.down.weight', 'single_blocks.1.processor.proj_lora.up.weight', 'single_blocks.1.processor.qkv_lora.down.weight', 'single_blocks.1.processor.qkv_lora.up.weight', 'single_blocks.2.processor.proj_lora.down.weight', 'single_blocks.2.processor.proj_lora.up.weight', 'single_blocks.2.processor.qkv_lora.down.weight', 'single_blocks.2.processor.qkv_lora.up.weight', 'single_blocks.3.processor.proj_lora.down.weight', 'single_blocks.3.processor.proj_lora.up.weight', 'single_blocks.3.processor.qkv_lora.down.weight', 'single_blocks.3.processor.qkv_lora.up.weight', 'single_blocks.4.processor.proj_lora.down.weight', 'single_blocks.4.processor.proj_lora.up.weight', 'single_blocks.4.processor.qkv_lora.down.weight', 'single_blocks.4.processor.qkv_lora.up.weight'

Then if we use the previous line code, single_blocks will never be updated in new_state_dict and removed from old_state_dict.

if "proj_lora" in old_key:
new_key += ".proj_out"
elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
new_key += ".norm.linear"
# elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
elif "qkv_lora" in old_key and "up" not in old_key:
handle_qkv(old_state_dict, new_state_dict, old_key, [
f"transformer.single_transformer_blocks.{block_num}.norm.linear"
])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain me this change?

Copy link
Contributor Author

@raulmosa raulmosa Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, here again, I forgot to remove the commented line, it is from the previous code snippet.
Using the same example shown in the previous comment, the original old_state_dict of a LoRA model where single blocks 1 to 4 are trained (only keys for double block 9 and single blocks are shown):

  • Double block keys example:
'double_blocks.9.processor.proj_lora1.down.weight', 'double_blocks.9.processor.proj_lora1.up.weight', 'double_blocks.9.processor.proj_lora2.down.weight', 'double_blocks.9.processor.proj_lora2.up.weight', 'double_blocks.9.processor.qkv_lora1.down.weight', 'double_blocks.9.processor.qkv_lora1.up.weight', 'double_blocks.9.processor.qkv_lora2.down.weight', 'double_blocks.9.processor.qkv_lora2.up.weight',
  • Single blocks keys example:
 'single_blocks.1.processor.proj_lora.down.weight', 'single_blocks.1.processor.proj_lora.up.weight', 'single_blocks.1.processor.qkv_lora.down.weight', 'single_blocks.1.processor.qkv_lora.up.weight', 'single_blocks.2.processor.proj_lora.down.weight', 'single_blocks.2.processor.proj_lora.up.weight', 'single_blocks.2.processor.qkv_lora.down.weight', 'single_blocks.2.processor.qkv_lora.up.weight', 'single_blocks.3.processor.proj_lora.down.weight', 'single_blocks.3.processor.proj_lora.up.weight', 'single_blocks.3.processor.qkv_lora.down.weight', 'single_blocks.3.processor.qkv_lora.up.weight', 'single_blocks.4.processor.proj_lora.down.weight', 'single_blocks.4.processor.proj_lora.up.weight', 'single_blocks.4.processor.qkv_lora.down.weight', 'single_blocks.4.processor.qkv_lora.up.weight'

qkv_lora1 and qkv_lora2 are not presented in single blocks, the key is qkv_lora, then I've used the same logic and function used to handle double blocks, i.e, function handle_qkv used to update the new_state_dict and remove the keys from old_state_dict. Then, in the last part of the code:

# Since we already handle qkv above.
        if "qkv" not in old_key:
            new_state_dict[new_key] = old_state_dict.pop(old_key)

    if len(old_state_dict) > 0:
        raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")

All "qkv" for double and single blocks are handled and ValueError is not raised.


if "down" in old_key:
new_key += ".lora_A.weight"
Expand All @@ -657,4 +661,4 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
if len(old_state_dict) > 0:
raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")

return new_state_dict
return new_state_dict