Commit 0437214

MPS fixes; now working
1 parent 4e96ca3 commit 0437214

1 file changed: +3 -0 lines changed

src/diffusers/models/sub_quadratic_attention.py (+3)
@@ -9,6 +9,7 @@
 # sparse broadcasting for bias, mask, weights
 # flattened conditions for clarity
 # Hyungon Ryu (device arg fix)
+# Alex Birch (MPS support)
 # implementation of:
 # "Self-attention Does Not Need O(n^2) Memory":
 # https://arxiv.org/abs/2112.05682v2
@@ -51,11 +52,13 @@ def summarize_chunk(key_idx, query, key, value, mask, bias):
         attn_weights = torch.where(mask, attn_weights, big_neg)
     if weights_calc_fn is not None:
         attn_weights = weights_calc_fn(query_idx, key_idx, attn_weights, calc_fn_data)
+    attn_weights = attn_weights.contiguous() if attn_weights.device.type == 'mps' else attn_weights
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
     exp_weights = torch.exp(attn_weights - max_score)
     exp_values = torch.einsum('...vhf,...qhv->...qhf', value, exp_weights)
     max_score = torch.einsum('...qhk->...qh', max_score)
+    exp_values = exp_values.contiguous() if exp_values.device.type == 'mps' else exp_values
     return exp_values, exp_weights.sum(dim=-1), max_score
 summarizer = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
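
Both added lines apply the same guard: force a contiguous copy only when the tensor lives on the MPS (Apple Silicon) backend, leaving CUDA and CPU code paths byte-for-byte unchanged. A minimal sketch of that guard factored into a standalone helper — the mps_contiguous name is hypothetical and not part of this commit:

import torch

def mps_contiguous(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the guard added in this commit:
    # on the 'mps' device, materialize a contiguous copy of the tensor
    # (the commit applies this to attn_weights before the max-reduction
    # and to the einsum output exp_values); on any other device, return
    # the tensor unchanged, so CUDA/CPU behavior is unaffected.
    return t.contiguous() if t.device.type == 'mps' else t

With this helper, the two added lines would read attn_weights = mps_contiguous(attn_weights) and exp_values = mps_contiguous(exp_values).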
6164
