Skip to content

Commit b119815

Browse files
committed
Use narrow instead of dynamic_slice
1 parent 3bfe2bb commit b119815

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

modules/sub_quadratic_attention.py

+19-15
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# credit:
66
# Amin Rezaei (original author)
77
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
8+
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
89
# implementation of:
910
# "Self-attention Does Not Need O(n^2) Memory":
1011
# https://arxiv.org/abs/2112.05682v2
@@ -16,13 +17,13 @@
1617
import math
1718
from typing import Optional, NamedTuple, Protocol, List
1819

19-
def dynamic_slice(
20-
x: Tensor,
21-
starts: List[int],
22-
sizes: List[int],
20+
def narrow_trunc(
21+
input: Tensor,
22+
dim: int,
23+
start: int,
24+
length: int
2325
) -> Tensor:
24-
slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
25-
return x[slicing]
26+
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
2627

2728
class AttnChunk(NamedTuple):
2829
exp_values: Tensor
@@ -76,15 +77,17 @@ def _query_chunk_attention(
7677
_, _, v_channels_per_head = value.shape
7778

7879
def chunk_scanner(chunk_idx: int) -> AttnChunk:
79-
key_chunk = dynamic_slice(
80+
key_chunk = narrow_trunc(
8081
key,
81-
(0, chunk_idx, 0),
82-
(batch_x_heads, kv_chunk_size, k_channels_per_head)
82+
1,
83+
chunk_idx,
84+
kv_chunk_size
8385
)
84-
value_chunk = dynamic_slice(
86+
value_chunk = narrow_trunc(
8587
value,
86-
(0, chunk_idx, 0),
87-
(batch_x_heads, kv_chunk_size, v_channels_per_head)
88+
1,
89+
chunk_idx,
90+
kv_chunk_size
8891
)
8992
return summarize_chunk(query, key_chunk, value_chunk)
9093

@@ -161,10 +164,11 @@ def efficient_dot_product_attention(
161164
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
162165

163166
def get_query_chunk(chunk_idx: int) -> Tensor:
164-
return dynamic_slice(
167+
return narrow_trunc(
165168
query,
166-
(0, chunk_idx, 0),
167-
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
169+
1,
170+
chunk_idx,
171+
min(query_chunk_size, q_tokens)
168172
)
169173

170174
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)

0 commit comments

Comments
 (0)