Skip to content

[Executorch][sdpa] Setup the structure to enable quantized gemms for sdpa #9912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 203 additions & 65 deletions extension/llm/custom_ops/op_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,51 @@ bool validate_flash_attention_args(
return true;
}

bool validate_cache_quant_params_args(
const Tensor& t,
const Tensor& t_zero_points,
const Tensor& t_scales) {
ET_CHECK_OR_RETURN_FALSE(
t.dim() == t_scales.dim(),
"Quantized tensor and scales must have the same number of dimensions");
ET_CHECK_OR_RETURN_FALSE(
t.dim() == t_zero_points.dim(),
"Quantized tensor and scales must have the same number of dimensions");

ET_CHECK_OR_RETURN_FALSE(
(t.scalar_type() == ScalarType::Char), "Tensor must be of int8_t type");

ET_CHECK_OR_RETURN_FALSE(
(t_scales.scalar_type() == ScalarType::Float),
"Scales tensor must be of float type");

ET_CHECK_OR_RETURN_FALSE(
(t_zero_points.scalar_type() == ScalarType::Char),
"Zero points tensor must be of int8_t type");

// Sizes
for (int64_t i = 0; i < t.dim() - 1; i++) {
ET_CHECK_OR_RETURN_FALSE(
(t.size(i) == t_scales.size(i)),
"Quantized tensor and scales have different shape"
"at dim: %" PRId64 ", t: %zd, t_scales: %zd",
i,
t.size(i),
t_scales.size(i));
;
ET_CHECK_OR_RETURN_FALSE(
(t.size(i) == t_zero_points.size(i)),
"Quantized tensor and zero points have different shape"
"at dim: %" PRId64 ", t: %zd, t_scales: %zd",
i,
t.size(i),
t_zero_points.size(i));
;
}

return true;
}

bool validate_cache_params(
const Tensor& k_cache,
const Tensor& v_cache,
Expand Down Expand Up @@ -233,7 +278,13 @@ Tensor& flash_attention_kernel_out(
dropout_p,
is_causal,
attn_mask,
scale);
scale,
nullopt,
nullopt,
nullopt,
nullopt,
nullopt,
nullopt);
} else if (q_seq_len >= 192) {
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
output,
Expand All @@ -243,7 +294,13 @@ Tensor& flash_attention_kernel_out(
dropout_p,
is_causal,
attn_mask,
scale);
scale,
nullopt,
nullopt,
nullopt,
nullopt,
nullopt,
nullopt);
} else {
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
output,
Expand All @@ -253,28 +310,19 @@ Tensor& flash_attention_kernel_out(
dropout_p,
is_causal,
attn_mask,
scale);
scale,
nullopt,
nullopt,
nullopt,
nullopt,
nullopt,
nullopt);
}
});
return output;
}

