diff --git a/aot_build_utils/generate.py b/aot_build_utils/generate.py index 1a9a56219..c1c3eef6a 100644 --- a/aot_build_utils/generate.py +++ b/aot_build_utils/generate.py @@ -24,6 +24,7 @@ generate_batch_paged_decode_inst, generate_batch_paged_prefill_inst, generate_batch_ragged_prefill_inst, + generate_dispatch_inc, generate_single_decode_inst, generate_single_prefill_inst, ) @@ -47,6 +48,19 @@ def write_if_different(path: Path, content: str) -> None: path.mkdir(parents=True, exist_ok=True) + write_if_different( + path / "dispatch.inc", + generate_dispatch_inc.get_dispatch_inc_str( + argparse.Namespace( + head_dims=head_dims, + head_dims_sm90=head_dims, + pos_encoding_modes=[0], + use_fp16_qk_reductions=[0], + mask_modes=mask_modes, + ) + ), + ) + write_if_different( path / "aot_default_additional_params.h", generate_aot_default_additional_params_header.get_aot_default_additional_params_header_str(), @@ -79,9 +93,10 @@ def write_if_different(path: Path, content: str) -> None: product(fp16_dtypes, fp8_dtypes) ): dtype_out = dtype_q - fname = f"single_decode_head_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" + fname = f"single_decode_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" content = generate_single_decode_inst.get_cu_file_str( - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo pos_encoding_mode, dtype_q, dtype_kv, @@ -93,7 +108,8 @@ def write_if_different(path: Path, content: str) -> None: f"single_decode_with_kv_cache_dtype_q_{dtype_q}_" f"dtype_kv_{dtype_kv}_" f"dtype_o_{dtype_out}_" - f"head_dim_{head_dim}_" + f"head_dim_qk_{head_dim}_" + f"head_dim_vo_{head_dim}_" f"posenc_{pos_encoding_mode}_" f"use_swa_{use_sliding_window}_" f"use_logits_cap_{use_logits_soft_cap}" @@ -114,9 +130,10 @@ def write_if_different(path: Path, content: str) -> None: product(fp16_dtypes, fp8_dtypes) ): dtype_out = dtype_q - fname = f"batch_paged_decode_head_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu" + fname = f"batch_paged_decode_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu" content = generate_batch_paged_decode_inst.get_cu_file_str( - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo pos_encoding_mode, dtype_q, dtype_kv, @@ -130,7 +147,8 @@ def write_if_different(path: Path, content: str) -> None: f"dtype_kv_{dtype_kv}_" f"dtype_o_{dtype_out}_" f"dtype_idx_{idtype}_" - f"head_dim_{head_dim}_" + f"head_dim_qk_{head_dim}_" + f"head_dim_vo_{head_dim}_" f"posenc_{pos_encoding_mode}_" f"use_swa_{use_sliding_window}_" f"use_logits_cap_{use_logits_soft_cap}" @@ -153,9 +171,10 @@ def write_if_different(path: Path, content: str) -> None: for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list( product(prefill_dtypes, fp8_dtypes) ): - fname = f"single_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu" + fname = f"single_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu" content = generate_single_prefill_inst.get_cu_file_str( - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo pos_encoding_mode, use_fp16_qk_reduction, 
mask_mode, @@ -172,7 +191,8 @@ def write_if_different(path: Path, content: str) -> None: f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_" f"dtype_kv_{dtype_kv}_" f"dtype_o_{dtype_q}_" - f"head_dim_{head_dim}_" + f"head_dim_qk_{head_dim}_" + f"head_dim_vo_{head_dim}_" f"posenc_{pos_encoding_mode}_" f"use_swa_{use_sliding_window}_" f"use_logits_cap_{use_logits_soft_cap}_" @@ -198,9 +218,10 @@ def write_if_different(path: Path, content: str) -> None: for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list( product(prefill_dtypes, fp8_dtypes) ): - fname = f"batch_paged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" + fname = f"batch_paged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" content = generate_batch_paged_prefill_inst.get_cu_file_str( - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo pos_encoding_mode, use_fp16_qk_reduction, mask_mode, @@ -211,9 +232,10 @@ def write_if_different(path: Path, content: str) -> None: ) write_if_different(path / fname, content) - fname = f"batch_ragged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" + fname = f"batch_ragged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" content = generate_batch_ragged_prefill_inst.get_cu_file_str( - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo pos_encoding_mode, use_fp16_qk_reduction, mask_mode, @@ -234,7 +256,8 @@ def write_if_different(path: Path, content: str) -> None: f"dtype_kv_{dtype_kv}_" f"dtype_o_{dtype_q}_" f"dtype_idx_{idtype}_" - f"head_dim_{head_dim}_" + f"head_dim_qk_{head_dim}_" + f"head_dim_vo_{head_dim}_" f"posenc_{pos_encoding_mode}_" f"use_swa_{sliding_window}_" f"use_logits_cap_{logits_soft_cap}_" diff --git a/aot_build_utils/generate_batch_paged_decode_inst.py b/aot_build_utils/generate_batch_paged_decode_inst.py index fcc2f6108..eee2bc4d7 100644 --- a/aot_build_utils/generate_batch_paged_decode_inst.py +++ b/aot_build_utils/generate_batch_paged_decode_inst.py @@ -22,7 +22,8 @@ def get_cu_file_str( - head_dim, + head_dim_qk, + head_dim_vo, pos_encoding_mode, dtype_q, dtype_kv, @@ -35,25 +36,25 @@ def get_cu_file_str( using Params = BatchDecodeParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>; -template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention< +template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention< /*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>( Params params, {dtype_out}* tmp_v, float* tmp_s, cudaStream_t stream); -template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention< +template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention< /*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>( Params params, {dtype_out}* tmp_v, float* tmp_s, 
cudaStream_t stream); -template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention< +template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention< /*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>( Params params, {dtype_out}* tmp_v, float* tmp_s, cudaStream_t stream); -template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention< +template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention< /*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>( Params params, {dtype_out}* tmp_v, float* tmp_s, @@ -69,20 +70,22 @@ def get_cu_file_str( }} """.format( - head_dim=head_dim, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], dtype_q=dtype_literal[dtype_q], dtype_kv=dtype_literal[dtype_kv], dtype_out=dtype_literal[dtype_out], idtype=idtype_literal[idtype], - head_dim_kpe=head_dim // 8, + head_dim=head_dim_vo, # NOTE(Zihao): for MLA instantiation, we should move them to a standalone file + head_dim_kpe=head_dim_vo // 8, ) return content if __name__ == "__main__": pattern = ( - r"batch_paged_decode_head_([0-9]+)_posenc_([0-9]+)_" + r"batch_paged_decode_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_" r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) diff --git a/aot_build_utils/generate_batch_paged_prefill_inst.py b/aot_build_utils/generate_batch_paged_prefill_inst.py index 5a8c17394..0f11df7e5 100644 --- a/aot_build_utils/generate_batch_paged_prefill_inst.py +++ b/aot_build_utils/generate_batch_paged_prefill_inst.py @@ -27,7 +27,8 @@ def get_cu_file_str( - head_dim, + head_dim_qk, + head_dim_vo, pos_encoding_mode, use_fp16_qk_reduction, mask_mode, @@ -41,13 +42,14 @@ def get_cu_file_str( def get_insts(attention_variant, dtype_out): return "\n".join( [ - """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{cta_tile_q}, {head_dim}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, {attention_variant}, Params>( + """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{cta_tile_q}, {head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, {attention_variant}, Params>( Params params, {dtype_out}* tmp_v, float* tmp_s, cudaStream_t stream); """.format( cta_tile_q=cta_tile_q, - head_dim=head_dim, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], use_fp16_qk_reduction=use_fp16_qk_reduction, mask_mode=mask_mode_literal[int(mask_mode)], @@ -92,7 +94,7 @@ def get_insts(attention_variant, dtype_out): if __name__ == "__main__": pattern = ( - r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_" + r"batch_paged_prefill_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_" r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py b/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py index a57d0b17b..111b29d41 100644 --- a/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py +++ b/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py @@ -22,7 +22,8 @@ def 
get_cu_file_str( - head_dim, + head_dim_qk, + head_dim_vo, pos_encoding_mode, use_fp16_qk_reduction, mask_mode, @@ -37,7 +38,8 @@ def get_cu_file_str( def get_insts(attention_variant): return """ template cudaError_t BatchPrefillWithPagedKVCacheDispatched - <{head_dim}, + <{head_dim_qk}, + {head_dim_vo}, {mask_mode}, /*USE_SLIDING_WINDOW=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, @@ -46,7 +48,8 @@ def get_insts(attention_variant): (Params& params, cudaStream_t stream); template cudaError_t BatchPrefillWithPagedKVCacheDispatched - <{head_dim}, + <{head_dim_qk}, + {head_dim_vo}, {mask_mode}, /*USE_SLIDING_WINDOW=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, @@ -55,7 +58,8 @@ def get_insts(attention_variant): (Params& params, cudaStream_t stream); template cudaError_t BatchPrefillWithPagedKVCacheDispatched - <{head_dim}, + <{head_dim_qk}, + {head_dim_vo}, {mask_mode}, /*USE_SLIDING_WINDOW=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, @@ -64,7 +68,8 @@ def get_insts(attention_variant): (Params& params, cudaStream_t stream); template cudaError_t BatchPrefillWithPagedKVCacheDispatched - <{head_dim}, + <{head_dim_qk}, + {head_dim_vo}, {mask_mode}, /*USE_SLIDING_WINDOW=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, @@ -72,7 +77,8 @@ def get_insts(attention_variant): Params> (Params& params, cudaStream_t stream); """.format( - head_dim=head_dim, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, mask_mode=mask_mode_literal[int(mask_mode)], attention_variant=attention_variant, ) @@ -107,7 +113,7 @@ def get_insts(attention_variant): if __name__ == "__main__": pattern = ( - r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_" + r"batch_paged_prefill_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_" r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_" r"dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu" ) diff --git a/aot_build_utils/generate_batch_ragged_prefill_inst.py b/aot_build_utils/generate_batch_ragged_prefill_inst.py index 504a325ed..305d82b7e 100644 --- a/aot_build_utils/generate_batch_ragged_prefill_inst.py +++ b/aot_build_utils/generate_batch_ragged_prefill_inst.py @@ -27,7 +27,8 @@ def get_cu_file_str( - head_dim, + head_dim_qk, + head_dim_vo, pos_encoding_mode, use_fp16_qk_reduction, mask_mode, @@ -41,13 +42,14 @@ def get_cu_file_str( def get_insts(attention_variant, dtype_out): return "\n".join( [ - """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{cta_tile_q}, {head_dim}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, {attention_variant}, Params>( + """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{cta_tile_q}, {head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, {attention_variant}, Params>( Params params, {dtype_out}* tmp_v, float* tmp_s, cudaStream_t stream); """.format( cta_tile_q=cta_tile_q, - head_dim=head_dim, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], use_fp16_qk_reduction=use_fp16_qk_reduction, mask_mode=mask_mode_literal[int(mask_mode)], @@ -94,7 +96,7 @@ def get_insts(attention_variant, dtype_out): if __name__ == "__main__": pattern = ( - r"batch_ragged_prefill_head_([0-9]+)_posenc_([0-9]+)_" + r"batch_ragged_prefill_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_" r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git 
a/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py b/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py index aca14beaf..1776ef4db 100644 --- a/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py +++ b/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py @@ -27,7 +27,8 @@ def get_cu_file_str( - head_dim, + head_dim_qk, + head_dim_vo, pos_encoding_mode, use_fp16_qk_reduction, mask_mode, @@ -40,7 +41,8 @@ def get_cu_file_str( def get_insts(attention_variant): return """ template cudaError_t BatchPrefillWithRaggedKVCacheDispatched - <{head_dim}, + <{head_dim_qk}, + {head_dim_vo}, {mask_mode}, /*USE_SLIDING_WINDOW=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, @@ -48,7 +50,8 @@ def get_insts(attention_variant): Params>(Params& params, cudaStream_t stream); template cudaError_t BatchPrefillWithRaggedKVCacheDispatched - <{head_dim}, + <{head_dim_qk}, + {head_dim_vo}, {mask_mode}, /*USE_SLIDING_WINDOW=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, @@ -56,7 +59,8 @@ def get_insts(attention_variant): Params>(Params& params, cudaStream_t stream); template cudaError_t BatchPrefillWithRaggedKVCacheDispatched - <{head_dim}, + <{head_dim_qk}, + {head_dim_vo}, {mask_mode}, /*USE_SLIDING_WINDOW=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, @@ -64,14 +68,16 @@ def get_insts(attention_variant): Params>(Params& params, cudaStream_t stream); template cudaError_t BatchPrefillWithRaggedKVCacheDispatched - <{head_dim}, + <{head_dim_qk}, + {head_dim_vo}, {mask_mode}, /*USE_SLIDING_WINDOW=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}, Params>(Params& params, cudaStream_t stream); """.format( - head_dim=head_dim, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, mask_mode=mask_mode_literal[int(mask_mode)], attention_variant=attention_variant, ) @@ -107,7 +113,7 @@ def get_insts(attention_variant): if __name__ == "__main__": pattern = ( - r"batch_ragged_prefill_head_([0-9]+)_posenc_([0-9]+)_" + r"batch_ragged_prefill_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_" r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/aot_build_utils/generate_single_decode_inst.py b/aot_build_utils/generate_single_decode_inst.py index 7d7f0b9e1..e72fda71e 100644 --- a/aot_build_utils/generate_single_decode_inst.py +++ b/aot_build_utils/generate_single_decode_inst.py @@ -22,7 +22,8 @@ def get_cu_file_str( - head_dim, + head_dim_qk, + head_dim_vo, pos_encoding_mode, dtype_q, dtype_kv, @@ -34,25 +35,25 @@ def get_cu_file_str( using Params = SingleDecodeParams<{dtype_q}, {dtype_kv}, {dtype_out}>; -template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention< +template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention< /*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>( Params params, {dtype_out}* tmp, cudaStream_t stream); -template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention< +template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention< /*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>( Params params, {dtype_out}* tmp, cudaStream_t stream); -template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, 
{pos_encoding_mode}, DefaultAttention< +template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention< /*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>( Params params, {dtype_out}* tmp, cudaStream_t stream); -template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention< +template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention< /*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>( Params params, {dtype_out}* tmp, @@ -60,7 +61,8 @@ def get_cu_file_str( }} """.format( - head_dim=head_dim, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], dtype_q=dtype_literal[dtype_q], dtype_kv=dtype_literal[dtype_kv], @@ -71,7 +73,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"single_decode_head_([0-9]+)_posenc_([0-9]+)_" + r"single_decode_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_" r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) diff --git a/aot_build_utils/generate_single_prefill_inst.py b/aot_build_utils/generate_single_prefill_inst.py index f0bcf490e..14535c04e 100644 --- a/aot_build_utils/generate_single_prefill_inst.py +++ b/aot_build_utils/generate_single_prefill_inst.py @@ -22,7 +22,8 @@ def get_cu_file_str( - head_dim, + head_dim_qk, + head_dim_vo, pos_encoding_mode, use_fp16_qk_reduction, mask_mode, @@ -36,25 +37,25 @@ def get_cu_file_str( using Params = SinglePrefillParams<{dtype_q}, {dtype_kv}, {dtype_out}>; -template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, DefaultAttention< +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, DefaultAttention< {use_custom_mask}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>( Params params, {dtype_out}* tmp, cudaStream_t stream); -template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, DefaultAttention< +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, DefaultAttention< {use_custom_mask}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>( Params params, {dtype_out}* tmp, cudaStream_t stream); -template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, DefaultAttention< +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, DefaultAttention< {use_custom_mask}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>( Params params, {dtype_out}* tmp, cudaStream_t stream); -template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, DefaultAttention< +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, DefaultAttention< {use_custom_mask}, /*use_sliding_window=*/false,
/*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>( Params params, {dtype_out}* tmp, @@ -62,7 +63,8 @@ def get_cu_file_str( }} """.format( - head_dim=head_dim, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], use_fp16_qk_reduction=use_fp16_qk_reduction, mask_mode=mask_mode_literal[int(mask_mode)], @@ -76,7 +78,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"single_prefill_head_([0-9]+)_posenc_([0-9]+)_" + r"single_prefill_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_" r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) diff --git a/aot_build_utils/generate_single_prefill_sm90_inst.py b/aot_build_utils/generate_single_prefill_sm90_inst.py index f19c5b9bb..291aad8ed 100644 --- a/aot_build_utils/generate_single_prefill_sm90_inst.py +++ b/aot_build_utils/generate_single_prefill_sm90_inst.py @@ -22,7 +22,8 @@ def get_cu_file_str( - head_dim, + head_dim_qk, + head_dim_vo, pos_encoding_mode, use_fp16_qk_reduction, mask_mode, @@ -45,24 +46,25 @@ def get_cu_file_str( using Params = SinglePrefillParams; template cudaError_t SinglePrefillWithKVCacheDispatched - <{head_dim}, {mask_mode}, /*USE_SLIDING_WINDOW=*/true, LogitsSoftCap, Params> + <{head_dim_qk}, {head_dim_vo}, {mask_mode}, /*USE_SLIDING_WINDOW=*/true, LogitsSoftCap, Params> (Params& params, cudaStream_t stream); template cudaError_t SinglePrefillWithKVCacheDispatched - <{head_dim}, {mask_mode}, /*USE_SLIDING_WINDOW=*/false, LogitsSoftCap, Params> + <{head_dim_qk}, {head_dim_vo}, {mask_mode}, /*USE_SLIDING_WINDOW=*/false, LogitsSoftCap, Params> (Params& params, cudaStream_t stream); template cudaError_t SinglePrefillWithKVCacheDispatched - <{head_dim}, {mask_mode}, /*USE_SLIDING_WINDOW=*/true, StandardAttention, Params> + <{head_dim_qk}, {head_dim_vo}, {mask_mode}, /*USE_SLIDING_WINDOW=*/true, StandardAttention, Params> (Params& params, cudaStream_t stream); template cudaError_t SinglePrefillWithKVCacheDispatched - <{head_dim}, {mask_mode}, /*USE_SLIDING_WINDOW=*/false, StandardAttention, Params> + <{head_dim_qk}, {head_dim_vo}, {mask_mode}, /*USE_SLIDING_WINDOW=*/false, StandardAttention, Params> (Params& params, cudaStream_t stream); }} """.format( - head_dim=head_dim, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, # pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], # use_fp16_qk_reduction=use_fp16_qk_reduction, mask_mode=mask_mode_literal[int(mask_mode)], @@ -76,7 +78,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"single_prefill_head_([0-9]+)_posenc_([0-9]+)_" + r"single_prefill_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_" r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_sm90\.cu" ) diff --git a/aot_build_utils/generate_sm90.py b/aot_build_utils/generate_sm90.py index dc63163a3..2466dd219 100644 --- a/aot_build_utils/generate_sm90.py +++ b/aot_build_utils/generate_sm90.py @@ -71,7 +71,8 @@ def write_if_different(path: Path, content: str) -> None: for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): fname = f"single_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_sm90.cu" content = generate_single_prefill_sm90_inst.get_cu_file_str( - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo pos_encoding_mode, use_fp16_qk_reduction, mask_mode, @@ -88,7
+89,8 @@ def write_if_different(path: Path, content: str) -> None: f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_" f"dtype_kv_{dtype_kv}_" f"dtype_o_{dtype_q}_" - f"head_dim_{head_dim}_" + f"head_dim_qk_{head_dim}_" + f"head_dim_vo_{head_dim}_" f"posenc_{pos_encoding_mode}_" f"use_swa_{use_sliding_window}_" f"use_logits_cap_{use_logits_soft_cap}_" @@ -112,9 +114,10 @@ def write_if_different(path: Path, content: str) -> None: idtypes, ): for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): - fname = f"batch_paged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu" + fname = f"batch_paged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu" content = generate_batch_paged_prefill_sm90_inst.get_cu_file_str( - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo pos_encoding_mode, use_fp16_qk_reduction, mask_mode, @@ -125,9 +128,10 @@ def write_if_different(path: Path, content: str) -> None: ) write_if_different(path / fname, content) - fname = f"batch_ragged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu" + fname = f"batch_ragged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu" content = generate_batch_ragged_prefill_sm90_inst.get_cu_file_str( - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo pos_encoding_mode, use_fp16_qk_reduction, mask_mode, @@ -148,7 +152,8 @@ def write_if_different(path: Path, content: str) -> None: f"dtype_kv_{dtype_kv}_" f"dtype_o_{dtype_q}_" f"dtype_idx_{idtype}_" - f"head_dim_{head_dim}_" + f"head_dim_qk_{head_dim}_" + f"head_dim_vo_{head_dim}_" f"posenc_{pos_encoding_mode}_" f"use_swa_{sliding_window}_" f"use_logits_cap_{logits_soft_cap}_" diff --git a/csrc/batch_decode.cu b/csrc/batch_decode.cu index bf2fe08c0..33d94ec8b 100644 --- a/csrc/batch_decode.cu +++ b/csrc/batch_decode.cu @@ -36,8 +36,9 @@ std::vector BatchDecodeWithPagedKVCachePlan( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, int window_left, float logits_soft_cap, unsigned int head_dim, - at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream) { + bool enable_cuda_graph, int window_left, float logits_soft_cap, unsigned int head_dim_qk, + unsigned int head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, + int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = @@ -48,14 +49,18 @@ std::vector BatchDecodeWithPagedKVCachePlan( auto q_scalar_type = empty_q_data.scalar_type(); auto kv_scalar_type = empty_kv_data.scalar_type(); + TORCH_CHECK(head_dim_qk == head_dim_vo, + "CUDA cores template only supports equal head dim for QK and VO, please use tensor " + "cores template for different head dim"); + cudaStream_t stream = reinterpret_cast(cuda_stream); 
DISPATCH_context( - DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM, POS_ENCODING_MODE, USE_SLIDING_WINDOW, - USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { + DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched< - GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, AttentionVariant, Params>; - cudaError_t status = DecodePlan( + GROUP_SIZE, HEAD_DIM_QK, POS_ENCODING_MODE, AttentionVariant, Params>; + cudaError_t status = DecodePlan( static_cast(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes, static_cast(int_workspace_buffer.data_ptr()), static_cast(page_locked_int_workspace_buffer.data_ptr()), @@ -93,7 +98,12 @@ void BatchDecodeWithPagedKVCacheRun( page_size = paged_k_cache.size(1); num_kv_heads = paged_k_cache.size(2); } - uint32_t head_dim = q.size(2); + uint32_t head_dim_qk = q.size(2); + uint32_t head_dim_vo = paged_v_cache.size(3); + + TORCH_CHECK(head_dim_qk == head_dim_vo, + "CUDA cores template only supports equal head dim for QK and VO, please use tensor " + "cores template for different head dim"); if (maybe_lse) { const auto& lse = *maybe_lse; @@ -122,10 +132,10 @@ void BatchDecodeWithPagedKVCacheRun( cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_context( - DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM, POS_ENCODING_MODE, USE_SLIDING_WINDOW, - USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { + DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { paged_kv_t paged_kv( - num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout, + num_kv_heads, page_size, HEAD_DIM_QK, batch_size, kv_layout, static_cast(paged_k_cache.data_ptr()), static_cast(paged_v_cache.data_ptr()), kv_cache_strides, static_cast(paged_kv_indices.data_ptr()), @@ -171,7 +181,7 @@ void BatchDecodeWithPagedKVCacheRun( params.padded_batch_size = plan_info.padded_batch_size; cudaError_t status = - flashinfer::BatchDecodeWithPagedKVCacheDispatched(params, tmp_v, tmp_s, /*stream=*/stream); diff --git a/csrc/batch_decode_config.inc b/csrc/batch_decode_config.inc index 2e02ee5c6..f0e89fd56 100644 --- a/csrc/batch_decode_config.inc +++ b/csrc/batch_decode_config.inc @@ -15,32 +15,38 @@ */ // NOTE(Zihao): this is the include file for AOT mode #pragma once -#include -#include #include #include +#include +#include -#include "aot_extension_utils.h" #include "aot_default_additional_params.h" +#include "aot_extension_utils.h" using IdType = int32_t; #define ADDITIONAL_FUNC_PARAMS BATCH_DECODE_ADDITIONAL_FUNC_PARAMS #define ADDITIONAL_PARAMS_SETTER BATCH_DECODE_ADDITIONAL_PARAMS_SETTER -#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) 
{ \ - DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, DTypeQ, DTypeKV, [&] { \ - using DTypeO = DTypeQ; \ - using Params = BatchDecodeParams; \ - constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \ - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { \ - return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ - return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ - using AttentionVariant = DefaultAttention; \ - __VA_ARGS__(); \ - return true; \ - }); \ - }); \ - }); \ - }); \ -} +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, \ + POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, \ + AttentionVariant, Params, ...) \ + { \ + DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, DTypeQ, DTypeKV, [&] { \ + using DTypeO = DTypeQ; \ + using Params = BatchDecodeParams; \ + constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \ + return DISPATCH_head_dim(head_dim_qk, HEAD_DIM_QK, [&] { \ + [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \ + return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ + return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ + using AttentionVariant = \ + DefaultAttention; \ + __VA_ARGS__(); \ + return true; \ + }); \ + }); \ + }); \ + }); \ + } diff --git a/csrc/batch_decode_customize_config.jinja b/csrc/batch_decode_customize_config.jinja index 03d140f4c..24ba9f1a6 100644 --- a/csrc/batch_decode_customize_config.jinja +++ b/csrc/batch_decode_customize_config.jinja @@ -7,7 +7,7 @@ #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} -#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) { \ +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) 
{ \ using AttentionVariant = {{ variant_name }}; \ __VA_ARGS__(); \ } @@ -18,7 +18,8 @@ using DTypeQ = {{ dtype_q }}; using DTypeKV = {{ dtype_kv }}; using DTypeO = {{ dtype_o }}; using IdType = {{ idtype }}; -constexpr int HEAD_DIM = {{ head_dim }}; +constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; +constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }}; constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; diff --git a/csrc/batch_decode_jit_pybind.cu b/csrc/batch_decode_jit_pybind.cu index 5d9f8c4da..db43cddd9 100644 --- a/csrc/batch_decode_jit_pybind.cu +++ b/csrc/batch_decode_jit_pybind.cu @@ -20,8 +20,9 @@ std::vector BatchDecodeWithPagedKVCachePlan( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, int window_left, float logits_soft_cap, unsigned int head_dim, - at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream); + bool enable_cuda_graph, int window_left, float logits_soft_cap, unsigned int head_dim_qk, + unsigned int head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, + int64_t cuda_stream); void BatchDecodeWithPagedKVCacheRun( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, diff --git a/csrc/batch_decode_kernel_inst.jinja b/csrc/batch_decode_kernel_inst.jinja index 459b2082e..df0c57bdb 100644 --- a/csrc/batch_decode_kernel_inst.jinja +++ b/csrc/batch_decode_kernel_inst.jinja @@ -6,7 +6,7 @@ using namespace flashinfer; namespace flashinfer { template cudaError_t -BatchDecodeWithPagedKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ variant_name }}, Params>( +BatchDecodeWithPagedKVCacheDispatched<{{ head_dim_qk }}, {{ pos_encoding_mode }}, {{ variant_name }}, Params>( Params params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream); diff --git a/csrc/batch_prefill.cu b/csrc/batch_prefill.cu index f7dff709f..e44121564 100644 --- a/csrc/batch_prefill.cu +++ b/csrc/batch_prefill.cu @@ -23,15 +23,15 @@ namespace flashinfer { -template +template cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, float* tmp_s, cudaStream_t stream); -template +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, float* tmp_s, cudaStream_t stream); @@ -44,7 +44,8 @@ std::vector BatchPrefillWithKVCachePlan( at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, unsigned int head_dim, bool causal, int64_t cuda_stream) { + bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal, + int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = @@ -58,7 +59,7 @@ std::vector BatchPrefillWithKVCachePlan( int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr(), kv_indptr.data_ptr(), total_num_rows, batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size, enable_cuda_graph, 
/*sizeof_dtype_o=*/2, stream); + head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); TORCH_CHECK(status == cudaSuccess, "Failed to plan prefill with error: ", cudaGetErrorString(status)); @@ -77,15 +78,20 @@ void BatchPrefillWithRaggedKVCacheRun( QKVLayout kv_layout = static_cast(layout); int64_t num_qo_heads = q.size(1); - int64_t head_dim = q.size(2); + int64_t head_dim_qk = q.size(2); int64_t num_kv_heads = (kv_layout == QKVLayout::kNHD) ? k.size(1) : k.size(0); - uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h; + uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), k_stride_n, k_stride_h, v_stride_n, + v_stride_h; if (kv_layout == QKVLayout::kNHD) { - kv_stride_n = k.stride(0); - kv_stride_h = k.stride(1); + k_stride_n = k.stride(0); + k_stride_h = k.stride(1); + v_stride_n = v.stride(0); + v_stride_h = v.stride(1); } else { - kv_stride_h = k.stride(0); - kv_stride_n = k.stride(1); + k_stride_h = k.stride(0); + k_stride_n = k.stride(1); + v_stride_h = v.stride(0); + v_stride_n = v.stride(1); } if (maybe_lse) { @@ -105,8 +111,9 @@ void BatchPrefillWithRaggedKVCacheRun( cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_context( - DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, POS_ENCODING_MODE, USE_SLIDING_WINDOW, - USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, RaggedParams, PagedParams, [&] { + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, + RaggedParams, PagedParams, [&] { RaggedParams params; params.q = static_cast(q.data_ptr()); @@ -120,8 +127,10 @@ void BatchPrefillWithRaggedKVCacheRun( params.num_kv_heads = num_kv_heads; params.q_stride_n = q_stride_n; params.q_stride_h = q_stride_h; - params.kv_stride_n = kv_stride_n; - params.kv_stride_h = kv_stride_h; + params.k_stride_n = k_stride_n; + params.k_stride_h = k_stride_h; + params.v_stride_n = v_stride_n; + params.v_stride_h = v_stride_h; params.window_left = window_left; params.request_indices = nullptr; @@ -171,7 +180,7 @@ void BatchPrefillWithRaggedKVCacheRun( DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { status = flashinfer::BatchPrefillWithRaggedKVCacheDispatched< - CTA_TILE_Q, HEAD_DIM, POS_ENCODING_MODE, + CTA_TILE_Q, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, /*use_fp16_qk_reduction=*/USE_FP16_QK_REDUCTION, MASK_MODE, AttentionVariant, RaggedParams>(params, tmp_v, tmp_s, stream); }); @@ -196,7 +205,7 @@ void BatchPrefillWithPagedKVCacheRun( int64_t batch_size = paged_kv_indptr.size(0) - 1; int64_t num_qo_heads = q.size(1); int64_t num_kv_heads, page_size; - uint32_t head_dim = q.size(2); + uint32_t head_dim_qk = q.size(2); if (kv_layout == QKVLayout::kHND) { num_kv_heads = paged_k_cache.size(1); page_size = paged_k_cache.size(2); @@ -232,13 +241,14 @@ void BatchPrefillWithPagedKVCacheRun( cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_context( - DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, POS_ENCODING_MODE, USE_SLIDING_WINDOW, - USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, RaggedParams, PagedParams, [&] { + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, + RaggedParams, PagedParams, [&] { PagedParams params; params.q = static_cast(q.data_ptr()); paged_kv_t paged_kv( - num_kv_heads, page_size, HEAD_DIM, 
batch_size, kv_layout, + num_kv_heads, page_size, HEAD_DIM_VO, batch_size, kv_layout, static_cast(paged_k_cache.data_ptr()), static_cast(paged_v_cache.data_ptr()), kv_cache_strides, static_cast(paged_kv_indices.data_ptr()), @@ -301,7 +311,7 @@ void BatchPrefillWithPagedKVCacheRun( DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { status = flashinfer::BatchPrefillWithPagedKVCacheDispatched< - CTA_TILE_Q, HEAD_DIM, POS_ENCODING_MODE, + CTA_TILE_Q, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, /*use_fp16_qk_reduction=*/USE_FP16_QK_REDUCTION, MASK_MODE, AttentionVariant, PagedParams>(params, tmp_v, tmp_s, stream); }); diff --git a/csrc/batch_prefill_config.inc b/csrc/batch_prefill_config.inc index cda91bfbe..5dcd730b8 100644 --- a/csrc/batch_prefill_config.inc +++ b/csrc/batch_prefill_config.inc @@ -30,9 +30,9 @@ using IdType = int32_t; #define ADDITIONAL_FUNC_PARAMS BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS #define ADDITIONAL_PARAMS_SETTER BATCH_PREFILL_ADDITIONAL_PARAMS_SETTER -#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, POS_ENCODING_MODE, \ - USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, \ - AttentionVariant, RaggedParams, PagedParams, ...) \ +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, \ + POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, \ + USE_FP16_QK_REDUCTION, AttentionVariant, RaggedParams, PagedParams, ...) \ { \ DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { \ return DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE( \ @@ -43,7 +43,8 @@ using IdType = int32_t; constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \ constexpr bool USE_FP16_QK_REDUCTION = false; \ constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; \ - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { \ + return DISPATCH_head_dim(head_dim_qk, HEAD_DIM_QK, [&] { \ + [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \ return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ using AttentionVariant = \ @@ -56,4 +57,4 @@ using IdType = int32_t; }); \ }); \ }); \ -} + } diff --git a/csrc/batch_prefill_customize_config.jinja b/csrc/batch_prefill_customize_config.jinja index 6770eb958..25778f638 100644 --- a/csrc/batch_prefill_customize_config.jinja +++ b/csrc/batch_prefill_customize_config.jinja @@ -8,7 +8,7 @@ #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} -#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, RaggedParams, PagedParams, ...) \ +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, RaggedParams, PagedParams, ...) 
\ DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ constexpr auto use_custom_mask = MASK_MODE == MaskMode::kCustom; \ using AttentionVariant = {{ variant_name }}; \ @@ -21,7 +21,8 @@ using DTypeQ = {{ dtype_q }}; using DTypeKV = {{ dtype_kv }}; using DTypeO = {{ dtype_o }}; using IdType = {{ idtype }}; -constexpr int HEAD_DIM = {{ head_dim }}; +constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; +constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }}; constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }}; @@ -46,8 +47,10 @@ struct RaggedParams { uint32_t num_kv_heads; uint32_t q_stride_n; uint32_t q_stride_h; - uint32_t kv_stride_n; - uint32_t kv_stride_h; + uint32_t k_stride_n; + uint32_t k_stride_h; + uint32_t v_stride_n; + uint32_t v_stride_h; int32_t window_left; IdType* request_indices; diff --git a/csrc/batch_prefill_jit_pybind.cu b/csrc/batch_prefill_jit_pybind.cu index f07ddc4c8..6d35deef4 100644 --- a/csrc/batch_prefill_jit_pybind.cu +++ b/csrc/batch_prefill_jit_pybind.cu @@ -21,7 +21,8 @@ std::vector BatchPrefillWithKVCachePlan( at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, unsigned int head_dim, bool causal, int64_t cuda_stream); + bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal, + int64_t cuda_stream); void BatchPrefillWithRaggedKVCacheRun( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, diff --git a/csrc/batch_prefill_paged_kernel_inst.jinja b/csrc/batch_prefill_paged_kernel_inst.jinja index 7f3357a1d..667887703 100644 --- a/csrc/batch_prefill_paged_kernel_inst.jinja +++ b/csrc/batch_prefill_paged_kernel_inst.jinja @@ -7,7 +7,7 @@ constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom; {% for cta_tile_q in [16, 64, 128] %} template cudaError_t BatchPrefillWithPagedKVCacheDispatched< - /*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}}, + /*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}}, {{ variant_name }}, PagedParams>(PagedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream); {% endfor %} diff --git a/csrc/batch_prefill_paged_sm90_kernel_inst.jinja b/csrc/batch_prefill_paged_sm90_kernel_inst.jinja index efc5aeae2..9e159e006 100644 --- a/csrc/batch_prefill_paged_sm90_kernel_inst.jinja +++ b/csrc/batch_prefill_paged_sm90_kernel_inst.jinja @@ -5,7 +5,8 @@ namespace flashinfer { {% for same_scheduler_for_all_heads in ["true", "false"] %} template cudaError_t BatchPrefillWithPagedKVCacheDispatched - <{{ head_dim }}, + <{{ head_dim_qk }}, + {{ head_dim_vo }}, {{ mask_mode }}, /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }}, diff --git a/csrc/batch_prefill_ragged_kernel_inst.jinja b/csrc/batch_prefill_ragged_kernel_inst.jinja index cbdb7ace8..39c49227c 100644 --- a/csrc/batch_prefill_ragged_kernel_inst.jinja +++ b/csrc/batch_prefill_ragged_kernel_inst.jinja @@ -7,7 +7,7 @@ constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom; {% for cta_tile_q in [16, 64, 128] %} template cudaError_t BatchPrefillWithRaggedKVCacheDispatched< - 
/*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}}, + /*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}}, {{ variant_name }}, RaggedParams>(RaggedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream); {% endfor %} diff --git a/csrc/batch_prefill_ragged_sm90_kernel_inst.jinja b/csrc/batch_prefill_ragged_sm90_kernel_inst.jinja index 4d4862868..90a44b4d7 100644 --- a/csrc/batch_prefill_ragged_sm90_kernel_inst.jinja +++ b/csrc/batch_prefill_ragged_sm90_kernel_inst.jinja @@ -5,7 +5,8 @@ namespace flashinfer { {% for same_scheduler_for_all_heads in ["true", "false"] %} template cudaError_t BatchPrefillWithRaggedKVCacheDispatched - <{{ head_dim }}, + <{{ head_dim_qk }}, + {{ head_dim_vo }}, {{ mask_mode }}, /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }}, diff --git a/csrc/batch_prefill_sm90.cu b/csrc/batch_prefill_sm90.cu index 739a13411..6ee020a7b 100644 --- a/csrc/batch_prefill_sm90.cu +++ b/csrc/batch_prefill_sm90.cu @@ -25,11 +25,11 @@ namespace flashinfer { -template cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params& params, cudaStream_t stream); -template cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params& params, cudaStream_t stream); @@ -42,7 +42,8 @@ std::vector BatchPrefillWithKVCacheSM90Plan( at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, unsigned int head_dim, bool causal, int64_t cuda_stream) { + bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal, + int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = @@ -57,8 +58,8 @@ std::vector BatchPrefillWithKVCacheSM90Plan( int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr(), kv_indptr.data_ptr(), kv_len_arr.data_ptr(), total_num_rows, - batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, causal, - enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); + batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size, + causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); TORCH_CHECK(status == cudaSuccess, "PrefillSM90Plan failed with error: ", cudaGetErrorString(status)); @@ -84,7 +85,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run( void* float_buffer_ptr = float_workspace_buffer.data_ptr(); void* int_buffer_ptr = int_workspace_buffer.data_ptr(); - unsigned int head_dim = q.size(2); + unsigned int head_dim_qk = q.size(2); + unsigned int head_dim_vo = v.size(2); auto q_scalar_type = q.scalar_type(); auto kv_scalar_type = k.scalar_type(); @@ -95,8 +97,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run( bool use_swa = window_left != -1; DISPATCH_context( - DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, - AttentionVariant, RaggedParams, PagedParams, [&] { + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, + USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] { RaggedParams params; params.q_ptr = static_cast(q.data_ptr()); @@ -121,7 +123,6 
@@ void BatchPrefillWithRaggedKVCacheSM90Run( } params.nnz_qo = q.size(0); params.nnz_kv = k.size(0); - params.head_dim = head_dim; params.num_qo_heads = q.size(1); params.num_kv_heads = k.size(1); params.group_size = params.num_qo_heads / params.num_kv_heads; @@ -142,10 +143,9 @@ void BatchPrefillWithRaggedKVCacheSM90Run( bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads; DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] { - cudaError_t status = - BatchPrefillWithRaggedKVCacheDispatched(params, stream); + cudaError_t status = BatchPrefillWithRaggedKVCacheDispatched< + HEAD_DIM_QK, HEAD_DIM_VO, MASK_MODE, USE_SLIDING_WINDOW, SAME_SCHEDULER_FOR_ALL_HEADS, + AttentionVariant>(params, stream); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCacheSM90Run failed with error: ", cudaGetErrorString(status)); @@ -171,7 +171,8 @@ void BatchPrefillWithPagedKVCacheSM90Run( } QKVLayout kv_layout = static_cast(layout); unsigned int num_kv_heads, page_size; - unsigned int head_dim = q.size(2); + unsigned int head_dim_qk = q.size(2); + unsigned int head_dim_vo = paged_v_cache.size(3); if (kv_layout == QKVLayout::kHND) { num_kv_heads = paged_k_cache.size(1); page_size = paged_k_cache.size(2); @@ -191,8 +192,8 @@ void BatchPrefillWithPagedKVCacheSM90Run( bool use_swa = window_left != -1; DISPATCH_context( - DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, - AttentionVariant, RaggedParams, PagedParams, [&] { + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, + USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] { PagedParams params; params.q_ptr = static_cast(q.data_ptr()); @@ -218,7 +219,6 @@ void BatchPrefillWithPagedKVCacheSM90Run( params.v_stride_n = paged_v_cache.stride(2); } params.nnz_qo = q.size(0); - params.head_dim = head_dim; params.num_qo_heads = q.size(1); params.num_kv_heads = num_kv_heads; params.group_size = params.num_qo_heads / num_kv_heads; @@ -241,10 +241,9 @@ void BatchPrefillWithPagedKVCacheSM90Run( bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads; DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] { - cudaError_t status = - BatchPrefillWithPagedKVCacheDispatched(params, stream); + cudaError_t status = BatchPrefillWithPagedKVCacheDispatched< + HEAD_DIM_QK, HEAD_DIM_VO, MASK_MODE, USE_SLIDING_WINDOW, SAME_SCHEDULER_FOR_ALL_HEADS, + AttentionVariant>(params, stream); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCacheSM90Run failed with error: ", cudaGetErrorString(status)); diff --git a/csrc/batch_prefill_sm90_config.inc b/csrc/batch_prefill_sm90_config.inc index 8053bc0b5..d344915f9 100644 --- a/csrc/batch_prefill_sm90_config.inc +++ b/csrc/batch_prefill_sm90_config.inc @@ -27,28 +27,30 @@ using IdType = int32_t; #define ADDITIONAL_FUNC_PARAMS BATCH_PREFILL_SM90_ADDITIONAL_FUNC_PARAMS #define ADDITIONAL_PARAMS_SETTER BATCH_PREFILL_SM90_ADDITIONAL_PARAMS_SETTER -#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, USE_SLIDING_WINDOW, \ - USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, ...) 
\ - { \ - DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { \ - if (q_scalar_type != kv_scalar_type) { \ - return false; \ - } \ - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, dtype_q, [&] { \ - using DTypeQ = cutlass_dtype_t; \ - using DTypeKV = DTypeQ; \ - using DTypeO = DTypeQ; \ - using RaggedParams = BatchPrefillRaggedParams; \ - using PagedParams = BatchPrefillPagedParams; \ - return DISPATCH_head_dim_sm90(head_dim, HEAD_DIM, [&] { \ - return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ - return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ - using AttentionVariant = DefaultAttention; \ - __VA_ARGS__(); \ - return true; \ - }); \ - }); \ - }); \ - }); \ - }); \ +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, \ + USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, \ + PagedParams, ...) \ + { \ + DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { \ + if (q_scalar_type != kv_scalar_type) { \ + return false; \ + } \ + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, dtype_q, [&] { \ + using DTypeQ = cutlass_dtype_t; \ + using DTypeKV = DTypeQ; \ + using DTypeO = DTypeQ; \ + using RaggedParams = BatchPrefillRaggedParams; \ + using PagedParams = BatchPrefillPagedParams; \ + return DISPATCH_head_dim_sm90(head_dim_qk, HEAD_DIM_QK, [&] { \ + [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \ + return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ + return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ + using AttentionVariant = DefaultAttention; \ + __VA_ARGS__(); \ + return true; \ + }); \ + }); \ + }); \ + }); \ + }); \ } diff --git a/csrc/batch_prefill_sm90_customize_config.jinja b/csrc/batch_prefill_sm90_customize_config.jinja index 70078c25a..73cdd25ce 100644 --- a/csrc/batch_prefill_sm90_customize_config.jinja +++ b/csrc/batch_prefill_sm90_customize_config.jinja @@ -8,7 +8,7 @@ #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} -#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, ...) \ +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, ...) 
\ DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { using AttentionVariant = {{ variant_name }}; __VA_ARGS__();}) using namespace flashinfer; @@ -18,7 +18,8 @@ using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>; using DTypeO = cutlass_dtype_t<{{ dtype_o }}>; using IdType = cutlass_dtype_t<{{ idtype }}>; -constexpr int HEAD_DIM = {{ head_dim }}; +constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; +constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; diff --git a/csrc/batch_prefill_sm90_jit_pybind.cu b/csrc/batch_prefill_sm90_jit_pybind.cu index 45f37af44..5466a6ac8 100644 --- a/csrc/batch_prefill_sm90_jit_pybind.cu +++ b/csrc/batch_prefill_sm90_jit_pybind.cu @@ -21,7 +21,8 @@ std::vector BatchPrefillWithKVCacheSM90Plan( at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, unsigned int head_dim, bool causal, int64_t cuda_stream); + bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal, + int64_t cuda_stream); void BatchPrefillWithRaggedKVCacheSM90Run( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, diff --git a/csrc/flashinfer_ops.cu b/csrc/flashinfer_ops.cu index 59234182e..2dc620603 100644 --- a/csrc/flashinfer_ops.cu +++ b/csrc/flashinfer_ops.cu @@ -44,8 +44,8 @@ std::vector BatchDecodeWithPagedKVCachePlan( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, int window_left, float logits_soft_cap, unsigned int head_dim, - at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream); + bool enable_cuda_graph, int window_left, float logits_soft_cap, unsigned int head_dim_qk, + unsigned head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream); void BatchDecodeWithPagedKVCacheRun( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, @@ -107,7 +107,8 @@ std::vector BatchPrefillWithKVCachePlan( at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, unsigned int head_dim, bool causal, int64_t cuda_stream); + bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal, + int64_t cuda_stream); void BatchPrefillWithRaggedKVCacheRun( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, diff --git a/csrc/flashinfer_ops_sm90.cu b/csrc/flashinfer_ops_sm90.cu index 55e4a2e8c..7e2ab7bdd 100644 --- a/csrc/flashinfer_ops_sm90.cu +++ b/csrc/flashinfer_ops_sm90.cu @@ -32,7 +32,8 @@ std::vector BatchPrefillWithKVCacheSM90Plan( at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, unsigned int head_dim, bool causal, int64_t cuda_stream); + bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal, + 
int64_t cuda_stream); void BatchPrefillWithRaggedKVCacheSM90Run( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, diff --git a/csrc/single_decode.cu b/csrc/single_decode.cu index 74713de72..b3d31ed5f 100644 --- a/csrc/single_decode.cu +++ b/csrc/single_decode.cu @@ -47,7 +47,8 @@ void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::T CHECK_EQ(q.size(1), k.size(2)); CHECK_EQ(v.scalar_type(), k.scalar_type()); unsigned int num_qo_heads = q.size(0); - unsigned int head_dim = q.size(1); + unsigned int head_dim_qk = q.size(1); + unsigned int head_dim_vo = v.size(2); unsigned int kv_len, num_kv_heads; QKVLayout kv_layout = static_cast(layout); if (kv_layout == QKVLayout::kNHD) { @@ -64,9 +65,13 @@ void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::T cudaStream_t stream = reinterpret_cast(cuda_stream); + TORCH_CHECK(head_dim_qk == head_dim_vo, + "CUDA cores template only supports equal head dim for QK and VO, please use tensor " + "cores template for different head dim"); + DISPATCH_context( - DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM, POS_ENCODING_MODE, USE_SLIDING_WINDOW, - USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { + DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { Params params; params.q = static_cast(q.data_ptr()); @@ -77,18 +82,18 @@ void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::T params.kv_len = kv_len; params.num_qo_heads = num_qo_heads; params.num_kv_heads = num_kv_heads; - params.q_stride_n = num_qo_heads * head_dim; - params.q_stride_h = head_dim; - params.kv_stride_n = (kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim : head_dim; - params.kv_stride_h = (kv_layout == QKVLayout::kNHD) ? head_dim : kv_len * head_dim; - params.head_dim = head_dim; + params.q_stride_n = num_qo_heads * head_dim_qk; + params.q_stride_h = head_dim_qk; + params.kv_stride_n = + (kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim_vo : head_dim_vo; + params.kv_stride_h = (kv_layout == QKVLayout::kNHD) ? head_dim_vo : kv_len * head_dim_vo; params.window_left = window_left; params.kv_chunk_size = 0; ADDITIONAL_PARAMS_SETTER cudaError_t status = - flashinfer::SingleDecodeWithKVCacheDispatched( params, static_cast(tmp.data_ptr()), stream); TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " + diff --git a/csrc/single_decode_config.inc b/csrc/single_decode_config.inc index 4536f968f..4cebfb098 100644 --- a/csrc/single_decode_config.inc +++ b/csrc/single_decode_config.inc @@ -15,31 +15,37 @@ */ // NOTE(Zihao): this is the include file for AOT mode #pragma once -#include #include #include +#include -#include "aot_extension_utils.h" #include "aot_default_additional_params.h" +#include "aot_extension_utils.h" using IdType = int32_t; #define ADDITIONAL_FUNC_PARAMS SINGLE_DECODE_ADDITIONAL_FUNC_PARAMS #define ADDITIONAL_PARAMS_SETTER SINGLE_DECODE_ADDITIONAL_PARAMS_SETTER -#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) 
{\ - DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, DTypeQ, DTypeKV, [&] { \ - using DTypeO = DTypeQ; \ - using Params = SingleDecodeParams; \ - constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \ - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { \ - return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ - return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ - using AttentionVariant = DefaultAttention; \ - __VA_ARGS__(); \ - return true; \ - }); \ - }); \ - }); \ - }); \ -} +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, \ + POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, \ + AttentionVariant, Params, ...) \ + { \ + DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, DTypeQ, DTypeKV, [&] { \ + using DTypeO = DTypeQ; \ + using Params = SingleDecodeParams; \ + constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \ + return DISPATCH_head_dim(head_dim_qk, HEAD_DIM_QK, [&] { \ + [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \ + return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ + return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ + using AttentionVariant = \ + DefaultAttention; \ + __VA_ARGS__(); \ + return true; \ + }); \ + }); \ + }); \ + }); \ + } diff --git a/csrc/single_decode_customize_config.jinja b/csrc/single_decode_customize_config.jinja index e773975c5..8a6baec5f 100644 --- a/csrc/single_decode_customize_config.jinja +++ b/csrc/single_decode_customize_config.jinja @@ -6,7 +6,7 @@ #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} -#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) {\ +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) 
{\ using AttentionVariant = {{ variant_name }}; \ __VA_ARGS__(); \ } @@ -17,7 +17,8 @@ using DTypeQ = {{ dtype_q }}; using DTypeKV = {{ dtype_kv }}; using DTypeO = {{ dtype_o }}; using IdType = int32_t; -constexpr int HEAD_DIM = {{ head_dim }}; +constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; +constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }}; constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; @@ -40,28 +41,9 @@ struct Params { uint32_t q_stride_h; uint32_t kv_stride_n; uint32_t kv_stride_h; - uint32_t head_dim; int32_t window_left; uint32_t kv_chunk_size; - __host__ __device__ __forceinline__ size_t get_q_elem_offset(uint32_t qo_idx, - uint32_t qo_head_idx, - uint32_t feat_idx) const { - return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, q_stride_n, q_stride_h); - } - - __host__ __device__ __forceinline__ size_t get_o_elem_offset(uint32_t qo_idx, - uint32_t qo_head_idx, - uint32_t feat_idx) const { - return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, num_qo_heads * head_dim, head_dim); - } - - __host__ __device__ __forceinline__ size_t get_kv_elem_offset(uint32_t kv_idx, - uint32_t kv_head_idx, - uint32_t feat_idx) const { - return get_elem_offset_impl(kv_idx, kv_head_idx, feat_idx, kv_stride_n, kv_stride_h); - } - __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { return 1; } __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { diff --git a/csrc/single_decode_kernel_inst.jinja b/csrc/single_decode_kernel_inst.jinja index 79f89eb0e..76e4a26d9 100644 --- a/csrc/single_decode_kernel_inst.jinja +++ b/csrc/single_decode_kernel_inst.jinja @@ -6,7 +6,7 @@ using namespace flashinfer; namespace flashinfer { template cudaError_t SingleDecodeWithKVCacheDispatched< - {{ head_dim }}, {{ pos_encoding_mode }}, {{ variant_name }}, Params>( + {{ head_dim_qk }}, {{ pos_encoding_mode }}, {{ variant_name }}, Params>( Params params, {{ dtype_o }}* tmp, cudaStream_t stream); diff --git a/csrc/single_prefill.cu b/csrc/single_prefill.cu index fda91059c..df0e0b115 100644 --- a/csrc/single_prefill.cu +++ b/csrc/single_prefill.cu @@ -22,8 +22,9 @@ namespace flashinfer { -template +template cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::DTypeO* tmp, cudaStream_t stream); @@ -36,22 +37,27 @@ void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at:: unsigned int mask_mode_code, unsigned int layout, int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { auto device = q.device(); - unsigned int head_dim = q.size(2); + unsigned int head_dim_qk = q.size(2); unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; QKVLayout kv_layout = static_cast(layout); qo_len = q.size(0); num_qo_heads = q.size(1); - uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h; + uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), k_stride_n, k_stride_h, v_stride_n, + v_stride_h; if (kv_layout == QKVLayout::kNHD) { kv_len = k.size(0); num_kv_heads = k.size(1); - kv_stride_n = k.stride(0); - kv_stride_h = k.stride(1); + k_stride_n = k.stride(0); + k_stride_h = k.stride(1); + v_stride_n = v.stride(0); + v_stride_h = v.stride(1); } else { kv_len = k.size(1); num_kv_heads = k.size(0); - kv_stride_h = k.stride(0); - kv_stride_n = k.stride(1); + k_stride_h = k.stride(0); + k_stride_n = k.stride(1); + v_stride_h = v.stride(0); + 
v_stride_n = v.stride(1); } if (maybe_lse) { const auto& lse = *maybe_lse; @@ -67,8 +73,9 @@ void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at:: cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_context( - DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, POS_ENCODING_MODE, USE_SLIDING_WINDOW, - USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, Params, [&] { + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, Params, + [&] { Params params; params.q = static_cast(q.data_ptr()); @@ -82,16 +89,18 @@ void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at:: params.kv_len = kv_len; params.q_stride_n = q_stride_n; params.q_stride_h = q_stride_h; - params.kv_stride_n = kv_stride_n; - params.kv_stride_h = kv_stride_h; - params.head_dim = head_dim; + params.k_stride_n = k_stride_n; + params.k_stride_h = k_stride_h; + params.v_stride_n = v_stride_n; + params.v_stride_h = v_stride_h; + params.window_left = window_left; params.partition_kv = false; ADDITIONAL_PARAMS_SETTER cudaError_t status = flashinfer::SinglePrefillWithKVCacheDispatched< - HEAD_DIM, POS_ENCODING_MODE, + HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, /*use_fp16_qk_reduction=*/USE_FP16_QK_REDUCTION, MASK_MODE, AttentionVariant>( params, static_cast(tmp.data_ptr()), stream); TORCH_CHECK(status == cudaSuccess, diff --git a/csrc/single_prefill_config.inc b/csrc/single_prefill_config.inc index 229159dd3..3265632a3 100644 --- a/csrc/single_prefill_config.inc +++ b/csrc/single_prefill_config.inc @@ -27,29 +27,30 @@ using IdType = int32_t; #define ADDITIONAL_FUNC_PARAMS SINGLE_PREFILL_ADDITIONAL_FUNC_PARAMS #define ADDITIONAL_PARAMS_SETTER SINGLE_PREFILL_ADDITIONAL_PARAMS_SETTER -#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, POS_ENCODING_MODE, \ - USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, \ - AttentionVariant, Params, ...) \ - { \ - DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { \ - return DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE( \ - q_scalar_type, kv_scalar_type, DTypeQ, DTypeKV, [&] { \ - using DTypeO = DTypeQ; \ - using Params = SinglePrefillParams; \ - constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \ - constexpr bool USE_FP16_QK_REDUCTION = false; \ - constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; \ - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { \ - return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ - return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ - using AttentionVariant = \ - DefaultAttention; \ - __VA_ARGS__(); \ - return true; \ - }); \ - }); \ - }); \ - }); \ - }); \ +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, \ + POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, \ + USE_FP16_QK_REDUCTION, AttentionVariant, Params, ...) 
\ + { \ + DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { \ + return DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE( \ + q_scalar_type, kv_scalar_type, DTypeQ, DTypeKV, [&] { \ + using DTypeO = DTypeQ; \ + using Params = SinglePrefillParams; \ + constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \ + constexpr bool USE_FP16_QK_REDUCTION = false; \ + constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; \ + return DISPATCH_head_dim(head_dim_qk, HEAD_DIM_QK, [&] { \ + [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \ + return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ + return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ + using AttentionVariant = \ + DefaultAttention; \ + __VA_ARGS__(); \ + return true; \ + }); \ + }); \ + }); \ + }); \ + }); \ } diff --git a/csrc/single_prefill_customize_config.jinja b/csrc/single_prefill_customize_config.jinja index 7f20bb283..9195dde3a 100644 --- a/csrc/single_prefill_customize_config.jinja +++ b/csrc/single_prefill_customize_config.jinja @@ -8,7 +8,7 @@ #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} -#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, Params, ...) \ +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, Params, ...) \ DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; \ using AttentionVariant = {{ variant_name }}; \ @@ -22,7 +22,8 @@ using DTypeQ = {{ dtype_q }}; using DTypeKV = {{ dtype_kv }}; using DTypeO = {{ dtype_o }}; using IdType = int32_t; -constexpr int HEAD_DIM = {{ head_dim }}; +constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; +constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }}; constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }}; @@ -47,8 +48,10 @@ struct Params { uint32_t num_kv_heads; uint32_t q_stride_n; uint32_t q_stride_h; - uint32_t kv_stride_n; - uint32_t kv_stride_h; + uint32_t k_stride_n; + uint32_t k_stride_h; + uint32_t v_stride_n; + uint32_t v_stride_h; uint32_t head_dim; int32_t window_left; diff --git a/csrc/single_prefill_kernel_inst.jinja b/csrc/single_prefill_kernel_inst.jinja index bba3bd4ac..7ff99e5f4 100644 --- a/csrc/single_prefill_kernel_inst.jinja +++ b/csrc/single_prefill_kernel_inst.jinja @@ -8,7 +8,7 @@ namespace flashinfer { constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom; template cudaError_t SinglePrefillWithKVCacheDispatched< - {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, {{ variant_name }}, Params>( + {{ head_dim_qk }}, {{ head_dim_vo }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, {{ variant_name }}, Params>( Params params, {{ dtype_o }}* tmp, cudaStream_t stream); diff --git a/csrc/single_prefill_sm90.cu b/csrc/single_prefill_sm90.cu index e0f57f874..4e89a1ecd 100644 --- a/csrc/single_prefill_sm90.cu +++ b/csrc/single_prefill_sm90.cu @@ -23,7 +23,7 @@ namespace flashinfer { -template cudaError_t SinglePrefillWithKVCacheDispatched(Params& params, cudaStream_t stream); @@ -36,7 +36,8 @@ void single_prefill_with_kv_cache_sm90(at::Tensor q, at::Tensor k, at::Tensor v, unsigned int 
mask_mode_code, unsigned int layout, int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { - unsigned int head_dim = q.size(2); + unsigned int head_dim_qk = q.size(2); + unsigned int head_dim_vo = v.size(2); unsigned int num_qo_heads = q.size(1); unsigned int qo_len = q.size(0); @@ -48,8 +49,8 @@ void single_prefill_with_kv_cache_sm90(at::Tensor q, at::Tensor k, at::Tensor v, const MaskMode mask_mode = static_cast(mask_mode_code); DISPATCH_context( - DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, - AttentionVariant, Params, [&] { + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, + USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { Params params; params.q_ptr = static_cast(q.data_ptr()); params.k_ptr = static_cast(k.data_ptr()); @@ -73,7 +74,6 @@ void single_prefill_with_kv_cache_sm90(at::Tensor q, at::Tensor k, at::Tensor v, } params.qo_len = q.size(0); params.kv_len = k.size(0); - params.head_dim = head_dim; params.num_qo_heads = q.size(1); params.num_kv_heads = k.size(1); params.causal = mask_mode == MaskMode::kCausal; @@ -83,8 +83,9 @@ void single_prefill_with_kv_cache_sm90(at::Tensor q, at::Tensor k, at::Tensor v, ADDITIONAL_PARAMS_SETTER cudaError_t status = - SinglePrefillWithKVCacheDispatched(params, stream); + SinglePrefillWithKVCacheDispatched(params, + stream); TORCH_CHECK(status == cudaSuccess, "single_prefill_with_kv_cache_sm90 failed with error: " + std::string(cudaGetErrorString(status))); return true; diff --git a/csrc/single_prefill_sm90_config.inc b/csrc/single_prefill_sm90_config.inc index 34f91cafc..2ec696b7f 100644 --- a/csrc/single_prefill_sm90_config.inc +++ b/csrc/single_prefill_sm90_config.inc @@ -27,27 +27,28 @@ using IdType = int32_t; #define ADDITIONAL_FUNC_PARAMS SINGLE_PREFILL_SM90_ADDITIONAL_FUNC_PARAMS #define ADDITIONAL_PARAMS_SETTER SINGLE_PREFILL_SM90_ADDITIONAL_PARAMS_SETTER -#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, USE_SLIDING_WINDOW, \ - USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) \ - { \ - DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { \ - if (q_scalar_type != kv_scalar_type) { \ - return false; \ - } \ - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, dtype_q, [&] { \ - using DTypeQ = cutlass_dtype_t; \ - using DTypeKV = DTypeQ; \ - using DTypeO = DTypeQ; \ - using Params = SinglePrefillParams; \ - return DISPATCH_head_dim_sm90(head_dim, HEAD_DIM, [&] { \ - return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ - return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ - using AttentionVariant = DefaultAttention; \ - __VA_ARGS__(); \ - return true; \ - }); \ - }); \ - }); \ - }); \ - }); \ +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, \ + USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) 
\ + { \ + DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { \ + if (q_scalar_type != kv_scalar_type) { \ + return false; \ + } \ + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, dtype_q, [&] { \ + using DTypeQ = cutlass_dtype_t; \ + using DTypeKV = DTypeQ; \ + using DTypeO = DTypeQ; \ + using Params = SinglePrefillParams; \ + return DISPATCH_head_dim_sm90(head_dim_qk, HEAD_DIM_QK, [&] { \ + [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \ + return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ + return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ + using AttentionVariant = DefaultAttention; \ + __VA_ARGS__(); \ + return true; \ + }); \ + }); \ + }); \ + }); \ + }); \ } diff --git a/csrc/single_prefill_sm90_customize_config.jinja b/csrc/single_prefill_sm90_customize_config.jinja index 982074630..0c8a0e4d5 100644 --- a/csrc/single_prefill_sm90_customize_config.jinja +++ b/csrc/single_prefill_sm90_customize_config.jinja @@ -10,7 +10,7 @@ #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} -#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) \ +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) \ DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { using AttentionVariant = {{ variant_name }}; __VA_ARGS__(); }) using namespace flashinfer; @@ -20,7 +20,8 @@ using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>; using DTypeO = cutlass_dtype_t<{{ dtype_o }}>; using IdType = cutlass_dtype_t; -constexpr int HEAD_DIM = {{ head_dim }}; +constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; +constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; diff --git a/csrc/single_prefill_sm90_kernel_inst.jinja b/csrc/single_prefill_sm90_kernel_inst.jinja index c5bbf4c63..0f72ade95 100644 --- a/csrc/single_prefill_sm90_kernel_inst.jinja +++ b/csrc/single_prefill_sm90_kernel_inst.jinja @@ -6,7 +6,7 @@ using namespace flashinfer; namespace flashinfer { template cudaError_t SinglePrefillWithKVCacheDispatched - <{{ head_dim }}, {{ mask_mode }}, /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, {{ variant_name }}, Params>( + <{{ head_dim_qk }}, {{ head_dim_vo }}, {{ mask_mode }}, /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, {{ variant_name }}, Params>( Params& params, cudaStream_t stream); }; diff --git a/custom_backend.py b/custom_backend.py index 276d09783..af7fb2a6a 100644 --- a/custom_backend.py +++ b/custom_backend.py @@ -8,7 +8,7 @@ def _get_requires_for_build(): requires = [] if os.environ.get("FLASHINFER_ENABLE_AOT", "0") == "1": - requires += ["torch", "ninja"] + requires += ["torch", "ninja", "numpy"] return requires diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 9a9c2d931..703950199 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -459,7 +459,8 @@ def single_decode_with_kv_cache( q.dtype, k.dtype, q.dtype, - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo PosEncodingMode[pos_encoding_mode].value, window_left != -1, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap @@ -488,7 +489,8 @@ def single_decode_with_kv_cache( q.dtype, k.dtype, q.dtype, - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo 
PosEncodingMode[pos_encoding_mode].value, window_left != -1, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap @@ -651,8 +653,10 @@ def __init__( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) self._pin_memory_int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, - pin_memory=True, device="cpu", + (8 * 1024 * 1024,), + dtype=torch.uint8, + pin_memory=True, + device="cpu", ) if use_cuda_graph: @@ -864,7 +868,8 @@ def plan( kv_data_type, q_data_type, indptr.dtype, - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo PosEncodingMode[pos_encoding_mode].value, window_left != -1, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap @@ -885,6 +890,7 @@ def plan( page_size, self.is_cuda_graph_enabled, head_dim, + head_dim, False, # causal get_cuda_stream(device), ) @@ -897,7 +903,8 @@ def plan( kv_data_type, q_data_type, indptr.dtype, - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo PosEncodingMode[pos_encoding_mode].value, window_left != -1, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap @@ -916,6 +923,7 @@ def plan( window_left, logits_soft_cap, head_dim, + head_dim, torch.empty(0, dtype=q_data_type), torch.empty(0, dtype=kv_data_type), get_cuda_stream(device), @@ -1279,8 +1287,10 @@ def __init__( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) self._pin_memory_int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, - pin_memory=True, device="cpu", + (8 * 1024 * 1024,), + dtype=torch.uint8, + pin_memory=True, + device="cpu", ) if use_cuda_graph: diff --git a/flashinfer/jit/attention.py b/flashinfer/jit/attention.py index 63264288b..000bf47f6 100644 --- a/flashinfer/jit/attention.py +++ b/flashinfer/jit/attention.py @@ -101,7 +101,8 @@ def get_single_decode_uri( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, - head_dim: int, + head_dim_qk: int, + head_dim_vo: int, pos_encoding_mode: int, use_sliding_window: bool, use_logits_soft_cap: bool, @@ -110,7 +111,8 @@ def get_single_decode_uri( f"single_decode_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" - f"head_dim_{head_dim}_" + f"head_dim_qk_{head_dim_qk}_" + f"head_dim_vo_{head_dim_vo}_" f"posenc_{pos_encoding_mode}_" f"use_swa_{use_sliding_window}_" f"use_logits_cap_{use_logits_soft_cap}" @@ -122,7 +124,8 @@ def get_batch_decode_uri( dtype_kv: torch.dtype, dtype_o: torch.dtype, dtype_idx: torch.dtype, - head_dim: int, + head_dim_qk: int, + head_dim_vo: int, pos_encoding_mode: int, use_sliding_window: bool, use_logits_soft_cap: bool, @@ -132,7 +135,8 @@ def get_batch_decode_uri( f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" - f"head_dim_{head_dim}_" + f"head_dim_qk_{head_dim_qk}_" + f"head_dim_vo_{head_dim_vo}_" f"posenc_{pos_encoding_mode}_" f"use_swa_{use_sliding_window}_" f"use_logits_cap_{use_logits_soft_cap}" @@ -144,7 +148,8 @@ def get_batch_decode_mla_uri( dtype_kv: torch.dtype, dtype_o: torch.dtype, dtype_idx: torch.dtype, - head_dim: int, + head_dim_qk: int, + head_dim_vo: int, use_sliding_window: bool, use_logits_soft_cap: bool, ) -> str: @@ -153,7 +158,8 @@ def get_batch_decode_mla_uri( f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" - f"head_dim_{head_dim}_" + 
f"head_dim_qk_{head_dim_qk}_" + f"head_dim_vo_{head_dim_vo}_" f"use_swa_{use_sliding_window}_" f"use_logits_cap_{use_logits_soft_cap}" ) @@ -218,7 +224,8 @@ def get_single_prefill_uri( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, - head_dim: int, + head_dim_qk: int, + head_dim_vo: int, pos_encoding_mode: int, use_sliding_window: bool, use_logits_soft_cap: bool, @@ -228,7 +235,8 @@ def get_single_prefill_uri( f"single_prefill_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" - f"head_dim_{head_dim}_" + f"head_dim_qk_{head_dim_qk}_" + f"head_dim_vo_{head_dim_vo}_" f"posenc_{pos_encoding_mode}_" f"use_swa_{use_sliding_window}_" f"use_logits_cap_{use_logits_soft_cap}_" @@ -242,7 +250,8 @@ def get_batch_prefill_uri( dtype_kv: torch.dtype, dtype_o: torch.dtype, dtype_idx: torch.dtype, - head_dim: int, + head_dim_qk: int, + head_dim_vo: int, pos_encoding_mode: int, use_sliding_window: bool, use_logits_soft_cap: bool, @@ -253,7 +262,8 @@ def get_batch_prefill_uri( f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" - f"head_dim_{head_dim}_" + f"head_dim_qk_{head_dim_qk}_" + f"head_dim_vo_{head_dim_vo}_" f"posenc_{pos_encoding_mode}_" f"use_swa_{use_sliding_window}_" f"use_logits_cap_{use_logits_soft_cap}_" @@ -265,7 +275,8 @@ def gen_single_decode_module( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, - head_dim: int, + head_dim_qk: int, + head_dim_vo: int, pos_encoding_mode: int, use_sliding_window: bool, use_logits_soft_cap: bool, @@ -274,7 +285,8 @@ def gen_single_decode_module( dtype_q, dtype_kv, dtype_o, - head_dim, + head_dim_qk, + head_dim_vo, pos_encoding_mode, use_sliding_window, use_logits_soft_cap, @@ -284,7 +296,8 @@ def gen_single_decode_module( dtype_q, dtype_kv, dtype_o, - head_dim, + head_dim_qk, + head_dim_vo, ["maybe_alibi_slopes"], # additional_tensor_names ["float"], # additional_tensor_dtypes [ @@ -307,7 +320,8 @@ def gen_single_prefill_module( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, - head_dim: int, + head_dim_qk: int, + head_dim_vo: int, pos_encoding_mode: int, use_sliding_window: bool, use_logits_soft_cap: bool, @@ -318,7 +332,8 @@ def gen_single_prefill_module( dtype_q, dtype_kv, dtype_o, - head_dim, + head_dim_qk, + head_dim_vo, pos_encoding_mode, use_sliding_window, use_logits_soft_cap, @@ -350,7 +365,8 @@ def gen_single_prefill_module( dtype_q, dtype_kv, dtype_o, - head_dim, + head_dim_qk, + head_dim_vo, additional_tensor_names, additional_tensor_dtypes, additional_scalar_names, @@ -369,7 +385,8 @@ def gen_batch_decode_module( dtype_kv: torch.dtype, dtype_o: torch.dtype, dtype_idx: torch.dtype, - head_dim: int, + head_dim_qk: int, + head_dim_vo: int, pos_encoding_mode: int, use_sliding_window: bool, use_logits_soft_cap: bool, @@ -379,7 +396,8 @@ def gen_batch_decode_module( dtype_kv, dtype_o, dtype_idx, - head_dim, + head_dim_qk, + head_dim_vo, pos_encoding_mode, use_sliding_window, use_logits_soft_cap, @@ -390,7 +408,8 @@ def gen_batch_decode_module( dtype_kv, dtype_o, dtype_idx, - head_dim, + head_dim_qk, + head_dim_vo, ["maybe_alibi_slopes"], # additional_tensor_names ["float"], # additional_tensor_dtypes [ @@ -414,7 +433,8 @@ def gen_batch_prefill_module( dtype_kv: torch.dtype, dtype_o: torch.dtype, dtype_idx: torch.dtype, - head_dim: int, + head_dim_qk: int, + head_dim_vo: int, pos_encoding_mode: int, 
use_sliding_window: bool, use_logits_soft_cap: bool, @@ -426,7 +446,8 @@ def gen_batch_prefill_module( dtype_kv, dtype_o, dtype_idx, - head_dim, + head_dim_qk, + head_dim_vo, pos_encoding_mode, use_sliding_window, use_logits_soft_cap, @@ -468,7 +489,8 @@ def gen_batch_prefill_module( dtype_kv, dtype_o, dtype_idx, - head_dim, + head_dim_qk, + head_dim_vo, additional_tensor_names, additional_tensor_dtypes, additional_scalar_names, @@ -487,7 +509,8 @@ def gen_customize_single_decode_module( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, - head_dim: int, + head_dim_qk: int, + head_dim_vo: int, additional_tensor_names: List[str], additional_tensor_dtypes: List[str], additional_scalar_names: List[str], @@ -526,7 +549,8 @@ def gen_customize_single_decode_module( "dtype_q": dtype_map[dtype_q], "dtype_kv": dtype_map[dtype_kv], "dtype_o": dtype_map[dtype_o], - "head_dim": head_dim, + "head_dim_qk": head_dim_qk, + "head_dim_vo": head_dim_vo, "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], "use_sliding_window": str(use_sliding_window).lower(), "use_logits_soft_cap": str(use_logits_soft_cap).lower(), @@ -570,7 +594,8 @@ def gen_customize_single_prefill_module( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, - head_dim: int, + head_dim_qk: int, + head_dim_vo: int, additional_tensor_names: List[str], additional_tensor_dtypes: List[str], additional_scalar_names: List[str], @@ -588,7 +613,8 @@ def gen_customize_single_prefill_module( "dtype_q": dtype_map[dtype_q], "dtype_kv": dtype_map[dtype_kv], "dtype_o": dtype_map[dtype_o], - "head_dim": head_dim, + "head_dim_qk": head_dim_qk, + "head_dim_vo": head_dim_vo, "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], "use_sliding_window": str(use_sliding_window).lower(), "use_logits_soft_cap": str(use_logits_soft_cap).lower(), @@ -721,7 +747,8 @@ def gen_customize_batch_decode_module( dtype_kv: torch.dtype, dtype_o: torch.dtype, idtype: torch.dtype, - head_dim: int, + head_dim_qk: int, + head_dim_vo: int, additional_tensor_names: List[str], additional_tensor_dtypes: List[str], additional_scalar_names: List[str], @@ -752,7 +779,8 @@ def gen_customize_batch_decode_module( "dtype_kv": dtype_map[dtype_kv], "dtype_o": dtype_map[dtype_o], "idtype": dtype_map[idtype], - "head_dim": head_dim, + "head_dim_qk": head_dim_qk, + "head_dim_vo": head_dim_vo, "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], "use_sliding_window": str(use_sliding_window).lower(), "use_logits_soft_cap": str(use_logits_soft_cap).lower(), @@ -803,7 +831,8 @@ def gen_customize_batch_prefill_module( dtype_kv: torch.dtype, dtype_o: torch.dtype, idtype: torch.dtype, - head_dim: int, + head_dim_qk: int, + head_dim_vo: int, additional_tensor_names: List[str], additional_tensor_dtypes: List[str], additional_scalar_names: List[str], @@ -822,7 +851,8 @@ def gen_customize_batch_prefill_module( "dtype_kv": dtype_map[dtype_kv], "dtype_o": dtype_map[dtype_o], "idtype": dtype_map[idtype], - "head_dim": head_dim, + "head_dim_qk": head_dim_qk, + "head_dim_vo": head_dim_vo, "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], "use_sliding_window": str(use_sliding_window).lower(), "use_logits_soft_cap": str(use_logits_soft_cap).lower(), diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index d4239cdbe..843c57364 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -604,7 +604,7 @@ def single_prefill_with_kv_cache_with_jit_module( tmp = _get_cache_buf( "single_prefill_with_kv_cache_tmp", 32 * 
1024 * 1024, device=device ) - o = torch.empty_like(q) + o = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=device) lse = None if return_lse: lse = torch.empty( @@ -690,14 +690,14 @@ def single_prefill_with_kv_cache( Parameters ---------- q : torch.Tensor - The query tensor, shape: ``[qo_len, num_qo_heads, head_dim]``. + The query tensor, shape: ``[qo_len, num_qo_heads, head_dim_qk]``. k : torch.Tensor - The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` - is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is + The key tensor, shape: ``[kv_len, num_kv_heads, head_dim_qk]`` if :attr:`kv_layout` + is ``NHD``, or ``[num_kv_heads, kv_len, head_dim_qk]`` if :attr:`kv_layout` is ``HND``. v : torch.Tensor - The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` - is ``NHD``, ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is + The value tensor, shape: ``[kv_len, num_kv_heads, head_dim_vo]`` if :attr:`kv_layout` + is ``NHD``, ``[num_kv_heads, kv_len, head_dim_vo]`` if :attr:`kv_layout` is ``HND``. custom_mask : Optional[torch.Tensor] The custom boolean mask tensor, shape: ``[qo_len, kv_len]``. @@ -733,7 +733,7 @@ def single_prefill_with_kv_cache( :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, where :math:`x` is the input logits. sm_scale : Optional[float] - The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim_qk)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to 1.0. rope_theta : Optional[float] @@ -748,10 +748,10 @@ Returns ------- Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_len, num_qo_heads, head_dim]``. + If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_len, num_qo_heads, head_dim_vo]``. If :attr:`return_lse` is ``True``, a tuple of two tensors: - * The attention output, shape: ``[qo_len, num_qo_heads, head_dim]``. + * The attention output, shape: ``[qo_len, num_qo_heads, head_dim_vo]``. * The log sum exp value, shape: ``[qo_len, num_qo_heads]``. Examples @@ -833,12 +833,13 @@ def single_prefill_with_kv_cache( ) module_getter = get_single_prefill_module(backend) - out = torch.empty_like(q) + out = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device) module_getter( q.dtype, k.dtype, q.dtype, - q.shape[-1], + q.shape[-1], # head_dim_qk + v.shape[-1], # head_dim_vo PosEncodingMode[pos_encoding_mode].value, window_left >= 0, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap @@ -1178,8 +1179,9 @@ def plan( paged_kv_last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, - head_dim: int, + head_dim_qk: int, page_size: int, + head_dim_vo: Optional[int] = None, custom_mask: Optional[torch.Tensor] = None, packed_custom_mask: Optional[torch.Tensor] = None, causal: bool = False, @@ -1211,10 +1213,13 @@ The number of query/output heads. num_kv_heads : int The number of key/value heads. - head_dim : int - The dimension of the heads. + head_dim_qk : int + The dimension of the query/key heads. page_size : int The size of each page in the paged kv-cache. + head_dim_vo : Optional[int] + The dimension of the value/output heads, if not provided, will be set to + ``head_dim_qk``.
custom_mask : Optional[torch.Tensor] The flattened boolean mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``. The elements in the mask tensor should be either ``True`` or ``False``, @@ -1285,6 +1290,8 @@ def plan( if logits_soft_cap is None: logits_soft_cap = 0.0 + if head_dim_vo is None: + head_dim_vo = head_dim_qk batch_size = len(qo_indptr) - 1 if custom_mask is not None or packed_custom_mask is not None: @@ -1402,7 +1409,8 @@ def plan( kv_data_type, q_data_type, paged_kv_indptr.dtype, - head_dim, + head_dim_qk, + head_dim_vo, PosEncodingMode[pos_encoding_mode].value, window_left >= 0, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap @@ -1417,7 +1425,9 @@ def plan( if page_size != 1: vector_sparse_indptr_host = torch.cat( [ - torch.tensor([0], dtype=torch.int32, device=kv_lens_arr_host.device), + torch.tensor( + [0], dtype=torch.int32, device=kv_lens_arr_host.device + ), torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), ], dim=0, @@ -1441,7 +1451,8 @@ def plan( num_kv_heads, page_size, self.is_cuda_graph_enabled, - head_dim, + head_dim_qk, + head_dim_vo, causal, get_cuda_stream(device), ) @@ -1582,7 +1593,9 @@ def run( (q.size(0), q.size(1)), dtype=torch.float32, device=q.device ) - out = torch.empty_like(q) + out = torch.empty( + q.shape[:-1] + v_cache.shape[-1:], dtype=q.dtype, device=q.device + ) if self._custom_mask_buf is not None: mask_mode = MaskMode.CUSTOM.value @@ -1860,8 +1873,10 @@ def __init__( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, dtype=torch.uint8, - pin_memory=True, device="cpu", + self._int_workspace_buffer.shape, + dtype=torch.uint8, + pin_memory=True, + device="cpu", ) self._use_cuda_graph = use_cuda_graph if use_cuda_graph: @@ -1924,7 +1939,8 @@ def plan( kv_indptr: torch.Tensor, num_qo_heads: int, num_kv_heads: int, - head_dim: int, + head_dim_qk: int, + head_dim_vo: Optional[int] = None, custom_mask: Optional[torch.Tensor] = None, packed_custom_mask: Optional[torch.Tensor] = None, causal: bool = False, @@ -1950,8 +1966,11 @@ The number of query/output heads. num_kv_heads : int The number of key/value heads. - head_dim : int - The dimension of the heads. + head_dim_qk : int + The dimension of the heads on the query/key tensor. + head_dim_vo : Optional[int] + The dimension of the heads on the value/output tensor. + If not provided, will be set to ``head_dim_qk``. custom_mask : Optional[torch.Tensor] The flattened boolean mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``. The elements in the mask tensor should be either ``True`` or ``False``, @@ -1991,7 +2010,7 @@ where :math:`x` is the input logits. sm_scale : Optional[float] The scale used in softmax, if not provided, will be set to - ``1.0 / sqrt(head_dim)``. + ``1.0 / sqrt(head_dim_qk)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to ``1.0``.
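For orientation, a minimal usage sketch (not part of the patch) of how the split head_dim_qk / head_dim_vo arguments to BatchPrefillWithRaggedKVCacheWrapper.plan shown above would be exercised. All concrete values here are illustrative assumptions: the 192/128 split mirrors an MLA-style configuration, the head counts, indptr contents, and dtypes are made up, and whether a given build ships kernels for these head dims depends on the compiled head-dim list.

# Hypothetical example: plan/run a ragged prefill where QK and VO head dims differ.
import torch
import flashinfer

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, kv_layout="NHD")

num_qo_heads = num_kv_heads = 16
head_dim_qk, head_dim_vo = 192, 128  # illustrative MLA-style head dims
qo_indptr = torch.tensor([0, 7, 15], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 29, 61], dtype=torch.int32, device="cuda")

wrapper.plan(
    qo_indptr,
    kv_indptr,
    num_qo_heads,
    num_kv_heads,
    head_dim_qk,              # query/key head dim (required)
    head_dim_vo=head_dim_vo,  # value/output head dim; defaults to head_dim_qk
    causal=True,
)

q = torch.randn(15, num_qo_heads, head_dim_qk, dtype=torch.half, device="cuda")
k = torch.randn(61, num_kv_heads, head_dim_qk, dtype=torch.half, device="cuda")
v = torch.randn(61, num_kv_heads, head_dim_vo, dtype=torch.half, device="cuda")
out = wrapper.run(q, k, v)  # output shape: [qo_indptr[-1], num_qo_heads, head_dim_vo]

The output head dim follows v, matching the out = torch.empty(q.shape[:-1] + v.shape[-1:], ...) allocation introduced in the run method below.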
@@ -2018,6 +2037,8 @@ def plan( if kv_data_type is None: kv_data_type = q_data_type kv_data_type = canonicalize_torch_dtype(kv_data_type) + if head_dim_vo is None: + head_dim_vo = head_dim_qk if logits_soft_cap is None: logits_soft_cap = 0.0 @@ -2103,7 +2124,8 @@ def plan( kv_data_type, q_data_type, kv_indptr.dtype, - head_dim, + head_dim_qk, + head_dim_vo, PosEncodingMode[pos_encoding_mode].value, window_left >= 0, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap @@ -2127,7 +2149,8 @@ def plan( num_kv_heads, 1, # page_size self.is_cuda_graph_enabled, - head_dim, + head_dim_qk, + head_dim_vo, causal, get_cuda_stream(device), ) @@ -2202,21 +2225,21 @@ def run( Parameters ---------- q : torch.Tensor - The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]`` + The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim_qk]`` k : torch.Tensor - The key tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` + The key tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim_qk]`` v : torch.Tensor - The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` + The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim_vo]`` return_lse : bool Whether to return the logsumexp of attention output Returns ------- Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim_vo]``. If :attr:`return_lse` is ``True``, a tuple of two tensors: - * The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + * The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim_vo]``. * The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads]``. 
""" _check_cached_qkv_data_type( @@ -2241,7 +2264,7 @@ def run( lse = torch.empty( (q.size(0), q.size(1)), dtype=torch.float32, device=q.device ) - out = torch.empty_like(q) + out = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device) if is_float8(q): logging.warning( diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index e98adc59c..c7940d23c 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -397,7 +397,8 @@ def plan( kv_data_type, q_data_type, indptr.dtype, - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo PosEncodingMode[pos_encoding_mode].value, False, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap @@ -583,8 +584,8 @@ def run( self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len, - lse, out, + lse, TensorLayout[self._kv_layout].value, -1, # window_left _get_cache_alibi_slopes_buf(q.shape[1], self.device), diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 54c4e51c6..b9375a0b5 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -219,6 +219,10 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par const DTypeQ* q = params.q; const DTypeKV* k = params.k; const DTypeKV* v = params.v; + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + const uint32_t kv_stride_n = params.kv_stride_n; + const uint32_t kv_stride_h = params.kv_stride_h; DTypeO* o = params.o; float* lse = params.lse; uint32_t kv_chunk_size = params.kv_chunk_size; @@ -254,11 +258,10 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); } // apply rotary embedding to q matrix - q_vec = vec_apply_llama_rope(q + params.get_q_elem_offset(0, qo_head_idx, 0), - freq, seq_len - 1); + q_vec = vec_apply_llama_rope(q + qo_head_idx * q_stride_h, freq, seq_len - 1); } else { // do not apply rotary embedding to q matrix - q_vec.cast_load(q + params.get_q_elem_offset(0, qo_head_idx, tx * vec_size)); + q_vec.cast_load(q + qo_head_idx * q_stride_h + tx * vec_size); } // multiple q_vec by sm_scale #pragma unroll @@ -280,9 +283,8 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par cp_async::pred_load( k_smem + (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - k + params.get_kv_elem_offset( - producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx, - tx * vec_size), + k + (producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j) * kv_stride_n + + kv_head_idx * kv_stride_h + tx * vec_size, producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); } cp_async::commit_group(); @@ -290,9 +292,8 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par cp_async::pred_load( v_smem + (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - v + params.get_kv_elem_offset( - producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx, - tx * vec_size), + v + (producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j) * kv_stride_n + + kv_head_idx * kv_stride_h + tx * vec_size, producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); } cp_async::commit_group(); @@ -320,9 +321,8 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par 
cp_async::pred_load( k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - k + params.get_kv_elem_offset( - producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx, - tx * vec_size), + k + (producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j) * kv_stride_n + + kv_head_idx * kv_stride_h + tx * vec_size, producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); } cp_async::commit_group(); @@ -340,9 +340,8 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par cp_async::pred_load( v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - v + params.get_kv_elem_offset( - producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx, - tx * vec_size), + v + (producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j) * kv_stride_n + + kv_head_idx * kv_stride_h + tx * vec_size, producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); } cp_async::commit_group(); diff --git a/include/flashinfer/attention/default_decode_params.cuh b/include/flashinfer/attention/default_decode_params.cuh index 350e3b923..d06e46338 100644 --- a/include/flashinfer/attention/default_decode_params.cuh +++ b/include/flashinfer/attention/default_decode_params.cuh @@ -44,7 +44,6 @@ struct SingleDecodeParams { uint32_t q_stride_h; uint32_t kv_stride_n; uint32_t kv_stride_h; - uint32_t head_dim; int32_t window_left; float logits_soft_cap; float sm_scale; @@ -66,7 +65,6 @@ struct SingleDecodeParams { q_stride_h(0), kv_stride_n(0), kv_stride_h(0), - head_dim(0), window_left(0), logits_soft_cap(0.0f), sm_scale(0.0f), @@ -93,7 +91,6 @@ struct SingleDecodeParams { q_stride_h(head_dim), kv_stride_n((kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim : head_dim), kv_stride_h((kv_layout == QKVLayout::kNHD) ? 
head_dim : seq_len * head_dim), - head_dim(head_dim), window_left(window_left), logits_soft_cap(logits_soft_cap), sm_scale(sm_scale), @@ -101,24 +98,6 @@ struct SingleDecodeParams { rope_rcp_theta(1.f / rope_theta), kv_chunk_size(0) {} - __host__ __device__ __forceinline__ size_t get_q_elem_offset(uint32_t qo_idx, - uint32_t qo_head_idx, - uint32_t feat_idx) const { - return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, q_stride_n, q_stride_h); - } - - __host__ __device__ __forceinline__ size_t get_o_elem_offset(uint32_t qo_idx, - uint32_t qo_head_idx, - uint32_t feat_idx) const { - return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, num_qo_heads * head_dim, head_dim); - } - - __host__ __device__ __forceinline__ size_t get_kv_elem_offset(uint32_t kv_idx, - uint32_t kv_head_idx, - uint32_t feat_idx) const { - return get_elem_offset_impl(kv_idx, kv_head_idx, feat_idx, kv_stride_n, kv_stride_h); - } - __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { return 1; } __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { diff --git a/include/flashinfer/attention/default_prefill_params.cuh b/include/flashinfer/attention/default_prefill_params.cuh index dcc2fa9da..99f6f42ae 100644 --- a/include/flashinfer/attention/default_prefill_params.cuh +++ b/include/flashinfer/attention/default_prefill_params.cuh @@ -44,8 +44,10 @@ struct SinglePrefillParams { uint32_t num_kv_heads; uint32_t q_stride_n; uint32_t q_stride_h; - uint32_t kv_stride_n; - uint32_t kv_stride_h; + uint32_t k_stride_n; + uint32_t k_stride_h; + uint32_t v_stride_n; + uint32_t v_stride_h; uint32_t head_dim; int32_t window_left; float logits_soft_cap; @@ -69,8 +71,10 @@ struct SinglePrefillParams { num_kv_heads(0), q_stride_n(0), q_stride_h(0), - kv_stride_n(0), - kv_stride_h(0), + k_stride_n(0), + k_stride_h(0), + v_stride_n(0), + v_stride_h(0), head_dim(0), window_left(0), logits_soft_cap(0.0f), @@ -99,8 +103,10 @@ struct SinglePrefillParams { kv_len(kv_len), q_stride_n(q_stride_n), q_stride_h(q_stride_h), - kv_stride_n(kv_stride_n), - kv_stride_h(kv_stride_h), + k_stride_n(kv_stride_n), + k_stride_h(kv_stride_h), + v_stride_n(kv_stride_n), + v_stride_h(kv_stride_h), head_dim(head_dim), window_left(window_left), logits_soft_cap(logits_soft_cap), @@ -141,8 +147,10 @@ struct BatchPrefillRaggedParams { uint32_t num_kv_heads; uint32_t q_stride_n; uint32_t q_stride_h; - uint32_t kv_stride_n; - uint32_t kv_stride_h; + uint32_t k_stride_n; + uint32_t k_stride_h; + uint32_t v_stride_n; + uint32_t v_stride_h; int32_t window_left; float logits_soft_cap; float sm_scale; @@ -178,8 +186,10 @@ struct BatchPrefillRaggedParams { num_kv_heads(0), q_stride_n(0), q_stride_h(0), - kv_stride_n(0), - kv_stride_h(0), + k_stride_n(0), + k_stride_h(0), + v_stride_n(0), + v_stride_h(0), window_left(0), logits_soft_cap(0.0f), sm_scale(0.0f), @@ -222,8 +232,10 @@ struct BatchPrefillRaggedParams { num_kv_heads(num_kv_heads), q_stride_n(q_stride_n), q_stride_h(q_stride_h), - kv_stride_n(kv_stride_n), - kv_stride_h(kv_stride_h), + k_stride_n(kv_stride_n), + k_stride_h(kv_stride_h), + v_stride_n(kv_stride_n), + v_stride_h(kv_stride_h), window_left(window_left), logits_soft_cap(logits_soft_cap), sm_scale(sm_scale), diff --git a/include/flashinfer/attention/hopper/attention_updater.cuh b/include/flashinfer/attention/hopper/attention_updater.cuh index e8d7c5d22..d15aded89 100644 --- a/include/flashinfer/attention/hopper/attention_updater.cuh +++ 
b/include/flashinfer/attention/hopper/attention_updater.cuh @@ -142,6 +142,7 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor& tenso template struct DefaultUpdater { using TensorT = decltype(make_tensor(Shape>{})); + constexpr static float fill_value = 0.f; template CUTLASS_DEVICE DefaultUpdater(MainloopParams params) {}; @@ -165,6 +166,7 @@ struct DefaultUpdater { template struct OnlineSoftmax { + constexpr static float fill_value = -math::inf; using TensorT = decltype(make_tensor(Shape>{})); TensorT row_max, row_sum, scores_scale; float sm_scale_log2; diff --git a/include/flashinfer/attention/hopper/default_params.cuh b/include/flashinfer/attention/hopper/default_params.cuh index 2f3d1a2de..f92b21849 100644 --- a/include/flashinfer/attention/hopper/default_params.cuh +++ b/include/flashinfer/attention/hopper/default_params.cuh @@ -51,7 +51,6 @@ struct SinglePrefillParams { int qo_len; int kv_len; - int head_dim; int num_qo_heads; int num_kv_heads; int group_size; @@ -97,7 +96,6 @@ struct BatchPrefillRaggedParams { int64_t nnz_qo; int64_t nnz_kv; - int head_dim; int num_qo_heads; int num_kv_heads; int group_size; @@ -143,7 +141,6 @@ struct BatchPrefillPagedParams { int64_t o_stride_h; int64_t nnz_qo; - int head_dim; int num_qo_heads; int num_kv_heads; int group_size; diff --git a/include/flashinfer/attention/hopper/epilogue.cuh b/include/flashinfer/attention/hopper/epilogue.cuh index 7f8b5a32c..781f8d222 100644 --- a/include/flashinfer/attention/hopper/epilogue.cuh +++ b/include/flashinfer/attention/hopper/epilogue.cuh @@ -68,8 +68,8 @@ struct CollectiveEpilogue { using DTypeO = typename Ktraits::DTypeO; static constexpr int CTA_Q = Ktraits::CTA_Q; static constexpr int CTA_KV = Ktraits::CTA_KV; - static constexpr int HEAD_DIM = Ktraits::HEAD_DIM; - using TileShape_QKD = Shape, Int, Int>; + static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO; + using TileShape_PDV = Shape, Int, Int>; static constexpr int NUM_WARPS = Ktraits::NUM_WARPS; static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp; @@ -78,9 +78,9 @@ struct CollectiveEpilogue { static constexpr int NUM_MMA_THREADS = NUM_THREADS - NUM_COPY_THREADS; using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector< - GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_QKD{})), - decltype(cute::get<2>(TileShape_QKD{}))>()); - using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_QKD{}))); + GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})), + decltype(cute::get<1>(TileShape_PDV{}))>()); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{}))); using SmemCopyAtomO = Copy_Atom; using SharedStorage = cute::array_aligned>; @@ -97,11 +97,11 @@ struct CollectiveEpilogue { using TMA_O = decltype(make_tma_copy( GmemTiledCopyOTMA{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeT{}, StrideT{}), SmemLayoutO{}, - select<0, 2>(TileShape_QKD{}), _1{})); // no mcast for O + select<0, 1>(TileShape_PDV{}), _1{})); // no mcast for O static constexpr int VEC_SIZE = cute::ceil_div(128, sizeof_bits_v); - static_assert(HEAD_DIM % VEC_SIZE == 0); - static constexpr int NUM_THREADS_PER_ROW = HEAD_DIM / VEC_SIZE; + static_assert(HEAD_DIM_VO % VEC_SIZE == 0); + static constexpr int NUM_THREADS_PER_ROW = HEAD_DIM_VO / VEC_SIZE; static_assert(NUM_MMA_THREADS % NUM_THREADS_PER_ROW == 0); static constexpr int NUM_ROWS = NUM_MMA_THREADS / NUM_THREADS_PER_ROW; using TiledCopyOAtom = cute::Copy_Atom, DTypeO>; @@ -116,11 
+116,11 @@ struct CollectiveEpilogue { // used for rmem -> smem O copy in fp8 kernel to undo column permutation using ThreadLayoutrO = Layout, _4, _1>, Stride<_4, _32, _1, _0>>; - using ValueLayoutrO = - Layout, Int>, Stride<_0, _2, Stride<_4, _1>, _8>>; + using ValueLayoutrO = Layout, Int>, + Stride<_0, _2, Stride<_4, _1>, _8>>; using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom, DTypeO>{}, ThreadLayoutrO{}, ValueLayoutrO{})); - using TiledCopyShaperO = Shape<_8, Int, _16, Int>; + using TiledCopyShaperO = Shape<_8, Int, _16, Int>; using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout{})); // Host side kernel arguments @@ -174,7 +174,7 @@ struct CollectiveEpilogue { Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.lse_ptr), epilogue_params.layout_LSE); Tensor gLSE = get_lse_local_tile_tensor(mLSE, Shape>{}, qo_head_idx, qo_indptr, qo_len)(_, qo_tile_idx); - Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_QKD{})); + Tensor caccO = cute::make_identity_tensor(select<0, 1>(TileShape_PDV{})); auto thread_mma = tiled_mma.get_thread_slice(thread_idx); Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) static_assert(decltype(size<0, 0>(taccOcO))::value == 2); @@ -201,7 +201,7 @@ struct CollectiveEpilogue { } TiledCopyO gmem_tiled_copy_O; write_O(epilogue_params.O_ptr, gmem_tiled_copy_O, epilogue_params.layout_O, - select<0, 2>(TileShape_QKD{}), sO, thread_idx, qo_tile_idx, + select<0, 1>(TileShape_PDV{}), sO, thread_idx, qo_tile_idx, qo_head_idx, qo_indptr, qo_len, write_warp_idx); } @@ -216,7 +216,7 @@ struct CollectiveEpilogue { auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = block_coord; Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.O_ptr), epilogue_params.layout_O); - Tensor gO = get_local_tile_tensor(mO, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr, + Tensor gO = get_local_tile_tensor(mO, select<0, 1>(TileShape_PDV{}), qo_head_idx, qo_indptr, qo_len)(_, _, qo_tile_idx); // (O, D) Tensor cO = cute::make_identity_tensor(gO.shape()); // (O, D) -> (o_idx, d_idx) Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.lse_ptr), epilogue_params.layout_LSE); @@ -233,7 +233,7 @@ struct CollectiveEpilogue { Tensor tOrOGroup = flatten_1(tOrO); // (CPY, (CPY_O, CPY_D)) Tensor tOcOGroup = flatten_1(tOcO); // (CPY, (CPY_O, CPY_D)) - const int qo_tile_size = get<0>(TileShape_QKD{}); + const int qo_tile_size = get<0>(TileShape_PDV{}); int valid_qo_tile_size = std::min(qo_len - qo_tile_idx * qo_tile_size, qo_tile_size); if (valid_qo_tile_size == qo_tile_size) { copy(tiled_copy_O, tOrOGroup, tOgOGroup); diff --git a/include/flashinfer/attention/hopper/kernel_traits.cuh b/include/flashinfer/attention/hopper/kernel_traits.cuh index 9fc8b10d2..2ac599ca6 100644 --- a/include/flashinfer/attention/hopper/kernel_traits.cuh +++ b/include/flashinfer/attention/hopper/kernel_traits.cuh @@ -39,8 +39,8 @@ struct SharedStorageQKVO { }; }; -template struct AttentionKernelTraits { using AttentionVariant = AttentionVariant_; @@ -54,8 +54,10 @@ struct AttentionKernelTraits { static constexpr int CTA_Q = CTA_Q_; static_assert(CTA_Q % 64 == 0); static constexpr int CTA_KV = CTA_KV_; - static constexpr int HEAD_DIM = HEAD_DIM_; - static_assert(HEAD_DIM % 32 == 0); + static constexpr int HEAD_DIM_QK = HEAD_DIM_QK_; + static constexpr int HEAD_DIM_VO = HEAD_DIM_VO_; + static_assert(HEAD_DIM_QK % 32 == 0); + static_assert(HEAD_DIM_VO % 32 == 0); static constexpr int NUM_WARPS = ((CTA_Q / 64) + 1) * 4; static constexpr int 
NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp; @@ -63,7 +65,8 @@ struct AttentionKernelTraits { // where only one warp inside a warp group is used for TMA. static constexpr int NUM_PRODUCER_THREADS = cutlass::NumThreadsPerWarp; - using TileShape_QKD = Shape, Int, Int>; + using TileShape_QKD = Shape, Int, Int>; + using TileShape_PDV = Shape, Int, Int>; static constexpr int NUM_STAGES = NUM_STAGES_; @@ -71,9 +74,8 @@ struct AttentionKernelTraits { using TiledMmaQK = decltype(cute::make_tiled_mma( cute::GMMA::ss_op_selector(), AtomLayoutQKD{})); using TiledMmaPV = decltype(cute::make_tiled_mma( - cute::GMMA::rs_op_selector(TileShape_QKD{})), GMMA::Major::K, - GMMA::Major::MN>(), + cute::GMMA::rs_op_selector(), AtomLayoutQKD{})); static constexpr int NUM_MMA_THREADS = size(TiledMmaQK{}); @@ -91,22 +93,22 @@ struct AttentionKernelTraits { make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int{}))); using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector< - GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})), - decltype(cute::get<2>(TileShape_QKD{}))>()); + GMMA::Major::K, DTypeKV, decltype(cute::get<2>(TileShape_PDV{})), + decltype(cute::get<1>(TileShape_PDV{}))>()); using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomV{}, - make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{}), Int{}))); + make_shape(get<2>(TileShape_PDV{}), get<1>(TileShape_PDV{}), Int{}))); // Note this is the transpose in terms of the view, not in terms of memory. using SmemLayoutVt = decltype(composition( - SmemLayoutV{}, make_ordered_layout(make_shape(get<2>(TileShape_QKD{}), - get<1>(TileShape_QKD{}), Int{}), + SmemLayoutV{}, make_ordered_layout(make_shape(get<1>(TileShape_PDV{}), + get<2>(TileShape_PDV{}), Int{}), Step<_2, _1, _3>{}))); using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector< - GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_QKD{})), - decltype(cute::get<2>(TileShape_QKD{}))>()); - using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_QKD{}))); + GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})), + decltype(cute::get<1>(TileShape_PDV{}))>()); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{}))); using MainloopPipeline = std::conditional_t, typename cutlass::PipelineAsync>; diff --git a/include/flashinfer/attention/hopper/mainloop.cuh b/include/flashinfer/attention/hopper/mainloop.cuh index b11442587..2a8f93620 100644 --- a/include/flashinfer/attention/hopper/mainloop.cuh +++ b/include/flashinfer/attention/hopper/mainloop.cuh @@ -29,12 +29,14 @@ struct CollectiveMainloop { using DTypeQ = typename Ktraits::DTypeQ; using DTypeKV = typename Ktraits::DTypeKV; using TileShape_QKD = typename Ktraits::TileShape_QKD; + using TileShape_PDV = typename Ktraits::TileShape_PDV; static constexpr int CTA_Q = get<0>(TileShape_QKD{}); static constexpr int CTA_KV = get<1>(TileShape_QKD{}); static constexpr int NUM_STAGES = Ktraits::NUM_STAGES; static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; - static constexpr int HEAD_DIM = Ktraits::HEAD_DIM; + static constexpr int HEAD_DIM_QK = Ktraits::HEAD_DIM_QK; + static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO; using GmemTiledCopyQ = cute::SM90_TMA_LOAD; using GmemTiledCopyKV = cute::SM90_TMA_LOAD; @@ -68,7 +70,7 @@ struct CollectiveMainloop { GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideT{}, int32_t(0)), StrideT{}), - 
take<0, 2>(SmemLayoutV{}), select<1, 2>(TileShape_QKD{}), _1{})); // no mcast + take<0, 2>(SmemLayoutV{}), select<2, 1>(TileShape_PDV{}), _1{})); // no mcast static constexpr bool USE_TMA_LOAD_KV = true; using MainloopPipeline = typename Ktraits::MainloopPipeline; @@ -80,11 +82,13 @@ struct CollectiveMainloop { static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesV = + static_cast(size(take<0, 2>(SmemLayoutV{})) * cutlass::sizeof_bits_v / 8); // Whether use scheduler barrier or hardware warp scheduler, using heuristic based on data type // and head dim static constexpr bool UseSchedulerBarrier = - cutlass::sizeof_bits_v == 8 ? HEAD_DIM >= 128 : HEAD_DIM <= 128; + cutlass::sizeof_bits_v == 8 ? HEAD_DIM_VO >= 128 : HEAD_DIM_VO <= 128; using WarpScheduler = WarpScheduler; // Host side kernel arguments @@ -120,7 +124,7 @@ struct CollectiveMainloop { select<1, 2>(TileShape_QKD{}), _1{}); // no mcast Tensor mV = make_tensor(make_gmem_ptr(args.V_ptr), args.layout_V); TMA_V tma_load_V = make_tma_copy(GmemTiledCopyKV{}, mV, SmemLayoutV{}(_, _, _0{}), - select<1, 2>(TileShape_QKD{}), _1{}); // no mcast + select<2, 1>(TileShape_PDV{}), _1{}); // no mcast return {args.layout_Q, args.layout_K, args.layout_V, tma_load_Q, tma_load_K, tma_load_V, args.window_left, args.additional_params}; } @@ -170,7 +174,7 @@ struct CollectiveMainloop { qo_len)(_, _, q_tile_idx); // (Q, D) Tensor gK = get_local_tile_tensor(mK, select<1, 2>(TileShape_QKD{}), kv_head_idx, kv_indptr, kv_len); // (K, D, _) - Tensor gV = get_local_tile_tensor(mV, select<1, 2>(TileShape_QKD{}), kv_head_idx, kv_indptr, + Tensor gV = get_local_tile_tensor(mV, select<2, 1>(TileShape_PDV{}), kv_head_idx, kv_indptr, kv_len); // (K, D, _) Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); diff --git a/include/flashinfer/attention/hopper/mainloop_mma.cuh b/include/flashinfer/attention/hopper/mainloop_mma.cuh index ffe2bd296..d784e0a70 100644 --- a/include/flashinfer/attention/hopper/mainloop_mma.cuh +++ b/include/flashinfer/attention/hopper/mainloop_mma.cuh @@ -104,16 +104,16 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, AttentionVariant& var qo_head_idx, kv_head_idx); if constexpr (!CAUSAL) { // Just masking based on col if (kv_idx >= kv_len) { - tSrS(i) = -math::inf; + tSrS(i) = AttentionUpdater::fill_value; } } else { if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) { - tSrS(i) = -math::inf; + tSrS(i) = AttentionUpdater::fill_value; } } if constexpr (LEFT_SLIDING_WINDOW) { if (kv_idx < col_limit_left(qo_idx)) { - tSrS(i) = -math::inf; + tSrS(i) = AttentionUpdater::fill_value; } } } @@ -151,11 +151,11 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, AttentionVariant& var tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, qo_head_idx, kv_head_idx); if (kv_idx >= col_limit_right(qo_idx)) { - tSrS(i) = -math::inf; + tSrS(i) = AttentionUpdater::fill_value; } if constexpr (LEFT_SLIDING_WINDOW) { if (kv_idx < col_limit_left(qo_idx)) { - tSrS(i) = -math::inf; + tSrS(i) = AttentionUpdater::fill_value; } } } @@ -227,7 +227,7 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, AttentionVariant& var tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, qo_head_idx, kv_head_idx); if (kv_idx < col_limit_left(qo_idx)) { - tSrS(i) = 
-math::inf; + tSrS(i) = AttentionUpdater::fill_value; } } attention_updater.update(tSrS); diff --git a/include/flashinfer/attention/hopper/prefill_sm90.cuh b/include/flashinfer/attention/hopper/prefill_sm90.cuh index 316bfd1fd..09355bacf 100644 --- a/include/flashinfer/attention/hopper/prefill_sm90.cuh +++ b/include/flashinfer/attention/hopper/prefill_sm90.cuh @@ -50,6 +50,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp using DTypeO = typename Ktraits::DTypeO; using DTypeQKAccum = typename Ktraits::DTypeQKAccum; using TileShape_QKD = typename Ktraits::TileShape_QKD; + using TileShape_PDV = typename Ktraits::TileShape_PDV; using AttentionVariant = typename Ktraits::AttentionVariant; static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; @@ -85,7 +86,6 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer : MainloopPipeline::ThreadCategory::Consumer; if constexpr (use_tma_load_kv) { - pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; pipeline_params.is_leader = warp_group_thread_idx == 0; pipeline_params.num_consumers = NUM_MMA_THREADS; } else { @@ -100,6 +100,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); MainloopPipeline pipeline_k = [&] { if constexpr (use_tma_load_kv) { + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; return MainloopPipeline(shared_storage.pipeline_k, pipeline_params, /*cluster_shape=*/Shape<_1, _1, _1>{}); } else { @@ -109,6 +110,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp MainloopPipeline pipeline_v = [&] { if constexpr (use_tma_load_kv) { + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; return MainloopPipeline(shared_storage.pipeline_v, pipeline_params, /*cluster_shape=*/Shape<_1, _1, _1>{}); } else { @@ -188,7 +190,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp work_tile_info = scheduler.template get_next_work(scheduler_params, work_tile_info)) { // Attention output (GEMM-II) accumulator. 
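The masked entries above are now filled with AttentionUpdater::fill_value rather than a hard-coded -math::inf, so DefaultUpdater (no softmax, neutral value 0) and OnlineSoftmax (neutral value -inf) can share the same masking loops. A minimal host-side sketch of that pattern; the struct and function names here are illustrative, not the FlashInfer types.

#include <limits>
#include <vector>

struct DefaultUpdaterSketch {
  static constexpr float fill_value = 0.f;  // no softmax: masked entries contribute 0 to the P*V GEMM
};

struct OnlineSoftmaxSketch {
  static constexpr float fill_value = -std::numeric_limits<float>::infinity();  // exp(-inf) == 0
};

// Generic masking loop: the updater supplies the neutral value, the loop stays identical.
template <class AttentionUpdater>
void apply_causal_mask(std::vector<float>& logits, int qo_idx, int kv_len) {
  for (int kv_idx = 0; kv_idx < static_cast<int>(logits.size()); ++kv_idx) {
    if (kv_idx >= kv_len || kv_idx > qo_idx) {
      logits[kv_idx] = AttentionUpdater::fill_value;
    }
  }
}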
- Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 2>(TileShape_QKD{})); + Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{})); AttentionUpdater attention_updater(mainloop_params); auto block_coord = work_tile_info.get_block_coord(scheduler_params); @@ -237,7 +239,6 @@ cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched(Params& params, cudaS using DTypeQ = typename KernelTraits::DTypeQ; using DTypeKV = typename KernelTraits::DTypeKV; using DTypeO = typename KernelTraits::DTypeO; - using TileShape_QKD = typename KernelTraits::TileShape_QKD; using CollectiveMainloop = CollectiveMainloop; @@ -245,19 +246,23 @@ cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched(Params& params, cudaS using Scheduler = SingleTileScheduler; typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( {params.q_ptr, - get_gmem_layout(params.qo_len, params.num_qo_heads, params.head_dim, params.q_stride_n, + get_gmem_layout(params.qo_len, params.num_qo_heads, KernelTraits::HEAD_DIM_QK, + params.q_stride_n, params.q_stride_h), // layout_Q params.k_ptr, - get_gmem_layout(params.kv_len, params.num_kv_heads, params.head_dim, params.k_stride_n, + get_gmem_layout(params.kv_len, params.num_kv_heads, KernelTraits::HEAD_DIM_QK, + params.k_stride_n, params.k_stride_h), // layout_K params.v_ptr, - get_gmem_layout(params.kv_len, params.num_kv_heads, params.head_dim, params.v_stride_n, + get_gmem_layout(params.kv_len, params.num_kv_heads, KernelTraits::HEAD_DIM_VO, + params.v_stride_n, params.v_stride_h), // layout_V params.window_left, params.additional_params}); typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments({ static_cast(params.o_ptr), - get_gmem_layout(params.qo_len, params.num_qo_heads, params.head_dim, params.o_stride_n, + get_gmem_layout(params.qo_len, params.num_qo_heads, KernelTraits::HEAD_DIM_VO, + params.o_stride_n, params.o_stride_h), // layout_O static_cast(params.lse_ptr), get_lse_gmem_layout(params.qo_len, params.num_qo_heads), // layout_LSE @@ -299,7 +304,6 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(Params& params, using DTypeKV = typename KernelTraits::DTypeKV; using DTypeO = typename KernelTraits::DTypeO; using IdType = typename KernelTraits::IdType; - using TileShape_QKD = typename KernelTraits::TileShape_QKD; using CollectiveMainloop = SparseCollectiveMainloop; @@ -310,20 +314,22 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(Params& params, typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( {params.q_ptr, - get_gmem_layout(params.nnz_qo, params.num_qo_heads, params.head_dim, params.q_stride_n, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, KernelTraits::HEAD_DIM_QK, + params.q_stride_n, params.q_stride_h), // layout_Q params.k_ptr, // NOTE(Zihao): nnz was useless here, we can just pass 0 - get_gmem_layout(/*nnz=*/0, params.num_kv_heads, params.head_dim, params.k_stride_n, + get_gmem_layout(/*nnz=*/0, params.num_kv_heads, KernelTraits::HEAD_DIM_QK, params.k_stride_n, params.k_stride_h), // layout_K params.v_ptr, - get_gmem_layout(/*nnz=*/0, params.num_kv_heads, params.head_dim, params.v_stride_n, + get_gmem_layout(/*nnz=*/0, params.num_kv_heads, KernelTraits::HEAD_DIM_VO, params.v_stride_n, params.v_stride_h), // layout_V params.kv_indices, params.window_left, params.additional_params}); typename CollectiveEpilogue::Params epilogue_params = 
CollectiveEpilogue::to_underlying_arguments({ params.o_ptr, - get_gmem_layout(params.nnz_qo, params.num_qo_heads, params.head_dim, params.o_stride_n, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, KernelTraits::HEAD_DIM_VO, + params.o_stride_n, params.o_stride_h), // layout_O params.lse_ptr, get_lse_gmem_layout(params.nnz_qo, params.num_qo_heads), // layout_LSE }); @@ -366,7 +372,6 @@ cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched(Params& params, using DTypeKV = typename KernelTraits::DTypeKV; using DTypeO = typename KernelTraits::DTypeO; using IdType = typename KernelTraits::IdType; - using TileShape_QKD = typename KernelTraits::TileShape_QKD; using CollectiveMainloop = CollectiveMainloop; @@ -376,20 +381,24 @@ cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched(Params& params, BatchPrefillPersistentTileScheduler>; typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( {params.q_ptr, - get_gmem_layout(params.nnz_qo, params.num_qo_heads, params.head_dim, params.q_stride_n, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, KernelTraits::HEAD_DIM_QK, + params.q_stride_n, params.q_stride_h), // layout_Q params.k_ptr, // NOTE(Zihao): nnz was useless here, we can just pass 0 - get_gmem_layout(params.nnz_kv, params.num_kv_heads, params.head_dim, params.k_stride_n, + get_gmem_layout(params.nnz_kv, params.num_kv_heads, KernelTraits::HEAD_DIM_QK, + params.k_stride_n, params.k_stride_h), // layout_K params.v_ptr, - get_gmem_layout(params.nnz_kv, params.num_kv_heads, params.head_dim, params.v_stride_n, + get_gmem_layout(params.nnz_kv, params.num_kv_heads, KernelTraits::HEAD_DIM_VO, + params.v_stride_n, params.v_stride_h), // layout_V params.window_left, params.additional_params}); typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments({ params.o_ptr, - get_gmem_layout(params.nnz_qo, params.num_qo_heads, params.head_dim, params.o_stride_n, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, KernelTraits::HEAD_DIM_VO, + params.o_stride_n, params.o_stride_h), // layout_O params.lse_ptr, get_lse_gmem_layout(params.nnz_qo, params.num_qo_heads), // layout_LSE }); @@ -425,117 +434,109 @@ cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched(Params& params, return cudaSuccess; } -template +constexpr auto getCTATileSize() { + if constexpr (HEAD_DIM_QK == HEAD_DIM_VO) { + if constexpr (HEAD_DIM_QK == 64) { + return std::make_tuple(192, 128); + } else if constexpr (HEAD_DIM_QK == 128) { + if constexpr (CAUSAL) { + return std::make_tuple(128, 128); + } else { + return std::make_tuple(128, 192); + } + } else { + return std::make_tuple(128, 64); + } + } else { + // NOTE(Zihao) hack for deepseek prefill + static_assert(HEAD_DIM_QK == 192 && HEAD_DIM_VO == 128); + return std::make_tuple(128, 128); + } +} + +template cudaError_t SinglePrefillWithKVCacheDispatched(Params& params, cudaStream_t stream) { - static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); + static_assert(HEAD_DIM_VO == 64 || HEAD_DIM_VO == 128 || HEAD_DIM_VO == 256); if (MASK_MODE == MaskMode::kCustom) { return cudaErrorNotSupported; // Not supported yet. 
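The per-head-dim dispatch branches are replaced by a single constexpr tile-size helper whose result feeds CTA_Q_ and CTA_KV_ via get<0>/get<1>. A sketch of its selection logic, assuming a <HEAD_DIM_QK, HEAD_DIM_VO, CAUSAL> parameterization and a (CTA_Q, CTA_KV) tuple result; the tile values are taken from the hunk above, including the asymmetric 192x128 DeepSeek-style case.

#include <tuple>

template <int HEAD_DIM_QK, int HEAD_DIM_VO, bool CAUSAL>
constexpr std::tuple<int, int> cta_tile_size_sketch() {  // returns (CTA_Q, CTA_KV)
  if constexpr (HEAD_DIM_QK == HEAD_DIM_VO) {
    if constexpr (HEAD_DIM_QK == 64) {
      return {192, 128};
    } else if constexpr (HEAD_DIM_QK == 128) {
      if constexpr (CAUSAL) {
        return {128, 128};
      } else {
        return {128, 192};
      }
    } else {
      return {128, 64};  // HEAD_DIM == 256
    }
  } else {
    // asymmetric head dims, used for DeepSeek-style prefill
    static_assert(HEAD_DIM_QK == 192 && HEAD_DIM_VO == 128);
    return {128, 128};
  }
}

static_assert(std::get<1>(cta_tile_size_sketch<128, 128, /*CAUSAL=*/false>()) == 192);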
} constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; - if constexpr (HEAD_DIM == 64) { - SinglePrefillWithKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLIDING_WINDOW, CAUSAL>(params, stream); - } else if constexpr (HEAD_DIM == 128) { - SinglePrefillWithKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLIDING_WINDOW, CAUSAL>(params, stream); - } else { - // HEAD_DIM == 256; - SinglePrefillWithKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLIDING_WINDOW, CAUSAL>(params, stream); - } + constexpr auto CTA_TILE_SIZE = getCTATileSize(); + SinglePrefillWithKVCacheKernelTraitsDispatched< + AttentionKernelTraits(CTA_TILE_SIZE), + /*CTA_KV_=*/get<1>(CTA_TILE_SIZE), + /*NUM_STAGES_=*/2, typename Params::DTypeQ, typename Params::DTypeKV, + typename Params::DTypeO, typename Params::IdType, AttentionVariant>, + LEFT_SLIDING_WINDOW, CAUSAL>(params, stream); cudaError_t status = cudaGetLastError(); return status; } -template cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params& params, cudaStream_t stream) { - static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); + static_assert(HEAD_DIM_VO == 64 || HEAD_DIM_VO == 128 || HEAD_DIM_VO == 256); if (MASK_MODE == MaskMode::kCustom) { return cudaErrorNotSupported; // Not supported yet. } constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; - if constexpr (HEAD_DIM == 64) { - BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); - } else if constexpr (HEAD_DIM == 128) { - BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); - } else { - // HEAD_DIM == 256; - BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); - } + constexpr auto CTA_TILE_SIZE = getCTATileSize(); + BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< + AttentionKernelTraits(CTA_TILE_SIZE), + /*CTA_KV_=*/get<1>(CTA_TILE_SIZE), + /*NUM_STAGES_=*/2, typename Params::DTypeQ, typename Params::DTypeKV, + typename Params::DTypeO, typename Params::IdType, AttentionVariant>, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); cudaError_t status = cudaGetLastError(); return status; } -template cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) { - static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); + static_assert(HEAD_DIM_VO == 64 || HEAD_DIM_VO == 128 || HEAD_DIM_VO == 256); if (MASK_MODE == MaskMode::kCustom) { return cudaErrorNotSupported; // Not supported yet. 
} constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; - if constexpr (HEAD_DIM == 64) { - // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 64, need to optimize later - BatchPrefillWithPagedKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); - } else if constexpr (HEAD_DIM == 128) { - BatchPrefillWithPagedKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + if constexpr (HEAD_DIM_QK == HEAD_DIM_VO) { + if constexpr (HEAD_DIM_VO == 64) { + // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 64, need to optimize later + BatchPrefillWithPagedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + } else if constexpr (HEAD_DIM_VO == 128) { + BatchPrefillWithPagedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + } else { + // HEAD_DIM == 256; + // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 256, need to optimize later + BatchPrefillWithPagedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + } } else { - // HEAD_DIM == 256; - // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 256, need to optimize later - BatchPrefillWithPagedKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + return cudaErrorNotSupported; } cudaError_t status = cudaGetLastError(); return status; diff --git a/include/flashinfer/attention/hopper/sparse_mainloop.cuh b/include/flashinfer/attention/hopper/sparse_mainloop.cuh index 033d4a9f3..7d93ec44e 100644 --- a/include/flashinfer/attention/hopper/sparse_mainloop.cuh +++ b/include/flashinfer/attention/hopper/sparse_mainloop.cuh @@ -39,11 +39,14 @@ struct SparseCollectiveMainloop { using DTypeKV = typename Ktraits::DTypeKV; using IdType = typename Ktraits::IdType; using TileShape_QKD = typename Ktraits::TileShape_QKD; + using TileShape_PDV = typename Ktraits::TileShape_PDV; static constexpr int CTA_Q = get<0>(TileShape_QKD{}); static constexpr int CTA_KV = get<1>(TileShape_QKD{}); static constexpr int NUM_STAGES = Ktraits::NUM_STAGES; - static constexpr int HEAD_DIM = Ktraits::HEAD_DIM; + static constexpr int HEAD_DIM_QK = Ktraits::HEAD_DIM_QK; + static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO; + static_assert(HEAD_DIM_QK == HEAD_DIM_VO); static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup; using GmemTiledCopyQ = cute::SM90_TMA_LOAD; @@ -51,11 +54,16 @@ struct SparseCollectiveMainloop { using AlignmentTypeKV = cute::uint_byte_t(sizeof(DTypeKV)) * AlignmentKV>; // NOTE(Zihao): use SM80_CP_ASYNC for sparse loading of KV-cache using GmemCopyAtomKV = cute::Copy_Atom, DTypeKV>; - using GmemTiledCopyKV = + using GmemTiledCopyK = decltype(cutlass::gemm::collective::detail::make_simt_gmem_tiled_copy< GmemCopyAtomKV, NUM_COPY_THREADS, AlignmentKV, cutlass::detail::TagToStrideB_t, decltype(cute::get<1>(TileShape_QKD{})), decltype(cute::get<2>(TileShape_QKD{}))>()); + using GmemTiledCopyV = + decltype(cutlass::gemm::collective::detail::make_simt_gmem_tiled_copy< + GmemCopyAtomKV, NUM_COPY_THREADS, AlignmentKV, + cutlass::detail::TagToStrideB_t, + decltype(cute::get<2>(TileShape_PDV{})), decltype(cute::get<1>(TileShape_PDV{}))>()); using SmemLayoutQ = typename Ktraits::SmemLayoutQ; 
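Once HEAD_DIM_QK and HEAD_DIM_VO are decoupled, a K tile spans CTA_KV x HEAD_DIM_QK while a V tile spans CTA_KV x HEAD_DIM_VO, which is why the TMA mainloop tracks TmaTransactionBytesK and TmaTransactionBytesV separately and the sparse mainloop above splits GmemTiledCopyKV into GmemTiledCopyK and GmemTiledCopyV. A rough per-stage byte count, ignoring swizzle padding; half_t is a stand-in dtype, and the 192/128 numbers apply to the dense TMA path, since the sparse path still asserts equal head dims.

#include <cstdint>
#include <cstdio>

// Bytes per pipeline stage for one (cta_kv x head_dim) tile of a given KV dtype.
// The real kernels derive this from the smem layouts; this sketch ignores swizzle padding.
template <typename DTypeKV>
constexpr uint32_t tile_bytes(uint32_t cta_kv, uint32_t head_dim) {
  return cta_kv * head_dim * static_cast<uint32_t>(sizeof(DTypeKV));
}

int main() {
  using half_t = uint16_t;  // stand-in for a 16-bit KV dtype
  constexpr uint32_t CTA_KV = 128, HEAD_DIM_QK = 192, HEAD_DIM_VO = 128;
  std::printf("K stage: %u bytes, V stage: %u bytes\n",
              tile_bytes<half_t>(CTA_KV, HEAD_DIM_QK),   // 128 * 192 * 2 = 49152
              tile_bytes<half_t>(CTA_KV, HEAD_DIM_VO));  // 128 * 128 * 2 = 32768
  return 0;
}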
using SmemLayoutK = typename Ktraits::SmemLayoutK; @@ -86,7 +94,7 @@ struct SparseCollectiveMainloop { static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); static constexpr bool UseSchedulerBarrier = - cutlass::sizeof_bits_v == 8 ? HEAD_DIM >= 128 : HEAD_DIM <= 128; + cutlass::sizeof_bits_v == 8 ? HEAD_DIM_VO >= 128 : HEAD_DIM_VO <= 128; using WarpScheduler = WarpScheduler; // Host side kernel arguments @@ -185,34 +193,46 @@ struct SparseCollectiveMainloop { q_tile_idx, qo_len, kv_len); } - constexpr int HEAD_DIM = get<2>(TileShape_QKD{}); + constexpr int HEAD_DIM_QK = get<2>(TileShape_QKD{}); + constexpr int HEAD_DIM_VO = get<1>(TileShape_PDV{}); constexpr int CTA_KV = get<1>(TileShape_QKD{}); auto indexed_gather = BlockSparseIndexedGather(mainloop_params.kv_indices + kv_indptr); - Tensor mK = make_block_sparse_tensor( // (kv_len, D) + Tensor mK = make_block_sparse_tensor( // (kv_len, D_K) make_gmem_ptr(mainloop_params.K_ptr + kv_head_idx * stride<2>(mainloop_params.layout_K)), - make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_K), indexed_gather); - Tensor mV = make_block_sparse_tensor( // (kv_len, D) + make_shape(kv_len, HEAD_DIM_QK), stride<0>(mainloop_params.layout_K), indexed_gather); + Tensor mV = make_block_sparse_tensor( // (kv_len, D_V) make_gmem_ptr(mainloop_params.V_ptr + kv_head_idx * stride<2>(mainloop_params.layout_V)), - make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_V), indexed_gather); - - Tensor gK = local_tile(mK, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{})); // (KV, D, kv) - Tensor gV = local_tile(mV, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{})); // (KV, D, kv) - Tensor cKV = cute::make_identity_tensor(gK.shape()); - - GmemTiledCopyKV gmem_tiled_copy_kv; - auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx); - - Tensor tKgK = gmem_thr_copy_kv.partition_S(gK); // (CPY, CPY_KV, CPY_D, kv) - Tensor tKsK = gmem_thr_copy_kv.partition_D(sK); // (CPY, CPY_KV, CPY_D, PIPE) - Tensor tVgV = gmem_thr_copy_kv.partition_S(gV); // (CPY, CPY_KV, CPY_D, kv) - Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); // (CPY, CPY_KV, CPY_D, PIPE) - Tensor tKVcKV = gmem_thr_copy_kv.partition_D(cKV); // (CPY, CPY_KV, CPY_D) - Tensor tKVcKVGroup = flatten_1(tKVcKV); // (CPY, (CPY_KV, CPY_D)) + make_shape(kv_len, HEAD_DIM_VO), stride<0>(mainloop_params.layout_V), indexed_gather); + + Tensor gK = + local_tile(mK, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{})); // (KV, D_K, kv) + Tensor gV = + local_tile(mV, select<2, 1>(TileShape_PDV{}), make_coord(_, _0{})); // (KV, D_V, kv) + Tensor cK = cute::make_identity_tensor(gK.shape()); + Tensor cV = cute::make_identity_tensor(gV.shape()); + + GmemTiledCopyK gmem_tiled_copy_k; + GmemTiledCopyV gmem_tiled_copy_v; + auto gmem_thr_copy_k = gmem_tiled_copy_k.get_slice(thread_idx); + auto gmem_thr_copy_v = gmem_tiled_copy_v.get_slice(thread_idx); + + Tensor tKgK = gmem_thr_copy_k.partition_S(gK); // (CPY, CPY_KV, CPY_D, kv) + Tensor tKsK = gmem_thr_copy_k.partition_D(sK); // (CPY, CPY_KV, CPY_D, PIPE) + Tensor tVgV = gmem_thr_copy_v.partition_S(gV); // (CPY, CPY_KV, CPY_D, kv) + Tensor tVsV = gmem_thr_copy_v.partition_D(sV); // (CPY, CPY_KV, CPY_D, PIPE) + Tensor tKcK = gmem_thr_copy_k.partition_D(cK); // (CPY, CPY_KV, CPY_D) + Tensor tKcKGroup = flatten_1(tKcK); // (CPY, (CPY_KV, CPY_D)) + Tensor tVcV = gmem_thr_copy_v.partition_D(cV); // (CPY, CPY_KV, CPY_D) + Tensor tVcVGroup = flatten_1(tVcV); // (CPY, (CPY_KV, CPY_D)) int valid_last_kv_tile_size = std::min(kv_len - kv_tile_idx * CTA_KV, 
CTA_KV); - auto predicate_fn = [&](auto coords) { - auto s_coords = tKVcKVGroup(_0{}, coords); + auto k_predicate_fn = [&](auto coords) { + auto s_coords = tKcKGroup(_0{}, coords); + return elem_less(get<0>(s_coords), valid_last_kv_tile_size); + }; + auto v_predicate_fn = [&](auto coords) { + auto s_coords = tVcVGroup(_0{}, coords); return elem_less(get<0>(s_coords), valid_last_kv_tile_size); }; @@ -222,7 +242,7 @@ struct SparseCollectiveMainloop { Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) Tensor tKsKiGroup = flatten_1(tKsK(_, _, _, smem_pipe_write_k.index())); // (CPY, (CPY_KV, CPY_D)) - copy_if(gmem_tiled_copy_kv, predicate_fn, tKgKiGroup, tKsKiGroup); + copy_if(gmem_tiled_copy_k, k_predicate_fn, tKgKiGroup, tKsKiGroup); pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); ++smem_pipe_write_k; @@ -251,7 +271,7 @@ struct SparseCollectiveMainloop { Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) Tensor tVsViGroup = flatten_1(tVsV(_, _, _, smem_pipe_write_v.index())); // (CPY, (CPY_KV, CPY_D)) - copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup); + copy_if(gmem_tiled_copy_v, v_predicate_fn, tVgViGroup, tVsViGroup); pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); ++smem_pipe_write_v; @@ -260,7 +280,7 @@ struct SparseCollectiveMainloop { pipeline_k.producer_acquire(smem_pipe_write_k); Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) - copy(gmem_tiled_copy_kv, tKgKi, tKsKi); + copy(gmem_tiled_copy_k, tKgKi, tKsKi); pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); ++smem_pipe_write_k; @@ -269,7 +289,7 @@ struct SparseCollectiveMainloop { Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) Tensor tVsViGroup = flatten_1(tVsV(_, _, _, smem_pipe_write_v.index())); // (CPY, (CPY_KV, CPY_D)) - copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup); + copy_if(gmem_tiled_copy_v, v_predicate_fn, tVgViGroup, tVsViGroup); pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); --kv_tile_idx; @@ -282,7 +302,7 @@ struct SparseCollectiveMainloop { Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) - copy(gmem_tiled_copy_kv, tKgKi, tKsKi); + copy(gmem_tiled_copy_k, tKgKi, tKsKi); pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); ++smem_pipe_write_k; @@ -290,7 +310,7 @@ struct SparseCollectiveMainloop { pipeline_v.producer_acquire(smem_pipe_write_v); Tensor tVgVi = tVgV(_, _, _, kv_tile_idx); // (CPY, CPY_KV, CPY_D) Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index()); // (CPY, CPY_KV, CPY_D) - copy(gmem_tiled_copy_kv, tVgVi, tVsVi); + copy(gmem_tiled_copy_v, tVgVi, tVsVi); pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); ++smem_pipe_write_v; @@ -302,7 +322,7 @@ struct SparseCollectiveMainloop { pipeline_v.producer_acquire(smem_pipe_write_v); Tensor tVgVi = tVgV(_, _, _, 0); // (CPY, (CPY_KV, CPY_D)) Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index()); // (CPY, (CPY_KV, CPY_D)) - copy(gmem_tiled_copy_kv, tVgVi, tVsVi); + copy(gmem_tiled_copy_v, tVgVi, tVsVi); pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); 
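The producer loop guards only the last KV tile with a bounds predicate (valid_last_kv_tile_size) and now keeps separate predicates for K and V because their identity tensors may differ in shape. A CuTe-free CUDA sketch of the same tail-tile handling; the real code uses cute::copy_if with a row predicate, while this illustration simply zero-fills the out-of-range rows.

#include <cuda_runtime.h>

// Copy one (CTA_KV x HEAD_DIM) K or V tile from global to shared memory.
// Rows past `valid_rows` (the tail of the KV sequence) are zero-filled instead of loaded.
template <int CTA_KV, int HEAD_DIM, typename T>
__device__ void load_kv_tile_sketch(const T* gmem_tile, T* smem_tile, int valid_rows) {
  for (int idx = threadIdx.x; idx < CTA_KV * HEAD_DIM; idx += blockDim.x) {
    int row = idx / HEAD_DIM;
    smem_tile[idx] = (row < valid_rows) ? gmem_tile[idx] : T(0);
  }
}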
++smem_pipe_write_v; } diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 91e6776ee..f0b7e0cd4 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -24,7 +24,6 @@ #include "../cp_async.cuh" #include "../fastdiv.cuh" #include "../frag_layout_swizzle.cuh" -#include "../layout.cuh" #include "../math.cuh" #include "../mma.cuh" #include "../page.cuh" @@ -69,11 +68,13 @@ constexpr uint32_t get_num_mma_q(const uint32_t cta_tile_q) { namespace { template -constexpr bool is_invalid_configuration(uint32_t NUM_MMA_Q, uint32_t NUM_MMA_D, uint32_t NUM_MMA_KV, +constexpr bool is_invalid_configuration(uint32_t NUM_MMA_Q, uint32_t NUM_MMA_D_QK, + uint32_t NUM_MMA_D_VO, uint32_t NUM_MMA_KV, uint32_t NUM_WARPS_Q, uint32_t NUM_WARPS_KV) { - return ((NUM_MMA_D < 4) || (NUM_MMA_D == 4 && NUM_MMA_KV % 2 == 1) || - (NUM_MMA_D > 4 && NUM_MMA_D % (2 * NUM_WARPS_Q) != 0) || - (NUM_MMA_Q * (8 * NUM_MMA_D + 2 * sizeof(DTypeQKAccum) * NUM_MMA_KV) >= 256) || + return ((NUM_MMA_D_VO < 4) || (NUM_MMA_D_VO == 4 && NUM_MMA_KV % 2 == 1) || + (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && NUM_MMA_D_VO > 4 && + NUM_MMA_D_VO % (2 * NUM_WARPS_Q) != 0) || + (NUM_MMA_Q * (8 * NUM_MMA_D_VO + 2 * sizeof(DTypeQKAccum) * NUM_MMA_KV) >= 256) || (sizeof(DTypeKV) == 1 && NUM_MMA_KV * 2 % NUM_WARPS_Q != 0) || (sizeof(DTypeKV) == 1 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama)); } @@ -176,7 +177,7 @@ __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(T* x_first_half /*! * \brief Produce k/v fragments from global memory to shared memory. * \tparam fill_mode The fill mode of the shared memory. - * \tparam NUM_MMA_D The number of fragments in y dimension. + * \tparam NUM_MMA_D_VO The number of fragments in y dimension. * \tparam NUM_MMA_KV The number of fragments in z dimension. * \tparam num_warps The number of warps in the threadblock. * \tparam T The data type of the input tensor. 
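The configuration filter now keys on NUM_MMA_D_VO and only enforces the 2 * NUM_WARPS_Q divisibility requirement when Llama RoPE is applied on the fly. A standalone restatement of the predicate, with plain byte counts in place of the dtype template parameters and placeholder names:

#include <cstdint>

enum class PosEncodingModeSketch { kNone, kRoPELlama, kALiBi };

// NUM_MMA_D_QK and NUM_WARPS_KV are accepted for symmetry with the kernel template,
// even though the constraints shown in the hunk above only involve the VO dimension.
template <PosEncodingModeSketch POS_ENCODING_MODE, uint32_t dtype_kv_bytes,
          uint32_t dtype_qk_accum_bytes>
constexpr bool is_invalid_configuration_sketch(uint32_t NUM_MMA_Q, uint32_t NUM_MMA_D_QK,
                                               uint32_t NUM_MMA_D_VO, uint32_t NUM_MMA_KV,
                                               uint32_t NUM_WARPS_Q, uint32_t NUM_WARPS_KV) {
  (void)NUM_MMA_D_QK;
  (void)NUM_WARPS_KV;
  return ((NUM_MMA_D_VO < 4) || (NUM_MMA_D_VO == 4 && NUM_MMA_KV % 2 == 1) ||
          (POS_ENCODING_MODE == PosEncodingModeSketch::kRoPELlama && NUM_MMA_D_VO > 4 &&
           NUM_MMA_D_VO % (2 * NUM_WARPS_Q) != 0) ||
          (NUM_MMA_Q * (8 * NUM_MMA_D_VO + 2 * dtype_qk_accum_bytes * NUM_MMA_KV) >= 256) ||
          (dtype_kv_bytes == 1 && NUM_MMA_KV * 2 % NUM_WARPS_Q != 0) ||
          (dtype_kv_bytes == 1 && POS_ENCODING_MODE == PosEncodingModeSketch::kRoPELlama));
}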
@@ -188,12 +189,12 @@ __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(T* x_first_half template __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* smem_offset, - T** gptr, const uint32_t kv_stride_n, + T** gptr, const uint32_t stride_n, const uint32_t kv_idx_base, const uint32_t kv_len) { // NOTE(Zihao): for fp8, this function doesn't work for head_dim = 64 at the moment constexpr uint32_t head_dim = NUM_MMA_D * 16; constexpr uint32_t num_warps = NUM_WARPS_Q * NUM_WARPS_KV; - constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); + constexpr uint32_t upcast_head_dim_kv = head_dim / upcast_size(); const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; if constexpr (swizzle_mode == SwizzleMode::k128B) { @@ -206,15 +207,15 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(T)); ++j) { smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); - *gptr += 8 * num_elems_per_128b(); + *gptr += 8 * upcast_size(); } kv_idx += num_warps * 4; *smem_offset = - smem.template advance_offset_by_row(*smem_offset) - + smem.template advance_offset_by_row(*smem_offset) - sizeof(T) * NUM_MMA_D; - *gptr += num_warps * 4 * kv_stride_n - sizeof(T) * NUM_MMA_D * num_elems_per_128b(); + *gptr += num_warps * 4 * stride_n - sizeof(T) * NUM_MMA_D * upcast_size(); } - *smem_offset -= NUM_WARPS_KV * NUM_MMA_KV * 16 * channel_size_128b_kv; + *smem_offset -= NUM_WARPS_KV * NUM_MMA_KV * 16 * upcast_head_dim_kv; } else { uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; // NOTE(Zihao): NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / num_warps @@ -223,15 +224,15 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = - smem.template advance_offset_by_row(*smem_offset); + smem.template advance_offset_by_row(*smem_offset); kv_idx += num_warps * 8; - *gptr += num_warps * 8 * kv_stride_n; + *gptr += num_warps * 8 * stride_n; } - *smem_offset -= NUM_WARPS_KV * NUM_MMA_KV * 16 * channel_size_128b_kv; + *smem_offset -= NUM_WARPS_KV * NUM_MMA_KV * 16 * upcast_head_dim_kv; } } -template __device__ __forceinline__ void page_produce_kv(smem_t smem, uint32_t* smem_offset, const paged_kv_t& paged_kv, @@ -240,9 +241,9 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint3 // NOTE(Zihao): for fp8, this function doesn't work for head_dim = 64 at the moment constexpr SharedMemFillMode fill_mode = produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill; - constexpr uint32_t head_dim = NUM_MMA_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D_VO * 16; constexpr uint32_t num_warps = NUM_WARPS_Q * NUM_WARPS_KV; - constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); + constexpr uint32_t upcast_head_dim_kv = head_dim / upcast_size(); const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; if constexpr (swizzle_mode == SwizzleMode::k128B) { uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; @@ -252,17 +253,17 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint3 for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { DType* gptr = produce_v ? 
paged_kv.v_data + kv_offset[i] : paged_kv.k_data + kv_offset[i]; #pragma unroll - for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); ++j) { + for (uint32_t j = 0; j < NUM_MMA_D_VO / (8 / sizeof(DType)); ++j) { smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); - gptr += 8 * num_elems_per_128b(); + gptr += 8 * upcast_size(); } kv_idx += num_warps * 4; *smem_offset = - smem.template advance_offset_by_row(*smem_offset) - - sizeof(DType) * NUM_MMA_D; + smem.template advance_offset_by_row(*smem_offset) - + sizeof(DType) * NUM_MMA_D_VO; } - *smem_offset -= NUM_WARPS_KV * NUM_MMA_KV * 16 * channel_size_128b_kv; + *smem_offset -= NUM_WARPS_KV * NUM_MMA_KV * 16 * upcast_head_dim_kv; } else { uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; // NOTE(Zihao): NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / num_warps @@ -273,19 +274,19 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint3 smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); kv_idx += num_warps * 8; *smem_offset = - smem.template advance_offset_by_row(*smem_offset); + smem.template advance_offset_by_row(*smem_offset); } - *smem_offset -= NUM_WARPS_KV * NUM_MMA_KV * 16 * channel_size_128b_kv; + *smem_offset -= NUM_WARPS_KV * NUM_MMA_KV * 16 * upcast_head_dim_kv; } } -template +template __device__ __forceinline__ void init_rope_freq(float (*rope_freq)[4], const float rope_rcp_scale, const float rope_rcp_theta) { - constexpr uint32_t head_dim = NUM_MMA_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D_VO * 16; const uint32_t lane_idx = threadIdx.x; #pragma unroll - for (uint32_t mma_d = 0; mma_d < NUM_MMA_D / 2; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_VO / 2; ++mma_d) { #pragma unroll for (uint32_t j = 0; j < 4; ++j) { rope_freq[mma_d][j] = @@ -298,13 +299,15 @@ __device__ __forceinline__ void init_rope_freq(float (*rope_freq)[4], const floa } } -template -__device__ __forceinline__ void init_states(AttentionVariant variant, float (*o_frag)[NUM_MMA_D][8], - DTypeQKAccum (*m)[2], float (*d)[2]) { +template +__device__ __forceinline__ void init_states(AttentionVariant variant, + float (*o_frag)[NUM_MMA_D_VO][8], DTypeQKAccum (*m)[2], + float (*d)[2]) { #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_VO; ++mma_d) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { o_frag[mma_q][mma_d][reg_id] = 0.f; @@ -324,7 +327,7 @@ __device__ __forceinline__ void init_states(AttentionVariant variant, float (*o_ } } -template __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, const uint32_t qo_upper_bound, @@ -332,12 +335,12 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, const uint32_t q_stride_h, const uint_fastdiv group_size, smem_t* q_smem) { - constexpr uint32_t head_dim = NUM_MMA_D * 16; - constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); + constexpr uint32_t head_dim = NUM_MMA_D_QK * 16; + constexpr uint32_t upcast_head_dim_q = head_dim / upcast_size(); const uint32_t lane_idx = threadIdx.x, warp_idx_x = get_warp_idx_q(); if (get_warp_idx_kv() == 0) { - uint32_t q_smem_offset_w = q_smem->get_permuted_offset( + uint32_t q_smem_offset_w = q_smem->get_permuted_offset( warp_idx_x * NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8); #pragma unroll @@ -349,41 +352,41 @@ __device__ 
__forceinline__ void load_q_global_smem(uint32_t packed_offset, const uint32_t q_idx = q; DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h; #pragma unroll - for (uint32_t mma_do = 0; mma_do < NUM_MMA_D / 4; ++mma_do) { + for (uint32_t mma_do = 0; mma_do < NUM_MMA_D_QK / 4; ++mma_do) { // load q fragment from gmem to smem q_smem->load_128b_async(q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); q_smem_offset_w = q_smem->template advance_offset_by_column<8>(q_smem_offset_w, mma_do); - q_ptr += 8 * num_elems_per_128b(); + q_ptr += 8 * upcast_size(); } q_smem_offset_w = - q_smem->template advance_offset_by_row<4, channel_size_128b_q>(q_smem_offset_w) - - 2 * NUM_MMA_D; + q_smem->template advance_offset_by_row<4, upcast_head_dim_q>(q_smem_offset_w) - + 2 * NUM_MMA_D_QK; } } } } -template __device__ __forceinline__ void q_smem_inplace_apply_rotary( const uint32_t q_packed_idx, const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, smem_t* q_smem, uint32_t* q_smem_offset_r, float (*rope_freq)[4]) { if (get_warp_idx_kv() == 0) { - constexpr uint32_t head_dim = NUM_MMA_D * 16; - constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); + constexpr uint32_t head_dim = NUM_MMA_D_QK * 16; + constexpr uint32_t upcast_head_dim_q = head_dim / upcast_size(); const uint32_t lane_idx = threadIdx.x; uint32_t q_frag_local[2][4]; - static_assert(NUM_MMA_D % 4 == 0, "NUM_MMA_D must be a multiple of 4"); + static_assert(NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; #pragma unroll - for (uint32_t mma_di = 0; mma_di < NUM_MMA_D / 2; ++mma_di) { + for (uint32_t mma_di = 0; mma_di < NUM_MMA_D_QK / 2; ++mma_di) { q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); uint32_t q_smem_offset_r_last_half = - q_smem->template advance_offset_by_column(q_smem_offset_r_first_half, 0); + q_smem->template advance_offset_by_column(q_smem_offset_r_first_half, 0); q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_frag_apply_llama_rope( (DTypeQ*)q_frag_local[0], (DTypeQ*)q_frag_local[1], rope_freq[mma_di], @@ -394,31 +397,31 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary( q_smem_offset_r_first_half = q_smem->template advance_offset_by_column<2>(q_smem_offset_r_first_half, mma_di); } - *q_smem_offset_r += 16 * channel_size_128b_q; + *q_smem_offset_r += 16 * upcast_head_dim_q; } - *q_smem_offset_r -= NUM_MMA_Q * 16 * channel_size_128b_q; + *q_smem_offset_r -= NUM_MMA_Q * 16 * upcast_head_dim_q; } } -template __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( const uint32_t q_packed_idx_base, const IdType* q_rope_offset, smem_t* q_smem, const uint_fastdiv group_size, uint32_t* q_smem_offset_r, float (*rope_freq)[4]) { if (get_warp_idx_kv() == 0) { - constexpr uint32_t head_dim = NUM_MMA_D * 16; - constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); + constexpr uint32_t head_dim = NUM_MMA_D_QK * 16; + constexpr uint32_t upcast_head_dim_q = head_dim / upcast_size(); const uint32_t lane_idx = threadIdx.x; uint32_t q_frag_local[2][4]; - static_assert(NUM_MMA_D % 4 == 0, "NUM_MMA_D must be a multiple of 4"); + static_assert(NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D must be a multiple of 4"); #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; #pragma unroll - for (uint32_t mma_di = 0; mma_di < 
NUM_MMA_D / 2; ++mma_di) { + for (uint32_t mma_di = 0; mma_di < NUM_MMA_D_QK / 2; ++mma_di) { q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); uint32_t q_smem_offset_r_last_half = - q_smem->template advance_offset_by_column(q_smem_offset_r_first_half, 0); + q_smem->template advance_offset_by_column(q_smem_offset_r_first_half, 0); q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_frag_apply_llama_rope_with_pos( (DTypeQ*)q_frag_local[0], (DTypeQ*)q_frag_local[1], rope_freq[mma_di], @@ -428,21 +431,21 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( q_smem_offset_r_first_half = q_smem->template advance_offset_by_column<2>(q_smem_offset_r_first_half, mma_di); } - *q_smem_offset_r += 16 * channel_size_128b_q; + *q_smem_offset_r += 16 * upcast_head_dim_q; } - *q_smem_offset_r -= NUM_MMA_Q * 16 * channel_size_128b_q; + *q_smem_offset_r -= NUM_MMA_Q * 16 * upcast_head_dim_q; } } -template __device__ __forceinline__ void q_smem_inplace_transform(const Params& params, AttentionVariant variant, smem_t* q_smem) { using DTypeQ = typename Params::DTypeQ; const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; - constexpr uint32_t head_dim = NUM_MMA_D * 16; - constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); + constexpr uint32_t head_dim = NUM_MMA_D_QK * 16; + constexpr uint32_t upcast_head_dim_q = head_dim / upcast_size(); constexpr uint32_t num_warps = NUM_WARPS_Q * NUM_WARPS_KV; #pragma unroll for (uint32_t i = 0; i < NUM_MMA_Q * head_dim / (NUM_WARPS_KV * 16); ++i) { @@ -456,18 +459,18 @@ __device__ __forceinline__ void q_smem_inplace_transform(const Params& params, } } -template __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_idx_base, smem_t* k_smem, uint32_t* k_smem_offset_r, float (*rope_freq)[4]) { static_assert(sizeof(DTypeKV) == 2); - constexpr uint32_t head_dim = NUM_MMA_D * 16; - constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); + constexpr uint32_t head_dim = NUM_MMA_D_QK * 16; + constexpr uint32_t upcast_head_dim_kv = head_dim / upcast_size(); uint32_t k_frag_local[2][4]; const uint32_t lane_idx = threadIdx.x; - if constexpr (NUM_MMA_D == 4 && NUM_WARPS_Q == 4) { + if constexpr (NUM_MMA_D_QK == 4 && NUM_WARPS_Q == 4) { static_assert(NUM_WARPS_KV == 1); const uint32_t warp_idx = get_warp_idx_q(); // horizontal-axis: y @@ -476,10 +479,11 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id // | 1-16 | 16-32 | 32-48 | 48-64 | // | 1-16 | warp_idx=0 | warp_idx=1 | warp_idx=0 | warp_idx=1 | // | 16-32 | warp_idx=2 | warp_idx=3 | warp_idx=2 | warp_idx=3 | - static_assert(NUM_MMA_KV % 2 == 0, "when NUM_MMA_D == 4, NUM_MMA_KV must be a multiple of 2"); + static_assert(NUM_MMA_KV % 2 == 0, + "when NUM_MMA_D_QK == 4, NUM_MMA_KV must be a multiple of 2"); uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + lane_idx / 4; *k_smem_offset_r = - (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) + (warp_idx / 2) * 16 * channel_size_128b_kv; + (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) + (warp_idx / 2) * 16 * upcast_head_dim_kv; #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV / 2; ++i) { // uint32_t mma_kv = warp_idx / 2 + i * 2; @@ -493,15 +497,15 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id rope_freq[mma_di], kv_idx); k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); - *k_smem_offset_r += 32 * 
channel_size_128b_kv; + *k_smem_offset_r += 32 * upcast_head_dim_kv; kv_idx += 32; } *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) - - ((warp_idx / 2) + NUM_MMA_KV) * 16 * channel_size_128b_kv; + ((warp_idx / 2) + NUM_MMA_KV) * 16 * upcast_head_dim_kv; } else { const uint32_t warp_idx_x = get_warp_idx_q(), warp_idx_z = get_warp_idx_kv(); - static_assert(NUM_MMA_D % (2 * NUM_WARPS_Q) == 0); + static_assert(NUM_MMA_D_QK % (2 * NUM_WARPS_Q) == 0); // horizontal axis: y // vertical axis: z // | (warp_idx_z, warp_idx_x) | 1-16 | 16-32 | 32-48 | 48-64 | ... @@ -514,11 +518,11 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id for (uint32_t i = 0; i < NUM_MMA_KV; ++i) { uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; #pragma unroll - for (uint32_t j = 0; j < NUM_MMA_D / (2 * NUM_WARPS_Q); ++j) { + for (uint32_t j = 0; j < NUM_MMA_D_QK / (2 * NUM_WARPS_Q); ++j) { uint32_t mma_di = warp_idx_x + j * NUM_WARPS_Q; k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); uint32_t k_smem_offset_r_last_half = - k_smem->template advance_offset_by_column(k_smem_offset_r_first_half, 0); + k_smem->template advance_offset_by_column(k_smem_offset_r_first_half, 0); k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_frag_apply_llama_rope((DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], rope_freq[mma_di], kv_idx); @@ -527,37 +531,38 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id k_smem_offset_r_first_half = k_smem->template advance_offset_by_column<2 * NUM_WARPS_Q>( k_smem_offset_r_first_half, mma_di); } - *k_smem_offset_r += 16 * channel_size_128b_kv; + *k_smem_offset_r += 16 * upcast_head_dim_kv; kv_idx += 16; } *k_smem_offset_r = - (*k_smem_offset_r ^ (0x2 * warp_idx_x)) - NUM_MMA_KV * 16 * channel_size_128b_kv; + (*k_smem_offset_r ^ (0x2 * warp_idx_x)) - NUM_MMA_KV * 16 * upcast_head_dim_kv; } } -template +template __device__ __forceinline__ void compute_qk(smem_t* q_smem, uint32_t* q_smem_offset_r, smem_t* k_smem, uint32_t* k_smem_offset_r, DTypeQKAccum (*s_frag)[NUM_MMA_KV][8]) { - constexpr uint32_t head_dim = NUM_MMA_D * 16; - constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); - constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); + constexpr uint32_t head_dim = NUM_MMA_D_QK * 16; + constexpr uint32_t upcast_head_dim_q = head_dim / upcast_size(); + constexpr uint32_t upcast_head_dim_k = head_dim / upcast_size(); uint32_t a_frag[NUM_MMA_Q][4], b_frag[4]; // compute q*k^T #pragma unroll - for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_QK; ++mma_d) { #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[mma_q]); *q_smem_offset_r = - q_smem->template advance_offset_by_row<16, channel_size_128b_q>(*q_smem_offset_r); + q_smem->template advance_offset_by_row<16, upcast_head_dim_q>(*q_smem_offset_r); } *q_smem_offset_r = q_smem->template advance_offset_by_column<2>(*q_smem_offset_r, mma_d) - - NUM_MMA_Q * 16 * channel_size_128b_q; + NUM_MMA_Q * 16 * upcast_head_dim_q; #pragma unroll for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { @@ -575,7 +580,7 @@ __device__ __forceinline__ void compute_qk(smem_t* q_smem, k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); } *k_smem_offset_r = - k_smem->template advance_offset_by_row<16, channel_size_128b_kv>(*k_smem_offset_r); + k_smem->template advance_offset_by_row<16, 
upcast_head_dim_k>(*k_smem_offset_r); #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { @@ -603,18 +608,18 @@ __device__ __forceinline__ void compute_qk(smem_t* q_smem, *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, mma_d / 2); } - *k_smem_offset_r -= NUM_MMA_KV * 16 * channel_size_128b_kv; + *k_smem_offset_r -= NUM_MMA_KV * 16 * upcast_head_dim_k; } else { *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, mma_d) - - NUM_MMA_KV * 16 * channel_size_128b_kv; + NUM_MMA_KV * 16 * upcast_head_dim_k; } } - *q_smem_offset_r -= NUM_MMA_D * 2; - *k_smem_offset_r -= NUM_MMA_D * sizeof(DTypeKV); + *q_smem_offset_r -= NUM_MMA_D_QK * 2; + *k_smem_offset_r -= NUM_MMA_D_QK * sizeof(DTypeKV); } -template +template __device__ __forceinline__ void logits_transform( const Params& params, AttentionVariant variant, const uint32_t batch_idx, const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, @@ -648,8 +653,8 @@ __device__ __forceinline__ void logits_transform( } } -template +template __device__ __forceinline__ void logits_mask(const Params& params, AttentionVariant variant, const uint32_t batch_idx, const uint32_t qo_packed_idx_base, @@ -691,11 +696,11 @@ __device__ __forceinline__ void logits_mask(const Params& params, AttentionVaria } } -template __device__ __forceinline__ void update_mdo_states(AttentionVariant variant, DTypeQKAccum (*s_frag)[NUM_MMA_KV][8], - float (*o_frag)[NUM_MMA_D][8], + float (*o_frag)[NUM_MMA_D_VO][8], DTypeQKAccum (*m)[2], float (*d)[2]) { if constexpr (variant.use_softmax) { if constexpr (std::is_same_v) { @@ -717,7 +722,7 @@ __device__ __forceinline__ void update_mdo_states(AttentionVariant variant, float o_scale = math::ptx_exp2(m_prev - m[mma_q][j]); d[mma_q][j] *= o_scale; #pragma unroll - for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_VO; ++mma_d) { o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; o_frag[mma_q][mma_d][j * 2 + 4] *= o_scale; @@ -759,7 +764,7 @@ __device__ __forceinline__ void update_mdo_states(AttentionVariant variant, float o_scale = math::ptx_exp2(float(m_prev[j] - m[mma_q][j])); d[mma_q][j] *= o_scale; #pragma unroll - for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_VO; ++mma_d) { o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; o_frag[mma_q][mma_d][j * 2 + 4] *= o_scale; @@ -779,15 +784,15 @@ __device__ __forceinline__ void update_mdo_states(AttentionVariant variant, } } -template __device__ __forceinline__ void compute_sfm_v(AttentionVariant variant, smem_t* v_smem, uint32_t* v_smem_offset_r, DTypeQKAccum (*s_frag)[NUM_MMA_KV][8], - float (*o_frag)[NUM_MMA_D][8], float (*d)[2]) { - constexpr uint32_t head_dim = NUM_MMA_D * 16; - constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); + float (*o_frag)[NUM_MMA_D_VO][8], float (*d)[2]) { + constexpr uint32_t head_dim = NUM_MMA_D_VO * 16; + constexpr uint32_t upcast_head_dim_v = head_dim / upcast_size(); DTypeQ s_frag_f16[NUM_MMA_Q][NUM_MMA_KV][8]; if constexpr (std::is_same_v) { @@ -817,7 +822,7 @@ __device__ __forceinline__ void compute_sfm_v(AttentionVariant variant, #pragma unroll for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { #pragma unroll - for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_VO; ++mma_d) { uint32_t b_frag[4]; 
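Throughout these hunks, channel_size_128b_* becomes upcast_head_dim_*: the head dimension counted in 128-bit vectors, which is the unit the swizzled smem offsets advance in, and Q/K offsets now scale with HEAD_DIM_QK while V/O offsets scale with HEAD_DIM_VO. A quick numeric sketch, assuming upcast_size<T>() plays the same role as the old num_elems_per_128b<T>(), i.e. elements of T per 128-bit access:

#include <cstdint>
#include <cstdio>

// Elements of T per 128-bit (16-byte) vectorized access.
template <typename T>
constexpr uint32_t upcast_size_sketch() {
  return 16 / sizeof(T);
}

int main() {
  using half_t = uint16_t;  // stand-in for fp16
  constexpr uint32_t NUM_MMA_D_QK = 12;           // 12 * 16 = 192 (e.g. a 192/128 head-dim split)
  constexpr uint32_t NUM_MMA_D_VO = 8;            //  8 * 16 = 128
  constexpr uint32_t head_dim_qk = NUM_MMA_D_QK * 16;
  constexpr uint32_t head_dim_vo = NUM_MMA_D_VO * 16;
  std::printf("upcast_head_dim_q = %u, upcast_head_dim_v = %u\n",
              head_dim_qk / upcast_size_sketch<half_t>(),   // 192 / 8 = 24
              head_dim_vo / upcast_size_sketch<half_t>());  // 128 / 8 = 16
  return 0;
}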
if constexpr (sizeof(DTypeKV) == 1) { uint32_t b_frag_f8[2]; @@ -853,15 +858,17 @@ __device__ __forceinline__ void compute_sfm_v(AttentionVariant variant, } } *v_smem_offset_r = - v_smem->template advance_offset_by_row<16, channel_size_128b_kv>(*v_smem_offset_r) - - sizeof(DTypeKV) * NUM_MMA_D; + v_smem->template advance_offset_by_row<16, upcast_head_dim_v>(*v_smem_offset_r) - + sizeof(DTypeKV) * NUM_MMA_D_VO; } - *v_smem_offset_r -= 16 * NUM_MMA_KV * channel_size_128b_kv; + *v_smem_offset_r -= 16 * NUM_MMA_KV * upcast_head_dim_v; } -template -__device__ __forceinline__ void normalize_d(AttentionVariant variant, float (*o_frag)[NUM_MMA_D][8], - DTypeQKAccum (*m)[2], float (*d)[2]) { +template +__device__ __forceinline__ void normalize_d(AttentionVariant variant, + float (*o_frag)[NUM_MMA_D_VO][8], DTypeQKAccum (*m)[2], + float (*d)[2]) { if constexpr (variant.use_softmax) { float d_rcp[NUM_MMA_Q][2]; // compute reciprocal of d @@ -877,7 +884,7 @@ __device__ __forceinline__ void normalize_d(AttentionVariant variant, float (*o_ #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_VO; ++mma_d) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { o_frag[mma_q][mma_d][reg_id] = @@ -888,27 +895,39 @@ __device__ __forceinline__ void normalize_d(AttentionVariant variant, float (*o_ } } +template +constexpr size_t SmemSizeThreadBlockAttnSync() { + if constexpr (NUM_WARPS_KV == 1) { + return 0; + } else { + return (NUM_WARPS_Q * NUM_WARPS_KV) * (NUM_MMA_Q * 16) * HEAD_DIM_VO * sizeof(float) + + (NUM_WARPS_Q * NUM_WARPS_KV) * (NUM_MMA_Q * 16) * 2 * sizeof(float); + } +} + /*! * \brief Synchronize the states of the MDO kernel across the threadblock along threadIdx.z. 
*/ -template __device__ __forceinline__ void threadblock_sync_mdo_states( - AttentionVariant variant, float (*o_frag)[NUM_MMA_D][8], float* smem_workspace, + AttentionVariant variant, float (*o_frag)[NUM_MMA_D_VO][8], float* smem_workspace, DTypeQKAccum (*m)[2], float (*d)[2], const uint32_t warp_idx, const uint32_t lane_idx) { // only necessary when blockDim.z > 1 if constexpr (NUM_WARPS_KV > 1) { - float2* smem_md = (float2*)(smem_workspace + - NUM_MMA_Q * NUM_MMA_D * NUM_WARPS_Q * NUM_WARPS_KV * WARP_SIZE * 8); - // o: [num_warps, NUM_MMA_Q, NUM_MMA_D, WARP_SIZE(32), 8] - // md: [num_warps, NUM_MMA_Q, 2, WARP_SIZE(32), 2 (m/d)] + float2* smem_md = (float2*)(smem_workspace + NUM_MMA_Q * NUM_MMA_D_VO * NUM_WARPS_Q * + NUM_WARPS_KV * WARP_SIZE * 8); + // o: [num_warps, NUM_MMA_Q, NUM_MMA_D_VO, WARP_SIZE(32), 8] + // md: [num_warps, NUM_MMA_Q, 16, 2 (m/d)] #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_VO; ++mma_d) { vec_t::memcpy( smem_workspace + - (((warp_idx * NUM_MMA_Q + mma_q) * NUM_MMA_D + mma_d) * WARP_SIZE + lane_idx) * 8, + (((warp_idx * NUM_MMA_Q + mma_q) * NUM_MMA_D_VO + mma_d) * WARP_SIZE + lane_idx) * + 8, o_frag[mma_q][mma_d]); } } @@ -918,7 +937,7 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - smem_md[((warp_idx * NUM_MMA_Q + mma_q) * 2 + j) * WARP_SIZE + lane_idx] = + smem_md[((warp_idx * NUM_MMA_Q + mma_q) * 2 + j) * 8 + lane_idx / 4] = make_float2(float(m[mma_q][j]), d[mma_q][j]); } } @@ -938,8 +957,8 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( mma_q) * 2 + j) * - WARP_SIZE + - lane_idx]; + 8 + + lane_idx / 4]; float m_prev = m_new, d_prev = d_new; m_new = max(m_new, md.x); d_new = d_prev * math::ptx_exp2(m_prev - m_new) + md.y * math::ptx_exp2(md.x - m_new); @@ -952,8 +971,8 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( mma_q) * 2 + j) * - WARP_SIZE + - lane_idx]; + 8 + + lane_idx / 4]; float mi = md.x; o_scale[j][i] = math::ptx_exp2(float(mi - m_new)); } @@ -962,7 +981,7 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( } #pragma unroll - for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_VO; ++mma_d) { vec_t o_new; o_new.fill(0.f); #pragma unroll @@ -971,7 +990,7 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( oi.load(smem_workspace + ((((i * NUM_WARPS_Q + get_warp_idx_q()) * NUM_MMA_Q + mma_q) * - NUM_MMA_D + + NUM_MMA_D_VO + mma_d) * WARP_SIZE + lane_idx) * @@ -990,7 +1009,7 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_VO; ++mma_d) { vec_t o_new; o_new.fill(0.f); #pragma unroll @@ -999,7 +1018,7 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( oi.load(smem_workspace + ((((i * NUM_WARPS_Q + get_warp_idx_q()) * NUM_MMA_Q + mma_q) * - NUM_MMA_D + + NUM_MMA_D_VO + mma_d) * WARP_SIZE + lane_idx) * @@ -1016,14 +1035,14 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( } } -template __device__ __forceinline__ void write_o_reg_gmem( - float (*o_frag)[NUM_MMA_D][8], smem_t* o_smem, DTypeO* o_ptr_base, + float (*o_frag)[NUM_MMA_D_VO][8], smem_t* o_smem, DTypeO* o_ptr_base, 
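SmemSizeThreadBlockAttnSync (added above) sizes the scratch buffer for the cross-warp merge, and threadblock_sync_mdo_states then folds each warp's (m, d) statistics together before rescaling the partial outputs. A host-side sketch of both pieces with hypothetical names; the kernel uses math::ptx_exp2 in the log2 domain, with std::exp2 standing in here.

#include <algorithm>
#include <cmath>
#include <cstddef>

// Scratch bytes for the cross-warp merge: one float per partial-O element per warp,
// plus one (m, d) float pair per fragment row per warp; zero when no merge is needed.
template <unsigned NUM_MMA_Q, unsigned HEAD_DIM_VO, unsigned NUM_WARPS_Q, unsigned NUM_WARPS_KV>
constexpr size_t smem_size_threadblock_attn_sync_sketch() {
  if constexpr (NUM_WARPS_KV == 1) {
    return 0;
  } else {
    return (NUM_WARPS_Q * NUM_WARPS_KV) * (NUM_MMA_Q * 16) * HEAD_DIM_VO * sizeof(float) +
           (NUM_WARPS_Q * NUM_WARPS_KV) * (NUM_MMA_Q * 16) * 2 * sizeof(float);
  }
}

struct MDStateSketch {
  float m;  // running (log2-domain) max
  float d;  // running denominator
};

// Fold another warp's (m, d) into the accumulator; returns the scale to apply to the
// accumulator's partial output (the other partial is scaled by exp2(other.m - m_new)).
inline float merge_md_sketch(MDStateSketch& acc, MDStateSketch other) {
  float m_new = std::max(acc.m, other.m);
  float self_scale = std::exp2(acc.m - m_new);
  acc.d = acc.d * self_scale + other.d * std::exp2(other.m - m_new);
  acc.m = m_new;
  return self_scale;
}

// e.g. smem_size_threadblock_attn_sync_sketch<1, 128, 4, 2>() == 66560 bytes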
const uint32_t o_packed_idx_base, const uint32_t qo_upper_bound, const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv group_size) { - constexpr uint32_t head_dim = NUM_MMA_D * 16; - constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); + constexpr uint32_t head_dim = NUM_MMA_D_VO * 16; + constexpr uint32_t upcast_head_dim_o = head_dim / upcast_size(); const uint32_t warp_idx_x = get_warp_idx_q(); const uint32_t lane_idx = threadIdx.x; @@ -1031,27 +1050,27 @@ __device__ __forceinline__ void write_o_reg_gmem( #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_VO; ++mma_d) { uint32_t o_frag_f16[4]; vec_cast::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_q][mma_d]); #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED - uint32_t o_smem_offset_w = o_smem->get_permuted_offset( + uint32_t o_smem_offset_w = o_smem->get_permuted_offset( (warp_idx_x * NUM_MMA_Q + mma_q) * 16 + lane_idx % 16, mma_d * 2 + lane_idx / 16); o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); #else - uint32_t o_smem_offset_w = o_smem->get_permuted_offset( + uint32_t o_smem_offset_w = o_smem->get_permuted_offset( (warp_idx_x * NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0]; - ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * channel_size_128b_out))[lane_idx % 4] = + ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * upcast_head_dim_o))[lane_idx % 4] = o_frag_f16[1]; ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2]; ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + - 8 * channel_size_128b_out))[lane_idx % 4] = o_frag_f16[3]; + 8 * upcast_head_dim_o))[lane_idx % 4] = o_frag_f16[3]; #endif } } - uint32_t o_smem_offset_w = o_smem->get_permuted_offset( + uint32_t o_smem_offset_w = o_smem->get_permuted_offset( warp_idx_x * NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8); #pragma unroll @@ -1063,16 +1082,16 @@ __device__ __forceinline__ void write_o_reg_gmem( const uint32_t o_idx = q; DTypeO* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h; #pragma unroll - for (uint32_t mma_do = 0; mma_do < NUM_MMA_D / 4; ++mma_do) { + for (uint32_t mma_do = 0; mma_do < NUM_MMA_D_VO / 4; ++mma_do) { if (o_idx < qo_upper_bound) { o_smem->store_128b(o_smem_offset_w, o_ptr); } - o_ptr += 8 * num_elems_per_128b(); + o_ptr += 8 * upcast_size(); o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, mma_do); } o_smem_offset_w = - o_smem->template advance_offset_by_row<4, channel_size_128b_out>(o_smem_offset_w) - - 2 * NUM_MMA_D; + o_smem->template advance_offset_by_row<4, upcast_head_dim_o>(o_smem_offset_w) - + 2 * NUM_MMA_D_VO; } } } @@ -1086,7 +1105,7 @@ __device__ __forceinline__ void write_o_reg_gmem( * \tparam mask_mode The mask mode used in the attention operation. * \tparam POS_ENCODING_MODE The positional encoding mode. * \tparam NUM_MMA_Q The number of fragments in x dimension. - * \tparam NUM_MMA_D The number of fragments in y dimension. + * \tparam NUM_MMA_D_VO The number of fragments in y dimension. * \tparam NUM_MMA_KV The number of fragments in z dimension. * \tparam num_warps The number of warps in the threadblock. * \tparam DTypeQ The data type of the query tensor. @@ -1104,8 +1123,8 @@ __device__ __forceinline__ void write_o_reg_gmem( * used in RoPE. 
*/ template + uint32_t NUM_MMA_D_QK, uint32_t NUM_MMA_D_VO, uint32_t NUM_MMA_KV, uint32_t NUM_WARPS_Q, + uint32_t NUM_WARPS_KV, typename DTypeQKAccum, typename AttentionVariant, typename Params> __global__ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKVCacheKernel( const uint_fastdiv group_size, const __grid_constant__ Params params) { @@ -1127,8 +1146,10 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV const bool partition_kv = params.partition_kv; const uint32_t q_stride_n = params.q_stride_n; const uint32_t q_stride_h = params.q_stride_h; - const uint32_t kv_stride_n = params.kv_stride_n; - const uint32_t kv_stride_h = params.kv_stride_h; + const uint32_t k_stride_n = params.k_stride_n; + const uint32_t k_stride_h = params.k_stride_h; + const uint32_t v_stride_n = params.v_stride_n; + const uint32_t v_stride_h = params.v_stride_h; const int32_t maybe_window_left = params.window_left; static_assert(sizeof(DTypeQ) == 2); @@ -1137,8 +1158,6 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV const uint32_t bx = blockIdx.x, chunk_idx = blockIdx.y, kv_head_idx = blockIdx.z; const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size; constexpr uint32_t num_rows_per_cta = NUM_MMA_Q * NUM_WARPS_Q * 16; - const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, - kv_stride_n, kv_stride_h, /*head_dim=*/NUM_MMA_D * 16); const uint32_t num_chunks = gridDim.y; const uint32_t max_chunk_size = partition_kv ? ceil_div(kv_len, num_chunks) : kv_len; @@ -1152,44 +1171,44 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV AttentionVariant variant(params, /*batch_idx=*/0, smem); const uint32_t window_left = variant.window_left; - constexpr uint32_t head_dim = NUM_MMA_D * 16; - constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); - constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); - constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); + constexpr uint32_t head_dim_vo = NUM_MMA_D_VO * 16; + constexpr uint32_t head_dim_qk = NUM_MMA_D_QK * 16; + constexpr uint32_t upcast_head_dim_q = head_dim_qk / upcast_size(); + constexpr uint32_t upcast_head_dim_k = head_dim_qk / upcast_size(); + constexpr uint32_t upcast_head_dim_v = head_dim_vo / upcast_size(); + constexpr uint32_t upcast_head_dim_o = head_dim_vo / upcast_size(); DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; - alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; DTypeQKAccum m[NUM_MMA_Q][2]; float d[NUM_MMA_Q][2]; - float rope_freq[NUM_MMA_D / 2][4]; + float rope_freq[NUM_MMA_D_QK / 2][4]; if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { const float rope_rcp_scale = params.rope_rcp_scale; const float rope_rcp_theta = params.rope_rcp_theta; - init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta); + init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta); } - init_states(variant, o_frag, m, d); + init_states(variant, o_frag, m, d); // cooperative fetch q fragment from gmem to reg const uint32_t qo_packed_idx_base = (bx * NUM_WARPS_Q + get_warp_idx_q()) * NUM_MMA_Q * 16; constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k128B; smem_t qo_smem(smem); + const uint32_t o_stride_n = num_qo_heads * head_dim_vo, o_stride_h = head_dim_vo; DTypeQ* q_ptr_base = - q + qkv_info.get_q_elem_offset(0, kv_head_idx * group_size, - 
(lane_idx % 8) * num_elems_per_128b()); + q + (kv_head_idx * group_size) * q_stride_h + (lane_idx % 8) * upcast_size(); DTypeO* o_ptr_base = partition_kv - ? o + chunk_idx * num_qo_heads * head_dim + - qkv_info.get_o_elem_offset(0, kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b()) - : o + qkv_info.get_o_elem_offset(0, kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b()); + ? o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h + + (lane_idx % 8) * upcast_size() + : o + (kv_head_idx * group_size) * o_stride_h + (lane_idx % 8) * upcast_size(); - uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( + uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( get_warp_idx_q() * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); - load_q_global_smem( + load_q_global_smem( qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem); cp_async::commit_group(); @@ -1197,23 +1216,22 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV block.sync(); if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { - q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, - &q_smem_offset_r, rope_freq); + q_smem_inplace_apply_rotary( + qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq); block.sync(); } - q_smem_inplace_transform( + q_smem_inplace_transform( params, variant, &qo_smem); constexpr SwizzleMode swizzle_mode_kv = - (sizeof(DTypeKV) == 1 && head_dim == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; + (sizeof(DTypeKV) == 1 && head_dim_vo == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; constexpr uint32_t kv_frag_rows = swizzle_mode_kv == SwizzleMode::k128B ? 4 : 8; constexpr uint32_t kv_frag_cols = swizzle_mode_kv == SwizzleMode::k128B ? 
8 : 4; smem_t k_smem(smem + - (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ)) * 16 * head_dim), - v_smem(smem + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ) + - NUM_WARPS_KV * NUM_MMA_KV * sizeof(DTypeKV)) * - 16 * head_dim); + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ)) * 16 * head_dim_qk), + v_smem(smem + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ) * 16 * head_dim_qk) + + (NUM_WARPS_KV * NUM_MMA_KV * sizeof(DTypeKV) * 16 * head_dim_qk)); const uint32_t num_iterations = ceil_div( MASK_MODE == MaskMode::kCausal @@ -1237,28 +1255,28 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV (16 * NUM_WARPS_KV * NUM_MMA_KV); DTypeKV* k_ptr = - k + qkv_info.get_kv_elem_offset( - chunk_start + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, kv_head_idx, - (lane_idx % kv_frag_cols) * num_elems_per_128b()); + k + (chunk_start + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols) * k_stride_n + + kv_head_idx * k_stride_h + (lane_idx % kv_frag_cols) * upcast_size(); DTypeKV* v_ptr = - v + qkv_info.get_kv_elem_offset( - chunk_start + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, kv_head_idx, - (lane_idx % kv_frag_cols) * num_elems_per_128b()); + v + (chunk_start + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols) * v_stride_n + + kv_head_idx * v_stride_h + (lane_idx % kv_frag_cols) * upcast_size(); - uint32_t k_smem_offset_r = k_smem.get_permuted_offset( + uint32_t k_smem_offset_r = k_smem.get_permuted_offset( get_warp_idx_kv() * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), - v_smem_offset_r = v_smem.get_permuted_offset( + v_smem_offset_r = v_smem.get_permuted_offset( get_warp_idx_kv() * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), - kv_smem_offset_w = k_smem.get_permuted_offset( + k_smem_offset_w = k_smem.get_permuted_offset( + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols), + v_smem_offset_w = v_smem.get_permuted_offset( warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); - produce_kv( - k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, 0, chunk_size); + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size); cp_async::commit_group(); - produce_kv( - v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, 0, chunk_size); + produce_kv( + v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size); cp_async::commit_group(); #pragma unroll 1 @@ -1267,7 +1285,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV block.sync(); if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { - k_smem_inplace_apply_rotary( chunk_start + iter * 16 * NUM_WARPS_KV * NUM_MMA_KV, &k_smem, &k_smem_offset_r, rope_freq); @@ -1275,10 +1293,10 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); - logits_transform( + logits_transform( params, variant, /*batch_idx=*/0, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * NUM_MMA_KV * 16, @@ -1286,7 +1304,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV // apply mask if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { - logits_mask( + logits_mask( params, variant, /*batch_idx=*/0, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * NUM_MMA_KV * 16, @@ -1294,42 +1312,42 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV 
} // compute m,d states in online softmax - update_mdo_states(variant, s_frag, o_frag, m, d); + update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); - produce_kv( - k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, - (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, chunk_size); + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, + chunk_size); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); // compute sfm*v - compute_sfm_v( + compute_sfm_v( variant, &v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); - produce_kv( - v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, - (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, chunk_size); + produce_kv( + v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, + chunk_size); cp_async::commit_group(); } cp_async::wait_group<0>(); block.sync(); // threadblock synchronization - threadblock_sync_mdo_states( + threadblock_sync_mdo_states( variant, o_frag, (float*)smem, m, d, warp_idx, lane_idx); // normalize d - normalize_d(variant, o_frag, m, d); + normalize_d(variant, o_frag, m, d); // write back - write_o_reg_gmem( + write_o_reg_gmem( o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, /*o_stride_n=*/ - partition_kv ? num_qo_heads * head_dim * num_chunks : num_qo_heads * head_dim, - /*o_stride_h=*/head_dim, group_size); + partition_kv ? num_chunks * o_stride_n : o_stride_n, + /*o_stride_h=*/o_stride_h, group_size); // write lse if constexpr (variant.use_softmax) { @@ -1362,8 +1380,9 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV #endif } -template +template cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::DTypeO* tmp, cudaStream_t stream) { using DTypeQ = typename Params::DTypeQ; @@ -1383,10 +1402,11 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D const uint32_t group_size = num_qo_heads / num_kv_heads; const uint_fastdiv group_size_fastdiv(group_size); - constexpr uint32_t NUM_MMA_D = HEAD_DIM / 16; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; uint32_t cta_tile_q = 0; int64_t unpacked_qo_len = qo_len * group_size; - if (unpacked_qo_len > 64 && HEAD_DIM < 256) { + if (unpacked_qo_len > 64 && HEAD_DIM_VO < 256) { cta_tile_q = 128; } else { auto compute_capacity = GetCudaComputeCapability(); @@ -1421,28 +1441,30 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); // we expect each sm execute two threadblocks // TODO(Zihao): fix the following computation - const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeQ) * 16) ? 2 : 1; + const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ) * 16) ? 2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; const uint32_t max_num_mma_kv_reg = - (HEAD_DIM >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION) ? 
2 : (8 / NUM_MMA_Q); // TODO(Zihao): fix the following computation const uint32_t max_num_mma_kv_smem = - (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - NUM_MMA_Q * NUM_WARPS_Q) / + (max_smem_per_threadblock / (16 * HEAD_DIM_QK * sizeof(DTypeQ)) - NUM_MMA_Q * NUM_WARPS_Q) / (2 * NUM_WARPS_KV); // control NUM_MMA_KV for maximum warp occupancy DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { if constexpr (is_invalid_configuration( - NUM_MMA_Q, NUM_MMA_D, NUM_MMA_KV, NUM_WARPS_Q, NUM_WARPS_KV)) { + NUM_MMA_Q, NUM_MMA_D_QK, NUM_MMA_D_VO, NUM_MMA_KV, NUM_WARPS_Q, + NUM_WARPS_KV)) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q - << " NUM_MMA_D=" << NUM_MMA_D << " NUM_MMA_KV=" << NUM_MMA_KV - << " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" " and report the issue to the developers."; FLASHINFER_ERROR(err_msg.str()); @@ -1450,13 +1472,15 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D constexpr uint32_t num_threads = (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE; constexpr uint32_t num_rows_per_cta = NUM_MMA_Q * NUM_WARPS_Q * 16; auto kernel = - SinglePrefillWithKVCacheKernel; + SinglePrefillWithKVCacheKernel; // TODO(Zihao): fix the following computation - uint32_t smem_size = (NUM_MMA_Q * NUM_WARPS_Q * sizeof(DTypeQ) + - NUM_MMA_KV * NUM_WARPS_KV * 2 * sizeof(DTypeQ)) * - 16 * HEAD_DIM; + size_t smem_size = + max(SmemSizeThreadBlockAttnSync(), + NUM_MMA_Q * NUM_WARPS_Q * 16 * HEAD_DIM_QK * sizeof(DTypeQ) + + NUM_MMA_KV * NUM_WARPS_KV * 16 * (HEAD_DIM_QK + HEAD_DIM_VO) * sizeof(DTypeKV)); FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); int num_blocks_per_sm = 0; @@ -1487,7 +1511,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D } else { // Use cooperative groups to increase occupancy params.partition_kv = true; - float* tmp_lse = (float*)(tmp + num_chunks * qo_len * num_qo_heads * HEAD_DIM); + float* tmp_lse = (float*)(tmp + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO); auto o = params.o; auto lse = params.lse; params.o = tmp; @@ -1499,10 +1523,10 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); if constexpr (AttentionVariant::use_softmax) { FLASHINFER_CUDA_CALL(MergeStates(tmp, tmp_lse, o, lse, num_chunks, qo_len, num_qo_heads, - HEAD_DIM, stream)); + HEAD_DIM_VO, stream)); } else { FLASHINFER_CUDA_CALL( - AttentionSum(tmp, o, num_chunks, qo_len, num_qo_heads, HEAD_DIM, stream)); + AttentionSum(tmp, o, num_chunks, qo_len, num_qo_heads, HEAD_DIM_VO, stream)); } } } @@ -1512,8 +1536,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D } template + uint32_t NUM_MMA_D_QK, uint32_t NUM_MMA_D_VO, uint32_t NUM_MMA_KV, uint32_t NUM_WARPS_Q, + uint32_t NUM_WARPS_KV, typename DTypeQKAccum, typename AttentionVariant, typename Params> __global__ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRaggedKVCacheKernel( const uint_fastdiv group_size, const __grid_constant__ Params params) { @@ 
-1541,13 +1565,16 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag const bool partition_kv = params.partition_kv; const uint32_t q_stride_n = params.q_stride_n; const uint32_t q_stride_h = params.q_stride_h; - const uint32_t kv_stride_n = params.kv_stride_n; - const uint32_t kv_stride_h = params.kv_stride_h; + const uint32_t k_stride_n = params.k_stride_n; + const uint32_t k_stride_h = params.k_stride_h; + const uint32_t v_stride_n = params.v_stride_n; + const uint32_t v_stride_h = params.v_stride_h; const int32_t maybe_window_left = params.window_left; static_assert(sizeof(DTypeQ) == 2); static_assert(sizeof(DTypeO) == 2); - constexpr uint32_t head_dim = NUM_MMA_D * 16; + constexpr uint32_t head_dim_qk = NUM_MMA_D_QK * 16; + constexpr uint32_t head_dim_vo = NUM_MMA_D_VO * 16; const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); auto block = cg::this_thread_block(); @@ -1570,50 +1597,49 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag const uint32_t chunk_end = partition_kv ? min((kv_tile_idx + 1) * max_chunk_size, kv_len) : kv_len; const uint32_t chunk_size = chunk_end - chunk_start; - const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, - kv_stride_n, kv_stride_h, /*head_dim=*/NUM_MMA_D * 16); const uint32_t qo_upper_bound = min(qo_len, ceil_div((qo_tile_idx + 1) * num_rows_per_cta, group_size)); - constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); - constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); - constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); + constexpr uint32_t upcast_head_dim_q = head_dim_qk / upcast_size(); + constexpr uint32_t upcast_head_dim_k = head_dim_qk / upcast_size(); + constexpr uint32_t upcast_head_dim_v = head_dim_vo / upcast_size(); + constexpr uint32_t upcast_head_dim_o = head_dim_vo / upcast_size(); DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; - alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; DTypeQKAccum m[NUM_MMA_Q][2]; float d[NUM_MMA_Q][2]; - float rope_freq[NUM_MMA_D / 2][4]; + float rope_freq[NUM_MMA_D_QK / 2][4]; if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { const float rope_rcp_scale = params.rope_rcp_scale; const float rope_rcp_theta = params.rope_rcp_theta; - init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta); + init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta); } - init_states(variant, o_frag, m, d); + init_states(variant, o_frag, m, d); const uint32_t qo_packed_idx_base = (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q()) * NUM_MMA_Q * 16; constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k128B; smem_t qo_smem(smem); + const uint32_t o_stride_n = num_qo_heads * head_dim_vo, o_stride_h = head_dim_vo; - DTypeQ* q_ptr_base = - q + qkv_info.get_q_elem_offset(q_indptr[request_idx], kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b()); + DTypeQ* q_ptr_base = q + q_indptr[request_idx] * q_stride_n + + kv_head_idx * group_size * q_stride_h + + (lane_idx % 8) * upcast_size(); DTypeO* o_ptr_base = partition_kv - ? o + kv_tile_idx * num_qo_heads * head_dim + - qkv_info.get_o_elem_offset(o_indptr[request_idx], kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b()) - : o + qkv_info.get_o_elem_offset(o_indptr[request_idx], kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b()); + ? 
o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n + + (kv_head_idx * group_size) * o_stride_h + (lane_idx % 8) * upcast_size() + : o + o_indptr[request_idx] * o_stride_n + (kv_head_idx * group_size) * o_stride_h + + (lane_idx % 8) * upcast_size(); - uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( + uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( get_warp_idx_q() * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); - load_q_global_smem( + load_q_global_smem( qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem); @@ -1628,18 +1654,18 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag q_rope_offset = params.maybe_q_rope_offset; } if (!q_rope_offset) { - q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, - &qo_smem, &q_smem_offset_r, rope_freq); + q_smem_inplace_apply_rotary( + qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq); } else { - q_smem_inplace_apply_rotary_with_pos( qo_packed_idx_base, q_rope_offset + q_indptr[request_idx], &qo_smem, group_size, &q_smem_offset_r, rope_freq); } block.sync(); } - q_smem_inplace_transform( + q_smem_inplace_transform( params, variant, &qo_smem); const uint32_t num_iterations = ceil_div( @@ -1665,41 +1691,42 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag (16 * NUM_WARPS_KV * NUM_MMA_KV); constexpr SwizzleMode swizzle_mode_kv = - (sizeof(DTypeKV) == 1 && head_dim == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; + (sizeof(DTypeKV) == 1 && head_dim_vo == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; constexpr uint32_t kv_frag_rows = swizzle_mode_kv == SwizzleMode::k128B ? 4 : 8; constexpr uint32_t kv_frag_cols = swizzle_mode_kv == SwizzleMode::k128B ? 
8 : 4; smem_t k_smem(smem + - (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ)) * 16 * head_dim), - v_smem(smem + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ) + - NUM_WARPS_KV * NUM_MMA_KV * sizeof(DTypeKV)) * - 16 * head_dim); + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ)) * 16 * head_dim_qk), + v_smem(smem + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ) * 16 * head_dim_qk) + + (NUM_WARPS_KV * NUM_MMA_KV * sizeof(DTypeKV) * 16 * head_dim_qk)); - uint32_t k_smem_offset_r = k_smem.get_permuted_offset( + uint32_t k_smem_offset_r = k_smem.get_permuted_offset( get_warp_idx_kv() * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), - v_smem_offset_r = v_smem.get_permuted_offset( + v_smem_offset_r = v_smem.get_permuted_offset( get_warp_idx_kv() * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), - kv_smem_offset_w = k_smem.get_permuted_offset( + k_smem_offset_w = k_smem.get_permuted_offset( + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols), + v_smem_offset_w = v_smem.get_permuted_offset( warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); DTypeKV* k_ptr = - k + qkv_info.get_kv_elem_offset(kv_indptr[request_idx] + chunk_start + - warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, - kv_head_idx, - (lane_idx % kv_frag_cols) * num_elems_per_128b()); + k + + (kv_indptr[request_idx] + chunk_start + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols) * + k_stride_n + + kv_head_idx * k_stride_h + (lane_idx % kv_frag_cols) * upcast_size(); DTypeKV* v_ptr = - v + qkv_info.get_kv_elem_offset(kv_indptr[request_idx] + chunk_start + - warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, - kv_head_idx, - (lane_idx % kv_frag_cols) * num_elems_per_128b()); + v + + (kv_indptr[request_idx] + chunk_start + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols) * + v_stride_n + + kv_head_idx * v_stride_h + (lane_idx % kv_frag_cols) * upcast_size(); - produce_kv( - k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, 0, chunk_size); + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size); cp_async::commit_group(); - produce_kv( - v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, 0, chunk_size); + produce_kv( + v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size); cp_async::commit_group(); #pragma unroll 1 @@ -1712,7 +1739,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag if constexpr (has_maybe_k_rope_offset_v) { k_rope_offset = params.maybe_k_rope_offset; } - k_smem_inplace_apply_rotary( (k_rope_offset == nullptr ? 
0 : k_rope_offset[request_idx]) + chunk_start + iter * 16 * NUM_WARPS_KV * NUM_MMA_KV, @@ -1721,10 +1748,10 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); - logits_transform( + logits_transform( params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * NUM_MMA_KV * 16, @@ -1732,7 +1759,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag // apply mask if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { - logits_mask( + logits_mask( params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * NUM_MMA_KV * 16, @@ -1740,44 +1767,44 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag } // compute m,d states in online softmax - update_mdo_states(variant, s_frag, o_frag, m, d); + update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); - produce_kv( - k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, - (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, chunk_size); + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, + chunk_size); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); // compute sfm*v - compute_sfm_v( + compute_sfm_v( variant, &v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); - produce_kv( - v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, - (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, chunk_size); + produce_kv( + v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, + chunk_size); cp_async::commit_group(); } cp_async::wait_group<0>(); block.sync(); // threadblock synchronization - threadblock_sync_mdo_states( + threadblock_sync_mdo_states( variant, o_frag, (float*)smem, m, d, warp_idx, lane_idx); // normalize d - normalize_d(variant, o_frag, m, d); + normalize_d(variant, o_frag, m, d); const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; // write back - write_o_reg_gmem( + write_o_reg_gmem( o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, /*o_stride_n=*/ - partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim, - /*o_stride_h=*/head_dim, group_size); + partition_kv ? 
num_kv_chunks * o_stride_n : o_stride_n, + /*o_stride_h=*/o_stride_h, group_size); // write lse if constexpr (variant.use_softmax) { @@ -1812,8 +1839,8 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag } template + uint32_t NUM_MMA_D_QK, uint32_t NUM_MMA_D_VO, uint32_t NUM_MMA_KV, uint32_t NUM_WARPS_Q, + uint32_t NUM_WARPS_KV, typename DTypeQKAccum, typename AttentionVariant, typename Params> __global__ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPagedKVCacheKernel( const uint_fastdiv group_size, const __grid_constant__ Params params) { @@ -1867,45 +1894,46 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag const uint32_t qo_upper_bound = min(qo_len, ceil_div((qo_tile_idx + 1) * num_rows_per_cta, group_size)); - constexpr uint32_t head_dim = NUM_MMA_D * 16; - constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); - constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); - constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); + constexpr uint32_t head_dim_qk = NUM_MMA_D_QK * 16; + constexpr uint32_t head_dim_vo = NUM_MMA_D_VO * 16; + constexpr uint32_t upcast_head_dim_q = head_dim_qk / upcast_size(); + constexpr uint32_t upcast_head_dim_k = head_dim_qk / upcast_size(); + constexpr uint32_t upcast_head_dim_v = head_dim_vo / upcast_size(); + constexpr uint32_t upcast_head_dim_o = head_dim_vo / upcast_size(); DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; - alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; DTypeQKAccum m[NUM_MMA_Q][2]; float d[NUM_MMA_Q][2]; - float rope_freq[NUM_MMA_D / 2][4]; + float rope_freq[NUM_MMA_D_QK / 2][4]; if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { const float rope_rcp_scale = params.rope_rcp_scale; const float rope_rcp_theta = params.rope_rcp_theta; - init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta); + init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta); } - init_states(variant, o_frag, m, d); + init_states(variant, o_frag, m, d); const uint32_t qo_packed_idx_base = (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q()) * NUM_MMA_Q * 16; const uint32_t q_stride_n = params.q_stride_n, q_stride_h = params.q_stride_h; constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k128B; smem_t qo_smem(smem); - DTypeQ* q_ptr_base = q + get_elem_offset_impl(q_indptr[request_idx], kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b(), - q_stride_n, q_stride_h); + const uint32_t o_stride_n = num_qo_heads * head_dim_vo, o_stride_h = head_dim_vo; + DTypeQ* q_ptr_base = q + q_indptr[request_idx] * q_stride_n + + (kv_head_idx * group_size) * q_stride_h + + (lane_idx % 8) * upcast_size(); DTypeO* o_ptr_base = - partition_kv ? o + kv_tile_idx * num_qo_heads * head_dim + - get_elem_offset_impl(o_indptr[request_idx], kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b(), - num_qo_heads * head_dim, head_dim) - : o + get_elem_offset_impl(o_indptr[request_idx], kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b(), - num_qo_heads * head_dim, head_dim); - uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( + partition_kv + ? 
o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n + + (kv_head_idx * group_size) * o_stride_h + (lane_idx % 8) * upcast_size() + : o + o_indptr[request_idx] * o_stride_n + (kv_head_idx * group_size) * o_stride_h + + (lane_idx % 8) * upcast_size(); + uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( get_warp_idx_q() * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); - load_q_global_smem( + load_q_global_smem( qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem); @@ -1919,39 +1947,40 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag q_rope_offset = params.maybe_q_rope_offset; } if (q_rope_offset == nullptr) { - q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, - &qo_smem, &q_smem_offset_r, rope_freq); + q_smem_inplace_apply_rotary( + qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq); } else { - q_smem_inplace_apply_rotary_with_pos( qo_packed_idx_base, q_rope_offset + q_indptr[request_idx], &qo_smem, group_size, &q_smem_offset_r, rope_freq); } block.sync(); } - q_smem_inplace_transform( + q_smem_inplace_transform( params, variant, &qo_smem); constexpr SwizzleMode swizzle_mode_kv = - (sizeof(DTypeKV) == 1 && head_dim == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; + (sizeof(DTypeKV) == 1 && head_dim_vo == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; constexpr uint32_t kv_frag_rows = swizzle_mode_kv == SwizzleMode::k128B ? 4 : 8; constexpr uint32_t kv_frag_cols = swizzle_mode_kv == SwizzleMode::k128B ? 8 : 4; smem_t k_smem(smem + - (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ)) * 16 * head_dim), - v_smem(smem + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ) + - NUM_WARPS_KV * NUM_MMA_KV * sizeof(DTypeKV)) * - 16 * head_dim); + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ)) * 16 * head_dim_qk), + v_smem(smem + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ) * 16 * head_dim_qk) + + (NUM_WARPS_KV * NUM_MMA_KV * sizeof(DTypeKV) * 16 * head_dim_qk)); size_t kv_offset[NUM_MMA_KV * (swizzle_mode_kv == SwizzleMode::k128B ? 
4 : 2) / NUM_WARPS_Q]; - uint32_t k_smem_offset_r = k_smem.get_permuted_offset( + uint32_t k_smem_offset_r = k_smem.get_permuted_offset( get_warp_idx_kv() * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), - v_smem_offset_r = v_smem.get_permuted_offset( + v_smem_offset_r = v_smem.get_permuted_offset( get_warp_idx_kv() * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), - kv_smem_offset_w = k_smem.get_permuted_offset( + k_smem_offset_w = k_smem.get_permuted_offset( + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols), + v_smem_offset_w = v_smem.get_permuted_offset( warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; @@ -1966,14 +1995,14 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag kv_frag_rows * NUM_WARPS_Q * NUM_WARPS_KV * i, page_iter, entry_idx); kv_offset[i] = paged_kv.protective_get_kv_offset( - page_iter, kv_head_idx, entry_idx, - (lane_idx % kv_frag_cols) * num_elems_per_128b(), last_indptr); + page_iter, kv_head_idx, entry_idx, (lane_idx % kv_frag_cols) * upcast_size(), + last_indptr); } - page_produce_kv( - k_smem, &kv_smem_offset_w, paged_kv, 0, kv_offset, chunk_size); + page_produce_kv( + k_smem, &k_smem_offset_w, paged_kv, 0, kv_offset, chunk_size); cp_async::commit_group(); - page_produce_kv( - v_smem, &kv_smem_offset_w, paged_kv, 0, kv_offset, chunk_size); + page_produce_kv( + v_smem, &v_smem_offset_w, paged_kv, 0, kv_offset, chunk_size); cp_async::commit_group(); const uint32_t num_iterations = ceil_div( @@ -2010,14 +2039,14 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag kv_frag_rows * NUM_WARPS_Q * NUM_WARPS_KV * i, page_iter, entry_idx); kv_offset[i] = paged_kv.protective_get_kv_offset( - page_iter, kv_head_idx, entry_idx, - (lane_idx % kv_frag_cols) * num_elems_per_128b(), last_indptr); + page_iter, kv_head_idx, entry_idx, (lane_idx % kv_frag_cols) * upcast_size(), + last_indptr); } cp_async::wait_group<1>(); block.sync(); if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { - k_smem_inplace_apply_rotary( (paged_kv.rope_pos_offset == nullptr ? 
0 : paged_kv.rope_pos_offset[request_idx]) + chunk_start + iter * 16 * NUM_WARPS_KV * NUM_MMA_KV, @@ -2026,10 +2055,10 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); - logits_transform( + logits_transform( params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * NUM_MMA_KV * 16, @@ -2037,7 +2066,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag // apply mask if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { - logits_mask( + logits_mask( params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * NUM_MMA_KV * 16, @@ -2045,23 +2074,23 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag } // compute m,d states in online softmax - update_mdo_states(variant, s_frag, o_frag, m, d); + update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); - page_produce_kv( - k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, + page_produce_kv( + k_smem, &k_smem_offset_w, paged_kv, (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, kv_offset, chunk_size); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); // compute sfm*v - compute_sfm_v( + compute_sfm_v( variant, &v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); - page_produce_kv( - v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, + page_produce_kv( + v_smem, &v_smem_offset_w, paged_kv, (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, kv_offset, chunk_size); cp_async::commit_group(); } @@ -2069,20 +2098,20 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag block.sync(); // threadblock synchronization - threadblock_sync_mdo_states( + threadblock_sync_mdo_states( variant, o_frag, (float*)smem, m, d, warp_idx, lane_idx); // normalize d - normalize_d(variant, o_frag, m, d); + normalize_d(variant, o_frag, m, d); const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; // write_back - write_o_reg_gmem( + write_o_reg_gmem( o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, /*o_stride_n=*/ - partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim, - /*o_stride_h=*/head_dim, group_size); + partition_kv ? 
num_kv_chunks * o_stride_n : o_stride_n, + /*o_stride_h=*/o_stride_h, group_size); // write lse if constexpr (variant.use_softmax) { @@ -2116,13 +2145,14 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag #endif } -template +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, float* tmp_s, cudaStream_t stream) { using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; const uint32_t padded_batch_size = params.padded_batch_size; const uint32_t num_qo_heads = params.num_qo_heads; const uint32_t num_kv_heads = params.num_kv_heads; @@ -2139,7 +2169,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para dim3 nblks(padded_batch_size, 1, num_kv_heads); dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); - constexpr uint32_t NUM_MMA_D = HEAD_DIM / 16; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; using DTypeQKAccum = typename std::conditional, half, float>::type; @@ -2151,39 +2182,42 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); // we expect each sm execute two threadblocks // TODO(Zihao): fix the following computation - const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeQ) * 16) ? 2 : 1; + const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ) * 16) ? 2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; const uint32_t max_num_mma_kv_reg = - (HEAD_DIM >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION) ? 
2 : (8 / NUM_MMA_Q); // TODO(Zihao): fix the following computation const uint32_t max_num_mma_kv_smem = - (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - NUM_MMA_Q * NUM_WARPS_Q) / + (max_smem_per_threadblock / (16 * HEAD_DIM_QK * sizeof(DTypeQ)) - NUM_MMA_Q * NUM_WARPS_Q) / (2 * NUM_WARPS_KV); DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { if constexpr (is_invalid_configuration( - NUM_MMA_Q, NUM_MMA_D, NUM_MMA_KV, NUM_WARPS_Q, NUM_WARPS_KV)) { + NUM_MMA_Q, NUM_MMA_D_QK, NUM_MMA_D_VO, NUM_MMA_KV, NUM_WARPS_Q, + NUM_WARPS_KV)) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q - << " NUM_MMA_D=" << NUM_MMA_D << " NUM_MMA_KV=" << NUM_MMA_KV - << " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" " and report the issue to the developers."; FLASHINFER_ERROR(err_msg.str()); } else { // TODO(Zihao): fix the following computation - uint32_t smem_size = (NUM_MMA_Q * NUM_WARPS_Q * sizeof(DTypeQ) + - NUM_MMA_KV * NUM_WARPS_KV * 2 * sizeof(DTypeQ)) * - 16 * HEAD_DIM; + size_t smem_size = max( + SmemSizeThreadBlockAttnSync(), + NUM_MMA_Q * NUM_WARPS_Q * 16 * HEAD_DIM_QK * sizeof(DTypeQ) + + NUM_MMA_KV * NUM_WARPS_KV * 16 * (HEAD_DIM_QK + HEAD_DIM_VO) * sizeof(DTypeKV)); auto kernel = - BatchPrefillWithRaggedKVCacheKernel; + BatchPrefillWithRaggedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); if (tmp_v == nullptr) { @@ -2205,11 +2239,11 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para if constexpr (AttentionVariant::use_softmax) { FLASHINFER_CUDA_CALL(VariableLengthMergeStates( tmp_v, tmp_s, params.merge_indptr, o, lse, params.max_total_num_rows, - params.total_num_rows, num_qo_heads, HEAD_DIM, stream)); + params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); } else { FLASHINFER_CUDA_CALL( VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, params.max_total_num_rows, - params.total_num_rows, num_qo_heads, HEAD_DIM, stream)); + params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); } } } @@ -2217,13 +2251,14 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para return cudaSuccess; } -template +template cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, float* tmp_s, cudaStream_t stream) { using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; const uint32_t padded_batch_size = params.padded_batch_size; const uint32_t num_qo_heads = params.num_qo_heads; const uint32_t num_kv_heads = params.paged_kv.num_heads; @@ -2241,7 +2276,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Param dim3 nblks(padded_batch_size, 1, num_kv_heads); dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); - constexpr uint32_t NUM_MMA_D = HEAD_DIM / 16; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; using DTypeQKAccum = typename std::conditional, half, float>::type; @@ -2253,39 +2289,42 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename 
Param cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); // we expect each sm execute two threadblocks // TODO(Zihao): fix the following computation - const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeQ) * 16) ? 2 : 1; + const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ) * 16) ? 2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; const uint32_t max_num_mma_kv_reg = - (HEAD_DIM >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION) ? 2 : (8 / NUM_MMA_Q); // TODO(Zihao): fix the following computation const uint32_t max_num_mma_kv_smem = - (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - NUM_MMA_Q * NUM_WARPS_Q) / + (max_smem_per_threadblock / (16 * HEAD_DIM_QK * sizeof(DTypeQ)) - NUM_MMA_Q * NUM_WARPS_Q) / (2 * NUM_WARPS_KV); DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { if constexpr (is_invalid_configuration( - NUM_MMA_Q, NUM_MMA_D, NUM_MMA_KV, NUM_WARPS_Q, NUM_WARPS_KV)) { + NUM_MMA_Q, NUM_MMA_D_QK, NUM_MMA_D_VO, NUM_MMA_KV, NUM_WARPS_Q, + NUM_WARPS_KV)) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q - << " NUM_MMA_D=" << NUM_MMA_D << " NUM_MMA_KV=" << NUM_MMA_KV - << " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" " and report the issue to the developers."; FLASHINFER_ERROR(err_msg.str()); } else { // TODO(Zihao): fix the following computation - uint32_t smem_size = (NUM_MMA_Q * NUM_WARPS_Q * sizeof(DTypeQ) + - NUM_MMA_KV * NUM_WARPS_KV * 2 * sizeof(DTypeQ)) * - 16 * HEAD_DIM; + size_t smem_size = max( + SmemSizeThreadBlockAttnSync(), + NUM_MMA_Q * NUM_WARPS_Q * 16 * HEAD_DIM_QK * sizeof(DTypeQ) + + NUM_MMA_KV * NUM_WARPS_KV * 16 * (HEAD_DIM_QK + HEAD_DIM_VO) * sizeof(DTypeKV)); auto kernel = - BatchPrefillWithPagedKVCacheKernel; + BatchPrefillWithPagedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); if (tmp_v == nullptr) { @@ -2306,11 +2345,11 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Param if constexpr (AttentionVariant::use_softmax) { FLASHINFER_CUDA_CALL(VariableLengthMergeStates( tmp_v, tmp_s, params.merge_indptr, o, lse, params.max_total_num_rows, - params.total_num_rows, num_qo_heads, HEAD_DIM, stream)); + params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); } else { FLASHINFER_CUDA_CALL( VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, params.max_total_num_rows, - params.total_num_rows, num_qo_heads, HEAD_DIM, stream)); + params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); } } } diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 2a46d9275..4c1b5af4c 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -626,8 +626,9 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info, IdType* qo_indptr_h, IdType* 
kv_indptr_h, uint32_t total_num_rows, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim, uint32_t page_size, bool enable_cuda_graph, - uint32_t sizeof_dtype_o, cudaStream_t stream) { + uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size, + bool enable_cuda_graph, uint32_t sizeof_dtype_o, + cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " @@ -648,7 +649,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i auto [split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] = PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads, - num_kv_heads, head_dim, page_size, max_batch_size_if_split, + num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, enable_cuda_graph); plan_info.cta_tile_q = cta_tile_q; @@ -696,7 +697,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i if (split_kv) { AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); plan_info.v_offset = float_allocator.aligned_alloc_offset( - num_qo_heads * padded_batch_size * cta_tile_q * head_dim * sizeof_dtype_o, 16, + num_qo_heads * padded_batch_size * cta_tile_q * head_dim_vo * sizeof_dtype_o, 16, "batch_prefill_tmp_v"); plan_info.s_offset = float_allocator.aligned_alloc_offset( num_qo_heads * padded_batch_size * cta_tile_q * sizeof(float), 16, "batch_prefill_tmp_s"); @@ -783,15 +784,13 @@ struct PrefillPlanSM90Info { }; template -inline cudaError_t PrefillSM90Plan(void* float_buffer, size_t float_workspace_size_in_bytes, - void* int_buffer, void* page_locked_int_buffer, - size_t int_workspace_size_in_bytes, - PrefillPlanSM90Info& plan_info, IdType* qo_indptr_h, - IdType* kv_indptr_h, IdType* kv_len_arr_h, - uint32_t total_num_rows, uint32_t batch_size, - uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, - uint32_t page_size, bool causal, bool enable_cuda_graph, - uint32_t sizeof_dtype_o, cudaStream_t stream) { +inline cudaError_t PrefillSM90Plan( + void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, + void* page_locked_int_buffer, size_t int_workspace_size_in_bytes, + PrefillPlanSM90Info& plan_info, IdType* qo_indptr_h, IdType* kv_indptr_h, IdType* kv_len_arr_h, + uint32_t total_num_rows, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size, bool causal, + bool enable_cuda_graph, uint32_t sizeof_dtype_o, cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " @@ -820,7 +819,7 @@ inline cudaError_t PrefillSM90Plan(void* float_buffer, size_t float_workspace_si std::sort(idx_qo_kv_len_vec.begin(), idx_qo_kv_len_vec.end(), [](const auto& a, const auto& b) { return std::get<2>(a) > std::get<2>(b); }); int cta_tile_q = 128; - if (head_dim == 64) { + if (head_dim_vo == 64) { cta_tile_q = 192; } diff --git a/include/flashinfer/permuted_smem.cuh b/include/flashinfer/permuted_smem.cuh index 0b0800d04..8c76f1ef0 100644 --- a/include/flashinfer/permuted_smem.cuh +++ b/include/flashinfer/permuted_smem.cuh @@ -40,7 +40,7 @@ using b128_t = uint4; * \tparam T The data type of the elements. 
*/ template -constexpr __host__ __device__ __forceinline__ uint32_t num_elems_per_128b() { +constexpr __host__ __device__ __forceinline__ uint32_t upcast_size() { return sizeof(b128_t) / sizeof(T); } diff --git a/setup.py b/setup.py index f7d95585d..827c1c2a3 100644 --- a/setup.py +++ b/setup.py @@ -133,7 +133,7 @@ def generate_cuda() -> None: ext_modules = [] cmdclass = {} -install_requires = ["torch", "ninja"] +install_requires = ["numpy", "torch", "ninja"] generate_build_meta({}) if enable_aot: diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index c216fdca9..a3c3ebaf5 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -159,15 +159,15 @@ class BatchDecodeHandler { cudaStream_t stream_; }; -template +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, float* tmp_s, cudaStream_t stream); -template +template cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, float* tmp_s, cudaStream_t stream); @@ -188,7 +188,7 @@ class BatchPrefillHandler { return PrefillPlan(float_buffer, float_workspace_size_in_bytes, int_buffer, page_locked_buffer_, int_workspace_size_in_bytes, plan_info_, qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads, - num_kv_heads, head_dim, page_size, enable_cuda_graph_, + num_kv_heads, head_dim, head_dim, page_size, enable_cuda_graph_, sizeof(DTypeO), stream_); } @@ -277,8 +277,9 @@ class BatchPrefillHandler { cudaStream_t stream_; }; -template +template cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::DTypeO* tmp, cudaStream_t stream); @@ -306,7 +307,7 @@ cudaError_t SinglePrefillWithKVCacheCustomMask( qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, head_dim, /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta); - return SinglePrefillWithKVCacheDispatched(params, tmp, stream); })})}); @@ -368,7 +369,7 @@ cudaError_t SinglePrefillWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeO* qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, head_dim, /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta); - return SinglePrefillWithKVCacheDispatched(params, tmp, stream); @@ -419,9 +420,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( params.padded_batch_size = plan_info.padded_batch_size; DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { - BatchPrefillWithRaggedKVCacheDispatched( + BatchPrefillWithRaggedKVCacheDispatched( params, handler->GetTmpV(), handler->GetTmpS(), stream); }); })})})}); @@ -471,9 +472,9 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( params.padded_batch_size = plan_info.padded_batch_size; DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { return BatchPrefillWithPagedKVCacheDispatched< - CTA_TILE_Q, HEAD_DIM, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, MASK_MODE, - AttentionVariant>(params, handler->GetTmpV(), handler->GetTmpS(), - stream); + CTA_TILE_Q, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, + MASK_MODE, AttentionVariant>(params, handler->GetTmpV(), + handler->GetTmpS(), stream); }) })})})}); return cudaSuccess; diff --git a/tests/jit_utils.py b/tests/jit_utils.py index b500e86ef..6cc8787bf 100644 --- a/tests/jit_utils.py +++ b/tests/jit_utils.py @@ -115,7 +115,8 @@ def jit_prefill_attention_func_args( q_dtype, kv_dtype, q_dtype, - head_dim, + head_dim, # head_dim_qk + head_dim, # head_dim_vo pos_encoding_mode, use_sliding_window, use_logits_soft_cap, @@ -132,7 +133,8 @@ def 
             jit_prefill_attention_func_args(
                 kv_dtype,
                 q_dtype,
                 torch.int32,
-                head_dim,
+                head_dim,  # head_dim_qk
+                head_dim,  # head_dim_vo
                 pos_encoding_mode,
                 use_sliding_window,
                 use_logits_soft_cap,
diff --git a/tests/test_block_sparse.py b/tests/test_block_sparse.py
index 037cc4538..a7a643008 100644
--- a/tests/test_block_sparse.py
+++ b/tests/test_block_sparse.py
@@ -127,3 +127,4 @@ def test_block_sparse_attention(
 
 if __name__ == "__main__":
     test_block_sparse_attention(1, 1, 64, 64, 1, 1, 128, False)
+    test_block_sparse_attention(16, 16, 256, 256, 16, 16, 256, True)
diff --git a/tests/test_deepseek_prefill.py b/tests/test_deepseek_prefill.py
new file mode 100644
index 000000000..87092ce06
--- /dev/null
+++ b/tests/test_deepseek_prefill.py
@@ -0,0 +1,154 @@
+"""
+Copyright (c) 2023 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import pytest
+import torch
+
+import flashinfer
+
+
+def attention_ref(
+    batch_size,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    causal: bool,
+    sm_scale: float,
+) -> torch.Tensor:
+    qo_len = q.shape[0] // batch_size
+    kv_len = k.shape[0] // batch_size
+    num_qo_heads = q.shape[1]
+    head_dim_qk = q.shape[2]
+    head_dim_vo = v.shape[2]
+    logits = (
+        torch.einsum(
+            "bmhd,bnhd->bhmn",
+            q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
+            k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
+        )
+        * sm_scale
+    )
+
+    if causal:
+        mask = (
+            torch.arange(kv_len - qo_len, kv_len).unsqueeze(1)
+            >= torch.arange(0, kv_len).unsqueeze(0)
+        ).to(q.device)
+    else:
+        mask = torch.ones(qo_len, kv_len).to(q.device)
+
+    logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))
+    p = torch.softmax(logits, dim=-1)
+    o_ref = (
+        torch.einsum(
+            "bhmn,bnhd->bmhd",
+            p,
+            v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
+        )
+        .contiguous()
+        .view(batch_size * qo_len, num_qo_heads, head_dim_vo)
+        .to(q)
+    )
+
+    return o_ref
+
+
+@pytest.mark.parametrize("kv_len", [5532, 7563])
+@pytest.mark.parametrize("qo_len", [1832, 3928])
+@pytest.mark.parametrize("num_kv_heads", [4, 32])
+@pytest.mark.parametrize("num_qo_heads", [32])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("backend", ["fa2", "fa3"])
+def test_single_prefill_with_kv_cache(
+    kv_len,
+    qo_len,
+    num_kv_heads,
+    num_qo_heads,
+    causal,
+    backend,
+):
+    head_dim_qk = 192
+    head_dim_vo = 128
+    q = torch.randn(qo_len, num_qo_heads, head_dim_qk).to(0).half()
+    k = torch.zeros(kv_len, num_kv_heads, head_dim_qk).to(0).half()
+    v = torch.randn(kv_len, num_kv_heads, head_dim_vo).to(0).half()
+
+    o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=causal, backend=backend)
+
+    sm_scale = 1.0 / (head_dim_qk**0.5)
+
+    if num_qo_heads != num_kv_heads:
+        k = k.repeat_interleave(num_qo_heads // num_kv_heads, dim=1)
+        v = v.repeat_interleave(num_qo_heads // num_kv_heads, dim=1)
+
+    o_ref = attention_ref(1, q, k, v, causal, sm_scale)
+    torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize("batch_size", [12, 17])
+@pytest.mark.parametrize("kv_len", [544, 977])
+@pytest.mark.parametrize("qo_len", [377, 177])
+@pytest.mark.parametrize("num_kv_heads", [4, 32])
+@pytest.mark.parametrize("num_qo_heads", [32])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("backend", ["fa2", "fa3"])
+def test_batch_prefill_with_ragged_kv_cache(
+    batch_size,
+    kv_len,
+    qo_len,
+    num_kv_heads,
+    num_qo_heads,
+    causal,
+    backend,
+):
+    kv_layout = "NHD"
+    head_dim_qk = 192
+    head_dim_vo = 128
+    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim_qk).to(0).half()
+    q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
+
+    k = torch.zeros(batch_size * kv_len, num_kv_heads, head_dim_qk).to(0).half()
+    v = torch.randn(batch_size * kv_len, num_kv_heads, head_dim_vo).to(0).half()
+    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
+
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
+    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
+        workspace_buffer, kv_layout, backend=backend
+    )
+    wrapper.plan(
+        q_indptr,
+        kv_indptr,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim_qk,
+        head_dim_vo=head_dim_vo,
+        causal=causal,
+    )
+    o = wrapper.run(q, k, v)
+
+    sm_scale = 1.0 / (head_dim_qk**0.5)
+    if num_qo_heads != num_kv_heads:
+        k = k.repeat_interleave(num_qo_heads // num_kv_heads, dim=1)
+        v = v.repeat_interleave(num_qo_heads // num_kv_heads, dim=1)
+
+    o_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
+
+    torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
+
+
+if __name__ == "__main__":
+    test_single_prefill_with_kv_cache(54, 37, 4, 32, False, "fa2")
+    test_batch_prefill_with_ragged_kv_cache(12, 54, 37, 4, 4, False, "fa2")
diff --git a/tests/test_jit_example.py b/tests/test_jit_example.py
index 2a4e27650..29c0711bc 100644
--- a/tests/test_jit_example.py
+++ b/tests/test_jit_example.py
@@ -60,7 +60,8 @@ def test_single_decode_mask():
         torch.float16,  # dtype_q
         torch.float16,  # dtype_kv
         torch.float16,  # dtype_o
-        128,  # head_dim
+        128,  # head_dim_qk
+        128,  # head_dim_vo
         ["custom_mask"],  # additional_tensor_names
         ["uint8_t"],  # additional_tensor_dtypes
         ["sm_scale"],  # # additional_scalar_names
@@ -158,7 +159,8 @@ def test_flash_sigmoid():
         torch.float16,  # dtype_q
         torch.float16,  # dtype_kv
         torch.float16,  # dtype_o
-        128,  # hidden_dim
+        128,  # head_dim_qk
+        128,  # head_dim_vo
         [],  # additional_tensor_names
         [],  # additional_tensor_dtypes
         ["logits_scale", "sigmoid_bias"],  # additional_scalar_names
@@ -229,7 +231,8 @@ def test_dump_logits():
         torch.float16,  # dtype_q
         torch.float16,  # dtype_kv
         torch.float16,  # dtype_o
-        128,  # hidden_dim
+        128,  # head_dim_qk
+        128,  # head_dim_vo
         ["output_logits"],  # additional_tensor_names
         ["float"],  # additional_tensor_dtypes
         ["sm_scale"],  # additional_scalar_names
@@ -263,7 +266,8 @@ def test_batch_decode_flash_sigmoid(use_tensor_cores):
         torch.float16,  # dtype_kv
         torch.float16,  # dtype_o
         torch.int32,  # idtype
-        128,  # hidden_dim
+        128,  # hidden_dim_qk
+        128,  # hidden_dim_vo
         [],  # additional_tensor_names
         [],  # additional_tensor_dtypes
         ["logits_scale", "sigmoid_bias"],  # additional_scalar_names
@@ -369,7 +373,8 @@ def test_batch_prefill_flash_sigmoid():
         torch.float16,  # dtype_kv
         torch.float16,  # dtype_o
         torch.int32,  # idtype
-        128,  # hidden_dim
+        128,  # hidden_dim_qk
+        128,  # hidden_dim_vo
         [],  # additional_tensor_names
         [],  # additional_tensor_dtypes
         ["logits_scale", "sigmoid_bias"],  # additional_scalar_names
@@ -487,7 +492,8 @@ def test_batch_prefill_sm90_flash_sigmoid():
         torch.float16,  # dtype_kv
         torch.float16,  # dtype_o
         torch.int32,  # idtype
-        128,  # hidden_dim
+        128,  # hidden_dim_qk
+        128,  # hidden_dim_vo
         [],  # additional_tensor_names
         [],  # additional_tensor_dtypes
         ["logits_scale", "sigmoid_bias"],  # additional_scalar_names
@@ -552,6 +558,7 @@ def test_batch_prefill_sm90_flash_sigmoid():
 
     sigmoid_bias = 0.25
     o = wrapper.run(q, k, v, logits_scale, sigmoid_bias)
+    print(o)
     wrapper_paged = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
         float_workspace_buffer, kv_layout="NHD", backend="fa3", jit_args=jit_args
     )
@@ -642,7 +649,8 @@ def test_debug_print_logits():
         torch.float16,  # dtype_q
         torch.float16,  # dtype_kv
         torch.float16,  # dtype_o
-        128,  # hidden_dim
+        128,  # hidden_dim_qk
+        128,  # hidden_dim_vo
         [],  # additional_tensor_names
         [],  # additional_tensor_dtypes
         ["sm_scale"],  # additional_scalar_names
@@ -715,7 +723,8 @@ def test_sm90_debug_print_logits():
         torch.float16,  # dtype_q
         torch.float16,  # dtype_kv
         torch.float16,  # dtype_o
-        128,  # hidden_dim
+        128,  # hidden_dim_qk
+        128,  # hidden_dim_vo
         [],  # additional_tensor_names
         [],  # additional_tensor_dtypes
         ["sm_scale"],  # additional_scalar_names
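
The new tests above exercise the decoupled head_dim_qk / head_dim_vo path (192/128, the DeepSeek-style MLA prefill shape) through the public API. Below is a minimal standalone usage sketch of that same path, not part of the patch itself, assuming a CUDA device and a FlashInfer build that includes the 192/128 head-dim instantiations:

    import torch
    import flashinfer

    # Decoupled QK / VO head dimensions (MLA-style prefill), mirroring the new tests.
    qo_len, kv_len, num_heads = 128, 512, 8
    head_dim_qk, head_dim_vo = 192, 128

    q = torch.randn(qo_len, num_heads, head_dim_qk, dtype=torch.half, device="cuda")
    k = torch.randn(kv_len, num_heads, head_dim_qk, dtype=torch.half, device="cuda")
    v = torch.randn(kv_len, num_heads, head_dim_vo, dtype=torch.half, device="cuda")

    # The output inherits the V/O head dimension: [qo_len, num_heads, head_dim_vo].
    o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, backend="fa2")
    assert o.shape == (qo_len, num_heads, head_dim_vo)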