|
5 | 5 | # credit:
|
6 | 6 | # Amin Rezaei (original author)
|
7 | 7 | # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
|
| 8 | +# brkirch (modified to use torch.narrow instead of dynamic_slice implementation) |
8 | 9 | # implementation of:
|
9 | 10 | # "Self-attention Does Not Need O(n^2) Memory":
|
10 | 11 | # https://arxiv.org/abs/2112.05682v2
|
|
16 | 17 | import math
|
17 | 18 | from typing import Optional, NamedTuple, Protocol, List
|
18 | 19 |
|
19 |
def narrow_trunc(
    input: Tensor,
    dim: int,
    start: int,
    length: int
) -> Tensor:
    """Return a narrow (zero-copy) view of `input` along `dim`, truncated to fit.

    Like `torch.narrow`, except the requested `length` is clamped so the
    slice never runs past the end of the tensor: when `start + length`
    exceeds the size of `dim`, the slice is shortened to whatever remains
    after `start`.

    Args:
        input: tensor to slice (any shape).
        dim: dimension along which to slice.
        start: index where the slice begins. Assumed to satisfy
            0 <= start <= input.shape[dim] — TODO confirm callers never
            pass an out-of-range start.
        length: desired slice length; clamped to `input.shape[dim] - start`.

    Returns:
        A view of `input` whose size along `dim` is at most `length`.
    """
    # Clamp to what is actually available past `start`; torch.narrow itself
    # raises if asked to read beyond the end of the dimension.
    available = input.shape[dim] - start
    return torch.narrow(input, dim, start, min(length, available))
26 | 27 |
|
27 | 28 | class AttnChunk(NamedTuple):
|
28 | 29 | exp_values: Tensor
|
@@ -76,15 +77,17 @@ def _query_chunk_attention(
|
76 | 77 | _, _, v_channels_per_head = value.shape
|
77 | 78 |
|
78 | 79 | def chunk_scanner(chunk_idx: int) -> AttnChunk:
|
79 |
| - key_chunk = dynamic_slice( |
| 80 | + key_chunk = narrow_trunc( |
80 | 81 | key,
|
81 |
| - (0, chunk_idx, 0), |
82 |
| - (batch_x_heads, kv_chunk_size, k_channels_per_head) |
| 82 | + 1, |
| 83 | + chunk_idx, |
| 84 | + kv_chunk_size |
83 | 85 | )
|
84 |
| - value_chunk = dynamic_slice( |
| 86 | + value_chunk = narrow_trunc( |
85 | 87 | value,
|
86 |
| - (0, chunk_idx, 0), |
87 |
| - (batch_x_heads, kv_chunk_size, v_channels_per_head) |
| 88 | + 1, |
| 89 | + chunk_idx, |
| 90 | + kv_chunk_size |
88 | 91 | )
|
89 | 92 | return summarize_chunk(query, key_chunk, value_chunk)
|
90 | 93 |
|
@@ -161,10 +164,11 @@ def efficient_dot_product_attention(
|
161 | 164 | kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
|
162 | 165 |
|
163 | 166 | def get_query_chunk(chunk_idx: int) -> Tensor:
|
164 |
| - return dynamic_slice( |
| 167 | + return narrow_trunc( |
165 | 168 | query,
|
166 |
| - (0, chunk_idx, 0), |
167 |
| - (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) |
| 169 | + 1, |
| 170 | + chunk_idx, |
| 171 | + min(query_chunk_size, q_tokens) |
168 | 172 | )
|
169 | 173 |
|
170 | 174 | summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
|
|
0 commit comments