From d5ed3df8b7c17b8e7091358bc717abb539a295b0 Mon Sep 17 00:00:00 2001 From: raul_ar Date: Tue, 12 Nov 2024 16:40:07 +0100 Subject: [PATCH 1/3] Update handle single blocks on _convert_xlabs_flux_lora_to_diffusers to fix bug on updating keys and old_state_dict --- src/diffusers/loaders/lora_conversion_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index d0ca40213b14..74a1dd3035b6 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -636,10 +636,14 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1) new_key = f"transformer.single_transformer_blocks.{block_num}" - if "proj_lora1" in old_key or "proj_lora2" in old_key: + # if "proj_lora1" in old_key or "proj_lora2" in old_key: + if "proj_lora" in old_key: new_key += ".proj_out" - elif "qkv_lora1" in old_key or "qkv_lora2" in old_key: - new_key += ".norm.linear" + # elif "qkv_lora1" in old_key or "qkv_lora2" in old_key: + elif "qkv_lora" in old_key and "up" not in old_key: + handle_qkv(old_state_dict, new_state_dict, old_key, [ + f"transformer.single_transformer_blocks.{block_num}.norm.linear" + ]) if "down" in old_key: new_key += ".lora_A.weight" @@ -657,4 +661,4 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): if len(old_state_dict) > 0: raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.") - return new_state_dict + return new_state_dict \ No newline at end of file From 5bdbd5b74e59d2b269f3dc260c6a28d170f691d6 Mon Sep 17 00:00:00 2001 From: raul_ar Date: Thu, 14 Nov 2024 10:31:58 +0100 Subject: [PATCH 2/3] Remove comments and add test --- .../loaders/lora_conversion_utils.py | 2 -- tests/lora/test_lora_layers_flux.py | 22 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 74a1dd3035b6..7660e6535d86 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -636,10 +636,8 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1) new_key = f"transformer.single_transformer_blocks.{block_num}" - # if "proj_lora1" in old_key or "proj_lora2" in old_key: if "proj_lora" in old_key: new_key += ".proj_out" - # elif "qkv_lora1" in old_key or "qkv_lora2" in old_key: elif "qkv_lora" in old_key and "up" not in old_key: handle_qkv(old_state_dict, new_state_dict, old_key, [ f"transformer.single_transformer_blocks.{block_num}.norm.linear" diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index b58525cc7a6f..2fb79cb986ab 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -282,3 +282,25 @@ def test_flux_xlabs(self): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 1e-3 + + def test_flux_xlabs_load_lora_with_single_blocks(self): + self.pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks", + weight_name="lora.safetensors") + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() + self.pipeline.enable_model_cpu_offload() + + prompt = "a wizard mouse playing chess" + + out = self.pipeline( + prompt, + num_inference_steps=self.num_inference_steps, + guidance_scale=3.5, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + out_slice = out[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625]) + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3 \ No newline at end of file From 703ed60e3ad9c9d9e59dac41f4bb0150076f132c Mon Sep 17 00:00:00 2001 From: raulmosa Date: Tue, 19 Nov 2024 07:20:20 +0000 Subject: [PATCH 3/3] Run style and quality --- src/diffusers/loaders/lora_conversion_utils.py | 11 +++++++---- tests/lora/test_lora_layers_flux.py | 11 +++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 7660e6535d86..51a406b2f6a3 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -639,9 +639,12 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): if "proj_lora" in old_key: new_key += ".proj_out" elif "qkv_lora" in old_key and "up" not in old_key: - handle_qkv(old_state_dict, new_state_dict, old_key, [ - f"transformer.single_transformer_blocks.{block_num}.norm.linear" - ]) + handle_qkv( + old_state_dict, + new_state_dict, + old_key, + [f"transformer.single_transformer_blocks.{block_num}.norm.linear"], + ) if "down" in old_key: new_key += ".lora_A.weight" @@ -659,4 +662,4 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): if len(old_state_dict) > 0: raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.") - return new_state_dict \ No newline at end of file + return new_state_dict diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 2fb79cb986ab..e6e87c7ba939 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -284,8 +284,9 @@ def test_flux_xlabs(self): assert max_diff < 1e-3 def test_flux_xlabs_load_lora_with_single_blocks(self): - self.pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks", - weight_name="lora.safetensors") + self.pipeline.load_lora_weights( + "salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors" + ) self.pipeline.fuse_lora() self.pipeline.unload_lora_weights() self.pipeline.enable_model_cpu_offload() @@ -300,7 +301,9 @@ def test_flux_xlabs_load_lora_with_single_blocks(self): generator=torch.manual_seed(self.seed), ).images out_slice = out[0, -3:, -3:, -1].flatten() - expected_slice = np.array([0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625]) + expected_slice = np.array( + [0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625] + ) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - assert max_diff < 1e-3 \ No newline at end of file + assert max_diff < 1e-3