@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,8 @@
 
 logger = logging.get_logger(__name__)
 
+TEXT_ENCODER_NAME = "text_encoder"
+UNET_NAME = "unet"
 
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
@@ -87,6 +90,9 @@ def map_from(module, state_dict, *args, **kwargs):
 
 
 class UNet2DConditionLoadersMixin:
+    text_encoder_name = TEXT_ENCODER_NAME
+    unet_name = UNET_NAME
+
     def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         r"""
         Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
@@ -225,6 +231,18 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
 
         if is_lora:
+            is_new_lora_format = all(
+                key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
+            )
+            if is_new_lora_format:
+                # Strip the `"unet"` prefix.
+                is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
+                if is_text_encoder_present:
+                    warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
+                    warnings.warn(warn_message)
+                unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
+                state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
+
             lora_grouped_dict = defaultdict(dict)
             for key, value in state_dict.items():
                 attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
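For readers following the hunk above: the new warning steers anyone holding a combined LoRA checkpoint toward the pipeline-level loader. A minimal usage sketch, assuming a hypothetical local directory `./lora_weights` containing a `pytorch_lora_weights.bin` saved in the new prefixed format (the model id is only an example):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# UNet-only entry point: any text-encoder LoRA params in the file are ignored,
# which is exactly the case the new warning flags.
pipe.unet.load_attn_procs("./lora_weights")

# Pipeline-level entry point recommended by the warning: loads both the
# "unet." and "text_encoder." prefixed LoRA params.
pipe.load_lora_weights("./lora_weights")
```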
@@ -672,8 +690,8 @@ class LoraLoaderMixin:
 
     </Tip>
     """
-    text_encoder_name = "text_encoder"
-    unet_name = "unet"
+    text_encoder_name = TEXT_ENCODER_NAME
+    unet_name = UNET_NAME
 
     def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         r"""
@@ -810,33 +828,33 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())
-
-        # Load the layers corresponding to UNet.
-        if all(key.startswith(self.unet_name) for key in keys):
+        if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
+            # Load the layers corresponding to UNet.
+            unet_keys = [k for k in keys if k.startswith(self.unet_name)]
             logger.info(f"Loading {self.unet_name}.")
-            unet_lora_state_dict = {k: v for k, v in state_dict.items() if k.startswith(self.unet_name)}
+            unet_lora_state_dict = {
+                k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
+            }
             self.unet.load_attn_procs(unet_lora_state_dict)
 
-        # Load the layers corresponding to text encoder and make necessary adjustments.
-        elif all(key.startswith(self.text_encoder_name) for key in keys):
+            # Load the layers corresponding to text encoder and make necessary adjustments.
+            text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
             logger.info(f"Loading {self.text_encoder_name}.")
             text_encoder_lora_state_dict = {
-                k: v for k, v in state_dict.items() if k.startswith(self.text_encoder_name)
+                k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
-            attn_procs_text_encoder = self.load_attn_procs(text_encoder_lora_state_dict)
-            self._modify_text_encoder(attn_procs_text_encoder)
+            if len(text_encoder_lora_state_dict) > 0:
+                attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
+                self._modify_text_encoder(attn_procs_text_encoder)
 
         # Otherwise, we're dealing with the old format. This means the `state_dict` should only
         # contain the module names of the `unet` as its keys WITHOUT any prefix.
         elif not all(
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
         ):
             self.unet.load_attn_procs(state_dict)
-            deprecation_message = "You have saved the LoRA weights using the old format. This will be"
-            " deprecated soon. To convert the old LoRA weights to the new format, you can first load them"
-            " in a dictionary and then create a new dictionary like the following:"
-            " `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
-            deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False)
+            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
+            warnings.warn(warn_message)
 
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
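The reworded warning in the hunk above still describes the old-to-new conversion only in prose. Spelled out as a short sketch, assuming `old_state_dict` holds unprefixed UNet attention-processor params loaded from a legacy checkpoint file:

```python
import torch

# Legacy file: keys are UNet module names without any "unet." prefix.
old_state_dict = torch.load("pytorch_lora_weights.bin", map_location="cpu")

# New format expected by `load_lora_weights`: the same params, re-keyed with
# the "unet." prefix.
new_state_dict = {f"unet.{module_name}": params for module_name, params in old_state_dict.items()}
```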
@@ -872,7 +890,9 @@ def _get_lora_layer_attribute(self, name: str) -> str:
         else:
             return "to_out_lora"
 
-    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def _load_text_encoder_attn_procs(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
+    ):
         r"""
         Load pretrained attention processor layers for
         [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
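Taken together, these hunks make `load_lora_weights` route a mixed state dict by prefix and hand each sub-model an already-stripped dict, with the text-encoder half now going through the renamed `_load_text_encoder_attn_procs` helper. A toy illustration of that routing logic; the key names below are made up purely for brevity:

```python
state_dict = {
    "unet.down_blocks.0.attentions.0.processor.to_q_lora.up.weight": "unet-param",
    "text_encoder.text_model.encoder.layers.0.q_lora.up.weight": "text-param",
}

# Mirror of the prefix split performed in the new `load_lora_weights` branch.
unet_lora = {k.replace("unet.", ""): v for k, v in state_dict.items() if k.startswith("unet.")}
text_lora = {k.replace("text_encoder.", ""): v for k, v in state_dict.items() if k.startswith("text_encoder.")}

assert "down_blocks.0.attentions.0.processor.to_q_lora.up.weight" in unet_lora
assert "text_model.encoder.layers.0.q_lora.up.weight" in text_lora
```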