
Commit 0c23279

[Executorch][sdpa] Add accidentally removed flash attention args check
As the title.

Differential Revision: [D71370594](https://our.internmc.facebook.com/intern/diff/D71370594/)
ghstack-source-id: 276012276
Pull Request resolved: #9887
1 parent 2623d09 commit 0c23279

File tree: 1 file changed, +7 -0 lines changed


extension/llm/custom_ops/op_sdpa.cpp

@@ -294,6 +294,13 @@ Tensor& custom_sdpa_out(
       output,
       "attn_mask and is_causal cannot be set at the same time");
 
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      validate_flash_attention_args(q, k, v, attn_mask),
+      InvalidArgument,
+      output,
+      "Invalid arguments");
+
   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
 
   const int64_t seq_len = q.size(1);
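The restored ET_KERNEL_CHECK_MSG call gates the kernel on validate_flash_attention_args(q, k, v, attn_mask), returning `output` with InvalidArgument when validation fails. For context, the sketch below illustrates the kind of shape and dtype consistency such a validator might enforce. It is a hypothetical example, not ExecuTorch's implementation; the [batch, seq_len, num_heads, head_dim] layout and the mask-rank rule are assumptions inferred from the surrounding lines (e.g. `seq_len = q.size(1)`).

```cpp
// Hypothetical sketch of a flash-attention argument validator (not the actual
// ExecuTorch helper). Assumes a Tensor type exposing dim(), size(), and
// scalar_type(), as used elsewhere in op_sdpa.cpp.
#include <optional>

template <typename Tensor>
bool validate_flash_attention_args_sketch(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const std::optional<Tensor>& attn_mask) {
  // Assumed layout: [batch, seq_len, num_heads, head_dim], so all three
  // projections should be 4-D.
  if (query.dim() != 4 || key.dim() != 4 || value.dim() != 4) {
    return false;
  }
  // Q, K, and V should share a dtype.
  if (query.scalar_type() != key.scalar_type() ||
      query.scalar_type() != value.scalar_type()) {
    return false;
  }
  // Head dimensions of Q and K must match for the Q·K^T product.
  if (query.size(3) != key.size(3)) {
    return false;
  }
  // If a mask is supplied, accept only 2-D or 4-D masks (an assumption for
  // this sketch).
  if (attn_mask.has_value()) {
    const auto mask_dim = attn_mask->dim();
    if (mask_dim != 2 && mask_dim != 4) {
      return false;
    }
  }
  return true;
}
```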
