diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md
index 1c6989a5e659..fd2c07e59f3f 100644
--- a/docs/source/en/api/pipelines/flux.md
+++ b/docs/source/en/api/pipelines/flux.md
@@ -305,6 +305,10 @@ image = control_pipe(
 image.save("output.png")
 ```
 
+## Note about `unload_lora_weights()` when using Flux LoRAs
+
+When unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to_overwritten_params=True)` to reset `pipe.transformer` completely back to its original form. The resulting pipeline can then be used with methods like [`DiffusionPipeline.from_pipe`]. More details about this argument are available in [this PR](https://github.com/huggingface/diffusers/pull/10397).
+
 ## Running FP16 inference
 
 Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index f55d9958e5c3..7b7693dcfbcf 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2277,8 +2277,24 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
         super().unfuse_lora(components=components)
 
-    # We override this here account for `_transformer_norm_layers`.
-    def unload_lora_weights(self):
+    # We override this here to account for `_transformer_norm_layers` and `_overwritten_params`.
+    def unload_lora_weights(self, reset_to_overwritten_params=False):
+        """
+        Unloads the LoRA parameters.
+
+        Args:
+            reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
+                to their original params. Refer to the [Flux
+                documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+        >>> pipeline.unload_lora_weights()
+        >>> ...
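+        >>> # If a Control LoRA expanded the transformer's input layer, passing
+        >>> # `reset_to_overwritten_params=True` also restores the expanded modules to their original shapes.
+        >>> pipeline.unload_lora_weights(reset_to_overwritten_params=True)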
+        ```
+        """
         super().unload_lora_weights()
 
         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
@@ -2286,7 +2302,7 @@ def unload_lora_weights(self):
             transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
             transformer._transformer_norm_layers = None
 
-        if getattr(transformer, "_overwritten_params", None) is not None:
+        if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
             overwritten_params = transformer._overwritten_params
 
             module_names = set()
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 9fa968c47107..ace0ad6b6044 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -706,7 +706,7 @@ def test_lora_unload_with_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
-        control_pipe.unload_lora_weights()
+        control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
         self.assertTrue(
             control_pipe.transformer.config.in_channels == num_channels_without_control,
             f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
@@ -724,6 +724,65 @@ def test_lora_unload_with_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
         self.assertTrue(pipe.transformer.config.in_channels == in_features)
 
+    def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+        )
+
+        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
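+        # `FluxPipeline` takes no `control_image`, so it is used to grab a baseline output from the
+        # unexpanded, `num_channels_without_control`-channel transformer before the Control LoRA below
+        # doubles the input features of `x_embedder`.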
+ components["transformer"] = transformer + pipe = FluxPipeline(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + control_image = inputs.pop("control_image") + original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + control_pipe = self.pipeline_class(**components) + out_features, in_features = control_pipe.transformer.x_embedder.weight.shape + rank = 4 + + dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, + } + with CaptureLogger(logger) as cap_logger: + control_pipe.load_lora_weights(lora_state_dict, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + inputs["control_image"] = control_image + lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) + self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) + self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) + + control_pipe.unload_lora_weights(reset_to_overwritten_params=False) + self.assertTrue( + control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", + ) + no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4)) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) + self.assertTrue(pipe.transformer.config.in_channels == in_features * 2) + @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass