Skip to content

Commit f1f38ff

Browse files
[ControlNet] Adds controlnet for SanaTransformer (#11040)
* added controlnet for sana transformer * improve code quality * addressed PR comments * bug fixes * added test cases * update * added dummy objects * addressed PR comments * update * Forcing update * add to docs * code quality * addressed PR comments * addressed PR comments * update * addressed PR comments * added proper styling * update * Revert "added proper styling" This reverts commit 344ee8a. * manually ordered * Apply suggestions from code review --------- Co-authored-by: Aryan <[email protected]>
1 parent 36538e1 commit f1f38ff

17 files changed

+2062
-22
lines changed

docs/source/en/_toctree.yml

+6-2
Original file line numberDiff line numberDiff line change
@@ -270,16 +270,18 @@
270270
- sections:
271271
- local: api/models/controlnet
272272
title: ControlNetModel
273+
- local: api/models/controlnet_union
274+
title: ControlNetUnionModel
273275
- local: api/models/controlnet_flux
274276
title: FluxControlNetModel
275277
- local: api/models/controlnet_hunyuandit
276278
title: HunyuanDiT2DControlNetModel
279+
- local: api/models/controlnet_sana
280+
title: SanaControlNetModel
277281
- local: api/models/controlnet_sd3
278282
title: SD3ControlNetModel
279283
- local: api/models/controlnet_sparsectrl
280284
title: SparseControlNetModel
281-
- local: api/models/controlnet_union
282-
title: ControlNetUnionModel
283285
title: ControlNets
284286
- sections:
285287
- local: api/models/allegro_transformer3d
@@ -424,6 +426,8 @@
424426
title: ControlNet with Stable Diffusion 3
425427
- local: api/pipelines/controlnet_sdxl
426428
title: ControlNet with Stable Diffusion XL
429+
- local: api/pipelines/controlnet_sana
430+
title: ControlNet-Sana
427431
- local: api/pipelines/controlnetxs
428432
title: ControlNet-XS
429433
- local: api/pipelines/controlnetxs_sdxl
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
-->
12+
13+
# SanaControlNetModel
14+
15+
The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.
16+
17+
The abstract from the paper is:
18+
19+
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
20+
21+
This model was contributed by [ishan24](https://huggingface.co/ishan24). ❤️
22+
The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.
23+
24+
## SanaControlNetModel
25+
[[autodoc]] SanaControlNetModel
26+
27+
## SanaControlNetOutput
28+
[[autodoc]] models.controlnets.controlnet_sana.SanaControlNetOutput
29+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
-->
12+
13+
# ControlNet
14+
15+
<div class="flex flex-wrap space-x-1">
16+
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
17+
</div>
18+
19+
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
20+
21+
With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
22+
23+
The abstract from the paper is:
24+
25+
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
26+
27+
This pipeline was contributed by [ishan24](https://huggingface.co/ishan24). ❤️
28+
The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.
29+
30+
## SanaControlNetPipeline
31+
[[autodoc]] SanaControlNetPipeline
32+
- all
33+
- __call__
34+
35+
## SanaPipelineOutput
36+
[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
#!/usr/bin/env python
2+
from __future__ import annotations
3+
4+
import argparse
5+
from contextlib import nullcontext
6+
7+
import torch
8+
from accelerate import init_empty_weights
9+
10+
from diffusers import (
11+
SanaControlNetModel,
12+
)
13+
from diffusers.models.modeling_utils import load_model_dict_into_meta
14+
from diffusers.utils.import_utils import is_accelerate_available
15+
16+
17+
CTX = init_empty_weights if is_accelerate_available else nullcontext
18+
19+
20+
def main(args):
21+
file_path = args.orig_ckpt_path
22+
23+
all_state_dict = torch.load(file_path, weights_only=True)
24+
state_dict = all_state_dict.pop("state_dict")
25+
converted_state_dict = {}
26+
27+
# Patch embeddings.
28+
converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
29+
converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
30+
31+
# Caption projection.
32+
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
33+
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
34+
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
35+
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
36+
37+
# AdaLN-single LN
38+
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
39+
"t_embedder.mlp.0.weight"
40+
)
41+
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
42+
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
43+
"t_embedder.mlp.2.weight"
44+
)
45+
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
46+
47+
# Shared norm.
48+
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
49+
converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
50+
51+
# y norm
52+
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
53+
54+
# Positional embedding interpolation scale.
55+
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
56+
57+
# ControlNet Input Projection.
58+
converted_state_dict["input_block.weight"] = state_dict.pop("controlnet.0.before_proj.weight")
59+
converted_state_dict["input_block.bias"] = state_dict.pop("controlnet.0.before_proj.bias")
60+
61+
for depth in range(7):
62+
# Transformer blocks.
63+
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
64+
f"controlnet.{depth}.copied_block.scale_shift_table"
65+
)
66+
67+
# Linear Attention is all you need 🤘
68+
# Self attention.
69+
q, k, v = torch.chunk(state_dict.pop(f"controlnet.{depth}.copied_block.attn.qkv.weight"), 3, dim=0)
70+
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
71+
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
72+
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
73+
# Projection.
74+
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
75+
f"controlnet.{depth}.copied_block.attn.proj.weight"
76+
)
77+
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
78+
f"controlnet.{depth}.copied_block.attn.proj.bias"
79+
)
80+
81+
# Feed-forward.
82+
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
83+
f"controlnet.{depth}.copied_block.mlp.inverted_conv.conv.weight"
84+
)
85+
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
86+
f"controlnet.{depth}.copied_block.mlp.inverted_conv.conv.bias"
87+
)
88+
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
89+
f"controlnet.{depth}.copied_block.mlp.depth_conv.conv.weight"
90+
)
91+
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
92+
f"controlnet.{depth}.copied_block.mlp.depth_conv.conv.bias"
93+
)
94+
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
95+
f"controlnet.{depth}.copied_block.mlp.point_conv.conv.weight"
96+
)
97+
98+
# Cross-attention.
99+
q = state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.q_linear.weight")
100+
q_bias = state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.q_linear.bias")
101+
k, v = torch.chunk(state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.weight"), 2, dim=0)
102+
k_bias, v_bias = torch.chunk(
103+
state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.bias"), 2, dim=0
104+
)
105+
106+
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
107+
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
108+
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
109+
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
110+
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
111+
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
112+
113+
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
114+
f"controlnet.{depth}.copied_block.cross_attn.proj.weight"
115+
)
116+
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
117+
f"controlnet.{depth}.copied_block.cross_attn.proj.bias"
118+
)
119+
120+
# ControlNet After Projection
121+
converted_state_dict[f"controlnet_blocks.{depth}.weight"] = state_dict.pop(
122+
f"controlnet.{depth}.after_proj.weight"
123+
)
124+
converted_state_dict[f"controlnet_blocks.{depth}.bias"] = state_dict.pop(f"controlnet.{depth}.after_proj.bias")
125+
126+
# ControlNet
127+
with CTX():
128+
controlnet = SanaControlNetModel(
129+
num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
130+
attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
131+
num_layers=model_kwargs[args.model_type]["num_layers"],
132+
num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
133+
cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
134+
cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
135+
caption_channels=2304,
136+
sample_size=args.image_size // 32,
137+
interpolation_scale=interpolation_scale[args.image_size],
138+
)
139+
140+
if is_accelerate_available():
141+
load_model_dict_into_meta(controlnet, converted_state_dict)
142+
else:
143+
controlnet.load_state_dict(converted_state_dict, strict=True, assign=True)
144+
145+
num_model_params = sum(p.numel() for p in controlnet.parameters())
146+
print(f"Total number of controlnet parameters: {num_model_params}")
147+
148+
controlnet = controlnet.to(weight_dtype)
149+
controlnet.load_state_dict(converted_state_dict, strict=True)
150+
151+
print(f"Saving Sana ControlNet in Diffusers format in {args.dump_path}.")
152+
controlnet.save_pretrained(args.dump_path)
153+
154+
155+
DTYPE_MAPPING = {
156+
"fp32": torch.float32,
157+
"fp16": torch.float16,
158+
"bf16": torch.bfloat16,
159+
}
160+
161+
VARIANT_MAPPING = {
162+
"fp32": None,
163+
"fp16": "fp16",
164+
"bf16": "bf16",
165+
}
166+
167+
168+
if __name__ == "__main__":
169+
parser = argparse.ArgumentParser()
170+
171+
parser.add_argument(
172+
"--orig_ckpt_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
173+
)
174+
parser.add_argument(
175+
"--image_size",
176+
default=1024,
177+
type=int,
178+
choices=[512, 1024, 2048, 4096],
179+
required=False,
180+
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
181+
)
182+
parser.add_argument(
183+
"--model_type",
184+
default="SanaMS_1600M_P1_ControlNet_D7",
185+
type=str,
186+
choices=["SanaMS_1600M_P1_ControlNet_D7", "SanaMS_600M_P1_ControlNet_D7"],
187+
)
188+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
189+
parser.add_argument("--dtype", default="fp16", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
190+
191+
args = parser.parse_args()
192+
193+
model_kwargs = {
194+
"SanaMS_1600M_P1_ControlNet_D7": {
195+
"num_attention_heads": 70,
196+
"attention_head_dim": 32,
197+
"num_cross_attention_heads": 20,
198+
"cross_attention_head_dim": 112,
199+
"cross_attention_dim": 2240,
200+
"num_layers": 7,
201+
},
202+
"SanaMS_600M_P1_ControlNet_D7": {
203+
"num_attention_heads": 36,
204+
"attention_head_dim": 32,
205+
"num_cross_attention_heads": 16,
206+
"cross_attention_head_dim": 72,
207+
"cross_attention_dim": 1152,
208+
"num_layers": 7,
209+
},
210+
}
211+
212+
device = "cuda" if torch.cuda.is_available() else "cpu"
213+
weight_dtype = DTYPE_MAPPING[args.dtype]
214+
variant = VARIANT_MAPPING[args.dtype]
215+
216+
main(args)

src/diffusers/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@
190190
"OmniGenTransformer2DModel",
191191
"PixArtTransformer2DModel",
192192
"PriorTransformer",
193+
"SanaControlNetModel",
193194
"SanaTransformer2DModel",
194195
"SD3ControlNetModel",
195196
"SD3MultiControlNetModel",
@@ -428,6 +429,7 @@
428429
"PixArtSigmaPAGPipeline",
429430
"PixArtSigmaPipeline",
430431
"ReduxImageEncoder",
432+
"SanaControlNetPipeline",
431433
"SanaPAGPipeline",
432434
"SanaPipeline",
433435
"SanaSprintPipeline",
@@ -782,6 +784,7 @@
782784
OmniGenTransformer2DModel,
783785
PixArtTransformer2DModel,
784786
PriorTransformer,
787+
SanaControlNetModel,
785788
SanaTransformer2DModel,
786789
SD3ControlNetModel,
787790
SD3MultiControlNetModel,
@@ -999,6 +1002,7 @@
9991002
PixArtSigmaPAGPipeline,
10001003
PixArtSigmaPipeline,
10011004
ReduxImageEncoder,
1005+
SanaControlNetPipeline,
10021006
SanaPAGPipeline,
10031007
SanaPipeline,
10041008
SanaSprintPipeline,

src/diffusers/models/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
"HunyuanDiT2DControlNetModel",
5050
"HunyuanDiT2DMultiControlNetModel",
5151
]
52+
_import_structure["controlnets.controlnet_sana"] = ["SanaControlNetModel"]
5253
_import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
5354
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
5455
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
@@ -134,6 +135,7 @@
134135
HunyuanDiT2DMultiControlNetModel,
135136
MultiControlNetModel,
136137
MultiControlNetUnionModel,
138+
SanaControlNetModel,
137139
SD3ControlNetModel,
138140
SD3MultiControlNetModel,
139141
SparseControlNetModel,

src/diffusers/models/controlnets/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
HunyuanDiT2DControlNetModel,
1010
HunyuanDiT2DMultiControlNetModel,
1111
)
12+
from .controlnet_sana import SanaControlNetModel
1213
from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
1314
from .controlnet_sparsectrl import (
1415
SparseControlNetConditioningEmbedding,

0 commit comments

Comments
 (0)