@@ -26,28 +26,28 @@ class SummarizeChunk(Protocol):
     @staticmethod
     def __call__(
         query: Tensor,
-        key: Tensor,
+        key_t: Tensor,
         value: Tensor,
     ) -> AttnChunk: ...
 
 class ComputeQueryChunkAttn(Protocol):
     @staticmethod
     def __call__(
         query: Tensor,
-        key: Tensor,
+        key_t: Tensor,
         value: Tensor,
     ) -> Tensor: ...
 
 def _summarize_chunk(
     query: Tensor,
-    key: Tensor,
+    key_t: Tensor,
     value: Tensor,
     scale: float,
 ) -> AttnChunk:
     attn_weights = torch.baddbmm(
         torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
-        key.transpose(1, 2),
+        key_t,
         alpha=scale,
         beta=0,
     )
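
Why the call sites can simply drop the transpose: with `beta=0`, `torch.baddbmm` ignores its `input` tensor and computes `alpha * (query @ key_t)`. A minimal sketch (mine, not part of this diff; shapes are arbitrary) checking that passing a pre-transposed key is numerically identical:

import torch

batch_x_heads, q_tokens, k_tokens, channels = 2, 4, 8, 16
query = torch.randn(batch_x_heads, q_tokens, channels)
key = torch.randn(batch_x_heads, k_tokens, channels)
key_t = key.transpose(1, 2)  # [batch_x_heads, channels, k_tokens]
scale = channels ** -0.5

# old call: transpose derived inside every invocation
old = torch.baddbmm(
    torch.empty(1, 1, 1, dtype=query.dtype),
    query,
    key.transpose(1, 2),
    alpha=scale,
    beta=0,
)
# new call: the caller supplies the transposed key up front
new = torch.baddbmm(
    torch.empty(1, 1, 1, dtype=query.dtype),
    query,
    key_t,
    alpha=scale,
    beta=0,
)
assert torch.allclose(old, new)

`transpose(1, 2)` only returns a view, so presumably the point is to establish the transposed layout once at the API boundary rather than re-derive it in every helper.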
@@ -60,19 +60,19 @@ def _summarize_chunk(
 
 def _query_chunk_attention(
     query: Tensor,
-    key: Tensor,
+    key_t: Tensor,
     value: Tensor,
     summarize_chunk: SummarizeChunk,
     kv_chunk_size: int,
 ) -> Tensor:
-    batch_x_heads, k_tokens, k_channels_per_head = key.shape
+    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
     _, _, v_channels_per_head = value.shape
 
     def chunk_scanner(chunk_idx: int) -> AttnChunk:
         key_chunk = dynamic_slice(
-            key,
-            (0, chunk_idx, 0),
-            (batch_x_heads, kv_chunk_size, k_channels_per_head)
+            key_t,
+            (0, 0, chunk_idx),
+            (batch_x_heads, k_channels_per_head, kv_chunk_size)
         )
         value_chunk = dynamic_slice(
             value,
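
Because the tokens axis moved from dim 1 to dim 2, the start and size tuples swap accordingly: the chunk offset now lives in the last position. `dynamic_slice` is defined elsewhere in this module; assuming it mirrors `jax.lax.dynamic_slice` (take `sizes[d]` elements starting at `starts[d]` along each dimension `d`), a minimal torch stand-in would look like:

from typing import Sequence
from torch import Tensor

def dynamic_slice(x: Tensor, starts: Sequence[int], sizes: Sequence[int]) -> Tensor:
    # narrow to sizes[d] elements starting at starts[d] along each dimension d
    for dim, (start, size) in enumerate(zip(starts, sizes)):
        x = x.narrow(dim, start, size)
    return x

Each key chunk then comes out as `[batch_x_heads, k_channels_per_head, kv_chunk_size]`, already in the layout `baddbmm` expects for its second operand.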
@@ -99,14 +99,14 @@ def chunk_scanner(chunk_idx: int) -> AttnChunk:
 # TODO: refactor CrossAttention#get_attention_scores to share code with this
 def _get_attention_scores_no_kv_chunking(
     query: Tensor,
-    key: Tensor,
+    key_t: Tensor,
     value: Tensor,
     scale: float,
 ) -> Tensor:
     attn_scores = torch.baddbmm(
         torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
-        key.transpose(1, 2),
+        key_t,
         alpha=scale,
         beta=0,
     )
@@ -121,21 +121,21 @@ class ScannedChunk(NamedTuple):
 
 def efficient_dot_product_attention(
     query: Tensor,
-    key: Tensor,
+    key_t: Tensor,
     value: Tensor,
     query_chunk_size=1024,
     kv_chunk_size: Optional[int] = None,
     kv_chunk_size_min: Optional[int] = None,
     use_checkpoint=True,
 ):
-    """Computes efficient dot-product attention given query, key, and value.
+    """Computes efficient dot-product attention given query, transposed key, and value.
     This is an efficient version of attention presented in
     https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
     Args:
       query: queries for calculating attention with shape of
         `[batch * num_heads, tokens, channels_per_head]`.
-      key: keys for calculating attention with shape of
-        `[batch * num_heads, tokens, channels_per_head]`.
+      key_t: keys for calculating attention with shape of
+        `[batch * num_heads, channels_per_head, tokens]`.
       value: values to be used in attention with shape of
         `[batch * num_heads, tokens, channels_per_head]`.
       query_chunk_size: int: query chunks size
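
The docstring change is the visible API contract: callers must now hand in the key already transposed to `[batch * num_heads, channels_per_head, tokens]`. A hypothetical call site (the module name and tensor values are my guesses, not from this diff):

import torch
# from sub_quadratic_attention import efficient_dot_product_attention  # module name assumed

batch_x_heads, q_tokens, k_tokens, channels = 16, 4096, 4096, 40
query = torch.randn(batch_x_heads, q_tokens, channels)
key = torch.randn(batch_x_heads, k_tokens, channels)
value = torch.randn(batch_x_heads, k_tokens, channels)

key_t = key.transpose(1, 2)  # transpose once, at the boundary
out = efficient_dot_product_attention(query=query, key_t=key_t, value=value)
assert out.shape == (batch_x_heads, q_tokens, channels)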
@@ -146,7 +146,7 @@ def efficient_dot_product_attention(
       Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
     """
     batch_x_heads, q_tokens, q_channels_per_head = query.shape
-    _, k_tokens, _ = key.shape
+    _, _, k_tokens = key_t.shape
     scale = q_channels_per_head ** -0.5
 
     kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
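
That last context line is where the O(sqrt(n)) claim in the docstring comes from: with no explicit `kv_chunk_size`, chunks default to roughly the square root of the key-token count. A worked example (my illustration, not code from the module):

import math

k_tokens = 4096
kv_chunk_size = None
kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
assert kv_chunk_size == 64  # 4096 keys -> 64 chunks of 64 tokens each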
@@ -178,7 +178,7 @@ def get_query_chunk(chunk_idx: int) -> Tensor:
         # fast-path for when there's just 1 query chunk
         return compute_query_chunk_attn(
             query=query,
-            key=key,
+            key_t=key_t,
             value=value,
         )
 
@@ -187,7 +187,7 @@ def get_query_chunk(chunk_idx: int) -> Tensor:
     res = torch.cat([
         compute_query_chunk_attn(
             query=get_query_chunk(i * query_chunk_size),
-            key=key,
+            key_t=key_t,
             value=value,
         ) for i in range(math.ceil(q_tokens / query_chunk_size))
     ], dim=1)
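
Concatenating the per-chunk outputs along `dim=1` is sound because softmax normalises over key tokens independently for each query row, so query chunks never interact. A one-shot reference (a sketch of the intended semantics under the same transposed-key layout, not code from this module) that the chunked result should match:

import torch

def reference_attention(query: torch.Tensor, key_t: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # plain softmax attention computed in one shot
    scale = query.shape[-1] ** -0.5
    attn = torch.softmax(query @ key_t * scale, dim=-1)  # [batch*heads, q_tokens, k_tokens]
    return attn @ value                                  # [batch*heads, q_tokens, channels]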