/*
Input params
@param[in] q_projected Projected query with query weights.
Format [n_layers, batch size, seq_len, num heads, head dim]
@param[in] k_projected Projected query with key weights.
Format [n_layers, batch size, seq_len, num heads, head dim]
@param[in] v_projected Projected query with value weights.
Format [n_layers, batch size, seq_len, num heads, head dim]
@param[in] key_cache Cache of previous k_projected.
Format [n_layers, batch size, max_seq_len, num heads, head dim]
@param[in] key_cache Cache of previous v_projected.
Format [n_layers, batch size, max_seq_len, num heads, head dim]
....
@param[in] start_pos: sequence position
*/
Tensor& custom_sdpa_out(
Tensor& custom_sdpa_out_impl(
RuntimeContext& ctx,
const Tensor& q,
const Tensor& k,
Expand All @@ -285,7 +333,13 @@ Tensor& custom_sdpa_out(
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output) {
Tensor& output,
const optional<Tensor>& q_zero_points = nullopt,
const optional<Tensor>& q_scales = nullopt,
const optional<Tensor>& k_zero_points = nullopt,
const optional<Tensor>& k_scales = nullopt,
const optional<Tensor>& v_zero_points = nullopt,
const optional<Tensor>& v_scales = nullopt) {
ET_KERNEL_CHECK_MSG(
ctx,
!attn_mask.has_value() || !is_causal,
Expand All @@ -300,6 +354,40 @@ Tensor& custom_sdpa_out(
output,
"Invalid arguments");

bool is_seq_at_dim_1{true};
if (q.scalar_type() == ScalarType::Char) {
is_seq_at_dim_1 = false;
ET_KERNEL_CHECK_MSG(
ctx,
q_scales.has_value() && q_zero_points.has_value() &&
k_scales.has_value() && k_zero_points.has_value() &&
q_scales.has_value() && q_zero_points.has_value(),
InvalidArgument,
output,
"If q is quantized, k and v must be quantized as well");
ET_KERNEL_CHECK_MSG(
ctx,
validate_cache_quant_params_args(
q, q_zero_points.value(), q_scales.value()),
InvalidArgument,
output,
"Invalid arguments for quantized query");
ET_KERNEL_CHECK_MSG(
ctx,
validate_cache_quant_params_args(
k, k_zero_points.value(), k_scales.value()),
InvalidArgument,
output,
"Invalid arguments for quantized key");
ET_KERNEL_CHECK_MSG(
ctx,
validate_cache_quant_params_args(
v, v_zero_points.value(), v_scales.value()),
InvalidArgument,
output,
"Invalid arguments for quantized value");
}

ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");

const int64_t seq_len = q.size(1);
Expand All @@ -315,53 +403,103 @@ Tensor& custom_sdpa_out(

// TODO(task): replace the template param selection logic
// with whatever apprpriately makes more sense for
ET_SWITCH_FLOAT_TYPES(q.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
// TODO we need to re-evaluate this for ARM CPUs
// And there can be many so instead of templatizing
// we might consider another appraoch
if (q_seq_len >= 768) {
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
output,
q,
k,
v,
dropout_p,
is_causal,
attn_mask,
scale,
true, /* is_seq_at_dim_1 */
start_pos,
num_keys_for_causal_attention);
} else if (q_seq_len >= 192) {
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
output,
q,
k,
v,
dropout_p,
is_causal,
attn_mask,
scale,
true, /* is_seq_at_dim_1 */
start_pos,
num_keys_for_causal_attention);
} else {
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
output,
q,
k,
v,
dropout_p,
is_causal,
attn_mask,
scale,
true, /* is_seq_at_dim_1 */
start_pos,
num_keys_for_causal_attention);
}
});
ET_SWITCH_FLOAT_TYPES(
output.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
// TODO we need to re-evaluate this for ARM CPUs
// And there can be many so instead of templatizing
// we might consider another appraoch
if (q_seq_len >= 768) {
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
output,
q,
k,
v,
dropout_p,
is_causal,
attn_mask,
scale,
nullopt, // q_zero_points
nullopt, // q_scales
nullopt, // k_zero_points
nullopt, // k_scales
nullopt, // v_zero_points
nullopt, // v_scales
is_seq_at_dim_1, /* is_seq_at_dim_1 */
start_pos,
num_keys_for_causal_attention);
} else if (q_seq_len >= 192) {
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
output,
q,
k,
v,
dropout_p,
is_causal,
attn_mask,
scale,
nullopt, // q_zero_points
nullopt, // q_scales
nullopt, // k_zero_points
nullopt, // k_scales
nullopt, // v_zero_points
nullopt, // v_scales
is_seq_at_dim_1, /* is_seq_at_dim_1 */
start_pos,
num_keys_for_causal_attention);
} else {
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
output,
q,
k,
v,
dropout_p,
is_causal,
attn_mask,
scale,
nullopt, // q_zero_points
nullopt, // q_scales
nullopt, // k_zero_points
nullopt, // k_scales
nullopt, // v_zero_points
nullopt, // v_scales
is_seq_at_dim_1, /* is_seq_at_dim_1 */
start_pos,
num_keys_for_causal_attention);
}
});
return output;
}

/*
Input params
@param[in] q_projected Projected query with query weights.
Format [n_layers, batch size, seq_len, num heads, head dim]
@param[in] k_projected Projected query with key weights.
Format [n_layers, batch size, seq_len, num heads, head dim]
@param[in] v_projected Projected query with value weights.
Format [n_layers, batch size, seq_len, num heads, head dim]
@param[in] key_cache Cache of previous k_projected.
Format [n_layers, batch size, max_seq_len, num heads, head dim]
@param[in] key_cache Cache of previous v_projected.
Format [n_layers, batch size, max_seq_len, num heads, head dim]
....
@param[in] start_pos: sequence position
*/
Tensor& custom_sdpa_out(
RuntimeContext& ctx,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const int64_t start_pos,
const optional<Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output) {
return custom_sdpa_out_impl(
ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
}
/*
Input params
@param[in] q_projected Projected query with query weights.
Expand Down
Loading
Loading