
Commit 56d4927

pytorchbot authored and kirklandsign committed
[Executorch][sdpa] Add accidentally removed flash attention args check (#9910)
Mostly preparing for adding quantized SDPA.
Differential Revision: [D71370596](https://our.internmc.facebook.com/intern/diff/D71370596/)
ghstack-source-id: 276012277
Pull Request resolved: #9886
1 parent 2b105f9 commit 56d4927

File tree

1 file changed: +7 −0 lines changed


extension/llm/custom_ops/op_sdpa.cpp

@@ -294,6 +294,13 @@ Tensor& custom_sdpa_out(
       output,
       "attn_mask and is_causal cannot be set at the same time");
 
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      validate_flash_attention_args(q, k, v, attn_mask),
+      InvalidArgument,
+      output,
+      "Invalid arguments");
+
   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
 
   const int64_t seq_len = q.size(1);
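For context, the restored guard routes an argument validator through the same ET_KERNEL_CHECK_MSG error-reporting path as the preceding attn_mask/is_causal check. The body of validate_flash_attention_args is not shown in this diff; the sketch below is a hypothetical, self-contained illustration of the kind of preconditions such a validator typically enforces (4-D inputs, matching dtypes, consistent head dimension, and a rank check on the optional mask). The TensorView struct and the function name with the `_sketch` suffix are stand-ins, not the ExecuTorch API, and the real checks in op_sdpa.cpp may differ.

```cpp
// Hypothetical sketch only: a stand-in for validate_flash_attention_args.
// TensorView is a minimal substitute for the ExecuTorch tensor type.
#include <cstdint>
#include <optional>
#include <vector>

struct TensorView {
  std::vector<int64_t> sizes;  // e.g. {batch, seq_len, num_heads, head_dim}
  int dtype = 0;               // opaque dtype tag for this sketch
  int64_t dim() const { return static_cast<int64_t>(sizes.size()); }
  int64_t size(int64_t d) const { return sizes[d]; }
};

// Returns true when q/k/v (and the optional mask) look usable by a
// flash-attention style kernel; the actual checks in op_sdpa.cpp may differ.
bool validate_flash_attention_args_sketch(
    const TensorView& q,
    const TensorView& k,
    const TensorView& v,
    const std::optional<TensorView>& attn_mask) {
  // All inputs must be 4-D: [batch, seq_len, num_heads, head_dim].
  if (q.dim() != 4 || k.dim() != 4 || v.dim() != 4) {
    return false;
  }
  // Query, key, and value must share a dtype.
  if (q.dtype != k.dtype || q.dtype != v.dtype) {
    return false;
  }
  // Head dimension must agree across q, k, and v.
  if (q.size(3) != k.size(3) || q.size(3) != v.size(3)) {
    return false;
  }
  // If a mask is supplied, only 2-D masks are accepted in this sketch.
  if (attn_mask.has_value() && attn_mask->dim() != 2) {
    return false;
  }
  return true;
}
```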
