
Commit 1623c9c

eliphatfs authored and Jimmy committed
Implement CustomDiffusionAttnProcessor2_0. (huggingface#4604)
* Implement `CustomDiffusionAttnProcessor2_0`
* Doc-strings and type annotations for `CustomDiffusionAttnProcessor2_0`. (huggingface#1)
* Update attnprocessor.md
* Update attention_processor.py
* Interops for `CustomDiffusionAttnProcessor2_0`.
* Formatted `attention_processor.py`.
* Formatted doc-string in `attention_processor.py`
* Conditional CustomDiffusion2_0 for training example.
* Remove unnecessary reference impl in comments.
* Fix `save_attn_procs`.
1 parent eba2a36 commit 1623c9c

File tree

4 files changed (+139, -7 lines)


docs/source/en/api/attnprocessor.md

+4, -1

@@ -17,6 +17,9 @@ An attention processor is a class for applying different types of attention mech
 ## CustomDiffusionAttnProcessor
 [[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
 
+## CustomDiffusionAttnProcessor2_0
+[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0
+
 ## AttnAddedKVProcessor
 [[autodoc]] models.attention_processor.AttnAddedKVProcessor
 
@@ -39,4 +42,4 @@ An attention processor is a class for applying different types of attention mech
 [[autodoc]] models.attention_processor.SlicedAttnProcessor
 
 ## SlicedAttnAddedKVProcessor
-[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
+[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor

examples/custom_diffusion/train_custom_diffusion.py

+8, -2

@@ -51,7 +51,11 @@
     UNet2DConditionModel,
 )
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor
+from diffusers.models.attention_processor import (
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionAttnProcessor2_0,
+    CustomDiffusionXFormersAttnProcessor,
+)
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -870,7 +874,9 @@ def main(args):
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
 
-    attention_class = CustomDiffusionAttnProcessor
+    attention_class = (
+        CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor
+    )
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
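For orientation, here is a minimal sketch (not taken verbatim from `train_custom_diffusion.py`) of the fallback introduced above: the training example now prefers `CustomDiffusionAttnProcessor2_0` when PyTorch exposes `scaled_dot_product_attention` and otherwise keeps `CustomDiffusionAttnProcessor`. The `hidden_size` and `cross_attention_dim` values below are illustrative placeholders, not values from the script.

```python
import torch.nn.functional as F

from diffusers.models.attention_processor import (
    CustomDiffusionAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
)

# Prefer the PyTorch 2.0 SDPA-based processor when available (mirrors the diff above).
attention_class = (
    CustomDiffusionAttnProcessor2_0
    if hasattr(F, "scaled_dot_product_attention")
    else CustomDiffusionAttnProcessor
)

# One processor per attention layer; the training script builds a dict of these
# keyed by the UNet's attention-processor names. Dimensions here are illustrative.
processor = attention_class(
    train_kv=True,
    train_q_out=False,
    hidden_size=320,          # illustrative
    cross_attention_dim=768,  # illustrative
)
```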

src/diffusers/loaders.py

+13, -2

@@ -559,6 +559,7 @@ def save_attn_procs(
         """
         from .models.attention_processor import (
             CustomDiffusionAttnProcessor,
+            CustomDiffusionAttnProcessor2_0,
             CustomDiffusionXFormersAttnProcessor,
         )
 
@@ -578,15 +579,25 @@ def save_function(weights, filename):
         os.makedirs(save_directory, exist_ok=True)
 
         is_custom_diffusion = any(
-            isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+            isinstance(
+                x,
+                (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
+            )
             for (_, x) in self.attn_processors.items()
         )
         if is_custom_diffusion:
             model_to_save = AttnProcsLayers(
                 {
                     y: x
                     for (y, x) in self.attn_processors.items()
-                    if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+                    if isinstance(
+                        x,
+                        (
+                            CustomDiffusionAttnProcessor,
+                            CustomDiffusionAttnProcessor2_0,
+                            CustomDiffusionXFormersAttnProcessor,
+                        ),
+                    )
                 }
             )
             state_dict = model_to_save.state_dict()
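To illustrate the updated detection in `save_attn_procs`, a small sketch follows; the processor dict and layer name are hypothetical stand-ins for `unet.attn_processors`, not code from this commit.

```python
from diffusers.models.attention_processor import (
    CustomDiffusionAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
    CustomDiffusionXFormersAttnProcessor,
)

# Hypothetical stand-in for unet.attn_processors.
attn_processors = {
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor": CustomDiffusionAttnProcessor2_0(
        train_kv=True, train_q_out=False, hidden_size=1280, cross_attention_dim=768
    ),
}

# Same detection pattern as the updated save_attn_procs: the 2.0 variant now counts
# as a Custom Diffusion processor.
is_custom_diffusion = any(
    isinstance(
        x,
        (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
    )
    for x in attn_processors.values()
)
print(is_custom_diffusion)  # True -> save_attn_procs takes the Custom Diffusion path
```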

src/diffusers/models/attention_processor.py

+114, -2

@@ -173,7 +173,8 @@ def set_use_memory_efficient_attention_xformers(
             LORA_ATTENTION_PROCESSORS,
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
-            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
+            self.processor,
+            (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
         )
         is_added_kv_processor = hasattr(self, "processor") and isinstance(
             self.processor,
@@ -261,7 +262,12 @@
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
             elif is_custom_diffusion:
-                processor = CustomDiffusionAttnProcessor(
+                attn_processor_class = (
+                    CustomDiffusionAttnProcessor2_0
+                    if hasattr(F, "scaled_dot_product_attention")
+                    else CustomDiffusionAttnProcessor
+                )
+                processor = attn_processor_class(
                     train_kv=self.processor.train_kv,
                     train_q_out=self.processor.train_q_out,
                     hidden_size=self.processor.hidden_size,
@@ -1156,6 +1162,111 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         return hidden_states
 
 
+class CustomDiffusionAttnProcessor2_0(nn.Module):
+    r"""
+    Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0's memory-efficient scaled
+    dot-product attention.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include the bias parameter in `train_q_out`.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+    """
+
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=True,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        inner_dim = hidden_states.shape[-1]
+
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class SlicedAttnProcessor:
     r"""
     Processor for implementing sliced attention.
@@ -1639,6 +1750,7 @@ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
     XFormersAttnAddedKVProcessor,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
+    CustomDiffusionAttnProcessor2_0,
     # depraceted
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
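The heart of the new processor is the head split/merge around `F.scaled_dot_product_attention`. The standalone sketch below uses arbitrary illustrative tensor sizes (it is not part of the commit) and shows the same reshape pattern as `CustomDiffusionAttnProcessor2_0.__call__`.

```python
# Standalone sketch of the head split / SDPA / merge pattern used in
# CustomDiffusionAttnProcessor2_0.__call__; sizes are arbitrary examples.
import torch
import torch.nn.functional as F

batch_size, seq_len, heads, head_dim = 2, 64, 8, 40
inner_dim = heads * head_dim

query = torch.randn(batch_size, seq_len, inner_dim)
key = torch.randn(batch_size, seq_len, inner_dim)
value = torch.randn(batch_size, seq_len, inner_dim)

# (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
q = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
k = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
v = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

# SDPA output shape: (batch, heads, seq, head_dim); no causal mask and no dropout,
# matching the processor's call.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# Merge heads back: (batch, seq, heads * head_dim)
out = out.transpose(1, 2).reshape(batch_size, -1, inner_dim)
assert out.shape == (batch_size, seq_len, inner_dim)
```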
