
Commit 07e6d08

pytorchbot authored and kirklandsign committed
[Executorch][SDPA] Remove slice creation (#9911)
The slices are not needed: the only reason we were slicing was to get the sliced seq len of k and v that is used for causal attention.

Differential Revision: [D71370595](https://our.internmc.facebook.com/intern/diff/D71370595/)

ghstack-source-id: 276012275
Pull Request resolved: #9888
1 parent 56d4927 commit 07e6d08

File tree

2 files changed: +22 −68 lines
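The gist of the change, as a minimal standalone sketch in toy C++ (this is not the ExecuTorch kernel; toy_causal_attention and its [seq_len][head_dim] cache layout are made up for illustration): instead of materializing sliced key/value views of length start_pos + seq_len, the full caches are passed through and the kernel is simply told how many keys causal attention may touch.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// Toy single-head causal attention over a pre-allocated KV cache of shape
// [max_seq_len][head_dim]. Rather than slicing the cache down to
// start_pos + seq_len rows, the caller passes num_keys and the loops below
// stop there -- the same effect the removed TensorImpl slices achieved.
std::vector<std::vector<float>> toy_causal_attention(
    const std::vector<std::vector<float>>& q,        // [seq_len][head_dim]
    const std::vector<std::vector<float>>& k_cache,  // [max_seq_len][head_dim]
    const std::vector<std::vector<float>>& v_cache,  // [max_seq_len][head_dim]
    int64_t start_pos,
    int64_t num_keys) {  // start_pos + q.size(1) in the real kernel
  const int64_t seq_len = static_cast<int64_t>(q.size());
  const int64_t head_dim = static_cast<int64_t>(q[0].size());
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  std::vector<std::vector<float>> out(
      seq_len, std::vector<float>(head_dim, 0.0f));
  for (int64_t i = 0; i < seq_len; ++i) {
    // Causal mask: the query at absolute position start_pos + i may look at
    // cache positions 0 .. start_pos + i, which is always within num_keys.
    const int64_t visible = std::min<int64_t>(start_pos + i + 1, num_keys);
    std::vector<float> logits(visible);
    float max_logit = -std::numeric_limits<float>::infinity();
    for (int64_t j = 0; j < visible; ++j) {
      float dot = 0.0f;
      for (int64_t d = 0; d < head_dim; ++d) {
        dot += q[i][d] * k_cache[j][d];
      }
      logits[j] = dot * scale;
      max_logit = std::max(max_logit, logits[j]);
    }
    // Numerically stable softmax over the visible keys only.
    float denom = 0.0f;
    for (int64_t j = 0; j < visible; ++j) {
      logits[j] = std::exp(logits[j] - max_logit);
      denom += logits[j];
    }
    // Weighted sum of the visible values.
    for (int64_t j = 0; j < visible; ++j) {
      const float w = logits[j] / denom;
      for (int64_t d = 0; d < head_dim; ++d) {
        out[i][d] += w * v_cache[j][d];
      }
    }
  }
  return out;
}

In the real kernel this bookkeeping lives in cpu_flash_attention, which clamps kvSize instead of receiving pre-sliced tensors; the toy loop only shows that stopping at num_keys is equivalent to attending over a sliced cache.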

extension/llm/custom_ops/op_sdpa.cpp (+13 −67)
@@ -273,7 +273,6 @@ Tensor& flash_attention_kernel_out(
   Format [n_layers, batch size, max_seq_len, num heads, head dim]
   ....
   @param[in] start_pos: sequence position
-  @param[in] seq_len: Seq length. e.g. seq_len dim of q_projected.
   */
 Tensor& custom_sdpa_out(
     RuntimeContext& ctx,
@@ -306,63 +305,7 @@ Tensor& custom_sdpa_out(
   const int64_t seq_len = q.size(1);
   auto q_seq_len = q.size(1);
 
-  // Refactor the following into create_view util perhaps using
-  // TensorPtr
-  std::array<::executorch::aten::DimOrderType, sdpa::impl::kKVDim>
-      sliced_key_dim_order{0, 1, 2, 3};
-  std::array<::executorch::aten::SizesType, sdpa::impl::kKVDim>
-      sliced_key_sizes;
-  sliced_key_sizes[0] = k.size(0);
-  sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
-  sliced_key_sizes[2] = k.size(2);
-  sliced_key_sizes[3] = k.size(3);
-  std::array<::executorch::aten::StridesType, sdpa::impl::kKVDim>
-      sliced_key_strides;
-  dim_order_to_stride_nocheck(
-      sliced_key_sizes.data(),
-      sliced_key_dim_order.data(),
-      sdpa::impl::kKVDim,
-      sliced_key_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_key_strides[0] = k.strides()[0];
-  void* key_cache_data = k.mutable_data_ptr();
-  TensorImpl k_impl = TensorImpl(
-      k.scalar_type(),
-      sdpa::impl::kKVDim,
-      sliced_key_sizes.data(),
-      key_cache_data,
-      sliced_key_dim_order.data(),
-      sliced_key_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_key_cache(&k_impl);
-
-  std::array<::executorch::aten::DimOrderType, sdpa::impl::kKVDim>
-      sliced_value_dim_order{0, 1, 2, 3};
-  std::array<::executorch::aten::SizesType, sdpa::impl::kKVDim>
-      sliced_value_sizes;
-  sliced_value_sizes[0] = v.size(0);
-  sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
-  sliced_value_sizes[2] = v.size(2);
-  sliced_value_sizes[3] = v.size(3);
-  std::array<::executorch::aten::StridesType, sdpa::impl::kKVDim>
-      sliced_value_strides;
-  dim_order_to_stride_nocheck(
-      sliced_value_sizes.data(),
-      sliced_value_dim_order.data(),
-      sdpa::impl::kKVDim,
-      sliced_value_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_value_strides[0] = v.strides()[0];
-  void* value_cache_data = v.mutable_data_ptr();
-  TensorImpl value_impl = TensorImpl(
-      v.scalar_type(),
-      sdpa::impl::kKVDim,
-      sliced_value_sizes.data(),
-      value_cache_data,
-      sliced_value_dim_order.data(),
-      sliced_value_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_value_cache(&value_impl);
+  const int64_t num_keys_for_causal_attention = start_pos + seq_len;
 
   ET_KERNEL_CHECK(
       ctx,
@@ -380,38 +323,41 @@ Tensor& custom_sdpa_out(
       sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
           output,
           q,
-          sliced_key_cache,
-          sliced_value_cache,
+          k,
+          v,
          dropout_p,
          is_causal,
          attn_mask,
          scale,
          true, /* is_seq_at_dim_1 */
-          start_pos);
+          start_pos,
+          num_keys_for_causal_attention);
    } else if (q_seq_len >= 192) {
      sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
          output,
          q,
-          sliced_key_cache,
-          sliced_value_cache,
+          k,
+          v,
          dropout_p,
          is_causal,
          attn_mask,
          scale,
          true, /* is_seq_at_dim_1 */
-          start_pos);
+          start_pos,
+          num_keys_for_causal_attention);
    } else {
      sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
          output,
          q,
-          sliced_key_cache,
-          sliced_value_cache,
+          k,
+          v,
          dropout_p,
          is_causal,
          attn_mask,
          scale,
          true, /* is_seq_at_dim_1 */
-          start_pos);
+          start_pos,
+          num_keys_for_causal_attention);
    }
  });
  return output;
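A worked example of the new bound (hypothetical numbers, not taken from the patch) shows why start_pos + seq_len is exactly the number of keys causal attention can ever need:

// Hypothetical decode step: 16 tokens already sit in the KV cache and a new
// 4-token chunk is being processed.
const int64_t start_pos = 16;  // tokens already cached
const int64_t seq_len = 4;     // q.size(1) for the new chunk
const int64_t num_keys_for_causal_attention = start_pos + seq_len;  // = 20
// Query i of the chunk (0-based) may attend to cache positions
// 0 .. start_pos + i, so even the last query needs only the first 20 keys;
// the rest of the pre-allocated cache can be ignored without slicing it.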

extension/llm/custom_ops/op_sdpa_impl.h (+9 −1)
@@ -212,7 +212,8 @@ void cpu_flash_attention(
     const optional<Tensor>& attn_mask,
     const optional<double>& scale,
     bool is_seq_at_dim_1 = false,
-    const int64_t start_pos = 0) {
+    const int64_t start_pos = 0,
+    const int64_t num_keys_for_causal_attention = -1) {
   (void)dropout_p;
   // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
   // Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
@@ -258,6 +259,13 @@ void cpu_flash_attention(
     kvSize = value.size(1);
   }
 
+  if (num_keys_for_causal_attention > 0) {
+    ET_CHECK_MSG(
+        num_keys_for_causal_attention <= kvSize,
+        "num_keys_for_causal_attention must be <= kvSize");
+    kvSize = num_keys_for_causal_attention;
+  }
+
   ET_CHECK_MSG(
       num_heads_kv <= num_head,
       "FlashAttention does not support num kv heads > num query heads.Got num query heads=%" PRId64
