@@ -59,59 +59,12 @@ def copy_blocks(
         HabanaPagedAttention.copy_blocks(kv_caches, src_to_dists)
 
 
-@dataclass
-class HabanaAttentionMetadata(AttentionMetadataPerStage, HabanaPagedAttentionMetadata):
-    """Metadata for HabanaAttentionbackend.
-
-    NOTE: Any python object stored here is not updated when it is
-    cuda-graph replayed. If you have values that need to be changed
-    dynamically, it should be stored in tensor. The tensor has to be
-    updated from `CUDAGraphRunner.forward` API.
-    """
-    # Currently, input sequences can only contain all prompts
-    # or all decoding. True if all sequences are prompts.
-    is_prompt: bool
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    seq_lens: Optional[List[int]]
-    # seq_lens stored as a tensor.
+@dataclass(frozen=True)
+class HabanaAttentionMetadata(HabanaPagedAttentionMetadata, AttentionMetadataPerStage):
+    """Metadata for HabanaAttentionbackend."""
+    attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
 
-    # |---------- N-1 iteration --------|
-    # |---------------- N iteration ---------------------|
-    # |- tokenA -|......................|-- newTokens ---|
-    # |---------- context_len ----------|
-    # |-------------------- seq_len ----------------------|
-    #                                   |-- query_len ---|
-
-    # Maximum query length in the batch.
-    max_query_len: Optional[int]
-    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
-    # the batch, used to index into subquery. E.g., if the subquery length
-    # is [4, 6], it is [0, 4, 10].
-    subquery_start_loc: Optional[torch.Tensor]
-    # FIXME: It is for flash attn.
-    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
-    # the batch, used to index into sequence. E.g., if the sequence length is
-    # [4, 6], it is [0, 4, 10].
-    seq_start_loc: Optional[torch.Tensor]
-    # (batch_size,) A tensor of context lengths (tokens that are computed
-    # so far).
-    context_lens_tensor: Optional[torch.Tensor]
-
-    # Whether or not if cuda graph is enabled.
-    # Cuda-graph is currently enabled for decoding only.
-    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
-    use_cuda_graph: bool
-
-    def __post_init__(self):
-        # Set during the execution of the first attention op.
-        # It is a list because it is needed to set per prompt
-        # when alibi slopes is used. It is because of the limitation
-        # from xformer API.
-        # will not appear in the __repr__ and __init__
-        self.attn_bias: Optional[List[AttentionBias]] = None
-
 
 
 class HabanaAttentionImpl(AttentionImpl, torch.nn.Module):
     """
@@ -202,57 +155,35 @@ def forward(
 
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
-                # TODO: move this outside of model
-                assert prefill_meta.attn_bias is not None, 'attn_bias must be set before calling model.forward!'
-                query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
-                kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size)
-                out = xops.prompt_attention(
-                    query.view(query_shape),
-                    key.view(kv_shape),
-                    value.view(kv_shape),
-                    attn_bias=prefill_meta.attn_bias,
-                    p=0.0,
-                    scale=self.scale,
-                    qk_matmul_op=self.qk_matmul,
-                    softmax_op=self.softmax,
-                    kv_matmul_op=self.kv_matmul,
-                )
-                output = out.reshape(batch_size, seq_len, hidden_size)
-            else:
-                # prefix-enabled attention
-                output = HabanaPagedAttention.forward_prefix(
-                    query,
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    prefill_meta.block_tables,
-                    prefill_meta.subquery_start_loc,
-                    prefill_meta.seq_lens_tensor,
-                    prefill_meta.context_lens_tensor,
-                    prefill_meta.max_query_len,
-                    self.alibi_slopes,
-                )
+            assert prefill_meta.attn_bias is not None, 'attn_bias must be set before calling model.forward!'
+            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
+            kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size)
+            out = xops.prompt_attention(
+                query.view(query_shape),
+                key.view(kv_shape),
+                value.view(kv_shape),
+                attn_bias=prefill_meta.attn_bias,
+                p=0.0,
+                scale=self.scale,
+                qk_matmul_op=self.qk_matmul,
+                softmax_op=self.softmax,
+                kv_matmul_op=self.kv_matmul,
+            )
+            output = out.reshape(batch_size, seq_len, hidden_size)
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
             output = HabanaPagedAttention.forward_decode(
-                query,
-                key_cache,
-                value_cache,
-                decode_meta.block_tables,
-                decode_meta.seq_lens_tensor,
-                attn_metadata.kv_cache_dtype,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-                kv_scale,
-                self.qk_matmul,
-                self.softmax,
-                self.kv_matmul,
-                self.key_cache.fetch_from_cache,
-                self.value_cache.fetch_from_cache,
-            )
+                query=query,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                block_list=decode_meta.block_list,
+                block_mapping=decode_meta.block_mapping,
+                block_bias=decode_meta.attn_bias,
+                scale=self.scale,
+                qk_matmul_op=self.qk_matmul,
+                kv_matmul_op=self.kv_matmul,
+                keys_fetch_func=self.key_cache.fetch_from_cache,
+                values_fetch_func=self.value_cache.fetch_from_cache)
 
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)
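Note on the metadata change in the first hunk: the CUDA-graph-oriented fields are dropped and HabanaAttentionMetadata becomes a frozen dataclass whose attn_bias is an ordinary tensor field, expected to be filled in before model.forward runs rather than attached lazily in __post_init__. Below is a minimal sketch of that pattern shift; the class and helper names are illustrative only, not part of the backend.

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class MutableMeta:
    """Old pattern: the bias is attached lazily by the first attention op."""
    seq_lens: Optional[List[int]]

    def __post_init__(self):
        # Not a declared field, so it never appears in __init__/__repr__.
        self.attn_bias: Optional[torch.Tensor] = None


@dataclass(frozen=True)
class FrozenMeta:
    """New pattern: the bias is precomputed and immutable for the whole pass."""
    attn_bias: Optional[torch.Tensor]
    seq_lens_tensor: Optional[torch.Tensor]


def make_causal_bias(seq_len: int) -> torch.Tensor:
    # One way to build an additive causal mask up front (illustrative helper).
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    return torch.zeros(seq_len, seq_len).masked_fill(mask, float("-inf"))


meta = FrozenMeta(attn_bias=make_causal_bias(4), seq_lens_tensor=torch.tensor([4]))

Because the metadata is now immutable, the prefill path can simply assert that the bias is already present, which is exactly what the rewritten prompt run does.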
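Note on the decode change in the second hunk: the positional forward_decode call (block tables, sequence lengths, kv_cache_dtype, alibi slopes, kv_scale) is replaced by a keyword call built around block_list, block_mapping, a per-block bias, and explicit fetch functions for the caches. The toy, single-sequence sketch below shows the basic gather-by-block-id step such a call works with, under the assumption that the cache is laid out as [num_blocks, block_size, num_heads, head_size]; the helper is illustrative only, not HabanaPagedAttention.forward_decode.

import torch


def toy_paged_decode(query, key_cache, value_cache, block_list, scale):
    # query:       [num_heads, head_size]                  (one new decode token)
    # key_cache:   [num_blocks, block_size, num_heads, head_size]
    # value_cache: same layout as key_cache
    # block_list:  1-D tensor of physical block ids owned by this sequence
    keys = key_cache[block_list]                       # gather: [n, block_size, H, D]
    values = value_cache[block_list]
    keys = keys.flatten(0, 1)                          # [n * block_size, H, D]
    values = values.flatten(0, 1)
    scores = torch.einsum("hd,thd->ht", query, keys) * scale
    probs = scores.softmax(dim=-1)
    return torch.einsum("ht,thd->hd", probs, values)   # [num_heads, head_size]


# Example: a cache of 10 blocks with 4 slots each, 8 heads, head size 16.
kc = torch.randn(10, 4, 8, 16)
vc = torch.randn(10, 4, 8, 16)
out = toy_paged_decode(torch.randn(8, 16), kc, vc, torch.tensor([3, 7]), scale=16 ** -0.5)

The sketch handles one sequence and fully occupied blocks only; relating gathered blocks back to their batch rows and masking partially filled blocks is what block_mapping and block_bias in the real call appear to address.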