
Commit c5f6fb9

yiyixuxu and sayakpaul committed
Sd35 controlnet (#10020)
* add model/pipeline Co-authored-by: Sayak Paul <[email protected]>
1 parent 30c9189 commit c5f6fb9

4 files changed: +367 -43 lines changed
scripts/convert_sd3_controlnet_to_diffusers.py

+185-0
@@ -0,0 +1,185 @@
"""
A script to convert Stable Diffusion 3.5 ControlNet checkpoints to the Diffusers format.

Example:
    Convert a SD3.5 ControlNet checkpoint to Diffusers format using local file:
    ```bash
    python scripts/convert_sd3_controlnet_to_diffusers.py \
        --checkpoint_path "path/to/local/sd3.5_large_controlnet_canny.safetensors" \
        --output_path "output/sd35-controlnet-canny" \
        --dtype "fp16"  # optional, defaults to fp32
    ```

    Or download and convert from HuggingFace repository:
    ```bash
    python scripts/convert_sd3_controlnet_to_diffusers.py \
        --original_state_dict_repo_id "stabilityai/stable-diffusion-3.5-controlnets" \
        --filename "sd3.5_large_controlnet_canny.safetensors" \
        --output_path "/raid/yiyi/sd35-controlnet-canny-diffusers" \
        --dtype "fp32"  # optional, defaults to fp32
    ```

Note:
    The script supports the following ControlNet types from SD3.5:
    - Canny edge detection
    - Depth estimation
    - Blur detection

    The checkpoint files can be downloaded from:
    https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets
"""

import argparse

import safetensors.torch
import torch
from huggingface_hub import hf_hub_download

from diffusers import SD3ControlNetModel


parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local checkpoint file")
parser.add_argument(
    "--original_state_dict_repo_id", type=str, default=None, help="HuggingFace repo ID containing the checkpoint"
)
parser.add_argument("--filename", type=str, default=None, help="Filename of the checkpoint in the HF repo")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
parser.add_argument(
    "--dtype", type=str, default="fp32", help="Data type for the converted model (fp16, bf16, or fp32)"
)

args = parser.parse_args()


def load_original_checkpoint(args):
    if args.original_state_dict_repo_id is not None:
        if args.filename is None:
            raise ValueError("When using `original_state_dict_repo_id`, `filename` must also be specified")
        print(f"Downloading checkpoint from {args.original_state_dict_repo_id}/{args.filename}")
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        print(f"Loading checkpoint from local path: {args.checkpoint_path}")
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    original_state_dict = safetensors.torch.load_file(ckpt_path)
    return original_state_dict


def convert_sd3_controlnet_checkpoint_to_diffusers(original_state_dict):
    converted_state_dict = {}

    # Direct mappings for controlnet blocks
    for i in range(19):  # 19 controlnet blocks
        converted_state_dict[f"controlnet_blocks.{i}.weight"] = original_state_dict[f"controlnet_blocks.{i}.weight"]
        converted_state_dict[f"controlnet_blocks.{i}.bias"] = original_state_dict[f"controlnet_blocks.{i}.bias"]

    # Positional embeddings
    converted_state_dict["pos_embed_input.proj.weight"] = original_state_dict["pos_embed_input.proj.weight"]
    converted_state_dict["pos_embed_input.proj.bias"] = original_state_dict["pos_embed_input.proj.bias"]

    # Time and text embeddings
    time_text_mappings = {
        "time_text_embed.timestep_embedder.linear_1.weight": "time_text_embed.timestep_embedder.linear_1.weight",
        "time_text_embed.timestep_embedder.linear_1.bias": "time_text_embed.timestep_embedder.linear_1.bias",
        "time_text_embed.timestep_embedder.linear_2.weight": "time_text_embed.timestep_embedder.linear_2.weight",
        "time_text_embed.timestep_embedder.linear_2.bias": "time_text_embed.timestep_embedder.linear_2.bias",
        "time_text_embed.text_embedder.linear_1.weight": "time_text_embed.text_embedder.linear_1.weight",
        "time_text_embed.text_embedder.linear_1.bias": "time_text_embed.text_embedder.linear_1.bias",
        "time_text_embed.text_embedder.linear_2.weight": "time_text_embed.text_embedder.linear_2.weight",
        "time_text_embed.text_embedder.linear_2.bias": "time_text_embed.text_embedder.linear_2.bias",
    }

    for new_key, old_key in time_text_mappings.items():
        if old_key in original_state_dict:
            converted_state_dict[new_key] = original_state_dict[old_key]

    # Transformer blocks
    for i in range(19):
        # Split QKV into separate Q, K, V
        qkv_weight = original_state_dict[f"transformer_blocks.{i}.attn.qkv.weight"]
        qkv_bias = original_state_dict[f"transformer_blocks.{i}.attn.qkv.bias"]
        q, k, v = torch.chunk(qkv_weight, 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)

        block_mappings = {
            f"transformer_blocks.{i}.attn.to_q.weight": q,
            f"transformer_blocks.{i}.attn.to_q.bias": q_bias,
            f"transformer_blocks.{i}.attn.to_k.weight": k,
            f"transformer_blocks.{i}.attn.to_k.bias": k_bias,
            f"transformer_blocks.{i}.attn.to_v.weight": v,
            f"transformer_blocks.{i}.attn.to_v.bias": v_bias,
            # Output projections
            f"transformer_blocks.{i}.attn.to_out.0.weight": original_state_dict[
                f"transformer_blocks.{i}.attn.proj.weight"
            ],
            f"transformer_blocks.{i}.attn.to_out.0.bias": original_state_dict[
                f"transformer_blocks.{i}.attn.proj.bias"
            ],
            # Feed forward
            f"transformer_blocks.{i}.ff.net.0.proj.weight": original_state_dict[
                f"transformer_blocks.{i}.mlp.fc1.weight"
            ],
            f"transformer_blocks.{i}.ff.net.0.proj.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc1.bias"],
            f"transformer_blocks.{i}.ff.net.2.weight": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.weight"],
            f"transformer_blocks.{i}.ff.net.2.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.bias"],
            # Norms
            f"transformer_blocks.{i}.norm1.linear.weight": original_state_dict[
                f"transformer_blocks.{i}.adaLN_modulation.1.weight"
            ],
            f"transformer_blocks.{i}.norm1.linear.bias": original_state_dict[
                f"transformer_blocks.{i}.adaLN_modulation.1.bias"
            ],
        }
        converted_state_dict.update(block_mappings)

    return converted_state_dict


def main(args):
    original_ckpt = load_original_checkpoint(args)
    original_dtype = next(iter(original_ckpt.values())).dtype

    # Initialize dtype with fp32 as default
    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}. Must be one of: fp16, bf16, fp32")

    if dtype != original_dtype:
        print(
            f"Converting checkpoint from {original_dtype} to {dtype}. This can lead to unexpected results, proceed with caution."
        )

    converted_controlnet_state_dict = convert_sd3_controlnet_checkpoint_to_diffusers(original_ckpt)

    controlnet = SD3ControlNetModel(
        patch_size=2,
        in_channels=16,
        num_layers=19,
        attention_head_dim=64,
        num_attention_heads=38,
        joint_attention_dim=None,
        caption_projection_dim=2048,
        pooled_projection_dim=2048,
        out_channels=16,
        pos_embed_max_size=None,
        pos_embed_type=None,
        use_pos_embed=False,
        force_zeros_for_pooled_projection=False,
    )

    controlnet.load_state_dict(converted_controlnet_state_dict, strict=True)

    print(f"Saving SD3 ControlNet in Diffusers format in {args.output_path}.")
    controlnet.to(dtype).save_pretrained(args.output_path)


if __name__ == "__main__":
    main(args)
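For context, a checkpoint saved by this script can then be loaded like any other Diffusers ControlNet. The snippet below is a minimal usage sketch, not part of this commit: the output directory, the base model choice (`stabilityai/stable-diffusion-3.5-large`), the control image file, the prompt, and the generation settings are illustrative assumptions.

```python
# Minimal usage sketch (illustrative only): load the converted ControlNet and
# run it with the SD3.5 base pipeline. Paths and settings are placeholders.
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

controlnet = SD3ControlNetModel.from_pretrained(
    "output/sd35-controlnet-canny",  # output_path used in the conversion command above
    torch_dtype=torch.float16,
)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

control_image = load_image("canny_edges.png")  # hypothetical pre-computed Canny map
image = pipe(
    prompt="a photo of a futuristic city at night",
    control_image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd35_controlnet_canny_output.png")
```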

src/diffusers/models/controlnets/controlnet_sd3.py

+73-30
@@ -27,6 +27,7 @@
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
+from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
 from .controlnet import BaseOutput, zero_module
 
 
@@ -58,40 +59,60 @@ def __init__(
         extra_conditioning_channels: int = 0,
         dual_attention_layers: Tuple[int, ...] = (),
         qk_norm: Optional[str] = None,
+        pos_embed_type: Optional[str] = "sincos",
+        use_pos_embed: bool = True,
+        force_zeros_for_pooled_projection: bool = True,
     ):
         super().__init__()
         default_out_channels = in_channels
         self.out_channels = out_channels if out_channels is not None else default_out_channels
         self.inner_dim = num_attention_heads * attention_head_dim
 
-        self.pos_embed = PatchEmbed(
-            height=sample_size,
-            width=sample_size,
-            patch_size=patch_size,
-            in_channels=in_channels,
-            embed_dim=self.inner_dim,
-            pos_embed_max_size=pos_embed_max_size,
-        )
+        if use_pos_embed:
+            self.pos_embed = PatchEmbed(
+                height=sample_size,
+                width=sample_size,
+                patch_size=patch_size,
+                in_channels=in_channels,
+                embed_dim=self.inner_dim,
+                pos_embed_max_size=pos_embed_max_size,
+                pos_embed_type=pos_embed_type,
+            )
+        else:
+            self.pos_embed = None
         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
             embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
-        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
-
-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to be crafted when we get the actual checkpoints.
-        self.transformer_blocks = nn.ModuleList(
-            [
-                JointTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
-                    context_pre_only=False,
-                    qk_norm=qk_norm,
-                    use_dual_attention=True if i in dual_attention_layers else False,
-                )
-                for i in range(num_layers)
-            ]
-        )
+        if joint_attention_dim is not None:
+            self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+
+            # `attention_head_dim` is doubled to account for the mixing.
+            # It needs to be crafted when we get the actual checkpoints.
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    JointTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                        context_pre_only=False,
+                        qk_norm=qk_norm,
+                        use_dual_attention=True if i in dual_attention_layers else False,
+                    )
+                    for i in range(num_layers)
+                ]
+            )
+        else:
+            self.context_embedder = None
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    SD3SingleTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
 
         # controlnet_blocks
         self.controlnet_blocks = nn.ModuleList([])
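To illustrate the new constructor branch: when `joint_attention_dim` is `None`, the model skips the `context_embedder` and stacks `SD3SingleTransformerBlock` layers, and `use_pos_embed=False` drops the internal patch embedding — the configuration the conversion script above uses. The toy instantiation below is a sketch with made-up small sizes, not the real SD3.5 8b configuration.

```python
# Toy instantiation (illustrative only): exercises the `joint_attention_dim=None`
# branch added above, with reduced sizes instead of the real 19-layer checkpoint.
from diffusers import SD3ControlNetModel

toy = SD3ControlNetModel(
    patch_size=2,
    in_channels=16,
    num_layers=2,               # 19 in the real SD3.5 large controlnets
    attention_head_dim=8,       # 64 in the real checkpoints
    num_attention_heads=4,      # 38 in the real checkpoints
    joint_attention_dim=None,   # selects SD3SingleTransformerBlock, no context_embedder
    caption_projection_dim=32,
    pooled_projection_dim=32,
    out_channels=16,
    pos_embed_max_size=None,
    pos_embed_type=None,
    use_pos_embed=False,        # no patch/positional embedding inside the controlnet
    force_zeros_for_pooled_projection=False,
)

print(type(toy.transformer_blocks[0]).__name__)           # SD3SingleTransformerBlock
print(toy.pos_embed is None, toy.context_embedder is None)  # True True
```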
@@ -318,9 +339,27 @@ def forward(
                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                 )
 
-        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+        if self.pos_embed is not None and hidden_states.ndim != 4:
+            raise ValueError("hidden_states must be 4D when pos_embed is used")
+
+        # SD3.5 8b controlnet does not have a `pos_embed`;
+        # it uses the `pos_embed` from the transformer to process the input before passing it to the controlnet.
+        elif self.pos_embed is None and hidden_states.ndim != 3:
+            raise ValueError("hidden_states must be 3D when pos_embed is not used")
+
+        if self.context_embedder is not None and encoder_hidden_states is None:
+            raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
+        # SD3.5 8b controlnet does not have a `context_embedder`, so it does not use `encoder_hidden_states`.
+        elif self.context_embedder is None and encoder_hidden_states is not None:
+            raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")
+
+        if self.pos_embed is not None:
+            hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+
         temb = self.time_text_embed(timestep, pooled_projections)
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if self.context_embedder is not None:
+            encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         # add
         hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
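The new checks encode two input contracts: with an internal `pos_embed`, the controlnet still takes a 4D latent; without one (the SD3.5 8b style), the caller must pass an already patchified 3D token sequence and no `encoder_hidden_states`. The sketch below demonstrates the second contract on a toy model; the tensor sizes and the `controlnet_block_samples` output field are assumptions based on the surrounding code, not values from this commit.

```python
# Shape-contract sketch (toy sizes, illustrative only): a controlnet built
# without pos_embed/context_embedder takes 3D tokens and no text states.
import torch
from diffusers import SD3ControlNetModel

controlnet = SD3ControlNetModel(
    patch_size=2, in_channels=16, num_layers=2,
    attention_head_dim=8, num_attention_heads=4,
    joint_attention_dim=None, caption_projection_dim=32, pooled_projection_dim=32,
    out_channels=16, pos_embed_max_size=None, pos_embed_type=None,
    use_pos_embed=False, force_zeros_for_pooled_projection=False,
)

inner_dim = 4 * 8                              # num_attention_heads * attention_head_dim
hidden_states = torch.randn(1, 64, inner_dim)  # 3D tokens, e.g. produced by the transformer's pos_embed
controlnet_cond = torch.randn(1, 16, 16, 16)   # 4D latent-space condition, patchified by pos_embed_input
pooled_projections = torch.randn(1, 32)
timestep = torch.tensor([500])

out = controlnet(
    hidden_states=hidden_states,
    controlnet_cond=controlnet_cond,
    encoder_hidden_states=None,                # must be None when context_embedder is None
    pooled_projections=pooled_projections,
    timestep=timestep,
)
print(len(out.controlnet_block_samples), out.controlnet_block_samples[0].shape)
# expected: 2 torch.Size([1, 64, 32])
```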
@@ -349,9 +388,13 @@ def custom_forward(*inputs):
                 )
 
             else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
-                )
+                if self.context_embedder is not None:
+                    encoder_hidden_states, hidden_states = block(
+                        hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    )
+                else:
+                    # SD3.5 8b controlnet uses a single transformer block, which does not use `encoder_hidden_states`.
+                    hidden_states = block(hidden_states, temb)
 
             block_res_samples = block_res_samples + (hidden_states,)
 