|
| 1 | +/* |
| 2 | + * Copyright (c) 2025 by FlashInfer team. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | +#include <driver_types.h> |
| 17 | + |
| 18 | +#include <flashinfer/attention/mla_fa2.cuh> |
| 19 | +#include <flashinfer/attention/scheduler.cuh> |
| 20 | +#include <flashinfer/fastdiv.cuh> |
| 21 | +#include <optional> |
| 22 | + |
| 23 | +#include "batch_mla_config.inc" |
| 24 | +#include "pytorch_extension_utils.h" |
| 25 | + |
| 26 | +using namespace flashinfer; |
| 27 | + |
| 28 | +void BatchMLAPageAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, |
| 29 | + std::vector<int64_t> plan_info_vec, at::Tensor q_nope, |
| 30 | + at::Tensor q_pe, at::Tensor ckv_cache, at::Tensor kpe_cache, |
| 31 | + at::Tensor kv_indices, at::Tensor o, |
| 32 | + std::optional<at::Tensor> maybe_lse, int mask_mode_code, |
| 33 | + int num_heads, int page_size, float sm_scale, int64_t cuda_stream) { |
| 34 | + // q_nope: [n, num_heads, head_dim_ckv] |
| 35 | + // q_pe: [n, num_heads, head_dim_kpe] |
| 36 | + // ckv_cache: [num_pages, page_size, head_dim_ckv] |
| 37 | + // kpe_cache: [num_pages, page_size, head_dim_kpe] |
| 38 | + MLAPlanInfo plan_info; |
| 39 | + plan_info.FromVector(plan_info_vec); |
| 40 | + |
| 41 | + auto device = q_nope.device(); |
| 42 | + |
| 43 | + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); |
| 44 | + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); |
| 45 | + |
| 46 | + const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code); |
| 47 | + |
| 48 | + auto q_scalar_type = q_nope.scalar_type(); |
| 49 | + auto kv_scalar_type = ckv_cache.scalar_type(); |
| 50 | + |
| 51 | + unsigned int q_nope_stride_n = q_nope.stride(0); |
| 52 | + unsigned int q_nope_stride_h = q_nope.stride(1); |
| 53 | + unsigned int q_pe_stride_n = q_pe.stride(0); |
| 54 | + unsigned int q_pe_stride_h = q_pe.stride(1); |
| 55 | + unsigned int ckv_stride_page = ckv_cache.stride(0); |
| 56 | + unsigned int ckv_stride_n = ckv_cache.stride(1); |
| 57 | + unsigned int kpe_stride_page = kpe_cache.stride(0); |
| 58 | + unsigned int kpe_stride_n = kpe_cache.stride(1); |
| 59 | + unsigned int o_stride_n = o.stride(0); |
| 60 | + unsigned int o_stride_h = o.stride(1); |
| 61 | + |
| 62 | + cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream); |
| 63 | + |
| 64 | + DISPATCH_context( |
| 65 | + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, [&] { |
| 66 | + Params params; |
| 67 | + |
| 68 | + params.q_nope = static_cast<DTypeQ*>(q_nope.data_ptr()); |
| 69 | + params.q_pe = static_cast<DTypeQ*>(q_pe.data_ptr()); |
| 70 | + params.ckv = static_cast<DTypeKV*>(ckv_cache.data_ptr()); |
| 71 | + params.kpe = static_cast<DTypeKV*>(kpe_cache.data_ptr()); |
| 72 | + params.kv_indices = static_cast<IdType*>(kv_indices.data_ptr()); |
| 73 | + |
| 74 | + params.q_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_indptr_offset); |
| 75 | + params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset); |
| 76 | + params.kv_indices = static_cast<IdType*>(kv_indices.data_ptr()); |
| 77 | + params.q_len = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_len_offset); |
| 78 | + params.kv_len = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset); |
| 79 | + params.q_start = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_start_offset); |
| 80 | + params.kv_start = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_start_offset); |
| 81 | + params.kv_end = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_end_offset); |
| 82 | + params.work_indptr = |
| 83 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset); |
| 84 | + params.final_o = static_cast<DTypeO*>(o.data_ptr()); |
| 85 | + params.final_lse = |
| 86 | + maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr; |
| 87 | + params.partial_o = |
| 88 | + GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_o_offset); |
| 89 | + params.partial_lse = |
| 90 | + GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset); |
| 91 | + |
| 92 | + params.num_heads = uint_fastdiv(num_heads); |
| 93 | + params.block_size = uint_fastdiv(page_size); |
| 94 | + |
| 95 | + params.q_nope_stride_n = q_nope_stride_n; |
| 96 | + params.q_nope_stride_h = q_nope_stride_h; |
| 97 | + params.q_pe_stride_n = q_pe_stride_n; |
| 98 | + params.q_pe_stride_h = q_pe_stride_h; |
| 99 | + params.ckv_stride_page = ckv_stride_page; |
| 100 | + params.ckv_stride_n = ckv_stride_n; |
| 101 | + params.kpe_stride_page = kpe_stride_page; |
| 102 | + params.kpe_stride_n = kpe_stride_n; |
| 103 | + params.o_stride_n = o_stride_n; |
| 104 | + params.o_stride_h = o_stride_h; |
| 105 | + |
| 106 | + params.sm_scale = sm_scale; |
| 107 | + |
| 108 | + cudaError_t status = mla::BatchMLAPageAttention<MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE>( |
| 109 | + params, plan_info.num_blks_x, plan_info.num_blks_y, stream); |
| 110 | + |
| 111 | + TORCH_CHECK(status == cudaSuccess, |
| 112 | + "Failed to run MLA, error: ", cudaGetErrorString(status)); |
| 113 | + }); |
| 114 | +} |
0 commit comments