Skip to content

Commit ed224f9

Browse files
DN6 and sayakpaul authored
Add single file support for Stable Cascade (#7274)
* update * update * update * update * update * update --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 531e719 commit ed224f9

File tree

4 files changed

+409
-5
lines changed

4 files changed

+409
-5
lines changed

src/diffusers/loaders/single_file_utils.py

+107-4
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,87 @@
8181
"timestep_spacing": "leading",
8282
}
8383

84+
85+
# Maps each Stable Cascade UNet variant to the keyword arguments used to fetch
# its reference config: every variant lives in the same Hub repo and differs
# only in the subfolder.
STABLE_CASCADE_DEFAULT_CONFIGS = {
    stage: {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": subfolder}
    for stage, subfolder in (
        ("stage_c", "prior"),
        ("stage_c_lite", "prior_lite"),
        ("stage_b", "decoder"),
        ("stage_b_lite", "decoder_lite"),
    )
}
91+
92+
93+
def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict):
    """
    Rename the keys of an original Stable Cascade UNet state dict to the diffusers layout.

    The following rewrites are applied; every other entry is copied through unchanged:

        - `*attn.in_proj_weight` / `*attn.in_proj_bias` are split into three chunks along dim 0 and stored as
          `*to_q.*`, `*to_k.*` and `*to_v.*` (in that order).
        - `*attn.out_proj.weight` / `*attn.out_proj.bias` become `*to_out.0.weight` / `*to_out.0.bias`.
        - When the checkpoint is not Stage C (no `clip_txt_mapper.weight` key), `*clip_mapper.weight` /
          `*clip_mapper.bias` become `*clip_txt_pooled_mapper.weight` / `*clip_txt_pooled_mapper.bias`.

    Parameters:
        original_state_dict (`dict`):
            Mapping from parameter name to tensor. Fused projection tensors must support `.chunk(3, 0)`.

    Returns:
        `dict`: A new mapping with converted keys. The input mapping is not modified.
    """
    # Stage C checkpoints are recognised by their text mapper; anything else is treated as Stage B.
    is_stage_c = "clip_txt_mapper.weight" in original_state_dict

    # Fused projections: one tensor is split into three equally ordered targets.
    split_suffixes = {
        "attn.in_proj_weight": ("to_q.weight", "to_k.weight", "to_v.weight"),
        "attn.in_proj_bias": ("to_q.bias", "to_k.bias", "to_v.bias"),
    }
    # Plain renames: the tensor is kept as-is, only the key suffix changes.
    rename_suffixes = {
        "attn.out_proj.weight": "to_out.0.weight",
        "attn.out_proj.bias": "to_out.0.bias",
    }
    if not is_stage_c:
        # rename clip_mapper to clip_txt_pooled_mapper
        rename_suffixes["clip_mapper.weight"] = "clip_txt_pooled_mapper.weight"
        rename_suffixes["clip_mapper.bias"] = "clip_txt_pooled_mapper.bias"

    state_dict = {}
    for key, tensor in original_state_dict.items():
        # Match on the *full* suffix that gets rewritten. Matching on a shorter suffix (e.g. just
        # "in_proj_weight") would let a key without the "attn." prefix through, and the three chunks
        # would then overwrite one another under the unchanged key.
        split_suffix = next((suffix for suffix in split_suffixes if key.endswith(suffix)), None)
        if split_suffix is not None:
            prefix = key[: -len(split_suffix)]
            chunks = tensor.chunk(3, 0)
            for index, target in enumerate(split_suffixes[split_suffix]):
                state_dict[prefix + target] = chunks[index]
            continue

        rename_suffix = next((suffix for suffix in rename_suffixes if key.endswith(suffix)), None)
        if rename_suffix is not None:
            state_dict[key[: -len(rename_suffix)] + rename_suffixes[rename_suffix]] = tensor
        else:
            state_dict[key] = tensor

    return state_dict
147+
148+
149+
def infer_stable_cascade_single_file_config(checkpoint):
    """
    Pick the default diffusers config location for a Stable Cascade UNet checkpoint.

    The variant is decided from the shape of one marker tensor:

        - `clip_txt_mapper.weight` with `shape[0] == 1536` -> `"stage_c_lite"`, `== 2048` -> `"stage_c"`.
        - `down_blocks.1.0.channelwise.0.weight` with `shape[-1] == 576` -> `"stage_b_lite"`, `== 640` ->
          `"stage_b"`.

    Parameters:
        checkpoint (`dict`):
            State dict in the original key layout; marker values must expose a `.shape`.

    Returns:
        `dict`: The matching entry of `STABLE_CASCADE_DEFAULT_CONFIGS`.

    Raises:
        `ValueError`: If no marker tensor is present or its size matches no known variant.
    """
    stage_c_key = "clip_txt_mapper.weight"
    stage_b_key = "down_blocks.1.0.channelwise.0.weight"

    config_type = None
    if stage_c_key in checkpoint:
        config_type = {1536: "stage_c_lite", 2048: "stage_c"}.get(checkpoint[stage_c_key].shape[0])

    # The Stage B marker is still consulted when the Stage C marker exists but has an unknown size.
    if config_type is None and stage_b_key in checkpoint:
        config_type = {576: "stage_b_lite", 640: "stage_b"}.get(checkpoint[stage_b_key].shape[-1])

    if config_type is None:
        # Fail with an actionable message instead of an UnboundLocalError on the lookup below.
        raise ValueError(
            "Unable to infer the Stable Cascade config from the checkpoint: expected "
            f"`{stage_c_key}` with shape[0] in (1536, 2048) or `{stage_b_key}` with shape[-1] in (576, 640). "
            "Pass `config` explicitly to load this checkpoint."
        )

    return STABLE_CASCADE_DEFAULT_CONFIGS[config_type]
