@@ -679,6 +679,55 @@ def _unfuse_lora_apply(self, module):
        if hasattr(module, "_unfuse_lora"):
            module._unfuse_lora()

+    def set_adapters(
+        self,
+        adapter_names: Union[List[str], str],
+        weights: List[float] = None,
+    ):
+        """
+        Sets the adapter layers for the unet.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            weights (`List[float]`, *optional*):
+                The weights to use for the unet. If `None`, the weights are set to `1.0` for all the adapters.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+
+        def process_weights(adapter_names, weights):
+            if weights is None:
+                weights = [1.0] * len(adapter_names)
+            elif isinstance(weights, float):
+                weights = [weights]
+
+            if len(adapter_names) != len(weights):
+                raise ValueError(
+                    f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
+                )
+            return weights
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        weights = process_weights(adapter_names, weights)
+        set_weights_and_activate_adapters(self, adapter_names, weights)
+
+    def disable_lora(self):
+        """
+        Disables the LoRA layers for the unet.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=False)
+
+    def enable_lora(self):
+        """
+        Enables the LoRA layers for the unet.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=True)
+


def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
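For orientation, here is a rough usage sketch of the UNet-level helpers introduced in the hunk above. It is not part of the diff: the repo ID is a real SDXL checkpoint, but the adapter names `"toy"` and `"pixel"` are placeholders, and the sketch assumes two LoRAs were already loaded under those names onto a pipeline whose `unet` carries these mixin methods.

```python
# Illustrative sketch only; adapter names are placeholders, LoRA loading not shown.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Activate both adapters on the UNet, weighting the first one more heavily.
pipe.unet.set_adapters(["toy", "pixel"], weights=[0.8, 0.4])

# Temporarily turn every LoRA layer off, then back on.
pipe.unet.disable_lora()
pipe.unet.enable_lora()
```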
@@ -1448,7 +1497,7 @@ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="

    @classmethod
    def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None, adapter_name="default"
+        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None, adapter_name=None
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1468,7 +1517,8 @@ def load_lora_into_unet(
                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
            adapter_name (`str`, *optional*):
-                The name of the adapter to load the weights into. By default we use `"default"`
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
        """
        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1500,38 +1550,19 @@ def load_lora_into_unet(

            state_dict = convert_unet_state_dict_to_peft(state_dict)

-            target_modules = []
-            ranks = []
+            rank = {}
            for key in state_dict.keys():
-                # filter out the name
-                filtered_name = ".".join(key.split(".")[:-2])
-                target_modules.append(filtered_name)
                if "lora_B" in key:
-                    rank = state_dict[key].shape[1]
-                    ranks.append(rank)
+                    rank[key] = state_dict[key].shape[1]

-            current_rank = ranks[0]
-            if not all(rank == current_rank for rank in ranks):
-                raise ValueError("Multi-rank not supported yet")
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
+            lora_config = LoraConfig(**lora_config_kwargs)

-            if network_alphas is not None:
-                alphas = set(network_alphas.values())
-                if len(alphas) == 1:
-                    alpha = alphas.pop()
-                    # TODO: support multi-alpha
-                else:
-                    raise ValueError("Multi-alpha not supported yet")
-            else:
-                alpha = current_rank
-
-            lora_config = LoraConfig(
-                r=current_rank,
-                lora_alpha=alpha,
-                target_modules=target_modules,
-            )
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(unet)

            inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
-
            incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)

            if incompatible_keys is not None:
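The hunk above drops the single-rank/single-alpha restriction: ranks are now collected per `lora_B` key into a dict and handed to `get_peft_kwargs` to build the `LoraConfig`. As a rough illustration of that idea only (this is a hypothetical helper, not the actual `get_peft_kwargs` implementation), mixed ranks can be expressed through a default `r` plus PEFT's `rank_pattern` option:

```python
from collections import Counter

def sketch_peft_kwargs(rank_dict, network_alphas, peft_state_dict):
    """Hypothetical sketch: derive LoraConfig kwargs from per-key ranks."""
    # The most common rank becomes the default `r`; outliers go into `rank_pattern`.
    r = Counter(rank_dict.values()).most_common(1)[0][0]
    rank_pattern = {
        ".".join(k.split(".")[:-2]): v for k, v in rank_dict.items() if v != r
    }

    # Target every module that contributes a lora_A weight to the state dict.
    target_modules = sorted({".".join(k.split(".")[:-2]) for k in peft_state_dict if "lora_A" in k})

    # Simplified alpha handling: reuse a single alpha if one is provided, else fall back to r.
    lora_alpha = r if not network_alphas else next(iter(set(network_alphas.values())))
    return {
        "r": r,
        "lora_alpha": lora_alpha,
        "rank_pattern": rank_pattern,
        "target_modules": target_modules,
    }
```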
@@ -1655,12 +1686,14 @@ def load_lora_into_text_encoder(
            if adapter_name is None:
                adapter_name = get_adapter_name(text_encoder)

+
            # inject LoRA layers and load the state dict
            text_encoder.load_adapter(
                adapter_name=adapter_name,
                adapter_state_dict=text_encoder_lora_state_dict,
                peft_config=lora_config,
            )
+

            # scale LoRA layers with `lora_scale`
            scale_lora_layers(text_encoder, weight=lora_scale)
@@ -2258,7 +2291,7 @@ def unfuse_text_encoder_lora(text_encoder):

        self.num_fused_loras -= 1

-    def set_adapter_for_text_encoder(
+    def set_adapters_for_text_encoder(
        self,
        adapter_names: Union[List[str], str],
        text_encoder: Optional[PreTrainedModel] = None,
@@ -2336,60 +2369,44 @@ def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] =
    def set_adapters(
        self,
        adapter_names: Union[List[str], str],
-        weights: List[float] = None,
+        unet_weights: List[float] = None,
+        te_weights: List[float] = None,
+        te2_weights: List[float] = None,
    ):
-        """
-        Sets the adapter layers for the unet.
-
-        Args:
-            adapter_names (`List[str]` or `str`):
-                The names of the adapters to use.
-            weights (`List[float]`, *optional*):
-                The weights to use for the unet. If `None`, the weights are set to `1.0` for all the adapters.
-        """
-        if not self.use_peft_backend:
-            raise ValueError("PEFT backend is required for this method.")
-
-        def process_weights(adapter_names, weights):
-            if weights is None:
-                weights = [1.0] * len(adapter_names)
-            elif isinstance(weights, float):
-                weights = [weights]
-
-            if len(adapter_names) != len(weights):
-                raise ValueError(
-                    f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
-                )
-            return weights
-
-        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
-        weights = process_weights(adapter_names, weights)
+        # Handle the UNET
+        self.unet.set_adapters(adapter_names, unet_weights)

-        for key, value in self.components.items():
-            if isinstance(value, nn.Module):
-                set_weights_and_activate_adapters(value, adapter_names, weights)
+        # Handle the Text Encoder
+        if hasattr(self, "text_encoder"):
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, te_weights)
+        if hasattr(self, "text_encoder_2"):
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, te2_weights)

    def disable_lora(self):
-        """
-        Disables the LoRA layers for the unet.
-        """
        if not self.use_peft_backend:
            raise ValueError("PEFT backend is required for this method.")

-        for key, value in self.components.items():
-            if isinstance(value, nn.Module):
-                set_adapter_layers(value, enabled=False)
+        # Disable unet adapters
+        self.unet.disable_lora()
+
+        # Disable text encoder adapters
+        if hasattr(self, "text_encoder"):
+            self.disable_lora_for_text_encoder(self.text_encoder)
+        if hasattr(self, "text_encoder_2"):
+            self.disable_lora_for_text_encoder(self.text_encoder_2)

    def enable_lora(self):
-        """
-        Enables the LoRA layers for the unet.
-        """
        if not self.use_peft_backend:
            raise ValueError("PEFT backend is required for this method.")

-        for key, value in self.components.items():
-            if isinstance(value, nn.Module):
-                set_adapter_layers(value, enabled=True)
+        # Enable unet adapters
+        self.unet.enable_lora()
+
+        # Enable text encoder adapters
+        if hasattr(self, "text_encoder"):
+            self.enable_lora_for_text_encoder(self.text_encoder)
+        if hasattr(self, "text_encoder_2"):
+            self.enable_lora_for_text_encoder(self.text_encoder_2)


class FromSingleFileMixin:
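Finally, a rough end-to-end sketch of the pipeline-level API as it is defined in the last hunk. The adapter names are placeholders, and the sketch assumes the corresponding LoRAs were loaded beforehand; `te2_weights` only takes effect when the pipeline has a second text encoder.

```python
# Illustrative sketch only; "toy" and "pixel" are placeholder adapter names.
pipe.set_adapters(
    ["toy", "pixel"],
    unet_weights=[0.7, 0.3],
    te_weights=[1.0, 1.0],
    te2_weights=[1.0, 1.0],  # ignored unless the pipeline defines text_encoder_2
)

# Switch all LoRA layers off for a baseline image, then re-enable them.
pipe.disable_lora()
image_base = pipe("a cute robot toy").images[0]

pipe.enable_lora()
image_lora = pipe("a cute robot toy").images[0]
```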