Commit b9a9cf5

support selecting torch.nn.functional.scaled_dot_product_attention
1 parent a77865e

2 files changed: 11 additions, 2 deletions

scripts/play.py

Lines changed: 5 additions & 2 deletions
@@ -21,6 +21,7 @@
 import torch
 from torch import Tensor, FloatTensor, BoolTensor, LongTensor, no_grad, zeros, tensor, arange, linspace, lerp
 from diffusers.models import UNet2DConditionModel, AutoencoderKL
+from diffusers.models.cross_attention import AttnProcessor2_0
 from diffusers.utils.import_utils import is_xformers_available
 from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras, sample_dpmpp_2m
 
@@ -121,7 +122,7 @@
   upcast_attention=upcast_attention,
 ).to(device).eval()
 
-attn_mode = AttentionMode.TorchMultiheadAttention
+attn_mode = AttentionMode.ScaledDPAttn
 match(attn_mode):
   case AttentionMode.Standard: pass
   case AttentionMode.Chunked:
@@ -134,6 +135,8 @@
   case AttentionMode.TorchMultiheadAttention:
     tap_module: TapModule = replace_attn_to_tap_module(to_mha)
     unet.apply(tap_module)
+  case AttentionMode.ScaledDPAttn:
+    unet.set_attn_processor(AttnProcessor2_0())
   case AttentionMode.Xformers:
     assert is_xformers_available()
     unet.enable_xformers_memory_efficient_attention()
@@ -411,7 +414,7 @@
   # xformers attn_bias is only implemented for Triton + A100 GPU
   # https://github.com/facebookresearch/xformers/issues/576
   # chunked attention *can* be made to support masks, but I didn't implement it yet
-  case AttentionMode.Xformers | AttentionMode.Chunked:
+  case AttentionMode.Xformers | AttentionMode.Chunked | AttentionMode.ScaledDPAttn:
     mask_denorm = None
 
 denoiser: Denoiser = denoiser_factory(
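For readers skimming the diff: the new branch is a one-liner that swaps every CrossAttention layer's processor for the PyTorch 2.0 fused-attention processor. A minimal, self-contained sketch of the same idea, assuming diffusers >= 0.13 (where AttnProcessor2_0 and set_attn_processor exist) and torch >= 2.0; the checkpoint id and dtype below are illustrative, not taken from this repo:

# Hedged sketch: route a diffusers UNet through torch.nn.functional.scaled_dot_product_attention.
# Assumes diffusers >= 0.13 and torch >= 2.0; the model id is an illustrative choice.
import torch
from diffusers.models import UNet2DConditionModel
from diffusers.models.cross_attention import AttnProcessor2_0

unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
  'runwayml/stable-diffusion-v1-5',  # hypothetical checkpoint
  subfolder='unet',
  torch_dtype=torch.float16,
).to('cuda').eval()

# replaces the default baddbmm()/bmm() CrossAttnProcessor on every attention layer
unet.set_attn_processor(AttnProcessor2_0())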

src/helpers/attention/mode.py

Lines changed: 6 additions & 0 deletions
@@ -1,8 +1,14 @@
 from enum import Enum, auto
 
 class AttentionMode(Enum):
+  # usual diffusers CrossAttention layer, CrossAttnProcessor via baddbmm(), bmm()
   Standard = auto()
   # https://github.com/huggingface/diffusers/issues/1892
+  # usual diffusers CrossAttention layer, CrossAttnProcessor via torch.narrow()'d baddbmm(), bmm()s ("memory-efficient" in pure PyTorch)
   Chunked = auto()
+  # replaces diffusers' CrossAttention layers with torch.nn.MultiheadAttention
   TorchMultiheadAttention = auto()
+  # usual diffusers CrossAttention layer, CrossAttnProcessor via torch.nn.functional.scaled_dot_product_attention
+  ScaledDPAttn = auto()
+  # usual diffusers CrossAttention layer, CrossAttnProcessor via Xformers
   Xformers = auto()
