
Commit c50b9bf

switch to my newer diffusers cross-attn API
1 parent b9a9cf5 commit c50b9bf

5 files changed: +54 -28 lines changed

Diff for: scripts/play.py

+26 -2
@@ -1,4 +1,5 @@
 import os, fnmatch
+from diffusers.models.attention_utils import mask_to_bias
 from diffusers.models.autoencoder_kl import AutoencoderKLOutput
 
 # monkey-patch _randn to use CPU random before k-diffusion uses it
@@ -20,6 +21,7 @@
 
 import torch
 from torch import Tensor, FloatTensor, BoolTensor, LongTensor, no_grad, zeros, tensor, arange, linspace, lerp
+from torch.nn.functional import pad
 from diffusers.models import UNet2DConditionModel, AutoencoderKL
 from diffusers.models.cross_attention import AttnProcessor2_0
 from diffusers.utils.import_utils import is_xformers_available
@@ -414,12 +416,34 @@
     # xformers attn_bias is only implemented for Triton + A100 GPU
     # https://github.com/facebookresearch/xformers/issues/576
     # chunked attention *can* be made to support masks, but I didn't implement it yet
-    case AttentionMode.Xformers | AttentionMode.Chunked | AttentionMode.ScaledDPAttn:
+    case AttentionMode.Xformers:
+        from packaging import version
+        from xformers import __version__ as xformers_version
+        # attn bias support was/will be added in 0.0.17:
+        # https://github.com/facebookresearch/xformers/blob/main/CHANGELOG.md
+        if version.parse(xformers_version) >= version.parse('0.0.17'):
+            # cutlassF is our best bet, but currently only supports token lengths which are multiples of 8
+            # https://gist.github.com/Birch-san/0c36d228e1d4b881a06d1c6e5289d569
+            # strictly speaking we should worry that making the key slightly longer, slightly
+            # affects the softmax averaging. oh well.
+            # https://github.com/lllyasviel/ControlNet/discussions/12
+            mask_length = mask_denorm.shape[-1]
+            extra_tokens_needed = 8 - (mask_length % 8)
+            # 0-pad mask to multiple of 8 tokens
+            mask_denorm = pad(mask_denorm, (0, extra_tokens_needed))
+            # replicate-pad embedding to multiple of 8 tokens (mask will hide the extra tokens)
+            embedding_denorm = pad(embedding_denorm, (0, 0, 0, extra_tokens_needed,), 'replicate')
+        else:
+            # if you're older than that, then we discard the masks
+            mask_denorm = None
+    case AttentionMode.Chunked:
         mask_denorm = None
+
+cross_attention_bias: Optional[FloatTensor] = None if mask_denorm is None else mask_to_bias(mask_denorm, unet.dtype)
 
 denoiser: Denoiser = denoiser_factory(
     cross_attention_conds=embedding_denorm,
-    cross_attention_mask=mask_denorm,
+    cross_attention_bias=cross_attention_bias,
     conds_per_prompt=conds_per_prompt_tensor,
     cond_weights=cond_weights,
     uncond_ixs=uncond_ixs,

Diff for: src/diffusers

Submodule diffusers updated 512 files

Diff for: src/helpers/attention/multi_head_attention/multi_head_attention.py

+9 -7
@@ -1,4 +1,4 @@
-from torch import nn, Tensor
+from torch import nn, Tensor, FloatTensor
 from typing import Optional
 from ..attn_compatible import CrossAttnCompatible
 
@@ -29,20 +29,22 @@ def forward(
         hidden_states: Tensor,
         encoder_hidden_states: Optional[Tensor] = None,
         attention_mask: Optional[Tensor] = None,
-        cross_attn_mask: Optional[Tensor] = None,
         **cross_attention_kwargs,
     ) -> Tensor:
         kv = hidden_states if encoder_hidden_states is None else encoder_hidden_states
-        if cross_attn_mask is not None:
-            cross_attn_mask = cross_attn_mask.repeat_interleave(self.num_heads, dim=0)
-            cross_attn_mask = cross_attn_mask.unsqueeze(-2)
+        if encoder_hidden_states is not None and 'encoder_attention_bias' in cross_attention_kwargs:
+            encoder_attention_bias: FloatTensor = cross_attention_kwargs['encoder_attention_bias']
+            encoder_attention_bias = encoder_attention_bias.repeat_interleave(self.num_heads, dim=0)
+            encoder_attention_bias = encoder_attention_bias.unsqueeze(-2)
             _, vision_tokens, _ = hidden_states.shape
-            cross_attn_mask = cross_attn_mask.expand(-1, vision_tokens, -1)
+            encoder_attention_bias = encoder_attention_bias.expand(-1, vision_tokens, -1)
+        else:
+            encoder_attention_bias = None
         out, _ = super().forward(
             query=hidden_states,
             key=kv,
             value=kv,
             need_weights=False,
-            attn_mask=cross_attn_mask,
+            attn_mask=encoder_attention_bias,
         )
         return out
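For context, here is a self-contained sketch of the broadcasting this hunk performs, using toy shapes and a plain nn.MultiheadAttention rather than the repo's CrossAttnCompatible subclass: repeat the per-sequence bias once per head, add a query dimension, expand it over the vision tokens, and hand it to attn_mask, which PyTorch adds to the attention logits when the mask is a float tensor. The specific bias values (hiding the last two text tokens) are purely illustrative.

import torch
from torch import nn

batch, vision_tokens, text_tokens, num_heads, dim = 2, 16, 8, 4, 64
attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)

hidden_states = torch.randn(batch, vision_tokens, dim)        # queries: image tokens
encoder_hidden_states = torch.randn(batch, text_tokens, dim)  # keys/values: text tokens
bias = torch.zeros(batch, text_tokens)                        # additive bias, one value per key token
bias[:, -2:] = torch.finfo(bias.dtype).min                    # e.g. hide the last two text tokens

# same broadcasting as the diff above
bias = bias.repeat_interleave(num_heads, dim=0)               # (batch*num_heads, text_tokens)
bias = bias.unsqueeze(-2).expand(-1, vision_tokens, -1)       # (batch*num_heads, vision_tokens, text_tokens)

out, _ = attn(
    hidden_states, encoder_hidden_states, encoder_hidden_states,
    need_weights=False,
    attn_mask=bias,                                           # float attn_mask: added to the logits
)
print(out.shape)  # torch.Size([2, 16, 64])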

Diff for: src/helpers/batch_denoiser.py

+6 -6
@@ -18,7 +18,7 @@ def __call__(
 class AbstractBatchDenoiser(PostInitMixin, ABC, Denoiser):
     denoiser: DiffusersSDDenoiser
     cross_attention_conds: FloatTensor
-    cross_attention_mask: Optional[BoolTensor]
+    cross_attention_bias: Optional[FloatTensor]
     conds_per_prompt: LongTensor
     cond_weights: FloatTensor
     center_denoise_outputs: Optional[BoolTensor]
@@ -54,7 +54,7 @@ def __call__(
             input=noised_latents_in,
             sigma=sigma_in,
             encoder_hidden_states=self.cross_attention_conds,
-            attention_mask=self.cross_attention_mask,
+            cross_attention_bias=self.cross_attention_bias,
         )
         del noised_latents_in, sigma_in
         if self.center_denoise_outputs is not None:
@@ -163,7 +163,7 @@ def __call__(
             input=noised_latents_in,
             sigma=sigma_in,
             encoder_hidden_states=self.cross_attention_conds,
-            attention_mask=self.cross_attention_mask,
+            cross_attention_bias=self.cross_attention_bias,
         )
         if self.center_denoise_outputs is not None:
             denoised_latents = where(
@@ -194,7 +194,7 @@ class BatchDenoiserFactory():
     def __call__(
         self,
         cross_attention_conds: FloatTensor,
-        cross_attention_mask: Optional[BoolTensor],
+        cross_attention_bias: Optional[FloatTensor],
        conds_per_prompt: LongTensor,
        cond_weights: FloatTensor,
        uncond_ixs: Optional[LongTensor],
@@ -208,15 +208,15 @@ def __call__(
            return BatchNoCFGDenoiser(
                denoiser=self.denoiser,
                cross_attention_conds=cross_attention_conds,
-               cross_attention_mask=cross_attention_mask,
+               cross_attention_bias=cross_attention_bias,
                conds_per_prompt=conds_per_prompt,
                cond_weights=cond_weights,
                center_denoise_outputs=center_denoise_outputs,
            )
        return BatchCFGDenoiser(
            denoiser=self.denoiser,
            cross_attention_conds=cross_attention_conds,
-           cross_attention_mask=cross_attention_mask,
+           cross_attention_bias=cross_attention_bias,
            conds_per_prompt=conds_per_prompt,
            cond_weights=cond_weights,
            center_denoise_outputs=center_denoise_outputs,

Diff for: src/helpers/diffusers_denoiser.py

+12 -12
@@ -1,4 +1,4 @@
-from torch import Tensor, FloatTensor, BoolTensor
+from torch import Tensor, FloatTensor
 from diffusers.models import UNet2DConditionModel
 from diffusers.models.unet_2d_condition import UNet2DConditionOutput
 from k_diffusion.external import DiscreteEpsDDPMDenoiser, DiscreteVDDPMDenoiser
@@ -17,21 +17,21 @@ def get_eps(
         sample: FloatTensor,
         timestep: Union[Tensor, float, int],
         encoder_hidden_states: Tensor,
+        cross_attention_bias: Optional[FloatTensor] = None,
         return_dict: bool = True,
-        attention_mask: Optional[BoolTensor] = None,
     ) -> Tensor:
-        # cross_attn_mask is a proposal from my xattn_mask_2 branch of diffusers:
+        # encoder_attention_bias is a proposal from my cross_attn_mask_3 branch of diffusers:
         # https://github.com/huggingface/diffusers/issues/1890
         # don't pass it in if we don't have to, to ensure compatibility with main branch of diffusers
-        attn_kwargs = {} if attention_mask is None else {
-            'cross_attn_mask': attention_mask,
+        cross_attention_kwargs = {} if cross_attention_bias is None else {
+            'encoder_attention_bias': cross_attention_bias,
         }
-        out: UNet2DConditionOutput = self.inner_model(
+        out: UNet2DConditionOutput = self.inner_model.forward(
            sample.to(self.inner_model.dtype),
            timestep.to(self.inner_model.dtype),
            encoder_hidden_states=encoder_hidden_states.to(self.inner_model.dtype),
            return_dict=return_dict,
-           **attn_kwargs,
+           cross_attention_kwargs=cross_attention_kwargs,
        )
        return out.sample.to(self.sampling_dtype)
 
@@ -50,21 +50,21 @@ def get_v(
         sample: FloatTensor,
         timestep: Union[Tensor, float, int],
         encoder_hidden_states: Tensor,
+        cross_attention_bias: Optional[FloatTensor] = None,
         return_dict: bool = True,
-        attention_mask: Optional[BoolTensor] = None,
     ) -> Tensor:
-        # cross_attn_mask is a proposal from my xattn_mask_2 branch of diffusers:
+        # encoder_attention_bias is a proposal from my cross_attn_mask_3 branch of diffusers:
        # https://github.com/huggingface/diffusers/issues/1890
        # don't pass it in if we don't have to, to ensure compatibility with main branch of diffusers
-        attn_kwargs = {} if attention_mask is None else {
-            'cross_attn_mask': attention_mask,
+        cross_attention_kwargs = {} if cross_attention_bias is None else {
+            'encoder_attention_bias': cross_attention_bias,
        }
        out: UNet2DConditionOutput = self.inner_model(
            sample.to(self.inner_model.dtype),
            timestep.to(self.inner_model.dtype),
            encoder_hidden_states=encoder_hidden_states.to(self.inner_model.dtype),
            return_dict=return_dict,
-           **attn_kwargs,
+           cross_attention_kwargs=cross_attention_kwargs,
        )
        return out.sample.to(self.sampling_dtype)
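The net effect on the diffusers side: get_eps and get_v now forward the bias through cross_attention_kwargs, which UNet2DConditionModel passes down to its attention processors. A minimal sketch of that call, assuming the cross_attn_mask_3 branch of diffusers (whose processors accept an 'encoder_attention_bias' entry; mainline diffusers would not recognise that key), with 'runwayml/stable-diffusion-v1-5' standing in as an example checkpoint id:

import torch
from diffusers.models import UNet2DConditionModel

# example checkpoint id; any SD 1.x UNet with cross_attention_dim=768 would do
unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
    'runwayml/stable-diffusion-v1-5', subfolder='unet',
)

sample = torch.randn(1, 4, 64, 64)               # latents
timestep = torch.tensor(999)
encoder_hidden_states = torch.randn(1, 77, 768)  # CLIP text embeddings
bias = torch.zeros(1, 77)                        # additive bias over the 77 text tokens; 0 = keep
bias[:, 20:] = torch.finfo(torch.float32).min    # hypothetical: ignore tokens from index 20 onward

with torch.no_grad():
    out = unet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        # forwarded to every attention processor on the branch
        cross_attention_kwargs={'encoder_attention_bias': bias},
    ).sample                                     # (1, 4, 64, 64) predicted noise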
