Skip to content

fix: missing AutoencoderKL lora adapter #9807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/diffusers/models/autoencoders/autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import deprecate
from ...utils.accelerate_utils import apply_forward_hook
Expand All @@ -34,7 +35,7 @@
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder


class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

Expand Down
38 changes: 38 additions & 0 deletions tests/models/autoencoders/test_models_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
backend_empty_cache,
enable_full_determinism,
floats_tensor,
is_peft_available,
load_hf_numpy,
require_peft_backend,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_gpu,
Expand All @@ -50,6 +52,10 @@
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


if is_peft_available():
from peft import LoraConfig


enable_full_determinism()


Expand Down Expand Up @@ -263,6 +269,38 @@ def test_output_pretrained(self):

self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

@require_peft_backend
def test_lora_adapter(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be decorated with:

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@beniz seems like this was not resolved?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sayakpaul ah apologies, may have forgot to push to repo. Done. Thanks for your vigilance.

init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
vae = self.model_class(**init_dict)

target_modules_vae = [
"conv1",
"conv2",
"conv_in",
"conv_shortcut",
"conv",
"conv_out",
"skip_conv_1",
"skip_conv_2",
"skip_conv_3",
"skip_conv_4",
"to_k",
"to_q",
"to_v",
"to_out.0",
]
vae_lora_config = LoraConfig(
r=16,
init_lora_weights="gaussian",
target_modules=target_modules_vae,
)

vae.add_adapter(vae_lora_config, adapter_name="vae_lora")
active_lora = vae.active_adapters()
self.assertTrue(len(active_lora) == 1)
self.assertTrue(active_lora[0] == "vae_lora")


class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AsymmetricAutoencoderKL
Expand Down
Loading