Commit 86bd6f5

Merge pull request #2 from huggingface/smangrul/unet-enhancements ("Smangrul/unet enhancements")
2 parents: 2646f3d + 0e771f0

2 files changed: +98 -78 lines

src/diffusers/loaders.py (+88 -71)
@@ -679,6 +679,55 @@ def _unfuse_lora_apply(self, module):
         if hasattr(module, "_unfuse_lora"):
             module._unfuse_lora()
 
+    def set_adapters(
+        self,
+        adapter_names: Union[List[str], str],
+        weights: List[float] = None,
+    ):
+        """
+        Sets the adapter layers for the unet.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            weights (`List[float]`, *optional*):
+                The weights to use for the unet. If `None`, the weights are set to `1.0` for all the adapters.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+
+        def process_weights(adapter_names, weights):
+            if weights is None:
+                weights = [1.0] * len(adapter_names)
+            elif isinstance(weights, float):
+                weights = [weights]
+
+            if len(adapter_names) != len(weights):
+                raise ValueError(
+                    f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
+                )
+            return weights
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        weights = process_weights(adapter_names, weights)
+        set_weights_and_activate_adapters(self, adapter_names, weights)
+
+    def disable_lora(self):
+        """
+        Disables the LoRA layers for the unet.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=False)
+
+    def enable_lora(self):
+        """
+        Enables the LoRA layers for the unet.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=True)
+
 
 def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
     cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
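Editor's note: for orientation, here is a minimal usage sketch of the UNet-level adapter API added in the hunk above. It is not part of the commit; the checkpoint and LoRA repository ids and the adapter names are placeholders, and it assumes a diffusers build with the PEFT backend enabled and LoRAs registered via `load_lora_weights` with an `adapter_name`.

```python
# Sketch only: placeholder model/LoRA ids and adapter names, assuming the PEFT backend is enabled.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("user/style-lora", adapter_name="style")    # hypothetical LoRA repos
pipe.load_lora_weights("user/detail-lora", adapter_name="detail")

# Activate both adapters on the UNet with per-adapter weights (omit `weights` for 1.0 each).
pipe.unet.set_adapters(["style", "detail"], weights=[0.8, 0.5])

# Switch the LoRA layers off and back on without unloading them.
pipe.unet.disable_lora()
pipe.unet.enable_lora()
```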
@@ -1448,7 +1497,7 @@ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="
 
     @classmethod
     def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None, adapter_name="default"
+        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None, adapter_name=None
    ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1468,7 +1517,8 @@ def load_lora_into_unet(
                 Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                 argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
-                The name of the adapter to load the weights into. By default we use `"default"`
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
         """
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
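Editor's note: the docstring above describes only the default naming; the actual logic lives in `get_adapter_name`, which this diff does not show. As a rough illustration of the scheme, the i-th adapter loaded without an explicit name ends up as `default_{i}`:

```python
# Illustration of the naming scheme only, not the real get_adapter_name helper.
def _default_adapter_name(num_existing_adapters: int) -> str:
    return f"default_{num_existing_adapters}"

assert _default_adapter_name(0) == "default_0"  # first adapter loaded without a name
assert _default_adapter_name(2) == "default_2"  # two adapters already attached to the model
```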
@@ -1500,38 +1550,19 @@ def load_lora_into_unet(
 
             state_dict = convert_unet_state_dict_to_peft(state_dict)
 
-            target_modules = []
-            ranks = []
+            rank = {}
             for key in state_dict.keys():
-                # filter out the name
-                filtered_name = ".".join(key.split(".")[:-2])
-                target_modules.append(filtered_name)
                 if "lora_B" in key:
-                    rank = state_dict[key].shape[1]
-                    ranks.append(rank)
+                    rank[key] = state_dict[key].shape[1]
 
-            current_rank = ranks[0]
-            if not all(rank == current_rank for rank in ranks):
-                raise ValueError("Multi-rank not supported yet")
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
+            lora_config = LoraConfig(**lora_config_kwargs)
 
-            if network_alphas is not None:
-                alphas = set(network_alphas.values())
-                if len(alphas) == 1:
-                    alpha = alphas.pop()
-                    # TODO: support multi-alpha
-                else:
-                    raise ValueError("Multi-alpha not supported yet")
-            else:
-                alpha = current_rank
-
-            lora_config = LoraConfig(
-                r=current_rank,
-                lora_alpha=alpha,
-                target_modules=target_modules,
-            )
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(unet)
 
             inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
-
             incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
 
             if incompatible_keys is not None:
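Editor's note: the hunk above drops the single-rank restriction. Instead of asserting that every `lora_B` weight has the same rank, the loader now records a per-key rank dictionary and lets `get_peft_kwargs` derive the `LoraConfig` arguments (including `rank_pattern`/`alpha_pattern`). A toy, self-contained sketch of that rank collection; the module names and shapes are invented:

```python
import torch

# Invented PEFT-style keys and shapes, for illustration only.
state_dict = {
    "down_blocks.0.attentions.0.to_q.lora_A.weight": torch.zeros(4, 320),
    "down_blocks.0.attentions.0.to_q.lora_B.weight": torch.zeros(320, 4),  # rank 4
    "mid_block.attentions.0.to_k.lora_A.weight": torch.zeros(8, 640),
    "mid_block.attentions.0.to_k.lora_B.weight": torch.zeros(640, 8),      # rank 8
}

rank = {}
for key in state_dict.keys():
    if "lora_B" in key:
        # lora_B has shape (out_features, rank), so dimension 1 is the adapter rank
        rank[key] = state_dict[key].shape[1]

print(rank)  # mixed ranks (4 and 8) are now representable instead of raising "Multi-rank not supported yet"
```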
@@ -1655,12 +1686,14 @@ def load_lora_into_text_encoder(
             if adapter_name is None:
                 adapter_name = get_adapter_name(text_encoder)
 
+
             # inject LoRA layers and load the state dict
             text_encoder.load_adapter(
                 adapter_name=adapter_name,
                 adapter_state_dict=text_encoder_lora_state_dict,
                 peft_config=lora_config,
             )
+
             # scale LoRA layers with `lora_scale`
             scale_lora_layers(text_encoder, weight=lora_scale)
 
@@ -2258,7 +2291,7 @@ def unfuse_text_encoder_lora(text_encoder):
 
         self.num_fused_loras -= 1
 
-    def set_adapter_for_text_encoder(
+    def set_adapters_for_text_encoder(
         self,
         adapter_names: Union[List[str], str],
         text_encoder: Optional[PreTrainedModel] = None,
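Editor's note: the plural rename matches the list-based adapter API; the pipeline-level `set_adapters` in the next hunk calls this method with the adapter names, the target text encoder, and a per-adapter weight list. A one-line usage sketch with placeholder names and weights:

```python
# Placeholder adapter names and weights; the third argument is the per-adapter weight list.
pipe.set_adapters_for_text_encoder(["style", "detail"], pipe.text_encoder, [1.0, 0.5])
```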
@@ -2336,60 +2369,44 @@ def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] =
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
-        weights: List[float] = None,
+        unet_weights: List[float] = None,
+        te_weights: List[float] = None,
+        te2_weights: List[float] = None,
     ):
-        """
-        Sets the adapter layers for the unet.
-
-        Args:
-            adapter_names (`List[str]` or `str`):
-                The names of the adapters to use.
-            weights (`List[float]`, *optional*):
-                The weights to use for the unet. If `None`, the weights are set to `1.0` for all the adapters.
-        """
-        if not self.use_peft_backend:
-            raise ValueError("PEFT backend is required for this method.")
-
-        def process_weights(adapter_names, weights):
-            if weights is None:
-                weights = [1.0] * len(adapter_names)
-            elif isinstance(weights, float):
-                weights = [weights]
-
-            if len(adapter_names) != len(weights):
-                raise ValueError(
-                    f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
-                )
-            return weights
-
-        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
-        weights = process_weights(adapter_names, weights)
+        # Handle the UNET
+        self.unet.set_adapters(adapter_names, unet_weights)
 
-        for key, value in self.components.items():
-            if isinstance(value, nn.Module):
-                set_weights_and_activate_adapters(value, adapter_names, weights)
+        # Handle the Text Encoder
+        if hasattr(self, "text_encoder"):
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, te_weights)
+        if hasattr(self, "text_encoder_2"):
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, te2_weights)
 
     def disable_lora(self):
-        """
-        Disables the LoRA layers for the unet.
-        """
         if not self.use_peft_backend:
             raise ValueError("PEFT backend is required for this method.")
 
-        for key, value in self.components.items():
-            if isinstance(value, nn.Module):
-                set_adapter_layers(value, enabled=False)
+        # Disable unet adapters
+        self.unet.disable_lora()
+
+        # Disable text encoder adapters
+        if hasattr(self, "text_encoder"):
+            self.disable_lora_for_text_encoder(self.text_encoder)
+        if hasattr(self, "text_encoder_2"):
+            self.disable_lora_for_text_encoder(self.text_encoder_2)
 
     def enable_lora(self):
-        """
-        Enables the LoRA layers for the unet.
-        """
         if not self.use_peft_backend:
             raise ValueError("PEFT backend is required for this method.")
 
-        for key, value in self.components.items():
-            if isinstance(value, nn.Module):
-                set_adapter_layers(value, enabled=True)
+        # Enable unet adapters
+        self.unet.enable_lora()
+
+        # Enable text encoder adapters
+        if hasattr(self, "text_encoder"):
+            self.enable_lora_for_text_encoder(self.text_encoder)
+        if hasattr(self, "text_encoder_2"):
+            self.enable_lora_for_text_encoder(self.text_encoder_2)
 
 
 class FromSingleFileMixin:
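Editor's note: taken together, the pipeline-level methods now delegate to the component-level ones added earlier instead of walking `self.components`. A hedged sketch of the resulting call pattern, reusing the placeholder `pipe` and adapter names from the first example:

```python
# Each component can weight the same adapters differently.
pipe.set_adapters(
    ["style", "detail"],
    unet_weights=[0.8, 0.4],
    te_weights=[1.0, 0.2],
    te2_weights=[1.0, 0.2],  # only used by pipelines with a second text encoder (e.g. SDXL)
)

# Toggle LoRA on the UNet and any text encoder(s) present on the pipeline in one call.
pipe.disable_lora()
pipe.enable_lora()
```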

src/diffusers/utils/peft_utils.py (+10 -7)
@@ -123,13 +123,16 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict):
     rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
     rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
 
-    if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
-        # get the alpha occuring the most number of times
-        lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
-
-        # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern`
-        alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
-        alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
+    if network_alpha_dict is not None:
+        if len(set(network_alpha_dict.values())) > 1:
+            # get the alpha occuring the most number of times
+            lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
+
+            # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern`
+            alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
+            alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
+        else:
+            lora_alpha = set(network_alpha_dict.values()).pop()
 
     # layer names without the Diffusers specific
     target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
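Editor's note: to make the new alpha handling concrete, here is a small self-contained sketch with invented module names (the key normalization the real code applies via `k.split(".down.")[0]` is omitted). When the alphas disagree, the most frequent value becomes the global `lora_alpha` and the outliers land in `alpha_pattern`; when they all agree, the shared value is used directly, which is what the new `else` branch adds.

```python
import collections

# Invented alpha dict, for illustration only.
network_alpha_dict = {
    "unet.down_blocks.0.to_q.alpha": 8,
    "unet.down_blocks.0.to_k.alpha": 8,
    "unet.mid_block.to_v.alpha": 16,
}

alpha_pattern = {}
if len(set(network_alpha_dict.values())) > 1:
    # the most frequent alpha becomes the global lora_alpha ...
    lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
    # ... and modules with a different alpha are kept as exceptions
    alpha_pattern = {k: v for k, v in network_alpha_dict.items() if v != lora_alpha}
else:
    # all alphas agree: use the single shared value (the case the new else branch covers)
    lora_alpha = set(network_alpha_dict.values()).pop()

print(lora_alpha)     # 8
print(alpha_pattern)  # {'unet.mid_block.to_v.alpha': 16}
```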
