Skip to content

Commit a573e3d

Browse files
committed
enable subquadratic_attn
1 parent 566b5c8 commit a573e3d

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

scripts/play.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
print(reassuring_message_2)
1313

1414
import torch
15-
from torch import Generator, Tensor, randn, no_grad, zeros
15+
from torch import Generator, Tensor, randn, no_grad, zeros, nn
1616
from diffusers.models import UNet2DConditionModel, AutoencoderKL
17+
from diffusers.models.cross_attention import CrossAttention
1718
from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras, sample_dpmpp_2m
1819

1920
from helpers.schedule_params import get_alphas, get_alphas_cumprod, get_betas, quantize_to
@@ -70,6 +71,12 @@
7071
upcast_attention=upcast_attention,
7172
).to(device).eval()
7273

74+
def subquad_attn(module: nn.Module) -> None:
75+
for m in module.children():
76+
if isinstance(m, CrossAttention):
77+
m.set_subquadratic_attention(query_chunk_size=1024, kv_chunk_size=4096)
78+
unet.apply(subquad_attn)
79+
7380
# sampling in higher-precision helps to converge more stably toward the "true" image (not necessarily better-looking though)
7481
sampling_dtype: torch.dtype = torch.float32
7582
# sampling_dtype: torch.dtype = torch_dtype

0 commit comments

Comments
 (0)