From dfdb07320b23a6cd340897dfb0099cbdf68fd24a Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 8 Feb 2025 22:51:49 +0000 Subject: [PATCH 1/8] upd --- csrc/batch_decode_customize_config.jinja | 1 + csrc/batch_prefill_customize_config.jinja | 1 + .../batch_prefill_sm90_customize_config.jinja | 1 + csrc/single_decode_customize_config.jinja | 1 + csrc/single_prefill_customize_config.jinja | 1 + ...single_prefill_sm90_customize_config.jinja | 1 + flashinfer/utils.py | 2 +- include/flashinfer/attention/decode.cuh | 25 +- .../attention/hopper/variant_helper.cuh | 64 +++++ .../flashinfer/attention/hopper/variants.cuh | 17 +- include/flashinfer/attention/prefill.cuh | 75 +++--- include/flashinfer/attention/scheduler.cuh | 4 +- .../flashinfer/attention/variant_helper.cuh | 64 +++++ include/flashinfer/attention/variants.cuh | 215 ++--------------- src/bench_single_prefill.cu | 29 ++- src/test_single_prefill.cu | 223 +++++++++--------- tests/test_jit_example.py | 76 ++---- 17 files changed, 340 insertions(+), 460 deletions(-) create mode 100644 include/flashinfer/attention/hopper/variant_helper.cuh create mode 100644 include/flashinfer/attention/variant_helper.cuh diff --git a/csrc/batch_decode_customize_config.jinja b/csrc/batch_decode_customize_config.jinja index 24ba9f1a6..da8037022 100644 --- a/csrc/batch_decode_customize_config.jinja +++ b/csrc/batch_decode_customize_config.jinja @@ -3,6 +3,7 @@ #include #include #include +#include #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} diff --git a/csrc/batch_prefill_customize_config.jinja b/csrc/batch_prefill_customize_config.jinja index ccdbcdee3..77490d71b 100644 --- a/csrc/batch_prefill_customize_config.jinja +++ b/csrc/batch_prefill_customize_config.jinja @@ -5,6 +5,7 @@ #include #include #include +#include #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} diff --git a/csrc/batch_prefill_sm90_customize_config.jinja b/csrc/batch_prefill_sm90_customize_config.jinja index 73cdd25ce..5b10355fc 100644 --- a/csrc/batch_prefill_sm90_customize_config.jinja +++ b/csrc/batch_prefill_sm90_customize_config.jinja @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include #include diff --git a/csrc/single_decode_customize_config.jinja b/csrc/single_decode_customize_config.jinja index 8a6baec5f..9bd3b429b 100644 --- a/csrc/single_decode_customize_config.jinja +++ b/csrc/single_decode_customize_config.jinja @@ -2,6 +2,7 @@ #include #include #include +#include #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} diff --git a/csrc/single_prefill_customize_config.jinja b/csrc/single_prefill_customize_config.jinja index ed2498ab9..fa31e08b7 100644 --- a/csrc/single_prefill_customize_config.jinja +++ b/csrc/single_prefill_customize_config.jinja @@ -4,6 +4,7 @@ #include #include #include +#include #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} diff --git a/csrc/single_prefill_sm90_customize_config.jinja b/csrc/single_prefill_sm90_customize_config.jinja index 0c8a0e4d5..7922ca2ba 100644 --- a/csrc/single_prefill_sm90_customize_config.jinja +++ b/csrc/single_prefill_sm90_customize_config.jinja @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include #include diff --git a/flashinfer/utils.py b/flashinfer/utils.py index df74de81e..3d6e70866 100644 --- 
a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -171,7 +171,7 @@ def _get_cache_alibi_slopes_buf( key = (f"alibi_slopes_{num_qo_heads}", device) buf = _cache_buf.get(key) if buf is None: - buf = (get_alibi_slopes(num_qo_heads) * log2e).to(device) + buf = get_alibi_slopes(num_qo_heads).to(device) _cache_buf[key] = buf return buf diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index d1aa9284d..af7b18b44 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -67,6 +67,7 @@ __device__ __forceinline__ void compute_qk(const Params& params, AttentionVarian uint32_t qo_head_idx, uint32_t kv_head_idx, float* s, state_t& st) { uint32_t tx = threadIdx.x, tz = threadIdx.z; + const float sm_scale_log2 = variant.sm_scale_log2; float m_prev = st.m; #pragma unroll for (uint32_t j = 0; j < tile_size; ++j) { @@ -91,18 +92,19 @@ __device__ __forceinline__ void compute_qk(const Params& params, AttentionVarian const uint32_t pos = kv_idx_base + tz * tile_size + j; s[j] = variant.LogitsTransform(params, s[j], batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx, kv_head_idx); - bool mask = variant.LogitsMask(params, batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx, - kv_head_idx); - s[j] = (iter_base + tz * tile_size + j < iter_bound && mask) ? s[j] : -math::inf; + bool mask = iter_base + tz * tile_size + j < iter_bound; + mask &= variant.LogitsMask(params, batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx, + kv_head_idx); + s[j] = mask ? s[j] : -math::inf; st.m = max(st.m, s[j]); } if constexpr (variant.use_softmax) { - float o_scale = math::ptx_exp2(m_prev - st.m); + float o_scale = math::ptx_exp2(m_prev * sm_scale_log2 - st.m * sm_scale_log2); st.d *= o_scale; #pragma unroll for (uint32_t j = 0; j < tile_size; ++j) { - s[j] = math::ptx_exp2(s[j] - st.m); + s[j] = math::ptx_exp2(s[j] * sm_scale_log2 - st.m * sm_scale_log2); st.d += s[j]; } #pragma unroll @@ -263,12 +265,6 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par // do not apply rotary embedding to q matrix q_vec.cast_load(q + qo_head_idx * q_stride_h + tx * vec_size); } - // multiple q_vec by sm_scale -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - q_vec[i] = variant.QueryTransform(params, q_vec[i]); - } - block.sync(); uint32_t chunk_start = kv_chunk_idx * kv_chunk_size; kv_chunk_size = min(kv_chunk_size, seq_len - chunk_start); @@ -361,6 +357,7 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par st_local.o.cast_store(o + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); if (lse != nullptr) { + st_local.m *= variant.sm_scale_log2; lse[kv_chunk_idx * num_qo_heads + qo_head_idx] = st_local.get_lse(); } } @@ -456,11 +453,6 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ Params // do not apply rotary embedding to q matrix q_vec.cast_load(q + batch_idx * q_stride_n + qo_head_idx * q_stride_h + tx * vec_size); } -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - q_vec[i] = variant.QueryTransform(params, q_vec[i]); - } - block.sync(); // preload k/v tiles uint32_t stage_idx = 0; @@ -586,6 +578,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ Params st.o.cast_store(o + (bx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); // write lse if (lse != nullptr) { + st.m *= variant.sm_scale_log2; lse[bx * num_qo_heads + qo_head_idx] = st.get_lse(); } } diff --git 
a/include/flashinfer/attention/hopper/variant_helper.cuh b/include/flashinfer/attention/hopper/variant_helper.cuh
new file mode 100644
index 000000000..22fda01f4
--- /dev/null
+++ b/include/flashinfer/attention/hopper/variant_helper.cuh
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FLASHINFER_ATTENTION_HOPPER_VARIANT_HELPER_H
+#define FLASHINFER_ATTENTION_HOPPER_VARIANT_HELPER_H
+
+#include
+
+#include
+
+namespace flashinfer {
+
+#define REGISTER_QUERY_TRANSFORM(params, q, ...)                                            \
+  template <typename MainloopParams, typename T>                                            \
+  __device__ __forceinline__ T QueryTransform(const MainloopParams& params, void* q_smem) { \
+    __VA_ARGS__                                                                             \
+  }
+
+#define REGISTER_KEY_TRANSFORM(params, k, ...)                                              \
+  template <typename MainloopParams, typename T>                                            \
+  __device__ __forceinline__ T KeyTransform(const MainloopParams& params, void* k_smem) {   \
+    __VA_ARGS__                                                                             \
+  }
+
+#define REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx,   \
+                                  kv_head_idx, ...)                                         \
+  template <typename MainloopParams, typename T>                                            \
+  __device__ __forceinline__ T LogitsTransform(                                             \
+      const MainloopParams& params, T logits, uint32_t batch_idx, uint32_t qo_idx,          \
+      uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) {                        \
+    __VA_ARGS__                                                                             \
+  }
+
+#define REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, ...)
\ + template \ + __device__ __forceinline__ bool LogitsMask(const MainloopParams& params, uint32_t batch_idx, \ + uint32_t qo_idx, uint32_t kv_idx, \ + uint32_t qo_head_idx, uint32_t kv_head_idx) { \ + __VA_ARGS__ \ + } + +struct AttentionVariantBase { + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, + { return logits; }) + + REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, + { return true; }) +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_VARIANT_HELPER_H diff --git a/include/flashinfer/attention/hopper/variants.cuh b/include/flashinfer/attention/hopper/variants.cuh index 69ad56eeb..8f17cdbdf 100644 --- a/include/flashinfer/attention/hopper/variants.cuh +++ b/include/flashinfer/attention/hopper/variants.cuh @@ -20,6 +20,7 @@ #include "../../math.cuh" #include "attention_updater.cuh" +#include "variant_helper.cuh" namespace flashinfer { @@ -36,12 +37,8 @@ struct StandardAttention { return OnlineSoftmax(sm_scale_log2); } - template - __device__ __forceinline__ T LogitsTransform(const MainloopParams& params, T logits, - uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, - uint32_t qo_head_idx, uint32_t kv_head_idx) { - return logits; - } + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, + { return logits; }) }; struct LogitsSoftCap { @@ -60,12 +57,8 @@ struct LogitsSoftCap { return OnlineSoftmax(0.); } - template - __device__ __forceinline__ T LogitsTransform(const MainloopParams& params, T logits, - uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, - uint32_t qo_head_idx, uint32_t kv_head_idx) { - return math::tanh(logits * pre_tanh_scale) * post_tanh_scale; - } + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, + { return math::tanh(logits * pre_tanh_scale) * post_tanh_scale; }) }; template diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 1adbbf51e..7e3528213 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -506,26 +506,6 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( } } -template -__device__ __forceinline__ void q_smem_inplace_transform(const Params& params, - typename KTraits::AttentionVariant variant, - smem_t* q_smem) { - using DTypeQ = typename KTraits::DTypeQ; - const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; - constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; - constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; -#pragma unroll - for (uint32_t i = 0; i < KTraits::CTA_TILE_Q * HEAD_DIM_QK / (NUM_WARPS * 256); ++i) { - vec_t tmp; - tmp.load((DTypeQ*)(q_smem->base) + (i * NUM_WARPS + warp_idx) * 256 + lane_idx * 8); -#pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { - tmp[reg_id] = variant.QueryTransform(params, tmp[reg_id]); - } - tmp.store((DTypeQ*)(q_smem->base) + (i * NUM_WARPS + warp_idx) * 256 + lane_idx * 8); - } -} - template __device__ __forceinline__ void k_smem_inplace_apply_rotary( const uint32_t kv_idx_base, smem_t* k_smem, uint32_t* k_smem_offset_r, @@ -760,6 +740,7 @@ __device__ __forceinline__ void logits_mask( template __device__ __forceinline__ void update_mdo_states( + typename KTraits::AttentionVariant variant, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], float (*o_frag)[KTraits::NUM_MMA_D_VO][8], typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2]) { @@ -768,6 +749,7 
@@ __device__ __forceinline__ void update_mdo_states( constexpr bool use_softmax = AttentionVariant::use_softmax; if constexpr (use_softmax) { + const float sm_scale = variant.sm_scale_log2; if constexpr (std::is_same_v) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { @@ -784,7 +766,7 @@ __device__ __forceinline__ void update_mdo_states( m[mma_q][j] = max(m[mma_q][j], math::shfl_xor_sync(m[mma_q][j], 0x2)); m[mma_q][j] = max(m[mma_q][j], math::shfl_xor_sync(m[mma_q][j], 0x1)); - float o_scale = math::ptx_exp2(m_prev - m[mma_q][j]); + float o_scale = math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); d[mma_q][j] *= o_scale; #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { @@ -795,14 +777,14 @@ __device__ __forceinline__ void update_mdo_states( } #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { - s_frag[mma_q][mma_kv][j * 2 + 0] = - math::ptx_exp2(s_frag[mma_q][mma_kv][j * 2 + 0] - m[mma_q][j]); - s_frag[mma_q][mma_kv][j * 2 + 1] = - math::ptx_exp2(s_frag[mma_q][mma_kv][j * 2 + 1] - m[mma_q][j]); - s_frag[mma_q][mma_kv][j * 2 + 4] = - math::ptx_exp2(s_frag[mma_q][mma_kv][j * 2 + 4] - m[mma_q][j]); - s_frag[mma_q][mma_kv][j * 2 + 5] = - math::ptx_exp2(s_frag[mma_q][mma_kv][j * 2 + 5] - m[mma_q][j]); + s_frag[mma_q][mma_kv][j * 2 + 0] = math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 0] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][j * 2 + 1] = math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 1] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][j * 2 + 4] = math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 4] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][j * 2 + 5] = math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 5] * sm_scale - m[mma_q][j] * sm_scale); } } } @@ -1300,15 +1282,14 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCache group_size, &qo_smem); cp_async::commit_group(); - cp_async::wait_group<0>(); - block.sync(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + cp_async::wait_group<0>(); + block.sync(); q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq); block.sync(); } - q_smem_inplace_transform(params, variant, &qo_smem); smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); @@ -1387,7 +1368,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCache } // compute m,d states in online softmax - update_mdo_states(s_frag, o_frag, m, d); + update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); produce_kv( @@ -1434,10 +1415,10 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCache if (qo_idx < qo_len) { if (partition_kv) { lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]) * variant.sm_scale_log2; } else { lse[qo_idx * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]) * variant.sm_scale_log2; } } } @@ -1693,10 +1674,10 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV q_stride_h, group_size, &qo_smem); cp_async::commit_group(); - cp_async::wait_group<0>(); - block.sync(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + cp_async::wait_group<0>(); + block.sync(); IdType* 
q_rope_offset = nullptr; if constexpr (has_maybe_q_rope_offset_v) { @@ -1712,7 +1693,6 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV } block.sync(); } - q_smem_inplace_transform(params, variant, &qo_smem); const uint32_t num_iterations = ceil_div( (MASK_MODE == MaskMode::kCausal @@ -1803,7 +1783,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV } // compute m,d states in online softmax - update_mdo_states(s_frag, o_frag, m, d); + update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); produce_kv( @@ -1853,10 +1833,11 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV if (partition_kv) { lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * num_qo_heads + - qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + qo_head_idx] = + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]) * variant.sm_scale_log2; } else { lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]) * variant.sm_scale_log2; } } } @@ -1975,10 +1956,10 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVC q_stride_h, group_size, &qo_smem); cp_async::commit_group(); - cp_async::wait_group<0>(); - block.sync(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + cp_async::wait_group<0>(); + block.sync(); IdType* q_rope_offset = nullptr; if constexpr (has_maybe_q_rope_offset_v) { q_rope_offset = params.maybe_q_rope_offset; @@ -1993,7 +1974,6 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVC } block.sync(); } - q_smem_inplace_transform(params, variant, &qo_smem); smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); size_t thr_local_kv_offset[NUM_MMA_KV * KV_THR_LAYOUT_COL / 2 / NUM_WARPS_Q]; @@ -2096,7 +2076,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVC } // compute m,d states in online softmax - update_mdo_states(s_frag, o_frag, m, d); + update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); page_produce_kv(k_smem, &k_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV, @@ -2146,10 +2126,11 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVC if (partition_kv) { lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * num_qo_heads + - qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + qo_head_idx] = + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]) * variant.sm_scale_log2; } else { lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]) * variant.sm_scale_log2; } } } diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index c7dce7fc5..0d8f52264 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -398,7 +398,7 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in if (split_kv) { AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); plan_info.v_offset = float_allocator.aligned_alloc_offset( - num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeO), 16, "batch_decode_tmp_v"); + num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(float), 16, "batch_decode_tmp_v"); plan_info.s_offset = 
float_allocator.aligned_alloc_offset( num_qo_heads * padded_batch_size * sizeof(float), 16, "batch_decode_tmp_s"); @@ -676,7 +676,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i if (split_kv) { AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); plan_info.v_offset = float_allocator.aligned_alloc_offset( - num_qo_heads * padded_batch_size * cta_tile_q * head_dim_vo * sizeof_dtype_o, 16, + num_qo_heads * padded_batch_size * cta_tile_q * head_dim_vo * sizeof(float), 16, "batch_prefill_tmp_v"); plan_info.s_offset = float_allocator.aligned_alloc_offset( num_qo_heads * padded_batch_size * cta_tile_q * sizeof(float), 16, "batch_prefill_tmp_s"); diff --git a/include/flashinfer/attention/variant_helper.cuh b/include/flashinfer/attention/variant_helper.cuh new file mode 100644 index 000000000..5836321fc --- /dev/null +++ b/include/flashinfer/attention/variant_helper.cuh @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ATTENTION_VARIANT_HELPER_H +#define FLASHINFER_ATTENTION_VARIANT_HELPER_H + +#include + +#include + +namespace flashinfer { + +#define REGISTER_QUERY_TRANSFORM(params, q, ...) \ + template \ + __device__ __forceinline__ T QueryTransform(const Params& params, void* q_smem) { \ + __VA_ARGS__ \ + } + +#define REGISTER_KEY_TRANSFORM(params, k, ...) \ + template \ + __device__ __forceinline__ T KeyTransform(const Params& params, void* k_smem) { \ + __VA_ARGS__ \ + } + +#define REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, \ + kv_head_idx, ...) \ + template \ + __device__ __forceinline__ T LogitsTransform(const Params& params, T logits, uint32_t batch_idx, \ + uint32_t qo_idx, uint32_t kv_idx, \ + uint32_t qo_head_idx, uint32_t kv_head_idx) { \ + __VA_ARGS__ \ + } + +#define REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, ...) 
\ + template \ + __device__ __forceinline__ bool LogitsMask(const Params& params, uint32_t batch_idx, \ + uint32_t qo_idx, uint32_t kv_idx, \ + uint32_t qo_head_idx, uint32_t kv_head_idx) { \ + __VA_ARGS__ \ + } + +struct AttentionVariantBase { + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, + { return logits; }) + + REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, + { return true; }) +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_VARIANT_HELPER_H diff --git a/include/flashinfer/attention/variants.cuh b/include/flashinfer/attention/variants.cuh index 6955bf295..4fa6cdb0d 100644 --- a/include/flashinfer/attention/variants.cuh +++ b/include/flashinfer/attention/variants.cuh @@ -22,193 +22,21 @@ #include "../math.cuh" #include "../utils.cuh" +#include "variant_helper.cuh" namespace flashinfer { -// Query Transform function that multiplies the query matrix by sm_scale -struct StandardAttention { - static constexpr bool use_softmax = true; - - uint32_t window_left, qo_len, kv_len; - - // Create closure - template - __device__ __host__ StandardAttention(const Params& params, uint32_t batch_idx, - uint8_t* smem_ptr) { - qo_len = params.get_qo_len(batch_idx); - kv_len = params.get_kv_len(batch_idx); - window_left = kv_len; - } - - template - __device__ __forceinline__ T QueryTransform(const Params& params, T q) { - return float(q) * params.sm_scale * math::log2e; - } - - template - __device__ __forceinline__ T LogitsTransform(const Params& params, T logits, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, - uint32_t qo_head_idx, uint32_t kv_head_idx) { - return logits; - } - - template - __device__ __forceinline__ bool LogitsMask(const Params& params, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, - uint32_t kv_head_idx) { - return true; - } -}; - DEFINE_HAS_MEMBER(maybe_mask_indptr) -struct CustomMaskAttention { - static constexpr bool use_softmax = true; - - uint8_t* custom_mask_ptr; - uint32_t window_left, qo_len, kv_len; - - // Create closure - template - __device__ __host__ CustomMaskAttention(const Params& params, uint32_t batch_idx, - uint8_t* smem_ptr) { - if constexpr (has_maybe_mask_indptr_v) { - custom_mask_ptr = params.maybe_custom_mask + params.maybe_mask_indptr[batch_idx]; - } else { - custom_mask_ptr = params.maybe_custom_mask; - } - qo_len = params.get_qo_len(batch_idx); - kv_len = params.get_kv_len(batch_idx); - window_left = kv_len; - } - - template - __device__ __forceinline__ T QueryTransform(const Params& params, T q) { - return float(q) * params.sm_scale * math::log2e; - } - - template - __device__ __forceinline__ T LogitsTransform(const Params& params, T logits, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, - uint32_t qo_head_idx, uint32_t kv_head_idx) { - return logits; - } - - template - __device__ __forceinline__ bool LogitsMask(const Params& params, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, - uint32_t kv_head_idx) { - const uint32_t offset = qo_idx * kv_len + kv_idx; - return ((custom_mask_ptr[offset / 8] >> (offset % 8)) & 1); - } -}; - -struct SlidingWindowAttention { - static constexpr bool use_softmax = true; - - uint32_t window_left, qo_len, kv_len; - - // Create closure - template - __device__ __host__ __forceinline__ SlidingWindowAttention(const Params& params, - uint32_t batch_idx, - uint8_t* smem_ptr) { - qo_len = params.get_qo_len(batch_idx); - kv_len = 
params.get_kv_len(batch_idx); - window_left = (params.window_left >= 0) ? params.window_left : kv_len; - } - - template - __device__ __forceinline__ T QueryTransform(const Params& params, T q) { - return float(q) * params.sm_scale * math::log2e; - } - - template - __device__ __forceinline__ T LogitsTransform(const Params& params, T logits, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, - uint32_t qo_head_idx, uint32_t kv_head_idx) { - return logits; - } - - template - __device__ __forceinline__ bool LogitsMask(const Params& params, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, - uint32_t kv_head_idx) { - return (kv_idx + qo_len + window_left >= kv_len + qo_idx); - } -}; - -struct LogitsSoftCap { - static constexpr bool use_softmax = true; - - uint32_t window_left, qo_len, kv_len; - - template - __device__ __host__ LogitsSoftCap(const Params& params, uint32_t batch_idx, uint8_t* smem_ptr) { - qo_len = params.get_qo_len(batch_idx); - kv_len = params.get_kv_len(batch_idx); - window_left = kv_len; - } - - template - __device__ __forceinline__ T QueryTransform(const Params& params, T q) { - return float(q) * params.sm_scale * math::ptx_rcp(params.logits_soft_cap); - } - - template - __device__ __forceinline__ T LogitsTransform(const Params& params, T logits, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, - uint32_t qo_head_idx, uint32_t kv_head_idx) { - return params.logits_soft_cap * math::log2e * float(math::tanh(logits)); - } - - template - __device__ __forceinline__ bool LogitsMask(const Params& params, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, - uint32_t kv_head_idx) { - return true; - } -}; - -struct ALIBIAttention { - static constexpr bool use_softmax = true; - - uint32_t window_left, qo_len, kv_len; - - template - __device__ __host__ ALIBIAttention(const Params& params, uint32_t batch_idx, uint8_t* smem_ptr) { - qo_len = params.get_qo_len(batch_idx); - kv_len = params.get_kv_len(batch_idx); - window_left = kv_len; - } - - template - __device__ __forceinline__ T QueryTransform(const Params& params, T q) { - return float(q) * params.sm_scale * math::log2e; - } - - template - __device__ __forceinline__ T LogitsTransform(const Params& params, T logits, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, - uint32_t qo_head_idx, uint32_t kv_head_idx) { - return logits + params.maybe_alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx)); - } - - template - __device__ __forceinline__ bool LogitsMask(const Params& params, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, - uint32_t kv_head_idx) { - return true; - } -}; - template -struct DefaultAttention { +struct DefaultAttention : AttentionVariantBase { static constexpr bool use_softmax = true; - uint32_t qo_len, kv_len; uint8_t* custom_mask_ptr; + uint32_t qo_len, kv_len; uint32_t window_left; + float sm_scale_log2; + float soft_cap_pre_tanh_scale; // Create closure template @@ -216,6 +44,12 @@ struct DefaultAttention { uint8_t* smem_ptr) { qo_len = params.get_qo_len(batch_idx); kv_len = params.get_kv_len(batch_idx); + if constexpr (use_logits_soft_cap) { + soft_cap_pre_tanh_scale = params.sm_scale * math::ptx_rcp(params.logits_soft_cap); + sm_scale_log2 = math::log2e * params.logits_soft_cap; + } else { + sm_scale_log2 = params.sm_scale * math::log2e; + } if constexpr (use_custom_mask) { if constexpr (has_maybe_mask_indptr_v) { custom_mask_ptr = params.maybe_custom_mask + params.maybe_mask_indptr[batch_idx]; @@ -228,32 
+62,17 @@ struct DefaultAttention { } } - template - __device__ __forceinline__ T QueryTransform(const Params& params, T q) { - if constexpr (use_logits_soft_cap) { - return float(q) * params.sm_scale * math::ptx_rcp(params.logits_soft_cap); - } else { - return float(q) * params.sm_scale * math::log2e; - } - } - - template - __device__ __forceinline__ T LogitsTransform(const Params& params, T logits, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, - uint32_t qo_head_idx, uint32_t kv_head_idx) { + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { if constexpr (use_alibi) { logits = logits + params.maybe_alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx)); } if constexpr (use_logits_soft_cap) { - logits = params.logits_soft_cap * math::log2e * float(math::tanh(logits)); + logits = float(math::tanh(logits * soft_cap_pre_tanh_scale)); } return logits; - } + }) - template - __device__ __forceinline__ bool LogitsMask(const Params& params, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, - uint32_t kv_head_idx) { + REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { bool mask = true; if constexpr (use_custom_mask) { const uint32_t offset = qo_idx * kv_len + kv_idx; @@ -263,9 +82,9 @@ struct DefaultAttention { mask &= (kv_idx + qo_len + window_left >= kv_len + qo_idx); } return mask; - } + }) }; -} // namespace flashinfer +}; // namespace flashinfer #endif // FLASHINFER_ATTENTION_VARIANTS_CUH_ diff --git a/src/bench_single_prefill.cu b/src/bench_single_prefill.cu index 36054a524..394c5f193 100644 --- a/src/bench_single_prefill.cu +++ b/src/bench_single_prefill.cu @@ -182,19 +182,18 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { .add_int64_axis("custom_mask", {0}) \ .add_int64_axis("cooperative", {1}) -auto bench_flashinfer_single_prefill_fp8_kv = bench_flashinfer_single_prefill_fp8; -NVBENCH_BENCH(bench_flashinfer_single_prefill_fp8_kv) - .set_name(("bench_flashinfer_single_prefill_fp8_kv")) - .add_int64_axis("kv_len", {32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}) - .add_int64_axis("num_qo_heads", {32}) - .add_int64_axis("num_kv_heads", {32}) - .add_int64_axis("head_dim", {128}) - .add_int64_axis("causal", {0, 1}) - .add_int64_axis("kv_layout", {0, 1}) - .add_int64_axis("pos_encoding_mode", {0, 1}) - .add_int64_axis("use_fp16_qk_reduction", {0, 1}) - .add_int64_axis("custom_mask", {0}) - .add_int64_axis("cooperative", {1}); +// auto bench_flashinfer_single_prefill_fp8_kv = bench_flashinfer_single_prefill_fp8; +// NVBENCH_BENCH(bench_flashinfer_single_prefill_fp8_kv) +// .set_name(("bench_flashinfer_single_prefill_fp8_kv")) +// .add_int64_axis("kv_len", {32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, +// 65536}) .add_int64_axis("num_qo_heads", {32}) .add_int64_axis("num_kv_heads", {32}) +// .add_int64_axis("head_dim", {128}) +// .add_int64_axis("causal", {0, 1}) +// .add_int64_axis("kv_layout", {0, 1}) +// .add_int64_axis("pos_encoding_mode", {0, 1}) +// .add_int64_axis("use_fp16_qk_reduction", {0, 1}) +// .add_int64_axis("custom_mask", {0}) +// .add_int64_axis("cooperative", {1}); #define BENCH_FLASHINFER_APPEND_PREFILL(dtype_in, dtype_out) \ auto bench_flashinfer_single_append_prefill_##dtype_in##_##dtype_out##_ = \ @@ -213,5 +212,5 @@ NVBENCH_BENCH(bench_flashinfer_single_prefill_fp8_kv) .add_int64_axis("custom_mask", {0}) \ .add_int64_axis("cooperative", {0, 1}) -BENCH_FLASHINFER_PREFILL(half, half); 
-BENCH_FLASHINFER_APPEND_PREFILL(half, half); +BENCH_FLASHINFER_PREFILL(half, float); +// BENCH_FLASHINFER_APPEND_PREFILL(half, half); diff --git a/src/test_single_prefill.cu b/src/test_single_prefill.cu index 6ec9ffc6d..beaabd4d4 100644 --- a/src/test_single_prefill.cu +++ b/src/test_single_prefill.cu @@ -29,6 +29,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, bool use_fp16_qk_reduction, float rtol = 1e-3, float atol = 1e-3) { + static_assert(std::is_same_v); std::vector q(qo_len * num_qo_heads * head_dim); std::vector k(kv_len * num_kv_heads * head_dim); std::vector v(kv_len * num_kv_heads * head_dim); @@ -88,7 +89,7 @@ void TestSinglePrefillKernelLongContextCorrectness(bool use_fp16_qk_reduction) { for (size_t qo_len : {1, 31, 63, 127}) { for (size_t kv_len : {31717}) { for (size_t num_heads : {1}) { - for (size_t head_dim : {64, 128, 256}) { + for (size_t head_dim : {128, 256}) { for (bool causal : {false, true}) { for (size_t pos_encoding_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { @@ -104,26 +105,26 @@ void TestSinglePrefillKernelLongContextCorrectness(bool use_fp16_qk_reduction) { } } -template -void TestSinglePrefillFP8KernelLongContextCorrectness(bool use_fp16_qk_reduction) { - for (size_t qo_len : {1, 31, 63, 127}) { - for (size_t kv_len : {31717}) { - for (size_t num_heads : {1}) { - for (size_t head_dim : {64, 128, 256}) { - for (bool causal : {false, true}) { - for (size_t pos_encoding_mode : {0}) { - for (size_t kv_layout : {0, 1}) { - _TestSinglePrefillKernelCorrectness( - qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), - PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction); - } - } - } - } - } - } - } -} +// template +// void TestSinglePrefillFP8KernelLongContextCorrectness(bool use_fp16_qk_reduction) { +// for (size_t qo_len : {1, 31, 63, 127}) { +// for (size_t kv_len : {31717}) { +// for (size_t num_heads : {1}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness( +// qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } template void TestSinglePrefillKernelShortContextCorrectness(bool use_fp16_qk_reduction) { @@ -132,7 +133,7 @@ void TestSinglePrefillKernelShortContextCorrectness(bool use_fp16_qk_reduction) for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { for (size_t num_qo_heads : {32}) { for (size_t num_kv_heads : {4, 8, 32}) { - for (size_t head_dim : {64, 128, 256}) { + for (size_t head_dim : {128, 256}) { for (bool causal : {false, true}) { for (size_t pos_encoding_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { @@ -149,36 +150,36 @@ void TestSinglePrefillKernelShortContextCorrectness(bool use_fp16_qk_reduction) } } -template -void TestSinglePrefillFP8KernelShortContextCorrectness(bool use_fp16_qk_reduction) { - float rtol = 1e-3; - float atol = 1e-3; - for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { - for (size_t num_qo_heads : {32}) { - for (size_t num_kv_heads : {4, 8, 32}) { - for (size_t head_dim : {64, 128, 256}) { - for (bool causal : {false, true}) { - for (size_t pos_encoding_mode : {0}) { - for (size_t kv_layout : {0, 1}) { - _TestSinglePrefillKernelCorrectness( - qkv_len, qkv_len, 
num_qo_heads, num_kv_heads, head_dim, causal, - QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction, - rtol, atol); - } - } - } - } - } - } - } -} +// template +// void TestSinglePrefillFP8KernelShortContextCorrectness(bool use_fp16_qk_reduction) { +// float rtol = 1e-3; +// float atol = 1e-3; +// for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { +// for (size_t num_qo_heads : {32}) { +// for (size_t num_kv_heads : {4, 8, 32}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness( +// qkv_len, qkv_len, num_qo_heads, num_kv_heads, head_dim, causal, +// QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction, rtol, atol); +// } +// } +// } +// } +// } +// } +// } +// } template void TestSinglePrefillKernelCorrectness(bool use_fp16_qk_reduction) { for (size_t qo_len : {399, 400, 401}) { for (size_t kv_len : {533, 534, 535}) { for (size_t num_heads : {12}) { - for (size_t head_dim : {64, 128, 256}) { + for (size_t head_dim : {128, 256}) { for (bool causal : {false, true}) { for (size_t pos_encoding_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { @@ -194,83 +195,83 @@ void TestSinglePrefillKernelCorrectness(bool use_fp16_qk_reduction) { } } -template -void TestSinglePrefillFP8KernelCorrectness(bool use_fp16_qk_reduction) { - for (size_t qo_len : {399, 400, 401}) { - for (size_t kv_len : {533, 534, 535}) { - for (size_t num_heads : {12}) { - for (size_t head_dim : {64, 128, 256}) { - for (bool causal : {false, true}) { - for (size_t pos_encoding_mode : {0}) { - for (size_t kv_layout : {0, 1}) { - _TestSinglePrefillKernelCorrectness( - qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), - PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction); - } - } - } - } - } - } - } -} +// template +// void TestSinglePrefillFP8KernelCorrectness(bool use_fp16_qk_reduction) { +// for (size_t qo_len : {399, 400, 401}) { +// for (size_t kv_len : {533, 534, 535}) { +// for (size_t num_heads : {12}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness( +// qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessFP16) { - TestSinglePrefillKernelLongContextCorrectness(false); + TestSinglePrefillKernelLongContextCorrectness(false); } -TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessFP16QKHalfAccum) { - TestSinglePrefillKernelLongContextCorrectness(true); -} +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessFP16QKHalfAccum) { +// TestSinglePrefillKernelLongContextCorrectness(true); +// } TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessFP16) { - TestSinglePrefillKernelShortContextCorrectness(false); + TestSinglePrefillKernelShortContextCorrectness(false); } -TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessFP16QKHalfAccum) { - TestSinglePrefillKernelShortContextCorrectness(true); -} +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessFP16QKHalfAccum) { +// 
TestSinglePrefillKernelShortContextCorrectness(true); +// } TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16) { - TestSinglePrefillKernelCorrectness(false); + TestSinglePrefillKernelCorrectness(false); } -TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16QKHalfAccum) { - TestSinglePrefillKernelCorrectness(true); -} +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16QKHalfAccum) { +// TestSinglePrefillKernelCorrectness(true); +// } -#ifdef FLASHINFER_ENABLE_BF16 -TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessBF16) { - TestSinglePrefillKernelLongContextCorrectness(false); -} -TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessBF16) { - TestSinglePrefillKernelShortContextCorrectness(false); -} -TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestBF16) { - TestSinglePrefillKernelCorrectness(false); -} -#endif +// #ifdef FLASHINFER_ENABLE_BF16 +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessBF16) { +// TestSinglePrefillKernelLongContextCorrectness(false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessBF16) { +// TestSinglePrefillKernelShortContextCorrectness(false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestBF16) { +// TestSinglePrefillKernelCorrectness(false); +// } +// #endif -#ifdef FLASHINFER_ENABLE_FP8_E4M3 -TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessE4M3) { - TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e4m3>(false); -} -TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE4M3) { - TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e4m3>(false); -} -TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessE4M3) { - TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e4m3>(false); -} -#endif +// #ifdef FLASHINFER_ENABLE_FP8_E4M3 +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessE4M3) { +// TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e4m3>(false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE4M3) { +// TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e4m3>(false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessE4M3) { +// TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e4m3>(false); +// } +// #endif -#ifdef FLASHINFER_ENABLE_FP8_E5M2 -TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessE5M2) { - TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e5m2>(false); -} -TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE5M2) { - TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e5m2>(false); -} -TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessE5M2) { - TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e5m2>(false); -} -#endif +// #ifdef FLASHINFER_ENABLE_FP8_E5M2 +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessE5M2) { +// TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e5m2>(false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE5M2) { +// TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e5m2>(false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessE5M2) { +// 
TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e5m2>(false); +// } +// #endif diff --git a/tests/test_jit_example.py b/tests/test_jit_example.py index ba9924230..fdcc0c16c 100644 --- a/tests/test_jit_example.py +++ b/tests/test_jit_example.py @@ -18,11 +18,12 @@ def test_single_decode_mask(): torch.manual_seed(42) variant_decl = r""" -struct SingleDecodeWithCustomMask { +struct SingleDecodeWithCustomMask : AttentionVariantBase { static constexpr bool use_softmax = true; uint8_t* custom_mask_ptr; uint32_t window_left, qo_len, kv_len; + float sm_scale_log2; // Create closure template @@ -32,18 +33,7 @@ def test_single_decode_mask(): qo_len = 1; kv_len = params.get_kv_len(batch_idx); window_left = kv_len; - } - - template - __device__ __forceinline__ T QueryTransform(const Params& params, T q) { - return float(q) * params.sm_scale * math::log2e; - } - - template - __device__ __forceinline__ T LogitsTransform(const Params& params, T logits, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, - uint32_t qo_head_idx, uint32_t kv_head_idx) { - return logits; + sm_scale_log2 = params.sm_scale * math::log2e; } template @@ -89,11 +79,12 @@ def test_single_decode_mask(): flash_sigmoid_sm80_decl = r""" -struct FlashSigmoid { +struct FlashSigmoid : AttentionVariantBase { static constexpr bool use_softmax = false; uint32_t window_left, qo_len, kv_len; - float sigmoid_bias_log2e; + float sigmoid_scale_log2; + float sigmoid_bias_log2; // Create closure template @@ -102,32 +93,21 @@ def test_single_decode_mask(): qo_len = params.get_qo_len(batch_idx); kv_len = params.get_kv_len(batch_idx); window_left = kv_len; - sigmoid_bias_log2e = params.sigmoid_bias * math::log2e; - } - - template - __device__ __forceinline__ T QueryTransform(const Params& params, T q) { - return float(q) * params.logits_scale * math::log2e; + sigmoid_bias_log2 = params.sigmoid_bias * math::log2e; + sigmoid_scale_log2 = params.logits_scale * math::log2e; } template __device__ __forceinline__ T LogitsTransform(const Params& params, T logits, uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) { - return math::ptx_rcp(1.f + math::ptx_exp2(-float(logits + sigmoid_bias_log2e))); - } - - template - __device__ __forceinline__ bool LogitsMask(const Params& params, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, - uint32_t kv_head_idx) { - return true; + return math::ptx_rcp(1.f + math::ptx_exp2(-float(logits * sigmoid_scale_log2 + sigmoid_bias_log2))); } }; """ flash_sigmoid_sm90_decl = r""" -struct FlashSigmoid { +struct FlashSigmoid : AttentionVariantBase { float logits_scale_log2, sigmoid_bias_log2e; // Init template @@ -191,10 +171,11 @@ def test_flash_sigmoid(): def test_dump_logits(): torch.manual_seed(42) variant_decl = r""" -struct DumpLogits { +struct DumpLogits : AttentionVariantBase { static constexpr bool use_softmax = true; uint32_t window_left, qo_len, kv_len; + float sm_scale_log2; // Create closure template @@ -203,11 +184,7 @@ def test_dump_logits(): qo_len = params.get_qo_len(batch_idx); kv_len = params.get_kv_len(batch_idx); window_left = kv_len; - } - - template - __device__ __forceinline__ T QueryTransform(const Params& params, T q) { - return float(q) * params.sm_scale * math::log2e; + sm_scale_log2 = params.sm_scale * math::log2e; } template @@ -219,13 +196,6 @@ def test_dump_logits(): } return logits; } - - template - __device__ __forceinline__ bool LogitsMask(const Params& params, uint32_t batch_idx, - uint32_t qo_idx, uint32_t 
kv_idx, uint32_t qo_head_idx, - uint32_t kv_head_idx) { - return true; - } }; """ jit_module = gen_customize_single_prefill_module( @@ -607,10 +577,11 @@ def test_batch_prefill_sm90_flash_sigmoid(): def test_debug_print_logits(): torch.manual_seed(42) variant_decl = r""" -struct DebugPrintLogits { +struct DebugPrintLogits : AttentionVariantBase { static constexpr bool use_softmax = true; uint32_t window_left, qo_len, kv_len; + float sm_scale_log2; // Create closure template @@ -619,11 +590,7 @@ def test_debug_print_logits(): qo_len = params.get_qo_len(batch_idx); kv_len = params.get_kv_len(batch_idx); window_left = kv_len; - } - - template - __device__ __forceinline__ T QueryTransform(const Params& params, T q) { - return float(q) * params.sm_scale * math::log2e; + sm_scale_log2 = params.sm_scale * math::log2e; } template @@ -636,13 +603,6 @@ def test_debug_print_logits(): } return logits; } - - template - __device__ __forceinline__ bool LogitsMask(const Params& params, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, - uint32_t kv_head_idx) { - return true; - } }; """ jit_module = gen_customize_single_prefill_module( @@ -677,7 +637,7 @@ def test_debug_print_logits(): def test_sm90_debug_print_logits(): torch.manual_seed(42) variant_decl = r""" -struct DebugPrintLogits { +struct DebugPrintLogits : AttentionVariantBase { float sm_scale_log2; int qo_len, kv_len; @@ -752,7 +712,7 @@ def test_sm90_debug_print_logits(): if __name__ == "__main__": - test_single_decode_mask() + # test_single_decode_mask() test_flash_sigmoid() test_dump_logits() test_debug_print_logits() From db6c743d94e75f5b1c95079ac87834206e9cd161 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 9 Feb 2025 01:47:02 +0000 Subject: [PATCH 2/8] upd --- src/bench_single_prefill.cu | 29 ++--- src/test_single_prefill.cu | 223 ++++++++++++++++++------------------ 2 files changed, 126 insertions(+), 126 deletions(-) diff --git a/src/bench_single_prefill.cu b/src/bench_single_prefill.cu index 394c5f193..36054a524 100644 --- a/src/bench_single_prefill.cu +++ b/src/bench_single_prefill.cu @@ -182,18 +182,19 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { .add_int64_axis("custom_mask", {0}) \ .add_int64_axis("cooperative", {1}) -// auto bench_flashinfer_single_prefill_fp8_kv = bench_flashinfer_single_prefill_fp8; -// NVBENCH_BENCH(bench_flashinfer_single_prefill_fp8_kv) -// .set_name(("bench_flashinfer_single_prefill_fp8_kv")) -// .add_int64_axis("kv_len", {32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, -// 65536}) .add_int64_axis("num_qo_heads", {32}) .add_int64_axis("num_kv_heads", {32}) -// .add_int64_axis("head_dim", {128}) -// .add_int64_axis("causal", {0, 1}) -// .add_int64_axis("kv_layout", {0, 1}) -// .add_int64_axis("pos_encoding_mode", {0, 1}) -// .add_int64_axis("use_fp16_qk_reduction", {0, 1}) -// .add_int64_axis("custom_mask", {0}) -// .add_int64_axis("cooperative", {1}); +auto bench_flashinfer_single_prefill_fp8_kv = bench_flashinfer_single_prefill_fp8; +NVBENCH_BENCH(bench_flashinfer_single_prefill_fp8_kv) + .set_name(("bench_flashinfer_single_prefill_fp8_kv")) + .add_int64_axis("kv_len", {32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}) + .add_int64_axis("num_qo_heads", {32}) + .add_int64_axis("num_kv_heads", {32}) + .add_int64_axis("head_dim", {128}) + .add_int64_axis("causal", {0, 1}) + .add_int64_axis("kv_layout", {0, 1}) + .add_int64_axis("pos_encoding_mode", {0, 1}) + .add_int64_axis("use_fp16_qk_reduction", {0, 1}) + 
.add_int64_axis("custom_mask", {0}) + .add_int64_axis("cooperative", {1}); #define BENCH_FLASHINFER_APPEND_PREFILL(dtype_in, dtype_out) \ auto bench_flashinfer_single_append_prefill_##dtype_in##_##dtype_out##_ = \ @@ -212,5 +213,5 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { .add_int64_axis("custom_mask", {0}) \ .add_int64_axis("cooperative", {0, 1}) -BENCH_FLASHINFER_PREFILL(half, float); -// BENCH_FLASHINFER_APPEND_PREFILL(half, half); +BENCH_FLASHINFER_PREFILL(half, half); +BENCH_FLASHINFER_APPEND_PREFILL(half, half); diff --git a/src/test_single_prefill.cu b/src/test_single_prefill.cu index beaabd4d4..6ec9ffc6d 100644 --- a/src/test_single_prefill.cu +++ b/src/test_single_prefill.cu @@ -29,7 +29,6 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, bool use_fp16_qk_reduction, float rtol = 1e-3, float atol = 1e-3) { - static_assert(std::is_same_v); std::vector q(qo_len * num_qo_heads * head_dim); std::vector k(kv_len * num_kv_heads * head_dim); std::vector v(kv_len * num_kv_heads * head_dim); @@ -89,7 +88,7 @@ void TestSinglePrefillKernelLongContextCorrectness(bool use_fp16_qk_reduction) { for (size_t qo_len : {1, 31, 63, 127}) { for (size_t kv_len : {31717}) { for (size_t num_heads : {1}) { - for (size_t head_dim : {128, 256}) { + for (size_t head_dim : {64, 128, 256}) { for (bool causal : {false, true}) { for (size_t pos_encoding_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { @@ -105,26 +104,26 @@ void TestSinglePrefillKernelLongContextCorrectness(bool use_fp16_qk_reduction) { } } -// template -// void TestSinglePrefillFP8KernelLongContextCorrectness(bool use_fp16_qk_reduction) { -// for (size_t qo_len : {1, 31, 63, 127}) { -// for (size_t kv_len : {31717}) { -// for (size_t num_heads : {1}) { -// for (size_t head_dim : {64, 128, 256}) { -// for (bool causal : {false, true}) { -// for (size_t pos_encoding_mode : {0}) { -// for (size_t kv_layout : {0, 1}) { -// _TestSinglePrefillKernelCorrectness( -// qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), -// PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction); -// } -// } -// } -// } -// } -// } -// } -// } +template +void TestSinglePrefillFP8KernelLongContextCorrectness(bool use_fp16_qk_reduction) { + for (size_t qo_len : {1, 31, 63, 127}) { + for (size_t kv_len : {31717}) { + for (size_t num_heads : {1}) { + for (size_t head_dim : {64, 128, 256}) { + for (bool causal : {false, true}) { + for (size_t pos_encoding_mode : {0}) { + for (size_t kv_layout : {0, 1}) { + _TestSinglePrefillKernelCorrectness( + qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), + PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction); + } + } + } + } + } + } + } +} template void TestSinglePrefillKernelShortContextCorrectness(bool use_fp16_qk_reduction) { @@ -133,7 +132,7 @@ void TestSinglePrefillKernelShortContextCorrectness(bool use_fp16_qk_reduction) for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { for (size_t num_qo_heads : {32}) { for (size_t num_kv_heads : {4, 8, 32}) { - for (size_t head_dim : {128, 256}) { + for (size_t head_dim : {64, 128, 256}) { for (bool causal : {false, true}) { for (size_t pos_encoding_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { @@ -150,36 +149,36 @@ void TestSinglePrefillKernelShortContextCorrectness(bool use_fp16_qk_reduction) } } -// template -// void TestSinglePrefillFP8KernelShortContextCorrectness(bool use_fp16_qk_reduction) { 
-// float rtol = 1e-3; -// float atol = 1e-3; -// for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { -// for (size_t num_qo_heads : {32}) { -// for (size_t num_kv_heads : {4, 8, 32}) { -// for (size_t head_dim : {64, 128, 256}) { -// for (bool causal : {false, true}) { -// for (size_t pos_encoding_mode : {0}) { -// for (size_t kv_layout : {0, 1}) { -// _TestSinglePrefillKernelCorrectness( -// qkv_len, qkv_len, num_qo_heads, num_kv_heads, head_dim, causal, -// QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), -// use_fp16_qk_reduction, rtol, atol); -// } -// } -// } -// } -// } -// } -// } -// } +template +void TestSinglePrefillFP8KernelShortContextCorrectness(bool use_fp16_qk_reduction) { + float rtol = 1e-3; + float atol = 1e-3; + for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { + for (size_t num_qo_heads : {32}) { + for (size_t num_kv_heads : {4, 8, 32}) { + for (size_t head_dim : {64, 128, 256}) { + for (bool causal : {false, true}) { + for (size_t pos_encoding_mode : {0}) { + for (size_t kv_layout : {0, 1}) { + _TestSinglePrefillKernelCorrectness( + qkv_len, qkv_len, num_qo_heads, num_kv_heads, head_dim, causal, + QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction, + rtol, atol); + } + } + } + } + } + } + } +} template void TestSinglePrefillKernelCorrectness(bool use_fp16_qk_reduction) { for (size_t qo_len : {399, 400, 401}) { for (size_t kv_len : {533, 534, 535}) { for (size_t num_heads : {12}) { - for (size_t head_dim : {128, 256}) { + for (size_t head_dim : {64, 128, 256}) { for (bool causal : {false, true}) { for (size_t pos_encoding_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { @@ -195,83 +194,83 @@ void TestSinglePrefillKernelCorrectness(bool use_fp16_qk_reduction) { } } -// template -// void TestSinglePrefillFP8KernelCorrectness(bool use_fp16_qk_reduction) { -// for (size_t qo_len : {399, 400, 401}) { -// for (size_t kv_len : {533, 534, 535}) { -// for (size_t num_heads : {12}) { -// for (size_t head_dim : {64, 128, 256}) { -// for (bool causal : {false, true}) { -// for (size_t pos_encoding_mode : {0}) { -// for (size_t kv_layout : {0, 1}) { -// _TestSinglePrefillKernelCorrectness( -// qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), -// PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction); -// } -// } -// } -// } -// } -// } -// } -// } +template +void TestSinglePrefillFP8KernelCorrectness(bool use_fp16_qk_reduction) { + for (size_t qo_len : {399, 400, 401}) { + for (size_t kv_len : {533, 534, 535}) { + for (size_t num_heads : {12}) { + for (size_t head_dim : {64, 128, 256}) { + for (bool causal : {false, true}) { + for (size_t pos_encoding_mode : {0}) { + for (size_t kv_layout : {0, 1}) { + _TestSinglePrefillKernelCorrectness( + qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), + PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction); + } + } + } + } + } + } + } +} TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessFP16) { - TestSinglePrefillKernelLongContextCorrectness(false); + TestSinglePrefillKernelLongContextCorrectness(false); } -// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessFP16QKHalfAccum) { -// TestSinglePrefillKernelLongContextCorrectness(true); -// } +TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessFP16QKHalfAccum) { + TestSinglePrefillKernelLongContextCorrectness(true); +} TEST(FlashInferCorrectnessTest, 
TestSinglePrefillKernelShortContextCorrectnessFP16) { - TestSinglePrefillKernelShortContextCorrectness(false); + TestSinglePrefillKernelShortContextCorrectness(false); } -// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessFP16QKHalfAccum) { -// TestSinglePrefillKernelShortContextCorrectness(true); -// } +TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessFP16QKHalfAccum) { + TestSinglePrefillKernelShortContextCorrectness(true); +} TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16) { - TestSinglePrefillKernelCorrectness(false); + TestSinglePrefillKernelCorrectness(false); } -// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16QKHalfAccum) { -// TestSinglePrefillKernelCorrectness(true); -// } +TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16QKHalfAccum) { + TestSinglePrefillKernelCorrectness(true); +} -// #ifdef FLASHINFER_ENABLE_BF16 -// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessBF16) { -// TestSinglePrefillKernelLongContextCorrectness(false); -// } -// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessBF16) { -// TestSinglePrefillKernelShortContextCorrectness(false); -// } -// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestBF16) { -// TestSinglePrefillKernelCorrectness(false); -// } -// #endif +#ifdef FLASHINFER_ENABLE_BF16 +TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessBF16) { + TestSinglePrefillKernelLongContextCorrectness(false); +} +TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessBF16) { + TestSinglePrefillKernelShortContextCorrectness(false); +} +TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestBF16) { + TestSinglePrefillKernelCorrectness(false); +} +#endif -// #ifdef FLASHINFER_ENABLE_FP8_E4M3 -// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessE4M3) { -// TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e4m3>(false); -// } -// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE4M3) { -// TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e4m3>(false); -// } -// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessE4M3) { -// TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e4m3>(false); -// } -// #endif +#ifdef FLASHINFER_ENABLE_FP8_E4M3 +TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessE4M3) { + TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e4m3>(false); +} +TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE4M3) { + TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e4m3>(false); +} +TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessE4M3) { + TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e4m3>(false); +} +#endif -// #ifdef FLASHINFER_ENABLE_FP8_E5M2 -// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessE5M2) { -// TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e5m2>(false); -// } -// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE5M2) { -// TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e5m2>(false); -// } -// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessE5M2) { -// TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e5m2>(false); -// } -// #endif +#ifdef FLASHINFER_ENABLE_FP8_E5M2 
+TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessE5M2) { + TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e5m2>(false); +} +TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE5M2) { + TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e5m2>(false); +} +TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessE5M2) { + TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e5m2>(false); +} +#endif From 04bd2e7a7c594c9a69ab7c9410b76c7723eb5020 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 9 Feb 2025 15:56:24 +0000 Subject: [PATCH 3/8] upd --- include/flashinfer/attention/prefill.cuh | 45 +++++++++++++++++------- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 7e3528213..15f9a579c 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -789,6 +789,7 @@ __device__ __forceinline__ void update_mdo_states( } } } else if constexpr (std::is_same_v) { + const half2 sm_scale = __float2half2_rn(variant.sm_scale_log2); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { half m_prev[2]; @@ -808,7 +809,7 @@ __device__ __forceinline__ void update_mdo_states( __hmax2(*(half2*)&m[mma_q], math::shfl_xor_sync(*(half2*)&m[mma_q], 0x1)); #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - float o_scale = math::ptx_exp2(float(m_prev[j] - m[mma_q][j])); + float o_scale = math::ptx_exp2(float(m_prev[j] * sm_scale.x - m[mma_q][j] * sm_scale.x)); d[mma_q][j] *= o_scale; #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { @@ -821,9 +822,9 @@ __device__ __forceinline__ void update_mdo_states( #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { *(half2*)&s_frag[mma_q][mma_kv][j * 2] = - math::ptx_exp2(*(half2*)&s_frag[mma_q][mma_kv][j * 2] - m2); - *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] = - math::ptx_exp2(*(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] - m2); + math::ptx_exp2(*(half2*)&s_frag[mma_q][mma_kv][j * 2] * sm_scale - m2 * sm_scale); + *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] = math::ptx_exp2( + *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] * sm_scale - m2 * sm_scale); } } } @@ -941,6 +942,22 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_V } } +template +__device__ __forceinline__ void finalize_m(typename KTraits::AttentionVariant variant, + typename KTraits::DTypeQKAccum (*m)[2]) { + if constexpr (variant.use_softmax) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + if (m[mma_q][j] != typename KTraits::DTypeQKAccum(-math::inf)) { + m[mma_q][j] *= variant.sm_scale_log2; + } + } + } + } +} + /*! * \brief Synchronize the states of the MDO kernel across the threadblock along threadIdx.z. 
*/ @@ -1388,6 +1405,8 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCache cp_async::wait_group<0>(); block.sync(); + finalize_m(variant, m); + // threadblock synchronization threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, lane_idx); @@ -1415,10 +1434,10 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCache if (qo_idx < qo_len) { if (partition_kv) { lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]) * variant.sm_scale_log2; + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } else { lse[qo_idx * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]) * variant.sm_scale_log2; + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } } } @@ -1803,6 +1822,8 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV cp_async::wait_group<0>(); block.sync(); + finalize_m(variant, m); + // threadblock synchronization threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, lane_idx); @@ -1833,11 +1854,10 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV if (partition_kv) { lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * num_qo_heads + - qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]) * variant.sm_scale_log2; + qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } else { lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]) * variant.sm_scale_log2; + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } } } @@ -2096,6 +2116,8 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVC cp_async::wait_group<0>(); block.sync(); + finalize_m(variant, m); + // threadblock synchronization threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, lane_idx); @@ -2126,11 +2148,10 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVC if (partition_kv) { lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * num_qo_heads + - qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]) * variant.sm_scale_log2; + qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } else { lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]) * variant.sm_scale_log2; + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } } } From 86732716d09eb6080c5d45d709930e0618a44f42 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 9 Feb 2025 16:07:00 +0000 Subject: [PATCH 4/8] upd --- include/flashinfer/attention/decode.cuh | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index af7b18b44..988bf5510 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -143,6 +143,18 @@ __device__ __forceinline__ void update_local_state(const T* smem, const float* s } } +template +__device__ __forceinline__ void finalize_m(AttentionVariant variant, state_t& st) { + if constexpr (variant.use_softmax) { +#pragma unroll + for (uint32_t j = 0; j < vec_size; ++j) { + if (st.m != -math::inf) { + st.m *= variant.sm_scale_log2; + } + } + } +} + /*! * \brief Synchronize the state of all warps inside a threadblock. 
* \tparam vec_size A template integer indicates the vector size @@ -349,6 +361,8 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par cp_async::wait_group<0>(); block.sync(); + finalize_m(variant, st_local); + // sync local state of all warps inside a threadblock sync_state(variant, st_local, reinterpret_cast(smem), smem_md); if constexpr (variant.use_softmax) { @@ -357,7 +371,6 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par st_local.o.cast_store(o + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); if (lse != nullptr) { - st_local.m *= variant.sm_scale_log2; lse[kv_chunk_idx * num_qo_heads + qo_head_idx] = st_local.get_lse(); } } @@ -568,6 +581,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ Params cp_async::wait_group<0>(); block.sync(); + finalize_m(variant, st); + // sync local state of all warps inside a threadblock sync_state(variant, st, reinterpret_cast(smem), smem_md); if constexpr (variant.use_softmax) { @@ -578,7 +593,6 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ Params st.o.cast_store(o + (bx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); // write lse if (lse != nullptr) { - st.m *= variant.sm_scale_log2; lse[bx * num_qo_heads + qo_head_idx] = st.get_lse(); } } From fd698342dbc4d536120be8668f0d10aa9a72befb Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 9 Feb 2025 16:12:53 +0000 Subject: [PATCH 5/8] upd --- include/flashinfer/attention/decode.cuh | 29 ++++++------------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 988bf5510..8c4c89c50 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -92,19 +92,19 @@ __device__ __forceinline__ void compute_qk(const Params& params, AttentionVarian const uint32_t pos = kv_idx_base + tz * tile_size + j; s[j] = variant.LogitsTransform(params, s[j], batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx, kv_head_idx); - bool mask = iter_base + tz * tile_size + j < iter_bound; - mask &= variant.LogitsMask(params, batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx, - kv_head_idx); - s[j] = mask ? s[j] : -math::inf; + s[j] *= variant.sm_scale_log2; + bool mask = variant.LogitsMask(params, batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx, + kv_head_idx); + s[j] = (iter_base + tz * tile_size + j < iter_bound && mask) ? s[j] : -math::inf; st.m = max(st.m, s[j]); } if constexpr (variant.use_softmax) { - float o_scale = math::ptx_exp2(m_prev * sm_scale_log2 - st.m * sm_scale_log2); + float o_scale = math::ptx_exp2(m_prev - st.m); st.d *= o_scale; #pragma unroll for (uint32_t j = 0; j < tile_size; ++j) { - s[j] = math::ptx_exp2(s[j] * sm_scale_log2 - st.m * sm_scale_log2); + s[j] = math::ptx_exp2(s[j] - st.m); st.d += s[j]; } #pragma unroll @@ -143,18 +143,6 @@ __device__ __forceinline__ void update_local_state(const T* smem, const float* s } } -template -__device__ __forceinline__ void finalize_m(AttentionVariant variant, state_t& st) { - if constexpr (variant.use_softmax) { -#pragma unroll - for (uint32_t j = 0; j < vec_size; ++j) { - if (st.m != -math::inf) { - st.m *= variant.sm_scale_log2; - } - } - } -} - /*! * \brief Synchronize the state of all warps inside a threadblock. 
* \tparam vec_size A template integer indicates the vector size @@ -277,6 +265,7 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par // do not apply rotary embedding to q matrix q_vec.cast_load(q + qo_head_idx * q_stride_h + tx * vec_size); } + block.sync(); uint32_t chunk_start = kv_chunk_idx * kv_chunk_size; kv_chunk_size = min(kv_chunk_size, seq_len - chunk_start); @@ -361,8 +350,6 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par cp_async::wait_group<0>(); block.sync(); - finalize_m(variant, st_local); - // sync local state of all warps inside a threadblock sync_state(variant, st_local, reinterpret_cast(smem), smem_md); if constexpr (variant.use_softmax) { @@ -581,8 +568,6 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ Params cp_async::wait_group<0>(); block.sync(); - finalize_m(variant, st); - // sync local state of all warps inside a threadblock sync_state(variant, st, reinterpret_cast(smem), smem_md); if constexpr (variant.use_softmax) { From 54a9f9acd629ed0fe01002cfd9aa5986cfe33a66 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 9 Feb 2025 16:19:37 +0000 Subject: [PATCH 6/8] upd --- include/flashinfer/attention/decode.cuh | 5 +++-- tests/test_jit_example.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 8c4c89c50..15e73c269 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -67,7 +67,6 @@ __device__ __forceinline__ void compute_qk(const Params& params, AttentionVarian uint32_t qo_head_idx, uint32_t kv_head_idx, float* s, state_t& st) { uint32_t tx = threadIdx.x, tz = threadIdx.z; - const float sm_scale_log2 = variant.sm_scale_log2; float m_prev = st.m; #pragma unroll for (uint32_t j = 0; j < tile_size; ++j) { @@ -92,7 +91,9 @@ __device__ __forceinline__ void compute_qk(const Params& params, AttentionVarian const uint32_t pos = kv_idx_base + tz * tile_size + j; s[j] = variant.LogitsTransform(params, s[j], batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx, kv_head_idx); - s[j] *= variant.sm_scale_log2; + if constexpr (variant.use_softmax) { + s[j] *= variant.sm_scale_log2; + } bool mask = variant.LogitsMask(params, batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx, kv_head_idx); s[j] = (iter_base + tz * tile_size + j < iter_bound && mask) ? 
s[j] : -math::inf; diff --git a/tests/test_jit_example.py b/tests/test_jit_example.py index fdcc0c16c..3bc08288a 100644 --- a/tests/test_jit_example.py +++ b/tests/test_jit_example.py @@ -192,7 +192,7 @@ def test_dump_logits(): uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) { if (qo_idx < qo_len && kv_idx < kv_len) { - params.output_logits[qo_head_idx * (qo_len * kv_len) + qo_idx * kv_len + kv_idx] = logits * math::loge2; + params.output_logits[qo_head_idx * (qo_len * kv_len) + qo_idx * kv_len + kv_idx] = logits * params.sm_scale; } return logits; } @@ -712,7 +712,7 @@ def test_sm90_debug_print_logits(): if __name__ == "__main__": - # test_single_decode_mask() + test_single_decode_mask() test_flash_sigmoid() test_dump_logits() test_debug_print_logits() From e412e3e78e3fd98cc45dd85332cb590f684ed903 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 9 Feb 2025 18:07:52 +0000 Subject: [PATCH 7/8] upd --- include/flashinfer/attention/decode.cuh | 1 + include/flashinfer/attention/hopper/variants.cuh | 4 ++-- include/flashinfer/attention/variants.cuh | 9 +++++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 15e73c269..7f3da453e 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -94,6 +94,7 @@ __device__ __forceinline__ void compute_qk(const Params& params, AttentionVarian if constexpr (variant.use_softmax) { s[j] *= variant.sm_scale_log2; } + bool mask = variant.LogitsMask(params, batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx, kv_head_idx); s[j] = (iter_base + tz * tile_size + j < iter_bound && mask) ? s[j] : -math::inf; diff --git a/include/flashinfer/attention/hopper/variants.cuh b/include/flashinfer/attention/hopper/variants.cuh index 8f17cdbdf..b9f658f84 100644 --- a/include/flashinfer/attention/hopper/variants.cuh +++ b/include/flashinfer/attention/hopper/variants.cuh @@ -54,11 +54,11 @@ struct LogitsSoftCap { template __device__ auto GetAttentionUpdater() { - return OnlineSoftmax(0.); + return OnlineSoftmax(post_tanh_scale); } REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, - { return math::tanh(logits * pre_tanh_scale) * post_tanh_scale; }) + { return math::tanh(logits * pre_tanh_scale); }) }; template diff --git a/include/flashinfer/attention/variants.cuh b/include/flashinfer/attention/variants.cuh index 4fa6cdb0d..9d6f546d7 100644 --- a/include/flashinfer/attention/variants.cuh +++ b/include/flashinfer/attention/variants.cuh @@ -48,7 +48,11 @@ struct DefaultAttention : AttentionVariantBase { soft_cap_pre_tanh_scale = params.sm_scale * math::ptx_rcp(params.logits_soft_cap); sm_scale_log2 = math::log2e * params.logits_soft_cap; } else { - sm_scale_log2 = params.sm_scale * math::log2e; + if constexpr (use_alibi) { + sm_scale_log2 = math::log2e; + } else { + sm_scale_log2 = params.sm_scale * math::log2e; + } } if constexpr (use_custom_mask) { if constexpr (has_maybe_mask_indptr_v) { @@ -64,7 +68,8 @@ struct DefaultAttention : AttentionVariantBase { REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { if constexpr (use_alibi) { - logits = logits + params.maybe_alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx)); + logits = logits * params.sm_scale + + params.maybe_alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx)); } if constexpr (use_logits_soft_cap) { logits = float(math::tanh(logits * 
soft_cap_pre_tanh_scale)); From ba27c9a30bc968c22bd9ef98fad35794ec838b66 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 9 Feb 2025 18:34:27 +0000 Subject: [PATCH 8/8] upd --- tests/test_jit_example.py | 68 ++++++++++++++------------------------- 1 file changed, 24 insertions(+), 44 deletions(-) diff --git a/tests/test_jit_example.py b/tests/test_jit_example.py index 3bc08288a..84fb86bc0 100644 --- a/tests/test_jit_example.py +++ b/tests/test_jit_example.py @@ -36,13 +36,10 @@ def test_single_decode_mask(): sm_scale_log2 = params.sm_scale * math::log2e; } - template - __device__ __forceinline__ bool LogitsMask(const Params& params, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, - uint32_t kv_head_idx) { + REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { const uint32_t offset = kv_idx; return ((custom_mask_ptr[offset / 8] >> (offset % 8)) & 1); - } + }) }; """ jit_module = gen_customize_single_decode_module( @@ -97,12 +94,9 @@ def test_single_decode_mask(): sigmoid_scale_log2 = params.logits_scale * math::log2e; } - template - __device__ __forceinline__ T LogitsTransform(const Params& params, T logits, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, - uint32_t qo_head_idx, uint32_t kv_head_idx) { + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { return math::ptx_rcp(1.f + math::ptx_exp2(-float(logits * sigmoid_scale_log2 + sigmoid_bias_log2))); - } + }); }; """ @@ -122,13 +116,9 @@ def test_single_decode_mask(): return DefaultUpdater(); } - template - __device__ __forceinline__ T LogitsTransform(const MainloopParams& params, T logits, - int batch_idx, - int qo_idx, int kv_idx, - int qo_head_idx, int kv_head_idx) { + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { return math::ptx_rcp(1.f + math::ptx_exp2(-float(logits * logits_scale_log2 + sigmoid_bias_log2e))); - } + }); }; """ @@ -187,15 +177,12 @@ def test_dump_logits(): sm_scale_log2 = params.sm_scale * math::log2e; } - template - __device__ __forceinline__ T LogitsTransform(const Params& params, T logits, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, - uint32_t qo_head_idx, uint32_t kv_head_idx) { + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { if (qo_idx < qo_len && kv_idx < kv_len) { params.output_logits[qo_head_idx * (qo_len * kv_len) + qo_idx * kv_len + kv_idx] = logits * params.sm_scale; } return logits; - } + }); }; """ jit_module = gen_customize_single_prefill_module( @@ -593,16 +580,13 @@ def test_debug_print_logits(): sm_scale_log2 = params.sm_scale * math::log2e; } - template - __device__ __forceinline__ T LogitsTransform(const Params& params, T logits, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, - uint32_t qo_head_idx, uint32_t kv_head_idx) { + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { if (logits >= 5) { printf("Large logits at qo_idx=%d, kv_idx=%d, qo_head_idx=%d, kv_head_idx=%d: %.3f\n", - qo_idx, kv_idx, qo_head_idx, kv_head_idx, float(logits)); + qo_idx, kv_idx, qo_head_idx, kv_head_idx, float(logits)); } return logits; - } + }); }; """ jit_module = gen_customize_single_prefill_module( @@ -659,27 +643,23 @@ def test_sm90_debug_print_logits(): } - template - __device__ __forceinline__ T LogitsTransform(const MainloopParams& params, T logits, - int batch_idx, - int qo_idx, int kv_idx, - int qo_head_idx, int kv_head_idx) { + 
REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { if (qo_idx < qo_len && kv_idx < kv_len) { - printf( - "---> LOGITS DEBUG: " - "qo_idx=%-5d " - "kv_idx=%-5d " - "sm_scale_log2=%-12.5f " - "logits=%-12.5f " - "\n", - qo_idx, - kv_idx, - sm_scale_log2, - static_cast(logits)); + printf( + "---> LOGITS DEBUG: " + "qo_idx=%-5d " + "kv_idx=%-5d " + "sm_scale_log2=%-12.5f " + "logits=%-12.5f " + "\n", + qo_idx, + kv_idx, + sm_scale_log2, + static_cast(logits)); } logits *= sm_scale_log2; return logits; - } + }) }; """ jit_module = gen_customize_single_prefill_module(