
Commit 1ae9b05

Fix enable memory efficient attention on ROCm (#10564)
* Fix enabling memory efficient attention on ROCm while calling the CK implementation
* Update attention_processor.py: refactor picking an element from a set
1 parent aad69ac commit 1ae9b05
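
Background for the second commit bullet: `SUPPORTED_DTYPES` on an xformers attention op is a set of `torch.dtype` values, so the refactor picks one element with starred unpacking. A minimal sketch of the pattern (the concrete set contents below are illustrative assumptions, not the actual values of any particular op):

    import torch

    # On xformers ops, SUPPORTED_DTYPES is a set; the exact contents vary
    # by op and build. The values here are illustrative only.
    SUPPORTED_DTYPES = {torch.float16, torch.bfloat16}

    # Starred unpacking binds one arbitrary element of the set to `dtype`
    # (set iteration order is unspecified):
    dtype, *_ = SUPPORTED_DTYPES
    print(dtype)  # e.g. torch.float16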

File tree

1 file changed: +6 −5 lines changed

src/diffusers/models/attention_processor.py

+6 −5
@@ -405,11 +405,12 @@ def set_use_memory_efficient_attention_xformers(
             else:
                 try:
                     # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
+                    dtype = None
+                    if attention_op is not None:
+                        op_fw, op_bw = attention_op
+                        dtype, *_ = op_fw.SUPPORTED_DTYPES
+                    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                    _ = xformers.ops.memory_efficient_attention(q, q, q)
                 except Exception as e:
                     raise e
 
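
Why the probe failed before this change: `torch.randn` defaults to `torch.float32`, so when the requested op does not support float32 (as with the CK implementation on ROCm), the capability check itself raised before memory efficient attention could be enabled. With the patch, the probe tensor's dtype is drawn from the forward op's `SUPPORTED_DTYPES`. A hedged usage sketch of the fixed path, assuming a ROCm build of xformers that exposes `MemoryEfficientAttentionCkOp` (the op name may differ across xformers versions, and the model id is just an example):

    import torch
    import xformers.ops
    from diffusers import DiffusionPipeline

    # Any diffusers pipeline works here; the model id is illustrative.
    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")  # ROCm GPUs are exposed through the "cuda" device in PyTorch

    # attention_op is an (op_fw, op_bw) pair; with the fix above, the probe
    # tensor is created with a dtype taken from op_fw.SUPPORTED_DTYPES
    # instead of the float32 default that the CK kernels reject.
    pipe.enable_xformers_memory_efficient_attention(
        attention_op=xformers.ops.MemoryEfficientAttentionCkOp
    )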
