@@ -3254,6 +3254,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
         super().unfuse_lora(components=components)
+class SanaLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return the state dict for the LoRA weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111-formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted
+                      on the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
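+        Example (an illustrative sketch, not part of the library docs; the repo id and weight name below are
+        hypothetical placeholders):
+
+        ```py
+        # Fetch a LoRA state dict without loading it into a model.
+        state_dict = SanaLoraLoaderMixin.lora_state_dict(
+            "some-user/some-sana-lora",  # hypothetical repo id
+            weight_name="pytorch_lora_weights.safetensors",  # assumed file name
+        )
+        print(sorted(state_dict)[:5])  # inspect a few LoRA keys
+        ```
+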
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        return state_dict
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All kwargs
+        are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
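+
+        Example (a minimal, hedged sketch; both repo ids below are hypothetical placeholders):
+
+        ```py
+        import torch
+
+        from diffusers import SanaPipeline
+
+        pipe = SanaPipeline.from_pretrained("org/sana-base", torch_dtype=torch.float16).to("cuda")  # placeholder id
+        pipe.load_lora_weights("some-user/some-sana-lora", adapter_name="my-style")  # placeholder id
+        image = pipe("a cat wearing a spacesuit").images[0]
+        ```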
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`SanaTransformer2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
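+
+        Example (a hedged sketch; `pipe` is assumed to be a `SanaPipeline` and the repo id is a hypothetical
+        placeholder):
+
+        ```py
+        # Load an already-fetched LoRA state dict directly into the transformer.
+        state_dict = SanaLoraLoaderMixin.lora_state_dict("some-user/some-sana-lora")  # hypothetical repo id
+        SanaLoraLoaderMixin.load_lora_into_transformer(
+            state_dict, transformer=pipe.transformer, adapter_name="my-style"
+        )
+        ```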
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
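+
+        Example (a hedged sketch of saving LoRA layers after training; `transformer` is assumed to be a
+        PEFT-wrapped `SanaTransformer2DModel` from your own training loop):
+
+        ```py
+        from peft.utils import get_peft_model_state_dict
+
+        lora_layers = get_peft_model_state_dict(transformer)  # extract only the LoRA parameters
+        SanaLoraLoaderMixin.save_lora_weights(
+            save_directory="./sana-lora",  # created if it doesn't exist
+            transformer_lora_layers=lora_layers,
+        )
+        ```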
3480
+ """
3481
+ state_dict = {}
3482
+
3483
+ if not transformer_lora_layers :
3484
+ raise ValueError ("You must pass `transformer_lora_layers`." )
3485
+
3486
+ if transformer_lora_layers :
3487
+ state_dict .update (cls .pack_weights (transformer_lora_layers , cls .transformer_name ))
3488
+
3489
+ # Save the model
3490
+ cls .write_lora_layers (
3491
+ state_dict = state_dict ,
3492
+ save_directory = save_directory ,
3493
+ is_main_process = is_main_process ,
3494
+ weight_name = weight_name ,
3495
+ save_function = save_function ,
3496
+ safe_serialization = safe_serialization ,
3497
+ )
3498
+
3499
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
3500
+ def fuse_lora (
3501
+ self ,
3502
+ components : List [str ] = ["transformer" , "text_encoder" ],
3503
+ lora_scale : float = 1.0 ,
3504
+ safe_fusing : bool = False ,
3505
+ adapter_names : Optional [List [str ]] = None ,
3506
+ ** kwargs ,
3507
+ ):
3508
+ r"""
3509
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
3510
+
3511
+ <Tip warning={true}>
3512
+
3513
+ This is an experimental API.
3514
+
3515
+ </Tip>
3516
+
3517
+ Args:
3518
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
3519
+ lora_scale (`float`, defaults to 1.0):
3520
+ Controls how much to influence the outputs with the LoRA parameters.
3521
+ safe_fusing (`bool`, defaults to `False`):
3522
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
3523
+ adapter_names (`List[str]`, *optional*):
3524
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
3525
+
3526
+ Example:
3527
+
3528
+ ```py
3529
+ from diffusers import DiffusionPipeline
3530
+ import torch
3531
+
3532
+ pipeline = DiffusionPipeline.from_pretrained(
3533
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
3534
+ ).to("cuda")
3535
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
3536
+ pipeline.fuse_lora(lora_scale=0.7)
3537
+ ```
3538
+ """
3539
+ super ().fuse_lora (
3540
+ components = components , lora_scale = lora_scale , safe_fusing = safe_fusing , adapter_names = adapter_names
3541
+ )
3542
+
3543
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
3544
+ def unfuse_lora (self , components : List [str ] = ["transformer" , "text_encoder" ], ** kwargs ):
3545
+ r"""
3546
+ Reverses the effect of
3547
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
3548
+
3549
+ <Tip warning={true}>
3550
+
3551
+ This is an experimental API.
3552
+
3553
+ </Tip>
3554
+
3555
+ Args:
3556
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3557
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
3558
+ unfuse_text_encoder (`bool`, defaults to `True`):
3559
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
3560
+ LoRA parameters then it won't have any effect.
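+
+        Example (a hedged sketch; `pipe` is assumed to already have a LoRA loaded and fused, as in `fuse_lora`
+        above):
+
+        ```py
+        pipe.fuse_lora(lora_scale=0.7)
+        # ... run inference with the fused weights ...
+        pipe.unfuse_lora()  # restore the original, unfused weights
+        ```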
+        """
+        super().unfuse_lora(components=components)
+
+
 class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     def __init__(self, *args, **kwargs):
         deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."