Commit d9d94e1

Authored by sayakpaul and a-r-r-o-w
[LoRA] fix: lora unloading when using expanded Flux LoRAs. (#10397)
* fix: lora unloading when using expanded Flux LoRAs.
* fix argument name.
* docs.

Co-authored-by: a-r-r-o-w <[email protected]>
1 parent: 2f25156

3 files changed (+83, -4)

Diff for: docs/source/en/api/pipelines/flux.md (+4)

````diff
@@ -305,6 +305,10 @@ image = control_pipe(
 image.save("output.png")
 ```
 
+## Note about `unload_lora_weights()` when using Flux LoRAs
+
+When unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to_overwritten_params=True)` to reset the `pipe.transformer` completely back to its original form. The resultant pipeline can then be used with methods like [`DiffusionPipeline.from_pipe`]. More details about this argument are available in [this PR](https://github.com/huggingface/diffusers/pull/10397).
+
 ## Running FP16 inference
 
 Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
````
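As a companion to the docs addition above, here is a minimal usage sketch of the new flag, assuming the Flux Control LoRA setup described on that docs page; the model IDs, `FluxControlPipeline`, and the `from_pipe` reuse are illustrative and not part of this diff.

```python
import torch
from diffusers import FluxControlPipeline, FluxPipeline

# Load the Control variant of the Flux pipeline (illustrative model ID, GPU assumed).
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Loading a Control LoRA expands `pipe.transformer.x_embedder` so the
# transformer can take the extra `control_image` channels.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")

# ... run control-conditioned inference here ...

# `reset_to_overwritten_params=True` restores the expanded modules to their
# original shapes; the default (False) removes the LoRA layers but keeps the
# expanded shapes.
pipe.unload_lora_weights(reset_to_overwritten_params=True)

# With the transformer fully reset, its components can be reused in a plain
# Flux pipeline via `from_pipe`, as the docs note suggests.
text_to_image_pipe = FluxPipeline.from_pipe(pipe)
```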

Diff for: src/diffusers/loaders/lora_pipeline.py (+19, -3)

````diff
@@ -2277,16 +2277,32 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
 
         super().unfuse_lora(components=components)
 
-    # We override this here account for `_transformer_norm_layers`.
-    def unload_lora_weights(self):
+    # We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
+    def unload_lora_weights(self, reset_to_overwritten_params=False):
+        """
+        Unloads the LoRA parameters.
+
+        Args:
+            reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
+                to their original params. Refer to the [Flux
+                documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+        >>> pipeline.unload_lora_weights()
+        >>> ...
+        ```
+        """
         super().unload_lora_weights()
 
         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
         if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
             transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
             transformer._transformer_norm_layers = None
 
-        if getattr(transformer, "_overwritten_params", None) is not None:
+        if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
             overwritten_params = transformer._overwritten_params
             module_names = set()
 
````
Diff for: tests/lora/test_lora_layers_flux.py (+60, -1)

````diff
@@ -706,7 +706,7 @@ def test_lora_unload_with_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
-        control_pipe.unload_lora_weights()
+        control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
         self.assertTrue(
             control_pipe.transformer.config.in_channels == num_channels_without_control,
             f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
@@ -724,6 +724,65 @@ def test_lora_unload_with_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
         self.assertTrue(pipe.transformer.config.in_channels == in_features)
 
+    def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+        )
+
+        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
+        components["transformer"] = transformer
+        pipe = FluxPipeline(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        control_image = inputs.pop("control_image")
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        control_pipe = self.pipeline_class(**components)
+        out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+        }
+        with CaptureLogger(logger) as cap_logger:
+            control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        inputs["control_image"] = control_image
+        lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+        control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
+        self.assertTrue(
+            control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
+        )
+        no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
+        self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
````
