feat: support deepseek prefill attention shape #765

Merged: 21 commits, Feb 1, 2025
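This PR threads separate QK and VO head dimensions through the AOT generators and the generated file and URI names (previously a single head_dim field), so attention shapes where the two differ, such as DeepSeek-style prefill, can be instantiated. A minimal sketch of the new naming with hypothetical values; the 192/128 split is an illustrative assumption, and the AOT loops below still pass the same value for both dimensions:

    head_dim_qk, head_dim_vo = 192, 128  # hypothetical DeepSeek-style prefill shape
    dtype = "f16"                        # hypothetical dtype key
    fname = (
        f"single_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_posenc_0_"
        f"fp16qkred_false_mask_1_dtypeq_{dtype}_dtypekv_{dtype}_dtypeout_{dtype}.cu"
    )
    # -> single_prefill_head_qk_192_head_vo_128_posenc_0_fp16qkred_false_mask_1_
    #    dtypeq_f16_dtypekv_f16_dtypeout_f16.cu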
51 changes: 37 additions & 14 deletions aot_build_utils/generate.py
@@ -24,6 +24,7 @@
generate_batch_paged_decode_inst,
generate_batch_paged_prefill_inst,
generate_batch_ragged_prefill_inst,
generate_dispatch_inc,
generate_single_decode_inst,
generate_single_prefill_inst,
)
@@ -47,6 +48,19 @@ def write_if_different(path: Path, content: str) -> None:

path.mkdir(parents=True, exist_ok=True)

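# get_dispatch_inc_str expects an argparse-style namespace, so one is built
# inline here: head_dims is reused for the SM90 list, and pos_encoding_modes /
# use_fp16_qk_reductions are each limited to the single option 0.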
write_if_different(
path / "dispatch.inc",
generate_dispatch_inc.get_dispatch_inc_str(
argparse.Namespace(
head_dims=head_dims,
head_dims_sm90=head_dims,
pos_encoding_modes=[0],
use_fp16_qk_reductions=[0],
mask_modes=mask_modes,
)
),
)

write_if_different(
path / "aot_default_additional_params.h",
generate_aot_default_additional_params_header.get_aot_default_additional_params_header_str(),
@@ -79,9 +93,10 @@ def write_if_different(path: Path, content: str) -> None:
product(fp16_dtypes, fp8_dtypes)
):
dtype_out = dtype_q
fname = f"single_decode_head_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu"
fname = f"single_decode_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu"
content = generate_single_decode_inst.get_cu_file_str(
head_dim,
head_dim, # head_dim_qk
head_dim, # head_dim_vo
pos_encoding_mode,
dtype_q,
dtype_kv,
@@ -93,7 +108,8 @@ def write_if_different(path: Path, content: str) -> None:
f"single_decode_with_kv_cache_dtype_q_{dtype_q}_"
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_out}_"
f"head_dim_{head_dim}_"
f"head_dim_qk_{head_dim}_"
f"head_dim_vo_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_cap_{use_logits_soft_cap}"
@@ -114,9 +130,10 @@ def write_if_different(path: Path, content: str) -> None:
product(fp16_dtypes, fp8_dtypes)
):
dtype_out = dtype_q
fname = f"batch_paged_decode_head_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu"
fname = f"batch_paged_decode_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu"
content = generate_batch_paged_decode_inst.get_cu_file_str(
head_dim,
head_dim, # head_dim_qk
head_dim, # head_dim_vo
pos_encoding_mode,
dtype_q,
dtype_kv,
@@ -130,7 +147,8 @@ def write_if_different(path: Path, content: str) -> None:
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_out}_"
f"dtype_idx_{idtype}_"
f"head_dim_{head_dim}_"
f"head_dim_qk_{head_dim}_"
f"head_dim_vo_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_cap_{use_logits_soft_cap}"
@@ -153,9 +171,10 @@ def write_if_different(path: Path, content: str) -> None:
for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list(
product(prefill_dtypes, fp8_dtypes)
):
fname = f"single_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu"
fname = f"single_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu"
content = generate_single_prefill_inst.get_cu_file_str(
head_dim,
head_dim, # head_dim_qk
head_dim, # head_dim_vo
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -172,7 +191,8 @@ def write_if_different(path: Path, content: str) -> None:
f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_"
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_q}_"
f"head_dim_{head_dim}_"
f"head_dim_qk_{head_dim}_"
f"head_dim_vo_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_cap_{use_logits_soft_cap}_"
@@ -198,9 +218,10 @@ def write_if_different(path: Path, content: str) -> None:
for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list(
product(prefill_dtypes, fp8_dtypes)
):
fname = f"batch_paged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
fname = f"batch_paged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
content = generate_batch_paged_prefill_inst.get_cu_file_str(
head_dim,
head_dim, # head_dim_qk
head_dim, # head_dim_vo
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -211,9 +232,10 @@ def write_if_different(path: Path, content: str) -> None:
)
write_if_different(path / fname, content)

fname = f"batch_ragged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
fname = f"batch_ragged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
content = generate_batch_ragged_prefill_inst.get_cu_file_str(
head_dim,
head_dim, # head_dim_qk
head_dim, # head_dim_vo
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -234,7 +256,8 @@ def write_if_different(path: Path, content: str) -> None:
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_q}_"
f"dtype_idx_{idtype}_"
f"head_dim_{head_dim}_"
f"head_dim_qk_{head_dim}_"
f"head_dim_vo_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{sliding_window}_"
f"use_logits_cap_{logits_soft_cap}_"
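For reference, a short sketch (hypothetical values) of the single-decode module identifier assembled above; both head-dimension keys now appear in it, so anything that looks kernels up by this string presumably has to use the new form:

    dtype_q = dtype_kv = dtype_o = "f16"  # hypothetical dtype keys
    head_dim = 128                        # hypothetical head dimension
    uri = (
        f"single_decode_with_kv_cache_dtype_q_{dtype_q}_"
        f"dtype_kv_{dtype_kv}_"
        f"dtype_o_{dtype_o}_"
        f"head_dim_qk_{head_dim}_"
        f"head_dim_vo_{head_dim}_"
        f"posenc_0_"
        f"use_swa_False_"
        f"use_logits_cap_False"
    )
    # single_decode_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_
    # head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False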
19 changes: 11 additions & 8 deletions aot_build_utils/generate_batch_paged_decode_inst.py
@@ -22,7 +22,8 @@


def get_cu_file_str(
head_dim,
head_dim_qk,
head_dim_vo,
pos_encoding_mode,
dtype_q,
dtype_kv,
@@ -35,25 +36,25 @@ def get_cu_file_str(

using Params = BatchDecodeParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>;

template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>(
Params params,
{dtype_out}* tmp_v, float* tmp_s,
cudaStream_t stream);

template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>(
Params params,
{dtype_out}* tmp_v, float* tmp_s,
cudaStream_t stream);

template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>(
Params params,
{dtype_out}* tmp_v, float* tmp_s,
cudaStream_t stream);

template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim_qk}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>(
Params params,
{dtype_out}* tmp_v, float* tmp_s,
@@ -69,20 +70,22 @@ def get_cu_file_str(

}}
""".format(
head_dim=head_dim,
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
dtype_q=dtype_literal[dtype_q],
dtype_kv=dtype_literal[dtype_kv],
dtype_out=dtype_literal[dtype_out],
idtype=idtype_literal[idtype],
head_dim_kpe=head_dim // 8,
head_dim=head_dim_vo, # NOTE(Zihao): for MLA instantiation, we should move them to a standalone file
head_dim_kpe=head_dim_vo // 8,
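# e.g. head_dim_vo=512 would give head_dim_kpe=64, the DeepSeek MLA split of a
# 512-dim compressed-KV head and a 64-dim rope key head (illustrative assumption,
# not taken from this diff).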
)
return content


if __name__ == "__main__":
pattern = (
r"batch_paged_decode_head_([0-9]+)_posenc_([0-9]+)_"
r"batch_paged_decode_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_"
r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu"
)

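As a sanity check on the renamed files, the updated pattern in this generator parses the new filenames; a minimal sketch with hypothetical dtype/idtype strings:

    import re

    pattern = (
        r"batch_paged_decode_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_"
        r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu"
    )
    fname = "batch_paged_decode_head_qk_128_head_vo_128_posenc_0_dtypeq_f16_dtypekv_f16_dtypeout_f16_idtype_i32.cu"
    m = re.match(pattern, fname)
    head_dim_qk, head_dim_vo = m.group(1), m.group(2)  # "128", "128"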
10 changes: 6 additions & 4 deletions aot_build_utils/generate_batch_paged_prefill_inst.py
@@ -27,7 +27,8 @@


def get_cu_file_str(
head_dim,
head_dim_qk,
head_dim_vo,
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -41,13 +42,14 @@ def get_cu_file_str(
def get_insts(attention_variant, dtype_out):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{cta_tile_q}, {head_dim}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, {attention_variant}, Params>(
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{cta_tile_q}, {head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, {attention_variant}, Params>(
Params params,
{dtype_out}* tmp_v,
float* tmp_s, cudaStream_t stream);
""".format(
cta_tile_q=cta_tile_q,
head_dim=head_dim,
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
use_fp16_qk_reduction=use_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
@@ -92,7 +94,7 @@ def get_insts(attention_variant, dtype_out):

if __name__ == "__main__":
pattern = (
r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"batch_paged_prefill_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu"
)
compiled_pattern = re.compile(pattern)
20 changes: 13 additions & 7 deletions aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
@@ -22,7 +22,8 @@


def get_cu_file_str(
head_dim,
head_dim_qk,
head_dim_vo,
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -37,7 +38,8 @@ def get_cu_file_str(
def get_insts(attention_variant):
return """
template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
<{head_dim_qk},
{head_dim_vo},
{mask_mode},
/*USE_SLIDING_WINDOW=*/true,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
Expand All @@ -46,7 +48,8 @@ def get_insts(attention_variant):
(Params& params, cudaStream_t stream);

template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
<{head_dim_qk},
{head_dim_vo},
{mask_mode},
/*USE_SLIDING_WINDOW=*/true,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
Expand All @@ -55,7 +58,8 @@ def get_insts(attention_variant):
(Params& params, cudaStream_t stream);

template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
<{head_dim_qk},
{head_dim_vo},
{mask_mode},
/*USE_SLIDING_WINDOW=*/false,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
Expand All @@ -64,15 +68,17 @@ def get_insts(attention_variant):
(Params& params, cudaStream_t stream);

template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
<{head_dim_qk},
{head_dim_vo},
{mask_mode},
/*USE_SLIDING_WINDOW=*/false,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
{attention_variant},
Params>
(Params& params, cudaStream_t stream);
""".format(
head_dim=head_dim,
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
mask_mode=mask_mode_literal[int(mask_mode)],
attention_variant=attention_variant,
)
@@ -107,7 +113,7 @@ def get_insts(attention_variant):

if __name__ == "__main__":
pattern = (
r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"batch_paged_prefill_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_"
r"dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
)
10 changes: 6 additions & 4 deletions aot_build_utils/generate_batch_ragged_prefill_inst.py
@@ -27,7 +27,8 @@


def get_cu_file_str(
head_dim,
head_dim_qk,
head_dim_vo,
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode,
@@ -41,13 +42,14 @@ def get_cu_file_str(
def get_insts(attention_variant, dtype_out):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{cta_tile_q}, {head_dim}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, {attention_variant}, Params>(
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{cta_tile_q}, {head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, {attention_variant}, Params>(
Params params,
{dtype_out}* tmp_v,
float* tmp_s, cudaStream_t stream);
""".format(
cta_tile_q=cta_tile_q,
head_dim=head_dim,
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
use_fp16_qk_reduction=use_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
@@ -94,7 +96,7 @@ def get_insts(attention_variant, dtype_out):

if __name__ == "__main__":
pattern = (
r"batch_ragged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"batch_ragged_prefill_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu"
)
compiled_pattern = re.compile(pattern)