Commit 580a6d5

feat: lora support for SANA.
1 parent 3bf5400 commit 580a6d5

File tree

7 files changed: +517 -5 lines

src/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -69,6 +69,7 @@ def text_encoder_attn_modules(text_encoder):
         "FluxLoraLoaderMixin",
         "CogVideoXLoraLoaderMixin",
         "Mochi1LoraLoaderMixin",
+        "SanaLoraLoaderMixin",
     ]
     _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
     _import_structure["ip_adapter"] = ["IPAdapterMixin"]
@@ -90,6 +91,7 @@ def text_encoder_attn_modules(text_encoder):
             FluxLoraLoaderMixin,
             LoraLoaderMixin,
             Mochi1LoraLoaderMixin,
+            SanaLoraLoaderMixin,
             SD3LoraLoaderMixin,
             StableDiffusionLoraLoaderMixin,
             StableDiffusionXLLoraLoaderMixin,
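
The first hunk registers the new name in the lazy `_import_structure` map and the second adds it to the corresponding import list, so the mixin resolves like any other loader. A minimal sketch of the resulting import path, using only what the diff above adds:

```py
# SanaPipeline inherits from this mixin, so end users rarely import it directly,
# but it is reachable through the loaders subpackage after this change.
from diffusers.loaders import SanaLoraLoaderMixin

# Only the transformer carries LoRA layers for Sana (see the class definition below).
print(SanaLoraLoaderMixin._lora_loadable_modules)  # ["transformer"]
```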

src/diffusers/loaders/lora_pipeline.py

Lines changed: 308 additions & 0 deletions

@@ -3254,6 +3254,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
         super().unfuse_lora(components=components)
 
 
+class SanaLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale' from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        return state_dict
+
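`lora_state_dict` only downloads and sanitizes the checkpoint; it never touches the model. A quick way to inspect what it returns before loading, sketched with a placeholder repo id (`pytorch_lora_weights.safetensors` is the diffusers default weight name):

```py
from diffusers import SanaPipeline

# Hypothetical LoRA repo; substitute a real Sana LoRA checkpoint.
state_dict = SanaPipeline.lora_state_dict(
    "your-username/sana-lora", weight_name="pytorch_lora_weights.safetensors"
)

# Every key should contain "lora" (e.g. `transformer.<module>.lora_A.weight`),
# which is exactly what `load_lora_weights` below validates.
print(list(state_dict)[:5])
```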
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
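End to end, this gives the familiar one-call LoRA workflow on `SanaPipeline`. A hedged sketch: the base checkpoint id matches the Sana release on the Hub, the LoRA repo is a placeholder, and the PEFT backend (`pip install peft`) is required:

```py
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float16
).to("cuda")

# Placeholder repo; `adapter_name` is optional and defaults to `default_0`.
pipe.load_lora_weights("your-username/sana-lora", adapter_name="style")

image = pipe(prompt="a tiny astronaut hatching from an egg on the moon").images[0]
```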
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`SanaTransformer2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
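`save_lora_weights` expects an already-extracted LoRA state dict, which pairs naturally with PEFT's `get_peft_model_state_dict` in a training script. A sketch under stated assumptions: the checkpoint id is the Hub release, the adapter config is purely illustrative, and in practice the adapter would be trained before saving:

```py
import torch
from peft import LoraConfig, get_peft_model_state_dict

from diffusers import SanaPipeline, SanaTransformer2DModel

transformer = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16
)
# Illustrative rank/targets; `add_adapter` comes from PeftAdapterMixin, which
# this commit wires up in sana_transformer.py below.
transformer.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))

# ... training would happen here ...

SanaPipeline.save_lora_weights(
    save_directory="./sana-lora",  # created if it doesn't exist
    transformer_lora_layers=get_peft_model_state_dict(transformer),
    safe_serialization=True,  # writes pytorch_lora_weights.safetensors
)
```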
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer", "text_encoder"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+            unfuse_text_encoder (`bool`, defaults to `True`):
+                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters then it won't have any effect.
+        """
+        super().unfuse_lora(components=components)
+
+
 class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     def __init__(self, *args, **kwargs):
         deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
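
The `fuse_lora` docstring example above is inherited verbatim from the SDXL mixin; on Sana the flow is the same, sketched here with a placeholder LoRA repo. Since only the transformer carries LoRA layers, `components=["transformer"]` is the meaningful choice:

```py
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("your-username/sana-lora")  # placeholder repo

# Bake the LoRA into the base weights for faster inference, then undo it.
pipe.fuse_lora(components=["transformer"], lora_scale=0.8)
image = pipe(prompt="a cyberpunk cat").images[0]
pipe.unfuse_lora(components=["transformer"])
```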

src/diffusers/loaders/peft.py

Lines changed: 1 addition & 0 deletions

@@ -53,6 +53,7 @@
     "FluxTransformer2DModel": lambda model_cls, weights: weights,
    "CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
     "MochiTransformer3DModel": lambda model_cls, weights: weights,
+    "SanaTransformer2DModel": lambda model_cls, weights: weights,
 }
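
This table (the adapter-scale mapping in `peft.py`, `_SET_ADAPTER_SCALE_FN_MAPPING`) controls how per-adapter weights passed to `set_adapters` are preprocessed; the identity lambda means Sana forwards them to the PEFT layers unchanged rather than expanding them per block the way the UNet entries do. A short sketch with a placeholder LoRA repo:

```py
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("your-username/sana-lora", adapter_name="style")  # placeholder

# With the identity mapping above, 0.6 reaches every LoRA layer as-is.
pipe.set_adapters(["style"], adapter_weights=[0.6])
```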

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from torch import nn
1919

2020
from ...configuration_utils import ConfigMixin, register_to_config
21-
from ...utils import is_torch_version, logging
21+
from ...loaders import PeftAdapterMixin
22+
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
2223
from ..attention_processor import (
2324
Attention,
2425
AttentionProcessor,
@@ -180,7 +181,7 @@ def forward(
180181
return hidden_states
181182

182183

183-
class SanaTransformer2DModel(ModelMixin, ConfigMixin):
184+
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
184185
r"""
185186
A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
186187
@@ -363,8 +364,24 @@ def forward(
363364
timestep: torch.LongTensor,
364365
encoder_attention_mask: Optional[torch.Tensor] = None,
365366
attention_mask: Optional[torch.Tensor] = None,
367+
attention_kwargs: Optional[Dict[str, Any]] = None,
366368
return_dict: bool = True,
367369
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
370+
if attention_kwargs is not None:
371+
attention_kwargs = attention_kwargs.copy()
372+
lora_scale = attention_kwargs.pop("scale", 1.0)
373+
else:
374+
lora_scale = 1.0
375+
376+
if USE_PEFT_BACKEND:
377+
# weight the lora layers by setting `lora_scale` for each PEFT layer
378+
scale_lora_layers(self, lora_scale)
379+
else:
380+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
381+
logger.warning(
382+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
383+
)
384+
368385
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
369386
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
370387
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -460,6 +477,11 @@ def custom_forward(*inputs):
460477
hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
461478
output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
462479

480+
if USE_PEFT_BACKEND:
481+
# remove `lora_scale` from each PEFT layer
482+
unscale_lora_layers(self, lora_scale)
483+
463484
if not return_dict:
464485
return (output,)
486+
465487
return Transformer2DModelOutput(sample=output)
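
These forward hooks are what make a runtime LoRA scale work: `scale` is popped out of `attention_kwargs`, applied with `scale_lora_layers`, and removed again with `unscale_lora_layers` after the blocks run, leaving the module unscaled between calls. Assuming `SanaPipeline` forwards `attention_kwargs` to the transformer (as its sibling pipelines do), usage looks like this with a placeholder LoRA repo:

```py
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("your-username/sana-lora")  # placeholder repo

# "scale" is consumed by SanaTransformer2DModel.forward and applied to every
# PEFT layer for the duration of the denoising loop.
image = pipe(prompt="a watercolor fox", attention_kwargs={"scale": 0.5}).images[0]
```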
