# sparse broadcasting for bias, mask, weights
# flattened conditions for clarity
# Hyungon Ryu (device arg fix)
+# Alex Birch (MPS support)
# implementation of:
# "Self-attention Does Not Need O(n²) Memory":
# https://arxiv.org/abs/2112.05682v2
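For context (not part of the patch): the paper's trick is to compute the softmax numerator and denominator per key chunk against that chunk's own maximum, so the exponentials stay bounded, and to rescale the partial sums to a shared maximum when chunks are combined. In the notation below (mine, not the repository's), for scores $s_i$ and values $v_i$ over a key chunk $C$:

$$
m_C = \max_{i \in C} s_i, \qquad
n_C = \sum_{i \in C} e^{\,s_i - m_C}\, v_i, \qquad
d_C = \sum_{i \in C} e^{\,s_i - m_C}
$$

$$
\mathrm{out} = \frac{\sum_C e^{\,m_C - m}\, n_C}{\sum_C e^{\,m_C - m}\, d_C},
\qquad m = \max_C m_C
$$

The triple $(n_C, d_C, m_C)$ is exactly what `summarize_chunk` in the hunk below returns as `exp_values`, `exp_weights.sum(dim=-1)`, and `max_score`.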
@@ -51,11 +52,13 @@ def summarize_chunk(key_idx, query, key, value, mask, bias):
        attn_weights = torch.where(mask, attn_weights, big_neg)
    if weights_calc_fn is not None:
        attn_weights = weights_calc_fn(query_idx, key_idx, attn_weights, calc_fn_data)
+    attn_weights = attn_weights.contiguous() if attn_weights.device.type == 'mps' else attn_weights
    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
    max_score = max_score.detach()
    exp_weights = torch.exp(attn_weights - max_score)
    exp_values = torch.einsum('...vhf,...qhv->...qhf', value, exp_weights)
    max_score = torch.einsum('...qhk->...qh', max_score)
+    exp_values = exp_values.contiguous() if exp_values.device.type == 'mps' else exp_values
    return exp_values, exp_weights.sum(dim=-1), max_score
summarizer = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
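The hunk above only produces per-chunk partials; how they are merged is not shown in this diff. The sketch below is an illustrative reconstruction of the standard rescale-and-sum step from the paper, not the repository's code: the function name `combine_chunks`, its argument names, and the shape conventions ([..., q, h, f] for values, [..., q, h] for weights and maxima, matching the einsums above) are assumptions.

```python
import torch

def combine_chunks(chunk_values, chunk_weights, chunk_maxes):
    """Merge per-key-chunk partials into the final attention output (illustrative only).

    chunk_values:  list of [..., q, h, f] tensors -- exp_values from each chunk
    chunk_weights: list of [..., q, h]    tensors -- exp_weights.sum(dim=-1) from each chunk
    chunk_maxes:   list of [..., q, h]    tensors -- max_score from each chunk
    """
    maxes = torch.stack(chunk_maxes)                    # [chunks, ..., q, h]
    global_max, _ = torch.max(maxes, dim=0)             # shared maximum per query/head
    scale = torch.exp(maxes - global_max.unsqueeze(0))  # rescale each chunk to the shared max
    values = torch.stack(chunk_values) * scale.unsqueeze(-1)
    weights = torch.stack(chunk_weights) * scale
    # numerator and denominator of the softmax, summed over chunks
    return values.sum(dim=0) / weights.sum(dim=0).unsqueeze(-1)
```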