Commit 6090bba

t2i pipeline (huggingface#3932)
* Quick implementation of t2i-adapter (squashed work-in-progress history):
  Load adapter module with from_pretrained; prototyping generalized adapter framework;
  write up doc string for sideload framework (WIP) plus minor implementation updates;
  update adapter models; remove old adapter optional args in UNet; add
  StableDiffusionAdapterPipeline unit test; handle cpu offload in
  StableDiffusionAdapterPipeline; auto-correct coding style; update model repo name to
  "RzZ/sd-v1-4-adapter-pipeline"; refactor MultiAdapter for better compatibility with the
  config system; export MultiAdapter; create pipeline document template from controlnet;
  create dummy objects; support new AdapterLight model; fix StableDiffusionAdapterPipeline
  common pipeline test; [WIP] update adapter pipeline document; handle num_inference_steps
  in StableDiffusionAdapterPipeline; update definition of Adapter "channels_in"; update
  documents; apply code style; fix doc typo and merge error; update doc string and example;
  quality-of-life improvements; remove redundant code and files from prototyping; remove
  unused package; remove comments; fix title; add conditioning scale arg; bring back old
  implementation; offload sideload; add supplementary info to document; update
  src/diffusers/models/adapter.py; update MultiAdapter constructor; swap out custom
  checkpoint and update pipeline constructor; apply suggestions from code review; correct
  style; follow single-file policy; update auto size in image preprocess func; update
  src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py; fix
  copies; update adapter pipeline behavior; add adapter_conditioning_scale doc string; add
  the missing doc string; fix a few bugs from review suggestions; handle L-mode PIL image
  as control image; rename to differentiate adapter resblock; update adapter parameter
  name; update test case and code style; update Adapter class name; add checkpoint
  converting script; remove dev script; updates for parameter rename; fix convert_adapter;
  remove main; more refactoring; refactor tests; more slow tests; more tests; update
  docs/source/en/api/pipelines/overview.mdx; add community contributor to docs; update
  docs/source/en/api/pipelines/stable_diffusion/adapter.mdx (several review passes);
  remove from_adapters; license and paper link; more url fixes; more docs; assorted
  small typo and style fixes.
  (Review suggestions co-authored by Will Berman <[email protected]>, Patrick von Platen
  <[email protected]>, and Sayak Paul <[email protected]>.)

* fix sample inplace add

* additional_kwargs -> additional_residuals

* move t2i adapter pipeline to own module

* preprocess -> _preprocess_adapter_image

* add TencentArc to license

* fix example code links

* add image converter and fix example doc string

* fix links

* clearer additional residual application

---------

Co-authored-by: HimariO <[email protected]>
1 parent a62d11b commit 6090bba
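
For orientation, here is a minimal usage sketch of the pipeline this commit adds. The
checkpoint names below are illustrative assumptions, not part of this diff (the commit
message only references the interim "RzZ/sd-v1-4-adapter-pipeline" repo and the TencentARC
license); the class and argument names match the code added here.

# Minimal sketch: driving the new StableDiffusionAdapterPipeline.
import torch
from PIL import Image

from diffusers import StableDiffusionAdapterPipeline, T2IAdapter

# Hypothetical depth-adapter checkpoint; any T2IAdapter weights would do.
adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd14v1", torch_dtype=torch.float16)

pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", adapter=adapter, torch_dtype=torch.float16
).to("cuda")

control_image = Image.open("depth_map.png")  # the *control image* fed to the adapter

image = pipe(
    "a photo of a cozy living room",
    image=control_image,
    num_inference_steps=50,
    adapter_conditioning_scale=1.0,  # conditioning-scale argument added in this PR
).images[0]
image.save("out.png")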

11 files changed: +1219 −6 lines

__init__.py (+3)

@@ -39,7 +39,9 @@
         AutoencoderKL,
         ControlNetModel,
         ModelMixin,
+        MultiAdapter,
         PriorTransformer,
+        T2IAdapter,
         T5FilmDecoder,
         Transformer2DModel,
         UNet1DModel,
@@ -151,6 +153,7 @@
         SemanticStableDiffusionPipeline,
         ShapEImg2ImgPipeline,
         ShapEPipeline,
+        StableDiffusionAdapterPipeline,
         StableDiffusionAttendAndExcitePipeline,
         StableDiffusionControlNetImg2ImgPipeline,
         StableDiffusionControlNetInpaintPipeline,

models/__init__.py (+1)

@@ -16,6 +16,7 @@
 
 
 if is_torch_available():
+    from .adapter import MultiAdapter, T2IAdapter
     from .autoencoder_kl import AutoencoderKL
     from .controlnet import ControlNetModel
     from .dual_transformer_2d import DualTransformer2DModel

models/adapter.py (new file, +291)

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from .modeling_utils import ModelMixin
from .resnet import Downsample2D


class MultiAdapter(ModelMixin):
    r"""
    MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
    user-assigned weighting.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all models (such as downloading or saving).

    Parameters:
        adapters (`List[T2IAdapter]`):
            A list of `T2IAdapter` model instances.
    """

    def __init__(self, adapters: List["T2IAdapter"]):
        super().__init__()

        self.num_adapter = len(adapters)
        self.adapters = nn.ModuleList(adapters)

    def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
        r"""
        Args:
            xs (`torch.Tensor`):
                (batch, channel, height, width) input images for the adapter models, concatenated along dimension 1;
                `channel` should equal `num_adapter` times the number of channels per image.
            adapter_weights (`List[float]`, *optional*, defaults to None):
                Weights by which each adapter's output features are multiplied before they are summed. Defaults to
                uniform weights of `1 / num_adapter`.
        """
        if adapter_weights is None:
            adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
        else:
            adapter_weights = torch.tensor(adapter_weights)

        if xs.shape[1] % self.num_adapter != 0:
            raise ValueError(
                f"Expected the number of input channels to be evenly divisible "
                f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0"
            )
        x_list = torch.chunk(xs, self.num_adapter, dim=1)
        accume_state = None
        for x, w, adapter in zip(x_list, adapter_weights, self.adapters):
            features = adapter(x)
            if accume_state is None:
                # weight the first adapter's features too, so every adapter is scaled consistently
                accume_state = [w * feature for feature in features]
            else:
                for i in range(len(features)):
                    accume_state[i] += w * features[i]
        return accume_state


class T2IAdapter(ModelMixin, ConfigMixin):
    r"""
    A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
    generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
    architecture follows the original implementation of
    [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
    and
    [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all models (such as downloading or saving).

    Parameters:
        in_channels (`int`, *optional*, defaults to 3):
            Number of channels of the Adapter's input (*control image*). Set this parameter to 1 if you're using a
            grayscale image as the *control image*.
        channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The number of channels in each downsample block's output hidden state. `len(channels)` also determines
            the number of downsample blocks in the Adapter.
        num_res_blocks (`int`, *optional*, defaults to 2):
            Number of ResNet blocks in each downsample block.
        downscale_factor (`int`, *optional*, defaults to 8):
            The initial pixel-unshuffle factor; together with the downsample blocks it determines the Adapter's
            `total_downscale_factor`.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280, 1280],
        num_res_blocks: int = 2,
        downscale_factor: int = 8,
        adapter_type: str = "full_adapter",
    ):
        super().__init__()

        if adapter_type == "full_adapter":
            self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
        elif adapter_type == "light_adapter":
            self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
        else:
            raise ValueError(f"unknown adapter_type: {adapter_type}. Choose either 'full_adapter' or 'light_adapter'")

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        return self.adapter(x)

    @property
    def total_downscale_factor(self):
        return self.adapter.total_downscale_factor


# full adapter


class FullAdapter(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280, 1280],
        num_res_blocks: int = 2,
        downscale_factor: int = 8,
    ):
        super().__init__()

        # PixelUnshuffle trades spatial resolution for channels: (C, H, W) -> (C * r**2, H / r, W / r)
        in_channels = in_channels * downscale_factor**2

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)
        self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)

        self.body = nn.ModuleList(
            [
                AdapterBlock(channels[0], channels[0], num_res_blocks),
                *[
                    AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)
                    for i in range(1, len(channels))
                ],
            ]
        )

        self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.unshuffle(x)
        x = self.conv_in(x)

        features = []

        for block in self.body:
            x = block(x)
            features.append(x)

        return features


class AdapterBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
        super().__init__()

        self.downsample = None
        if down:
            self.downsample = Downsample2D(in_channels)

        self.in_conv = None
        if in_channels != out_channels:
            self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        self.resnets = nn.Sequential(
            *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
        )

    def forward(self, x):
        if self.downsample is not None:
            x = self.downsample(x)

        if self.in_conv is not None:
            x = self.in_conv(x)

        x = self.resnets(x)

        return x


class AdapterResnetBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        h = x
        h = self.block1(h)
        h = self.act(h)
        h = self.block2(h)

        return h + x


# light adapter


class LightAdapter(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280],
        num_res_blocks: int = 4,
        downscale_factor: int = 8,
    ):
        super().__init__()

        in_channels = in_channels * downscale_factor**2

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)

        self.body = nn.ModuleList(
            [
                LightAdapterBlock(in_channels, channels[0], num_res_blocks),
                *[
                    LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True)
                    for i in range(len(channels) - 1)
                ],
                LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True),
            ]
        )

        self.total_downscale_factor = downscale_factor * (2 ** len(channels))

    def forward(self, x):
        x = self.unshuffle(x)

        features = []

        for block in self.body:
            x = block(x)
            features.append(x)

        return features


class LightAdapterBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
        super().__init__()
        # the light variant bottlenecks each block to a quarter of its output width
        mid_channels = out_channels // 4

        self.downsample = None
        if down:
            self.downsample = Downsample2D(in_channels)

        self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
        self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x):
        if self.downsample is not None:
            x = self.downsample(x)

        x = self.in_conv(x)
        x = self.resnets(x)
        x = self.out_conv(x)

        return x


class LightAdapterResnetBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        h = x
        h = self.block1(h)
        h = self.act(h)
        h = self.block2(h)

        return h + x
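
A short sketch of how the two public models above compose, assuming the default
`full_adapter` configuration; the printed shapes follow from `downscale_factor=8` and the
per-block downsampling. The 0.7/0.3 weights are arbitrary example values.

# Sketch: merging two adapters with MultiAdapter (default T2IAdapter config assumed).
import torch

from diffusers import MultiAdapter, T2IAdapter

depth_adapter = T2IAdapter(in_channels=3)
pose_adapter = T2IAdapter(in_channels=3)
multi = MultiAdapter([depth_adapter, pose_adapter])

depth_map = torch.randn(1, 3, 512, 512)
pose_map = torch.randn(1, 3, 512, 512)

# MultiAdapter expects the control images concatenated along dim 1:
# channel == num_adapter * channels-per-image (2 * 3 = 6 here).
xs = torch.cat([depth_map, pose_map], dim=1)
features = multi(xs, adapter_weights=[0.7, 0.3])

for f in features:
    print(f.shape)
# one weighted feature map per downsample block:
# (1, 320, 64, 64), (1, 640, 32, 32), (1, 1280, 16, 16), (1, 1280, 8, 8)
# total_downscale_factor == 8 * 2**3 == 64, so 512 -> 8 at the deepest map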

models/unet_2d_blocks.py (+8 −1)

@@ -955,10 +955,13 @@ def forward(
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        additional_residuals=None,
     ):
         output_states = ()
 
-        for resnet, attn in zip(self.resnets, self.attentions):
+        blocks = list(zip(self.resnets, self.attentions))
+
+        for i, (resnet, attn) in enumerate(blocks):
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
@@ -999,6 +1002,10 @@ def custom_forward(*inputs):
                     return_dict=False,
                 )[0]
 
+            # apply additional residuals to the output of the last pair of resnet and attention blocks
+            if i == len(blocks) - 1 and additional_residuals is not None:
+                hidden_states = hidden_states + additional_residuals
+
             output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
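
To make the control flow concrete, here is a self-contained toy of the same pattern (plain
torch, not the diffusers block itself): the extra residual is added only after the final
(resnet, attn) pair, before the block's output is recorded.

# Toy sketch of the pattern in the diff above.
import torch
import torch.nn as nn

def down_block_forward(hidden_states, blocks, additional_residuals=None):
    output_states = ()
    for i, block in enumerate(blocks):
        hidden_states = block(hidden_states)
        # mirror of the diff: only the final sub-block's output gets the adapter residual
        if i == len(blocks) - 1 and additional_residuals is not None:
            hidden_states = hidden_states + additional_residuals
        output_states = output_states + (hidden_states,)
    return hidden_states, output_states

blocks = [nn.Identity(), nn.Identity()]
h = torch.zeros(1, 4)
res = torch.ones(1, 4)
out, states = down_block_forward(h, blocks, res)
print(states[0])  # zeros: intermediate output untouched
print(out)        # ones: residual applied exactly once, at the end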

models/unet_2d_condition.py (+15 −2)

@@ -899,23 +899,36 @@ def forward(
         sample = self.conv_in(sample)
 
         # 3. down
+
+        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                # For t2i-adapter CrossAttnDownBlock2D
+                additional_residuals = {}
+                if is_adapter and len(down_block_additional_residuals) > 0:
+                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
                     temb=emb,
                     encoder_hidden_states=encoder_hidden_states,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
                     encoder_attention_mask=encoder_attention_mask,
+                    **additional_residuals,
                 )
             else:
                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
 
+                if is_adapter and len(down_block_additional_residuals) > 0:
+                    sample += down_block_additional_residuals.pop(0)
+
             down_block_res_samples += res_samples
 
-        if down_block_additional_residuals is not None:
+        if is_controlnet:
             new_down_block_res_samples = ()
 
             for down_block_res_sample, down_block_additional_residual in zip(
@@ -937,7 +950,7 @@ def forward(
                 encoder_attention_mask=encoder_attention_mask,
             )
 
-        if mid_block_additional_residual is not None:
+        if is_controlnet:
             sample = sample + mid_block_additional_residual
 
         # 5. up
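
The dispatch above hinges on one invariant: ControlNet supplies both mid- and down-block
residuals, while a T2I-Adapter supplies only down-block residuals. A small sketch of that
predicate (names mirror the diff; the tensors are dummies): on the adapter path the
residuals are pop(0)'d and added per down block, on the controlnet path they are added to
the res samples and the mid-block output afterwards.

# Sketch of the routing predicate introduced above (dummy tensors).
import torch

def classify(mid_block_additional_residual, down_block_additional_residuals):
    is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
    is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
    return is_controlnet, is_adapter

down_res = [torch.zeros(1, 320, 64, 64)]
print(classify(None, down_res))                        # (False, True): adapter path
print(classify(torch.zeros(1, 1280, 8, 8), down_res))  # (True, False): controlnet path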

pipelines/__init__.py (+1)

@@ -101,6 +101,7 @@
         StableUnCLIPPipeline,
     )
     from .stable_diffusion_safe import StableDiffusionPipelineSafe
+    from .t2i_adapter import StableDiffusionAdapterPipeline
     from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
     from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
     from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
