diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index 1d0d64f12d5..202ff17188d 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -82,6 +82,51 @@ bool validate_flash_attention_args(
   return true;
 }
 
+bool validate_cache_quant_params_args(
+    const Tensor& t,
+    const Tensor& t_zero_points,
+    const Tensor& t_scales) {
+  ET_CHECK_OR_RETURN_FALSE(
+      t.dim() == t_scales.dim(),
+      "Quantized tensor and scales must have the same number of dimensions");
+  ET_CHECK_OR_RETURN_FALSE(
+      t.dim() == t_zero_points.dim(),
+      "Quantized tensor and zero points must have the same number of dimensions");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t.scalar_type() == ScalarType::Char), "Tensor must be of int8_t type");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t_scales.scalar_type() == ScalarType::Float),
+      "Scales tensor must be of float type");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t_zero_points.scalar_type() == ScalarType::Char),
+      "Zero points tensor must be of int8_t type");
+
+  // Sizes
+  for (int64_t i = 0; i < t.dim() - 1; i++) {
+    ET_CHECK_OR_RETURN_FALSE(
+        (t.size(i) == t_scales.size(i)),
+        "Quantized tensor and scales have different shape "
+        "at dim: %" PRId64 ", t: %zd, t_scales: %zd",
+        i,
+        t.size(i),
+        t_scales.size(i));
+    ET_CHECK_OR_RETURN_FALSE(
+        (t.size(i) == t_zero_points.size(i)),
+        "Quantized tensor and zero points have different shape "
+        "at dim: %" PRId64 ", t: %zd, t_zero_points: %zd",
+        i,
+        t.size(i),
+        t_zero_points.size(i));
+  }
+
+  return true;
+}
+
 bool validate_cache_params(
     const Tensor& k_cache,
     const Tensor& v_cache,
@@ -233,7 +278,13 @@ Tensor& flash_attention_kernel_out(
           dropout_p,
           is_causal,
           attn_mask,
-          scale);
+          scale,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt);
     } else if (q_seq_len >= 192) {
       sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
           output,
@@ -243,7 +294,13 @@ Tensor& flash_attention_kernel_out(
           dropout_p,
           is_causal,
           attn_mask,
-          scale);
+          scale,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt);
     } else {
       sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
           output,
@@ -253,28 +310,19 @@ Tensor& flash_attention_kernel_out(
           dropout_p,
           is_causal,
           attn_mask,
-          scale);
+          scale,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt);
     }
   });
   return output;
 }
 
-/*
-  Input params
-  @param[in] q_projected Projected query with query weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] k_projected Projected query with key weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] v_projected Projected query with value weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] key_cache Cache of previous k_projected.
-  Format [n_layers, batch size, max_seq_len, num heads, head dim]
-  @param[in] key_cache Cache of previous v_projected.
-  Format [n_layers, batch size, max_seq_len, num heads, head dim]
-  ....
-  @param[in] start_pos: sequence position
-*/
-Tensor& custom_sdpa_out(
+Tensor& custom_sdpa_out_impl(
     RuntimeContext& ctx,
     const Tensor& q,
     const Tensor& k,
@@ -285,7 +333,13 @@ Tensor& custom_sdpa_out(
     const bool is_causal,
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const optional<double> scale,
-    Tensor& output) {
+    Tensor& output,
+    const optional<Tensor>& q_zero_points = nullopt,
+    const optional<Tensor>& q_scales = nullopt,
+    const optional<Tensor>& k_zero_points = nullopt,
+    const optional<Tensor>& k_scales = nullopt,
+    const optional<Tensor>& v_zero_points = nullopt,
+    const optional<Tensor>& v_scales = nullopt) {
   ET_KERNEL_CHECK_MSG(
       ctx,
       !attn_mask.has_value() || !is_causal,
@@ -300,6 +354,40 @@ Tensor& custom_sdpa_out(
       output,
       "Invalid arguments");
 
+  bool is_seq_at_dim_1{true};
+  if (q.scalar_type() == ScalarType::Char) {
+    is_seq_at_dim_1 = false;
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        q_scales.has_value() && q_zero_points.has_value() &&
+            k_scales.has_value() && k_zero_points.has_value() &&
+            v_scales.has_value() && v_zero_points.has_value(),
+        InvalidArgument,
+        output,
+        "If q is quantized, k and v must be quantized as well");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            q, q_zero_points.value(), q_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized query");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            k, k_zero_points.value(), k_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized key");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            v, v_zero_points.value(), v_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized value");
+  }
+
   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
 
   const int64_t seq_len = q.size(1);
@@ -315,53 +403,103 @@ Tensor& custom_sdpa_out(
 
   // TODO(task): replace the template param selection logic
   // with whatever apprpriately makes more sense for
-  ET_SWITCH_FLOAT_TYPES(q.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
-    // TODO we need to re-evaluate this for ARM CPUs
-    // And there can be many so instead of templatizing
-    // we might consider another appraoch
-    if (q_seq_len >= 768) {
-      sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    } else if (q_seq_len >= 192) {
-      sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    } else {
-      sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    }
-  });
+  ET_SWITCH_FLOAT_TYPES(
+      output.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
+        // TODO we need to re-evaluate this for ARM CPUs
+        // And there can be many so instead of templatizing
+        // we might consider another approach
+        if (q_seq_len >= 768) {
+          sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1, /* is_seq_at_dim_1 */
+              start_pos,
+              num_keys_for_causal_attention);
+        } else if (q_seq_len >= 192) {
+          sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1, /* is_seq_at_dim_1 */
+              start_pos,
+              num_keys_for_causal_attention);
+        } else {
+          sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1, /* is_seq_at_dim_1 */
+              start_pos,
+              num_keys_for_causal_attention);
+        }
+      });
   return output;
 }
+
+/*
+  Input params
+  @param[in] q_projected Projected query with query weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] k_projected Projected query with key weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] v_projected Projected query with value weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] key_cache Cache of previous k_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  @param[in] value_cache Cache of previous v_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  ....
+  @param[in] start_pos: sequence position
+*/
+Tensor& custom_sdpa_out(
+    RuntimeContext& ctx,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  return custom_sdpa_out_impl(
+      ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
+}
 /*
   Input params
   @param[in] q_projected Projected query with query weights.
diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h
index 46a1797f67c..0639c539ed1 100644
--- a/extension/llm/custom_ops/op_sdpa_impl.h
+++ b/extension/llm/custom_ops/op_sdpa_impl.h
@@ -30,6 +30,94 @@ namespace native {
 
 namespace sdpa::impl {
 
+struct MaybeQuantizedMatrixData {
+  const void* data{nullptr};
+  const int8_t* zero_points{nullptr};
+  const float* scales{nullptr};
+  int64_t m = 0, n = 0;
+  ScalarType dtype{ScalarType::Float};
+  MaybeQuantizedMatrixData() = default;
+  MaybeQuantizedMatrixData(
+      const void* data_,
+      const int8_t* zero_points_,
+      const float* scales_,
+      int64_t m_,
+      int64_t n_,
+      ScalarType dtype_)
+      : data(data_),
+        zero_points(zero_points_),
+        scales(scales_),
+        m(m_),
+        n(n_),
+        dtype(dtype_) {}
+};
+
+template <typename accum_t>
+void _q_at_k_gemm(
+    const int64_t q_m,
+    const int64_t k_n,
+    const int64_t qk_k,
+    const MaybeQuantizedMatrixData& q_data,
+    const int64_t q_stride_m,
+    const MaybeQuantizedMatrixData& k_data,
+    const int64_t k_stride_n,
+    accum_t* qk_data) {
+  ET_CHECK_MSG(q_data.dtype == k_data.dtype, "q and k must have same dtype");
+  ET_CHECK_MSG(
+      q_data.dtype == ScalarType::Char || q_data.dtype == ScalarType::Float,
+      "q and k must be either int8 or float");
+  if (q_data.dtype == ScalarType::Char) {
+    ET_CHECK_MSG(false, "int8 not supported yet");
+  } else {
+    ::executorch::cpublas::gemm(
+        ::executorch::cpublas::TransposeType::Transpose,
+        ::executorch::cpublas::TransposeType::NoTranspose,
+        k_n,
+        q_m,
+        qk_k,
+        static_cast<accum_t>(1),
+        static_cast<const float*>(k_data.data),
+        k_stride_n,
+        static_cast<const float*>(q_data.data),
+        q_stride_m,
+        static_cast<accum_t>(0),
+        qk_data,
+        k_n);
+  }
+}
+
+template <typename accum_t>
+void _qk_at_v_gemm(
+    const int64_t m,
+    const int64_t n,
+    const int64_t k,
+    const accum_t* qk_data,
+    const int64_t qk_stride_m,
+    const MaybeQuantizedMatrixData& v_data,
+    const int64_t v_stride_n,
+    accum_t* o_data,
+    const int64_t o_stride_m,
+    const accum_t beta) {
+  if (v_data.dtype == ScalarType::Char) {
+    ET_CHECK_MSG(false, "int8 not supported yet");
+  } else {
+    ::executorch::cpublas::gemm(
+        ::executorch::cpublas::TransposeType::NoTranspose,
+        ::executorch::cpublas::TransposeType::NoTranspose,
+        n,
+        m,
+        k,
+        static_cast<accum_t>(1),
+        static_cast<const float*>(v_data.data),
+        v_stride_n,
+        qk_data,
+        qk_stride_m,
+        beta,
+        o_data,
+        o_stride_m);
+  }
+}
+
 constexpr size_t kKVDim = 4;
 
 template <typename scalar_t>
@@ -211,6 +299,12 @@ void cpu_flash_attention(
     bool is_causal,
     const optional<Tensor>& attn_mask,
     const optional<double>& scale,
+    const optional<Tensor>& q_zero_points,
+    const optional<Tensor>& q_scales,
+    const optional<Tensor>& k_zero_points,
+    const optional<Tensor>& k_scales,
+    const optional<Tensor>& v_zero_points,
+    const optional<Tensor>& v_scales,
     bool is_seq_at_dim_1 = false,
     const int64_t start_pos = 0,
     const int64_t num_keys_for_causal_attention = -1) {
@@ -300,6 +394,8 @@ void cpu_flash_attention(
         kvSize);
   }
 
+  bool is_quantized_sdpa = query.scalar_type() == ScalarType::Char;
+
   auto strides = query.strides();
   int64_t qStrideB = strides[0];
   int64_t qStrideH = strides[1];
@@ -453,20 +549,54 @@ void cpu_flash_attention(
       int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
       // Calculate scale * q @ k.T
       fill_stub(qk_data, static_cast<accum_t>(0), qSplitSize * kvSplitSize);
-      ::executorch::cpublas::gemm(
-          ::executorch::cpublas::TransposeType::Transpose,
-          ::executorch::cpublas::TransposeType::NoTranspose,
+
+      const void* q_sub_matrix_data_ptr;
+      const void* k_sub_matrix_data_ptr;
+      const float* q_scales_ptr = nullptr;
+      const float* k_scales_ptr = nullptr;
+      const int8_t* q_zero_points_ptr = nullptr;
+      const int8_t* k_zero_points_ptr = nullptr;
+      int64_t q_offset = i * qStrideB + j * qStrideH + m * qStrideM;
+      int64_t k_offset = i * kStrideB + j_kv * kStrideH + n * kStrideN;
+      if (is_quantized_sdpa) {
+        ET_CHECK_MSG(
+            !is_seq_at_dim_1, "For quantized SDPA, seq_len must be at dim 2");
+        q_scales_ptr = q_scales.value().const_data_ptr<float>() + q_offset;
+        k_scales_ptr = k_scales.value().const_data_ptr<float>() + k_offset;
+        q_zero_points_ptr =
+            q_zero_points.value().const_data_ptr<int8_t>() + q_offset;
+        k_zero_points_ptr =
+            k_zero_points.value().const_data_ptr<int8_t>() + k_offset;
+        q_sub_matrix_data_ptr = (const int8_t*)(q_data) + q_offset;
+        k_sub_matrix_data_ptr = (const int8_t*)(k_data) + k_offset;
+      } else {
+        q_sub_matrix_data_ptr = (const scalar_t*)(q_data) + q_offset;
+        k_sub_matrix_data_ptr = (const scalar_t*)(k_data) + k_offset;
+      }
+      MaybeQuantizedMatrixData q_sub_matrix_data = MaybeQuantizedMatrixData(
+          static_cast<const void*>(q_sub_matrix_data_ptr),
+          q_zero_points_ptr,
+          q_scales_ptr,
+          qBlockSize,
+          headSize,
+          query.scalar_type());
+      MaybeQuantizedMatrixData k_sub_matrix_data = MaybeQuantizedMatrixData(
+          static_cast<const void*>(k_sub_matrix_data_ptr),
+          k_zero_points_ptr,
+          k_scales_ptr,
           kvBlockSize,
+          headSize,
+          key.scalar_type());
+      _q_at_k_gemm(
           qBlockSize,
+          kvBlockSize,
           headSize,
-          static_cast<accum_t>(1),
-          k_data + i * kStrideB + j_kv * kStrideH + n * kStrideN,
-          kStrideN,
-          q_data + i * qStrideB + j * qStrideH + m * qStrideM,
+          q_sub_matrix_data,
           qStrideM,
-          static_cast<accum_t>(0),
-          qk_data,
-          kvBlockSize);
+          k_sub_matrix_data,
+          kStrideN,
+          qk_data);
+
       // There are 4 cases that is_causal has to cover to fill
       // not-attendable-position with -inf
       /* 1. Everything is attended to. This happens when m_start_pos > n +
@@ -583,21 +713,40 @@ void cpu_flash_attention(
                 headSize);
           }
         }
-      // Calculate Softmax(q @ k.T) @ v
-      ::executorch::cpublas::gemm(
-          ::executorch::cpublas::TransposeType::NoTranspose,
-          ::executorch::cpublas::TransposeType::NoTranspose,
+
+      const void* v_sub_matrix_data_ptr;
+      const float* v_scales_ptr = nullptr;
+      const int8_t* v_zero_points_ptr = nullptr;
+      int64_t v_offset = i * vStrideB + j_kv * vStrideH + n * vStrideN;
+      if (is_quantized_sdpa) {
+        ET_CHECK_MSG(
+            !is_seq_at_dim_1, "For quantized SDPA, seq_len must be at dim 2");
+        v_scales_ptr = v_scales.value().const_data_ptr<float>() + v_offset;
+        v_zero_points_ptr =
+            v_zero_points.value().const_data_ptr<int8_t>() + v_offset;
+        v_sub_matrix_data_ptr = (const int8_t*)(v_data) + v_offset;
+      } else {
+        v_sub_matrix_data_ptr = (const scalar_t*)(v_data) + v_offset;
+      }
+      MaybeQuantizedMatrixData v_sub_matrix_data = MaybeQuantizedMatrixData(
+          static_cast<const void*>(v_sub_matrix_data_ptr),
+          v_zero_points_ptr,
+          v_scales_ptr,
+          kvBlockSize,
           headSize,
+          value.scalar_type());
+      // Calculate Softmax(q @ k.T) @ v
+      _qk_at_v_gemm(
           qBlockSize,
+          headSize,
           kvBlockSize,
-          static_cast<accum_t>(1),
-          v_data + i * vStrideB + j_kv * vStrideH + n * vStrideN,
-          vStrideN,
-          conditional_data_ptr(qk_data, qk_reduced_data),
+          qk_data,
           kvBlockSize,
-          n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
+          v_sub_matrix_data,
+          vStrideN,
           dst_data,
-          headSize);
+          headSize,
+          n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1));
     }
     // dst <- dst / sum[row]
     // reorder MHA output with strides
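
Note on the new shape checks (an illustrative sketch, not part of the patch): validate_cache_quant_params_args only requires that the int8 tensor, its float scales, and its int8 zero points have the same rank and agree on every dimension except the last; the last dimension is left unchecked, which is what a per-token quantization layout (a single scale and zero point per token, per head) would rely on. The shapes below are hypothetical, chosen only to illustrate that contract with plain size vectors instead of Tensors:

// Standalone sketch of the shape contract enforced by
// validate_cache_quant_params_args (ranks match, all dims but the last match).
#include <cstdint>
#include <vector>

bool quant_params_shapes_ok(
    const std::vector<int64_t>& t,        // int8 quantized tensor sizes
    const std::vector<int64_t>& zps,      // int8 zero-point sizes
    const std::vector<int64_t>& scales) { // float scale sizes
  if (t.size() != zps.size() || t.size() != scales.size()) {
    return false; // ranks must match
  }
  for (size_t i = 0; i + 1 < t.size(); ++i) { // last dim is not checked
    if (t[i] != scales[i] || t[i] != zps[i]) {
      return false;
    }
  }
  return true;
}

int main() {
  // Hypothetical shapes: k is [B, H, S, D] int8 (seq_len at dim 2, as the
  // quantized path requires); its scales/zero points are [B, H, S, 1].
  return quant_params_shapes_ok(
             {1, 8, 128, 64}, {1, 8, 128, 1}, {1, 8, 128, 1})
      ? 0
      : 1;
}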