Skip to content

Commit 1ebbde3

Browse files
authored
feat: Separate QK/VO head dim dispatch for sm90 AOT (#778)
1 parent fc03772 commit 1ebbde3

8 files changed

+61
-49
lines changed

Diff for: aot_build_utils/generate.py

-14
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
generate_batch_paged_decode_inst,
2525
generate_batch_paged_prefill_inst,
2626
generate_batch_ragged_prefill_inst,
27-
generate_dispatch_inc,
2827
generate_single_decode_inst,
2928
generate_single_prefill_inst,
3029
)
@@ -48,19 +47,6 @@ def write_if_different(path: Path, content: str) -> None:
4847

4948
path.mkdir(parents=True, exist_ok=True)
5049

51-
write_if_different(
52-
path / "dispatch.inc",
53-
generate_dispatch_inc.get_dispatch_inc_str(
54-
argparse.Namespace(
55-
head_dims=head_dims,
56-
head_dims_sm90=head_dims,
57-
pos_encoding_modes=[0],
58-
use_fp16_qk_reductions=[0],
59-
mask_modes=mask_modes,
60-
)
61-
),
62-
)
63-
6450
write_if_different(
6551
path / "aot_default_additional_params.h",
6652
generate_aot_default_additional_params_header.get_aot_default_additional_params_header_str(),

Diff for: aot_build_utils/generate_dispatch_inc.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
3535
# head dims for sm90
3636
dispatch_head_dims_sm90_entries = "\n".join(
3737
[
38-
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(_)
39-
for _ in args.head_dims_sm90
38+
" _DISPATCH_CASE_U16x2({}, {}, case_var1, case_var2, __VA_ARGS__) \\".format(
39+
qk, vo
40+
)
41+
for qk, vo in args.head_dims_sm90
4042
]
4143
)
42-
dispatch_head_dims_sm90_str = f"""#define _DISPATCH_CASES_head_dim_sm90(case_var, ...) \\
44+
dispatch_head_dims_sm90_str = f"""#define _DISPATCH_CASES_head_dim_sm90(case_var1, case_var2, ...) \\
4345
{dispatch_head_dims_sm90_entries}
4446
// EOL
4547
"""

Diff for: aot_build_utils/generate_sm90.py

+29-24
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import argparse
1818
from itertools import product
1919
from pathlib import Path
20-
from typing import List
20+
from typing import List, Tuple
2121

2222
from . import (
2323
generate_batch_paged_prefill_sm90_inst,
@@ -33,7 +33,7 @@ def write_if_different(path: Path, content: str) -> None:
3333
path.write_text(content)
3434

3535
path: Path = args.path
36-
head_dims: List[int] = args.head_dims
36+
head_dims: List[Tuple[int, int]] = args.head_dims
3737
pos_encoding_modes: List[int] = args.pos_encoding_modes
3838
use_fp16_qk_reductions: List[int] = args.use_fp16_qk_reductions
3939
mask_modes: List[int] = args.mask_modes
@@ -58,7 +58,7 @@ def write_if_different(path: Path, content: str) -> None:
5858
# single prefill files
5959
single_prefill_sm90_uris = []
6060
for (
61-
head_dim,
61+
(head_dim_qk, head_dim_vo),
6262
pos_encoding_mode,
6363
use_fp16_qk_reduction,
6464
mask_mode,
@@ -69,15 +69,15 @@ def write_if_different(path: Path, content: str) -> None:
6969
mask_modes,
7070
):
7171
for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)):
72-
fname = f"single_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_sm90.cu"
72+
fname = f"single_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_sm90.cu"
7373
content = generate_single_prefill_sm90_inst.get_cu_file_str(
74-
head_dim, # head_dim_qk
75-
head_dim, # head_dim_vo
74+
head_dim_qk,
75+
head_dim_vo,
7676
pos_encoding_mode,
7777
use_fp16_qk_reduction,
7878
mask_mode,
79-
dtype_q, # dtype_q
80-
dtype_kv, # dtype_kv
79+
dtype_q,
80+
dtype_kv,
8181
dtype_q, # dtype_out
8282
)
8383
for use_sliding_window in [True, False]:
@@ -89,8 +89,8 @@ def write_if_different(path: Path, content: str) -> None:
8989
f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_"
9090
f"dtype_kv_{dtype_kv}_"
9191
f"dtype_o_{dtype_q}_"
92-
f"head_dim_qk_{head_dim}_"
93-
f"head_dim_vo_{head_dim}_"
92+
f"head_dim_qk_{head_dim_qk}_"
93+
f"head_dim_vo_{head_dim_vo}_"
9494
f"posenc_{pos_encoding_mode}_"
9595
f"use_swa_{use_sliding_window}_"
9696
f"use_logits_cap_{use_logits_soft_cap}_"
@@ -101,7 +101,7 @@ def write_if_different(path: Path, content: str) -> None:
101101
# batch prefill files
102102
batch_prefill_sm90_uris = []
103103
for (
104-
head_dim,
104+
(head_dim_qk, head_dim_vo),
105105
pos_encoding_mode,
106106
use_fp16_qk_reduction,
107107
mask_mode,
@@ -114,29 +114,29 @@ def write_if_different(path: Path, content: str) -> None:
114114
idtypes,
115115
):
116116
for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)):
117-
fname = f"batch_paged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu"
117+
fname = f"batch_paged_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu"
118118
content = generate_batch_paged_prefill_sm90_inst.get_cu_file_str(
119-
head_dim, # head_dim_qk
120-
head_dim, # head_dim_vo
119+
head_dim_qk,
120+
head_dim_vo,
121121
pos_encoding_mode,
122122
use_fp16_qk_reduction,
123123
mask_mode,
124-
dtype_q, # dtype_q
125-
dtype_kv, # dtype_kv
124+
dtype_q,
125+
dtype_kv,
126126
dtype_q, # dtype_out
127127
idtype,
128128
)
129129
write_if_different(path / fname, content)
130130

131-
fname = f"batch_ragged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu"
131+
fname = f"batch_ragged_prefill_head_qk_{head_dim_qk}_head_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu"
132132
content = generate_batch_ragged_prefill_sm90_inst.get_cu_file_str(
133-
head_dim, # head_dim_qk
134-
head_dim, # head_dim_vo
133+
head_dim_qk,
134+
head_dim_vo,
135135
pos_encoding_mode,
136136
use_fp16_qk_reduction,
137137
mask_mode,
138-
dtype_q, # dtype_q
139-
dtype_kv, # dtype_kv
138+
dtype_q,
139+
dtype_kv,
140140
dtype_q, # dtype_out
141141
idtype,
142142
)
@@ -152,8 +152,8 @@ def write_if_different(path: Path, content: str) -> None:
152152
f"dtype_kv_{dtype_kv}_"
153153
f"dtype_o_{dtype_q}_"
154154
f"dtype_idx_{idtype}_"
155-
f"head_dim_qk_{head_dim}_"
156-
f"head_dim_vo_{head_dim}_"
155+
f"head_dim_qk_{head_dim_qk}_"
156+
f"head_dim_vo_{head_dim_vo}_"
157157
f"posenc_{pos_encoding_mode}_"
158158
f"use_swa_{sliding_window}_"
159159
f"use_logits_cap_{logits_soft_cap}_"
@@ -169,7 +169,11 @@ def write_if_different(path: Path, content: str) -> None:
169169
"--path", type=Path, required=True, help="Path to the dispatch inc file"
170170
)
171171
parser.add_argument(
172-
"--head_dims", type=int, required=True, nargs="+", help="Head dimensions"
172+
"--head_dims",
173+
type=str,
174+
required=True,
175+
nargs="+",
176+
help="Head dimensions in format of 'head_dim_qk,head_dim_vo'",
173177
)
174178
parser.add_argument(
175179
"--pos_encoding_modes",
@@ -207,4 +211,5 @@ def write_if_different(path: Path, content: str) -> None:
207211
help="Enable bf16",
208212
)
209213
args = parser.parse_args()
214+
args.head_dims = [tuple(map(int, x.split(","))) for x in args.head_dims]
210215
get_sm90_instantiation_cu(args)

Diff for: csrc/aot_extension_utils.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
#define DISPATCH_head_dim(expr, const_expr, ...) \
2020
_DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__))
2121

22-
#define DISPATCH_head_dim_sm90(expr, const_expr, ...) \
23-
_DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim_sm90(const_expr, __VA_ARGS__))
22+
#define DISPATCH_head_dim_sm90(expr1, expr2, const_expr1, const_expr2, ...) \
23+
_DISPATCH_SWITCH_U16x2("head_dim_qk", "head_dim_vo", expr1, expr2, \
24+
_DISPATCH_CASES_head_dim_sm90(const_expr1, const_expr2, __VA_ARGS__))
2425

2526
#define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \
2627
_DISPATCH_SWITCH("positional encoding mode", expr, \

Diff for: csrc/batch_prefill_sm90_config.inc

+1-2
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ using IdType = int32_t;
4141
using DTypeO = DTypeQ; \
4242
using RaggedParams = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
4343
using PagedParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
44-
return DISPATCH_head_dim_sm90(head_dim_qk, HEAD_DIM_QK, [&] { \
45-
[[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \
44+
return DISPATCH_head_dim_sm90(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO, [&] { \
4645
return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \
4746
return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
4847
using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>; \

Diff for: csrc/pytorch_extension_utils.h

+20
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,32 @@
122122
} \
123123
}()
124124

125+
#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \
126+
[&]() -> bool { \
127+
switch (pack_u16(cond1, cond2)) { \
128+
__VA_ARGS__ \
129+
default: \
130+
std::ostringstream oss; \
131+
oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" \
132+
<< int(cond1) << ", " << int(cond2) << ")"; \
133+
TORCH_CHECK(false, oss.str()); \
134+
return false; \
135+
} \
136+
}()
137+
125138
#define _DISPATCH_CASE(case_expr, case_var, ...) \
126139
case case_expr: { \
127140
constexpr auto case_var = case_expr; \
128141
return __VA_ARGS__(); \
129142
}
130143

144+
#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \
145+
case pack_u16(case_expr1, case_expr2): { \
146+
constexpr auto case_var1 = case_expr1; \
147+
constexpr auto case_var2 = case_expr2; \
148+
return __VA_ARGS__(); \
149+
}
150+
131151
#define DISPATCH_BOOL(expr, const_expr, ...) \
132152
[&]() -> bool { \
133153
if (expr) { \

Diff for: csrc/single_prefill_sm90_config.inc

+1-2
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ using IdType = int32_t;
3939
using DTypeKV = DTypeQ; \
4040
using DTypeO = DTypeQ; \
4141
using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>; \
42-
return DISPATCH_head_dim_sm90(head_dim_qk, HEAD_DIM_QK, [&] { \
43-
[[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \
42+
return DISPATCH_head_dim_sm90(head_dim_qk, head_dim_vo, HEAD_DIM_QK, HEAD_DIM_VO, [&] { \
4443
return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \
4544
return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
4645
using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>; \

Diff for: setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929

3030
head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",")
3131
head_dims = list(map(int, head_dims))
32-
SM90_ALLOWED_HEAD_DIMS = {64, 128, 256}
33-
head_dims_sm90 = [d for d in head_dims if d in SM90_ALLOWED_HEAD_DIMS]
32+
SM90_ALLOWED_HEAD_DIMS = {(64, 64), (128, 128), (256, 256), (192, 128)}
33+
head_dims_sm90 = list(SM90_ALLOWED_HEAD_DIMS) # No support for custom head dims for SM90
3434

3535
mask_modes = [0, 1, 2]
3636

0 commit comments

Comments
 (0)