
Commit 95d38c4

[Executorch][sdpa] Setup the structure to enable quantized gemms for sdpa (#9912)
moving q_at_k and qk_at_v into separate functions to prepare for quantized sdpa

Differential Revision: [D71370607](https://our.internmc.facebook.com/intern/diff/D71370607/)

ghstack-source-id: 276012278
Pull Request resolved: #9889
1 parent 356d4f9 commit 95d38c4
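
For context on the refactor: scaled dot-product attention contains two GEMMs, Q·Kᵀ and (softmax(Q·Kᵀ))·V, and splitting them into separate helpers is what allows each one to later be replaced by a quantized GEMM. The sketch below is a minimal, self-contained illustration of that split; the names q_at_k and qk_at_v are borrowed from the commit message, but the real helpers live in op_sdpa_impl (the second changed file, not reproduced on this page) and use flash-attention tiling rather than this naive form.

```cpp
// Minimal sketch (not the ExecuTorch implementation): single-head attention
// with the two GEMMs factored into separate helpers, so a quantized GEMM
// could later be substituted for either one.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// scores[s, t] = scale * sum_d q[s, d] * k[t, d]
std::vector<float> q_at_k(
    const std::vector<float>& q, // [seq_len, head_dim], row-major
    const std::vector<float>& k, // [kv_len, head_dim], row-major
    size_t seq_len,
    size_t kv_len,
    size_t head_dim,
    float scale) {
  std::vector<float> scores(seq_len * kv_len, 0.0f);
  for (size_t s = 0; s < seq_len; ++s) {
    for (size_t t = 0; t < kv_len; ++t) {
      float acc = 0.0f;
      for (size_t d = 0; d < head_dim; ++d) {
        acc += q[s * head_dim + d] * k[t * head_dim + d];
      }
      scores[s * kv_len + t] = acc * scale;
    }
  }
  return scores;
}

// out[s, d] = sum_t softmax(scores)[s, t] * v[t, d]
std::vector<float> qk_at_v(
    std::vector<float> scores, // [seq_len, kv_len], softmaxed in place
    const std::vector<float>& v, // [kv_len, head_dim], row-major
    size_t seq_len,
    size_t kv_len,
    size_t head_dim) {
  for (size_t s = 0; s < seq_len; ++s) {
    float max_val = scores[s * kv_len];
    for (size_t t = 1; t < kv_len; ++t) {
      max_val = std::max(max_val, scores[s * kv_len + t]);
    }
    float sum = 0.0f;
    for (size_t t = 0; t < kv_len; ++t) {
      scores[s * kv_len + t] = std::exp(scores[s * kv_len + t] - max_val);
      sum += scores[s * kv_len + t];
    }
    for (size_t t = 0; t < kv_len; ++t) {
      scores[s * kv_len + t] /= sum;
    }
  }
  std::vector<float> out(seq_len * head_dim, 0.0f);
  for (size_t s = 0; s < seq_len; ++s) {
    for (size_t t = 0; t < kv_len; ++t) {
      for (size_t d = 0; d < head_dim; ++d) {
        out[s * head_dim + d] += scores[s * kv_len + t] * v[t * head_dim + d];
      }
    }
  }
  return out;
}
```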

2 files changed, +372 −85 lines changed

Diff for: extension/llm/custom_ops/op_sdpa.cpp (+203 −65)
@@ -82,6 +82,51 @@ bool validate_flash_attention_args(
   return true;
 }
 
+bool validate_cache_quant_params_args(
+    const Tensor& t,
+    const Tensor& t_zero_points,
+    const Tensor& t_scales) {
+  ET_CHECK_OR_RETURN_FALSE(
+      t.dim() == t_scales.dim(),
+      "Quantized tensor and scales must have the same number of dimensions");
+  ET_CHECK_OR_RETURN_FALSE(
+      t.dim() == t_zero_points.dim(),
+      "Quantized tensor and zero points must have the same number of dimensions");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t.scalar_type() == ScalarType::Char), "Tensor must be of int8_t type");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t_scales.scalar_type() == ScalarType::Float),
+      "Scales tensor must be of float type");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t_zero_points.scalar_type() == ScalarType::Char),
+      "Zero points tensor must be of int8_t type");
+
+  // Sizes: scales and zero points must match the quantized tensor on every
+  // dimension except the last.
+  for (int64_t i = 0; i < t.dim() - 1; i++) {
+    ET_CHECK_OR_RETURN_FALSE(
+        (t.size(i) == t_scales.size(i)),
+        "Quantized tensor and scales have different shape"
+        " at dim: %" PRId64 ", t: %zd, t_scales: %zd",
+        i,
+        t.size(i),
+        t_scales.size(i));
+    ET_CHECK_OR_RETURN_FALSE(
+        (t.size(i) == t_zero_points.size(i)),
+        "Quantized tensor and zero points have different shape"
+        " at dim: %" PRId64 ", t: %zd, t_zero_points: %zd",
+        i,
+        t.size(i),
+        t_zero_points.size(i));
+  }
+
+  return true;
+}
+
 bool validate_cache_params(
     const Tensor& k_cache,
     const Tensor& v_cache,
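
The shape contract this validator enforces is: the scales and zero-point tensors must have the same rank as the int8 tensor and match it on every dimension except the last, which is left unchecked (presumably where the per-row quantization parameters live). A standalone sketch of that rule over plain shape vectors, not using the ExecuTorch tensor API:

```cpp
// Standalone illustration of the shape rule enforced by
// validate_cache_quant_params_args (not the ExecuTorch API): scales and
// zero points must have the same rank as the quantized tensor and match it
// on every dimension except the last.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

bool quant_params_shapes_ok(
    const std::vector<int64_t>& t_sizes,
    const std::vector<int64_t>& zp_sizes,
    const std::vector<int64_t>& scale_sizes) {
  if (t_sizes.size() != zp_sizes.size() ||
      t_sizes.size() != scale_sizes.size()) {
    return false; // rank mismatch
  }
  for (size_t i = 0; i + 1 < t_sizes.size(); ++i) {
    if (t_sizes[i] != zp_sizes[i] || t_sizes[i] != scale_sizes[i]) {
      return false; // leading dims must match; last dim is unconstrained
    }
  }
  return true;
}

int main() {
  // e.g. an int8 [batch, heads, seq, head_dim] tensor with one scale and one
  // zero point per row (hypothetical layout, for illustration only).
  std::vector<int64_t> t = {1, 8, 128, 64};
  std::vector<int64_t> zp = {1, 8, 128, 1};
  std::vector<int64_t> scales = {1, 8, 128, 1};
  std::cout << quant_params_shapes_ok(t, zp, scales) << "\n"; // prints 1
}
```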
@@ -233,7 +278,13 @@ Tensor& flash_attention_kernel_out(
           dropout_p,
           is_causal,
           attn_mask,
-          scale);
+          scale,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt);
     } else if (q_seq_len >= 192) {
       sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
           output,
@@ -243,7 +294,13 @@ Tensor& flash_attention_kernel_out(
           dropout_p,
           is_causal,
           attn_mask,
-          scale);
+          scale,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt);
     } else {
       sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
           output,
@@ -253,28 +310,19 @@ Tensor& flash_attention_kernel_out(
           dropout_p,
           is_causal,
           attn_mask,
-          scale);
+          scale,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt);
     }
   });
   return output;
 }
 
-/*
-  Input params
-  @param[in] q_projected Projected query with query weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] k_projected Projected query with key weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] v_projected Projected query with value weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] key_cache Cache of previous k_projected.
-  Format [n_layers, batch size, max_seq_len, num heads, head dim]
-  @param[in] value_cache Cache of previous v_projected.
-  Format [n_layers, batch size, max_seq_len, num heads, head dim]
-  ....
-  @param[in] start_pos: sequence position
-*/
-Tensor& custom_sdpa_out(
+Tensor& custom_sdpa_out_impl(
     RuntimeContext& ctx,
     const Tensor& q,
     const Tensor& k,
@@ -285,7 +333,13 @@ Tensor& custom_sdpa_out(
     const bool is_causal,
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const optional<double> scale,
-    Tensor& output) {
+    Tensor& output,
+    const optional<Tensor>& q_zero_points = nullopt,
+    const optional<Tensor>& q_scales = nullopt,
+    const optional<Tensor>& k_zero_points = nullopt,
+    const optional<Tensor>& k_scales = nullopt,
+    const optional<Tensor>& v_zero_points = nullopt,
+    const optional<Tensor>& v_scales = nullopt) {
   ET_KERNEL_CHECK_MSG(
       ctx,
       !attn_mask.has_value() || !is_causal,
@@ -300,6 +354,40 @@ Tensor& custom_sdpa_out(
       output,
       "Invalid arguments");
 
+  bool is_seq_at_dim_1{true};
+  if (q.scalar_type() == ScalarType::Char) {
+    is_seq_at_dim_1 = false;
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        q_scales.has_value() && q_zero_points.has_value() &&
+            k_scales.has_value() && k_zero_points.has_value() &&
+            v_scales.has_value() && v_zero_points.has_value(),
+        InvalidArgument,
+        output,
+        "If q is quantized, k and v must be quantized as well");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            q, q_zero_points.value(), q_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized query");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            k, k_zero_points.value(), k_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized key");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            v, v_zero_points.value(), v_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized value");
+  }
+
   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
 
   const int64_t seq_len = q.size(1);
@@ -315,53 +403,103 @@ Tensor& custom_sdpa_out(
 
   // TODO(task): replace the template param selection logic
   // with whatever appropriately makes more sense for
-  ET_SWITCH_FLOAT_TYPES(q.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
-    // TODO we need to re-evaluate this for ARM CPUs
-    // And there can be many so instead of templatizing
-    // we might consider another approach
-    if (q_seq_len >= 768) {
-      sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    } else if (q_seq_len >= 192) {
-      sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    } else {
-      sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    }
-  });
+  ET_SWITCH_FLOAT_TYPES(
+      output.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
+        // TODO we need to re-evaluate this for ARM CPUs
+        // And there can be many so instead of templatizing
+        // we might consider another approach
+        if (q_seq_len >= 768) {
+          sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1,
+              start_pos,
+              num_keys_for_causal_attention);
+        } else if (q_seq_len >= 192) {
+          sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1,
+              start_pos,
+              num_keys_for_causal_attention);
+        } else {
+          sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1,
+              start_pos,
+              num_keys_for_causal_attention);
+        }
+      });
   return output;
 }
+
+/*
+  Input params
+  @param[in] q_projected Projected query with query weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] k_projected Projected query with key weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] v_projected Projected query with value weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] key_cache Cache of previous k_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  @param[in] value_cache Cache of previous v_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  ....
+  @param[in] start_pos: sequence position
+*/
+Tensor& custom_sdpa_out(
+    RuntimeContext& ctx,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  return custom_sdpa_out_impl(
+      ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
+}
 /*
   Input params
   @param[in] q_projected Projected query with query weights.
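
Taken together, the diff follows a common API-evolution pattern: the public custom_sdpa_out keeps its original signature and simply forwards to custom_sdpa_out_impl, where the quantization tensors are defaulted optional parameters, so existing float callers are untouched while a future quantized entry point can thread scales and zero points through. A generic sketch of that pattern with hypothetical names (std::optional standing in for the runtime's optional type):

```cpp
// Generic sketch of the wrapper/forwarding pattern used above (hypothetical
// names; not the ExecuTorch API).
#include <iostream>
#include <optional>
#include <string>

struct Tensor {
  std::string name;
};

// Internal implementation: new parameters are optional and defaulted, so a
// quantized path can be added without breaking existing call sites.
Tensor& sdpa_out_impl(
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    Tensor& output,
    const std::optional<Tensor>& q_scales = std::nullopt,
    const std::optional<Tensor>& k_scales = std::nullopt,
    const std::optional<Tensor>& v_scales = std::nullopt) {
  if (q_scales && k_scales && v_scales) {
    output.name = "quantized sdpa(" + q.name + ", " + k.name + ", " + v.name + ")";
  } else {
    output.name = "fp32 sdpa(" + q.name + ", " + k.name + ", " + v.name + ")";
  }
  return output;
}

// Public entry point: signature unchanged, quantization params left at their
// defaults, mirroring custom_sdpa_out forwarding to custom_sdpa_out_impl.
Tensor& sdpa_out(const Tensor& q, const Tensor& k, const Tensor& v, Tensor& output) {
  return sdpa_out_impl(q, k, v, output);
}

int main() {
  Tensor q{"q"}, k{"k"}, v{"v"}, out{};
  std::cout << sdpa_out(q, k, v, out).name << "\n"; // fp32 sdpa(q, k, v)
}
```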
