Commit 44640c8

maxs-kan, hlky, and sayakpaul authored
Fix Flux multiple Lora loading bug (#10388)
* check for base_layer key in transformer state dict
* test_lora_expansion_works_for_absent_keys
* check
* Update tests/lora/test_lora_layers_flux.py

Co-authored-by: Sayak Paul <[email protected]>

* check
* test_lora_expansion_works_for_absent_keys/test_lora_expansion_works_for_extra_keys
* absent->extra

---------

Co-authored-by: hlky <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 4b9f1c7 commit 44640c8

File tree

src/diffusers/loaders/lora_pipeline.py
tests/lora/test_lora_layers_flux.py

2 files changed: +103 −1 lines changed


src/diffusers/loaders/lora_pipeline.py (+3 −1)

@@ -2466,7 +2466,9 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
                 continue
 
             base_param_name = (
-                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+                f"{k.replace(prefix, '')}.base_layer.weight"
+                if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
+                else f"{k.replace(prefix, '')}.weight"
             )
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
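For reference, a minimal standalone sketch of what the guarded lookup above does. The helper name resolve_base_param_name and the sample state-dict keys are illustrative only, not part of the diffusers API:

# Minimal sketch of the guarded key lookup in _maybe_expand_lora_state_dict.
# The helper and the sample keys below are illustrative, not real diffusers code.
def resolve_base_param_name(k, prefix, is_peft_loaded, transformer_state_dict):
    # Strip the LoRA prefix (e.g. "transformer.") to get the module path.
    module_path = k.replace(prefix, "")
    candidate = f"{module_path}.base_layer.weight"
    # When PEFT layers are already injected (a previous LoRA was loaded), weights of
    # wrapped modules live under ".base_layer.weight". A module the earlier LoRA did
    # not target still exposes a plain ".weight", so fall back to it instead of
    # letting the later transformer_state_dict[...] lookup raise a KeyError.
    if is_peft_loaded and candidate in transformer_state_dict:
        return candidate
    return f"{module_path}.weight"


# Illustrative state dict: x_embedder was wrapped by a first LoRA, proj_out was not.
transformer_state_dict = {
    "x_embedder.base_layer.weight": "tensor...",
    "proj_out.weight": "tensor...",
}
print(resolve_base_param_name("transformer.x_embedder", "transformer.", True, transformer_state_dict))
# -> x_embedder.base_layer.weight
print(resolve_base_param_name("transformer.proj_out", "transformer.", True, transformer_state_dict))
# -> proj_out.weight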

tests/lora/test_lora_layers_flux.py (+100 −0)

@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import gc
 import os
 import sys
@@ -162,6 +163,105 @@ def test_with_alpha_in_state_dict(self):
         )
         self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
 
+    def test_lora_expansion_works_for_absent_keys(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # Modify the config to have a layer which won't be present in the second LoRA we will load.
+        modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
+        modified_denoiser_lora_config.target_modules.add("x_embedder")
+
+        pipe.transformer.add_adapter(modified_denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertFalse(
+            np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
+            "LoRA should lead to different results.",
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
+
+            # Modify the state dict to exclude "x_embedder" related LoRA params.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
+
+            pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
+            pipe.set_adapters(["one", "two"])
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+        images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        self.assertFalse(
+            np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+            "Different LoRAs should lead to different results.",
+        )
+        self.assertFalse(
+            np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+            "LoRA should lead to different results.",
+        )
+
+    def test_lora_expansion_works_for_extra_keys(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # Modify the config to have a layer which won't be present in the first LoRA we will load.
+        modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
+        modified_denoiser_lora_config.target_modules.add("x_embedder")
+
+        pipe.transformer.add_adapter(modified_denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertFalse(
+            np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
+            "LoRA should lead to different results.",
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            pipe.unload_lora_weights()
+            # Modify the state dict to exclude "x_embedder" related LoRA params.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
+            pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
+
+            # Load state dict with `x_embedder`.
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
+
+        pipe.set_adapters(["one", "two"])
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+        images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        self.assertFalse(
+            np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
+            "Different LoRAs should lead to different results.",
+        )
+        self.assertFalse(
+            np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
+            "LoRA should lead to different results.",
+        )
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
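For context, a hedged pipeline-level sketch of the scenario these tests exercise: two LoRAs that target different sets of transformer modules are loaded and activated together. The checkpoint id is a real Flux repo; the LoRA file paths and prompt are placeholders:

# Sketch only: loading two Flux LoRAs whose key sets differ, then activating both.
# The LoRA file names below are placeholders, not real files.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# The first adapter does not touch x_embedder; the second one does.
pipe.load_lora_weights("lora_without_x_embedder.safetensors", adapter_name="one")
pipe.load_lora_weights("lora_with_x_embedder.safetensors", adapter_name="two")

# Before this fix, expanding the second adapter's state dict could fail with a
# KeyError on "x_embedder.base_layer.weight", because the first adapter never
# wrapped x_embedder; the new guard falls back to "x_embedder.weight".
pipe.set_adapters(["one", "two"])
image = pipe("a photo of a cat", num_inference_steps=4).images[0]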

0 commit comments