163+
164+
84165
DIFFUSERS_TO_LDM_MAPPING = {
85166
"unet": {
86167
"layers": {
@@ -229,10 +310,34 @@ def fetch_ldm_config_and_checkpoint(
229310
cache_dir=None,
230311
local_files_only=None,
231312
revision=None,
313+
):
314+
checkpoint = load_single_file_model_checkpoint(
315+
pretrained_model_link_or_path,
316+
resume_download=resume_download,
317+
force_download=force_download,
318+
proxies=proxies,
319+
token=token,
320+
cache_dir=cache_dir,
321+
local_files_only=local_files_only,
322+
revision=revision,
323+
)
324+
original_config = fetch_original_config(class_name, checkpoint, original_config_file)
325+
326+
return original_config, checkpoint
327+
328+
329+
def load_single_file_model_checkpoint(
330+
pretrained_model_link_or_path,
331+
resume_download=False,
332+
force_download=False,
333+
proxies=None,
334+
token=None,
335+
cache_dir=None,
336+
local_files_only=None,
337+
revision=None,
232338
):
233339
if os.path.isfile(pretrained_model_link_or_path):
234340
checkpoint = load_state_dict(pretrained_model_link_or_path)
235-
236341
else:
237342
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
238343
checkpoint_path = _get_model_file(
@@ -252,9 +357,7 @@ def fetch_ldm_config_and_checkpoint(
252357
while "state_dict" in checkpoint:
253358
checkpoint = checkpoint["state_dict"]
254359

255-
original_config = fetch_original_config(class_name, checkpoint, original_config_file)
256-
257-
return original_config, checkpoint
360+
return checkpoint
258361

259362

260363
def infer_original_config_file(class_name, checkpoint):

src/diffusers/loaders/unet.py

+105
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@
4242
set_adapter_layers,
4343
set_weights_and_activate_adapters,
4444
)
45+
from .single_file_utils import (
46+
convert_stable_cascade_unet_single_file_to_diffusers,
47+
infer_stable_cascade_single_file_config,
48+
load_single_file_model_checkpoint,
49+
)
4550
from .utils import AttnProcsLayers
4651

4752

@@ -896,3 +901,103 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
896901
self.config.encoder_hid_dim_type = "ip_image_proj"
897902

898903
self.to(dtype=self.dtype, device=self.device)
904+
905+
906+
class FromOriginalUNetMixin:
    """
    Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
    """

    @classmethod
    @validate_hf_hub_args
    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
        r"""
        Instantiate a [`StableCascadeUNet`] from pretrained weights saved in the original `.ckpt` or `.safetensors`
        format. This method does not call `model.eval()` itself; the model is returned in whatever mode
        `cls.from_config` leaves it.

        Parameters:
            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:
                    - A link to the `.ckpt` file (for example
                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
                    - A path to a *file* containing the model weights.
            config (`dict`, *optional*):
                Configuration of the model, passed as-is to `cls.from_config`. When omitted, the config location is
                inferred from the checkpoint with `infer_stable_cascade_single_file_config` and loaded with
                `cls.load_config`.
            torch_dtype (`torch.dtype`, *optional*):
                Dtype to load the model with. It is forwarded to `load_model_dict_into_meta` when `accelerate` is
                available, and the model is cast with `model.to(torch_dtype)` before being returned.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `None`):
                Whether to only load local model weights and configuration files or not. If set to True, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `None`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Forwarded to `cls.load_config` (only when `config` is not given) and to `cls.from_config`.

        Returns:
            The instantiated model with the converted checkpoint weights loaded.

        Raises:
            `ValueError`: If `cls` is not `StableCascadeUNet`.
        """
        # Pop everything this method consumes so that only model/config overrides remain in `kwargs`.
        config = kwargs.pop("config", None)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        cache_dir = kwargs.pop("cache_dir", None)
        local_files_only = kwargs.pop("local_files_only", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)

        class_name = cls.__name__
        if class_name != "StableCascadeUNet":
            raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")

        checkpoint = load_single_file_model_checkpoint(
            pretrained_model_link_or_path,
            resume_download=resume_download,
            force_download=force_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
        )

        if config is None:
            # No explicit config: derive the reference config location from the checkpoint contents.
            config = infer_stable_cascade_single_file_config(checkpoint)
            model_config = cls.load_config(**config, **kwargs)
        else:
            model_config = config

        # With accelerate the model is built under `init_empty_weights` and filled in afterwards by
        # `load_model_dict_into_meta`; without it the model is built normally.
        ctx = init_empty_weights if is_accelerate_available() else nullcontext
        with ctx():
            model = cls.from_config(model_config, **kwargs)

        diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
        if is_accelerate_available():
            unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
            if len(unexpected_keys) > 0:
                # `Logger.warn` is a deprecated alias of `Logger.warning`.
                logger.warning(
                    f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
                )

        else:
            model.load_state_dict(diffusers_format_checkpoint)

        if torch_dtype is not None:
            model.to(torch_dtype)

        return model

src/diffusers/models/unets/unet_stable_cascade.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch.nn as nn
2222

2323
from ...configuration_utils import ConfigMixin, register_to_config
24+
from ...loaders.unet import FromOriginalUNetMixin
2425
from ...utils import BaseOutput
2526
from ..attention_processor import Attention
2627
from ..modeling_utils import ModelMixin
@@ -134,7 +135,7 @@ class StableCascadeUNetOutput(BaseOutput):
134135
sample: torch.FloatTensor = None
135136

136137

137-
class StableCascadeUNet(ModelMixin, ConfigMixin):
138+
class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
138139
_supports_gradient_checkpointing = True
139140

140141
@register_to_config

0 commit comments

Comments
 (0)