"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import re
import sys
from pathlib import Path

from .literal_map import (
    dtype_literal,
    idtype_literal,
    mask_mode_literal,
    pos_encoding_mode_literal,
)
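

# Generates a .cu translation unit containing explicit template instantiations
# of BatchPrefillWithRaggedKVCacheDispatched for one combination of head dims,
# positional-encoding mode, fp16 QK reduction, mask mode, and dtypes; the
# combination is decoded from the output filename (see __main__ below).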
def get_cu_file_str(
    head_dim_qk,
    head_dim_vo,
    pos_encoding_mode,
    use_fp16_qk_reduction,
    mask_mode,
    dtype_q,
    dtype_kv,
    dtype_out,
    idtype,
):
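    # Candidate CTA (thread-block) tile sizes along the query dimension; one
    # explicit template instantiation is emitted per size below.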
    cta_tile_q_choice = [128, 64, 32, 16]
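    # Helper: render the explicit-instantiation declarations for one attention
    # variant, one declaration per CTA tile size.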
    def get_insts(attention_variant, dtype_out):
        return "\n".join(
            [
                """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{cta_tile_q}, {head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode}, {attention_variant}, Params>(
    Params params,
    {dtype_out}* tmp_v,
    float* tmp_s, cudaStream_t stream);
""".format(
                    cta_tile_q=cta_tile_q,
                    head_dim_qk=head_dim_qk,
                    head_dim_vo=head_dim_vo,
                    pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
                    use_fp16_qk_reduction=use_fp16_qk_reduction,
                    mask_mode=mask_mode_literal[int(mask_mode)],
                    attention_variant=attention_variant,
                    dtype_out=dtype_out,
                )
                for cta_tile_q in cta_tile_q_choice
            ]
        )
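    # mask_mode 2 selects the custom-mask code path; the literal maps translate
    # the short dtype/idtype tokens from the filename into C++ type names.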
    use_custom_mask = "true" if int(mask_mode) == 2 else "false"
    dtype_q = dtype_literal[dtype_q]
    dtype_kv = dtype_literal[dtype_kv]
    dtype_out = dtype_literal[dtype_out]
    idtype = idtype_literal[idtype]
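    # The generated file instantiates four DefaultAttention variants, covering
    # the cross product of use_sliding_window x use_logits_soft_cap (ALiBi bias
    # is disabled in all four).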
content = f"""#include <flashinfer/attention_impl.cuh>
namespace flashinfer {{
using Params = BatchPrefillRaggedParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>;
using AttentionVariant1 = DefaultAttention<{use_custom_mask}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>;
{get_insts("AttentionVariant1", dtype_out)}
using AttentionVariant2 = DefaultAttention<{use_custom_mask}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>;
{get_insts("AttentionVariant2", dtype_out)}
using AttentionVariant3 = DefaultAttention<{use_custom_mask}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>;
{get_insts("AttentionVariant3", dtype_out)}
using AttentionVariant4 = DefaultAttention<{use_custom_mask}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>;
{get_insts("AttentionVariant4", dtype_out)}
}}
"""
return content
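

# When invoked as a script, the single CLI argument is the output path; every
# kernel parameter is decoded from its filename. A hypothetical filename that
# matches the regex below (assuming "f16" and "i32" are keys of the literal maps):
# batch_ragged_prefill_head_qk_128_head_vo_128_posenc_0_fp16qkred_false_mask_1_dtypeq_f16_dtypekv_f16_dtypeout_f16_idtype_i32.cu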
if __name__ == "__main__":
pattern = (
r"batch_ragged_prefill_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu"
)
compiled_pattern = re.compile(pattern)
path = Path(sys.argv[1])
fname = path.name
match = compiled_pattern.match(fname)
with open(path, "w") as f:
f.write(get_cu_file_str(*match.groups()))