@@ -273,7 +273,6 @@ Tensor& flash_attention_kernel_out(
   Format [n_layers, batch size, max_seq_len, num heads, head dim]
   ....
   @param[in] start_pos: sequence position
-  @param[in] seq_len: Seq length. e.g. seq_len dim of q_projected.
 */
 Tensor& custom_sdpa_out(
     RuntimeContext& ctx,
@@ -306,63 +305,7 @@ Tensor& custom_sdpa_out(
   const int64_t seq_len = q.size(1);
   auto q_seq_len = q.size(1);
 
-  // Refactor the following into create_view util perhaps using
-  // TensorPtr
-  std::array<::executorch::aten::DimOrderType, sdpa::impl::kKVDim>
-      sliced_key_dim_order{0, 1, 2, 3};
-  std::array<::executorch::aten::SizesType, sdpa::impl::kKVDim>
-      sliced_key_sizes;
-  sliced_key_sizes[0] = k.size(0);
-  sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
-  sliced_key_sizes[2] = k.size(2);
-  sliced_key_sizes[3] = k.size(3);
-  std::array<::executorch::aten::StridesType, sdpa::impl::kKVDim>
-      sliced_key_strides;
-  dim_order_to_stride_nocheck(
-      sliced_key_sizes.data(),
-      sliced_key_dim_order.data(),
-      sdpa::impl::kKVDim,
-      sliced_key_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_key_strides[0] = k.strides()[0];
-  void* key_cache_data = k.mutable_data_ptr();
-  TensorImpl k_impl = TensorImpl(
-      k.scalar_type(),
-      sdpa::impl::kKVDim,
-      sliced_key_sizes.data(),
-      key_cache_data,
-      sliced_key_dim_order.data(),
-      sliced_key_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_key_cache(&k_impl);
-
-  std::array<::executorch::aten::DimOrderType, sdpa::impl::kKVDim>
-      sliced_value_dim_order{0, 1, 2, 3};
-  std::array<::executorch::aten::SizesType, sdpa::impl::kKVDim>
-      sliced_value_sizes;
-  sliced_value_sizes[0] = v.size(0);
-  sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
-  sliced_value_sizes[2] = v.size(2);
-  sliced_value_sizes[3] = v.size(3);
-  std::array<::executorch::aten::StridesType, sdpa::impl::kKVDim>
-      sliced_value_strides;
-  dim_order_to_stride_nocheck(
-      sliced_value_sizes.data(),
-      sliced_value_dim_order.data(),
-      sdpa::impl::kKVDim,
-      sliced_value_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_value_strides[0] = v.strides()[0];
-  void* value_cache_data = v.mutable_data_ptr();
-  TensorImpl value_impl = TensorImpl(
-      v.scalar_type(),
-      sdpa::impl::kKVDim,
-      sliced_value_sizes.data(),
-      value_cache_data,
-      sliced_value_dim_order.data(),
-      sliced_value_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_value_cache(&value_impl);
+  const int64_t num_keys_for_causal_attention = start_pos + seq_len;
 
   ET_KERNEL_CHECK(
       ctx,
@@ -380,38 +323,41 @@ Tensor& custom_sdpa_out(
       sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
           output,
           q,
-          sliced_key_cache,
-          sliced_value_cache,
+          k,
+          v,
           dropout_p,
           is_causal,
           attn_mask,
           scale,
           true, /* is_seq_at_dim_1 */
-          start_pos);
+          start_pos,
+          num_keys_for_causal_attention);
     } else if (q_seq_len >= 192) {
       sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
           output,
           q,
-          sliced_key_cache,
-          sliced_value_cache,
+          k,
+          v,
           dropout_p,
           is_causal,
           attn_mask,
           scale,
           true, /* is_seq_at_dim_1 */
-          start_pos);
+          start_pos,
+          num_keys_for_causal_attention);
     } else {
       sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
           output,
           q,
-          sliced_key_cache,
-          sliced_value_cache,
+          k,
+          v,
           dropout_p,
           is_causal,
           attn_mask,
           scale,
           true, /* is_seq_at_dim_1 */
-          start_pos);
+          start_pos,
+          num_keys_for_causal_attention);
     }
   });
   return output;
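
Note on the refactor above: the removed block materialized sliced K/V cache views whose sequence dimension was sized `start_pos + seq_len`; the new `num_keys_for_causal_attention` argument passes that same bound to `cpu_flash_attention` directly, so the kernel can limit how far into the cache it reads without building temporary `TensorImpl` views. Below is a minimal standalone sketch of what the bound means for a causal kernel; the helper name and layout are hypothetical, not the actual `cpu_flash_attention` implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical illustration only: attention scores for one query row against a
// prefilled KV cache. A causal kernel never reads cache entries at or beyond
// num_keys_for_causal_attention (= start_pos + seq_len), which is why passing
// this single bound can replace the sliced cache views removed above.
std::vector<float> causal_scores_for_query(
    const std::vector<float>& q_row, // [head_dim], query at chunk offset i
    const std::vector<float>& key_cache, // [max_seq_len * head_dim], row-major
    int64_t head_dim,
    int64_t start_pos, // cache entries filled by previously processed tokens
    int64_t i, // offset of this query within the current chunk
    int64_t num_keys_for_causal_attention) { // start_pos + seq_len
  // The query at absolute position start_pos + i may attend to keys
  // [0, start_pos + i], which is always within the passed bound.
  const int64_t last_visible_key = start_pos + i;
  assert(last_visible_key < num_keys_for_causal_attention);

  std::vector<float> scores(static_cast<size_t>(last_visible_key) + 1);
  const float inv_sqrt_d = 1.0f / std::sqrt(static_cast<float>(head_dim));
  for (int64_t j = 0; j <= last_visible_key; ++j) {
    float dot = 0.0f;
    for (int64_t d = 0; d < head_dim; ++d) {
      dot += q_row[static_cast<size_t>(d)] *
          key_cache[static_cast<size_t>(j * head_dim + d)];
    }
    scores[static_cast<size_t>(j)] = dot * inv_sqrt_d;
  }
  return scores;
}
```

The softmax and value accumulation are omitted; the point is only that the per-query key range depends on `start_pos + i` and never exceeds `start_pos + seq_len`, so an integer bound carries the same information as the sliced cache tensors did.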