17
17
import argparse
18
18
from itertools import product
19
19
from pathlib import Path
20
- from typing import List
20
+ from typing import List , Tuple
21
21
22
22
from . import (
23
23
generate_batch_paged_prefill_sm90_inst ,
@@ -33,7 +33,7 @@ def write_if_different(path: Path, content: str) -> None:
33
33
path .write_text (content )
34
34
35
35
path : Path = args .path
36
- head_dims : List [int ] = args .head_dims
36
+ head_dims : List [Tuple [ int , int ] ] = args .head_dims
37
37
pos_encoding_modes : List [int ] = args .pos_encoding_modes
38
38
use_fp16_qk_reductions : List [int ] = args .use_fp16_qk_reductions
39
39
mask_modes : List [int ] = args .mask_modes
@@ -58,7 +58,7 @@ def write_if_different(path: Path, content: str) -> None:
58
58
# single prefill files
59
59
single_prefill_sm90_uris = []
60
60
for (
61
- head_dim ,
61
+ ( head_dim_qk , head_dim_vo ) ,
62
62
pos_encoding_mode ,
63
63
use_fp16_qk_reduction ,
64
64
mask_mode ,
@@ -69,15 +69,15 @@ def write_if_different(path: Path, content: str) -> None:
69
69
mask_modes ,
70
70
):
71
71
for dtype_q , dtype_kv in list (zip (prefill_dtypes , prefill_dtypes )):
72
- fname = f"single_prefill_head_ { head_dim } _posenc_{ pos_encoding_mode } _fp16qkred_{ use_fp16_qk_reduction } _mask_{ mask_mode } _dtypeq_{ dtype_q } _dtypekv_{ dtype_kv } _dtypeout_{ dtype_q } _sm90.cu"
72
+ fname = f"single_prefill_head_qk_ { head_dim_qk } _head_vo_ { head_dim_vo } _posenc_{ pos_encoding_mode } _fp16qkred_{ use_fp16_qk_reduction } _mask_{ mask_mode } _dtypeq_{ dtype_q } _dtypekv_{ dtype_kv } _dtypeout_{ dtype_q } _sm90.cu"
73
73
content = generate_single_prefill_sm90_inst .get_cu_file_str (
74
- head_dim , # head_dim_qk
75
- head_dim , # head_dim_vo
74
+ head_dim_qk ,
75
+ head_dim_vo ,
76
76
pos_encoding_mode ,
77
77
use_fp16_qk_reduction ,
78
78
mask_mode ,
79
- dtype_q , # dtype_q
80
- dtype_kv , # dtype_kv
79
+ dtype_q ,
80
+ dtype_kv ,
81
81
dtype_q , # dtype_out
82
82
)
83
83
for use_sliding_window in [True , False ]:
@@ -89,8 +89,8 @@ def write_if_different(path: Path, content: str) -> None:
89
89
f"single_prefill_with_kv_cache_dtype_q_{ dtype_q } _"
90
90
f"dtype_kv_{ dtype_kv } _"
91
91
f"dtype_o_{ dtype_q } _"
92
- f"head_dim_qk_{ head_dim } _"
93
- f"head_dim_vo_{ head_dim } _"
92
+ f"head_dim_qk_{ head_dim_qk } _"
93
+ f"head_dim_vo_{ head_dim_vo } _"
94
94
f"posenc_{ pos_encoding_mode } _"
95
95
f"use_swa_{ use_sliding_window } _"
96
96
f"use_logits_cap_{ use_logits_soft_cap } _"
@@ -101,7 +101,7 @@ def write_if_different(path: Path, content: str) -> None:
101
101
# batch prefill files
102
102
batch_prefill_sm90_uris = []
103
103
for (
104
- head_dim ,
104
+ ( head_dim_qk , head_dim_vo ) ,
105
105
pos_encoding_mode ,
106
106
use_fp16_qk_reduction ,
107
107
mask_mode ,
@@ -114,29 +114,29 @@ def write_if_different(path: Path, content: str) -> None:
114
114
idtypes ,
115
115
):
116
116
for dtype_q , dtype_kv in list (zip (prefill_dtypes , prefill_dtypes )):
117
- fname = f"batch_paged_prefill_head_qk_{ head_dim } _head_vo_{ head_dim } _posenc_{ pos_encoding_mode } _fp16qkred_{ use_fp16_qk_reduction } _mask_{ mask_mode } _dtypeq_{ dtype_q } _dtypekv_{ dtype_kv } _dtypeout_{ dtype_q } _idtype_{ idtype } _sm90.cu"
117
+ fname = f"batch_paged_prefill_head_qk_{ head_dim_qk } _head_vo_{ head_dim_vo } _posenc_{ pos_encoding_mode } _fp16qkred_{ use_fp16_qk_reduction } _mask_{ mask_mode } _dtypeq_{ dtype_q } _dtypekv_{ dtype_kv } _dtypeout_{ dtype_q } _idtype_{ idtype } _sm90.cu"
118
118
content = generate_batch_paged_prefill_sm90_inst .get_cu_file_str (
119
- head_dim , # head_dim_qk
120
- head_dim , # head_dim_vo
119
+ head_dim_qk ,
120
+ head_dim_vo ,
121
121
pos_encoding_mode ,
122
122
use_fp16_qk_reduction ,
123
123
mask_mode ,
124
- dtype_q , # dtype_q
125
- dtype_kv , # dtype_kv
124
+ dtype_q ,
125
+ dtype_kv ,
126
126
dtype_q , # dtype_out
127
127
idtype ,
128
128
)
129
129
write_if_different (path / fname , content )
130
130
131
- fname = f"batch_ragged_prefill_head_qk_{ head_dim } _head_vo_{ head_dim } _posenc_{ pos_encoding_mode } _fp16qkred_{ use_fp16_qk_reduction } _mask_{ mask_mode } _dtypeq_{ dtype_q } _dtypekv_{ dtype_kv } _dtypeout_{ dtype_q } _idtype_{ idtype } _sm90.cu"
131
+ fname = f"batch_ragged_prefill_head_qk_{ head_dim_qk } _head_vo_{ head_dim_vo } _posenc_{ pos_encoding_mode } _fp16qkred_{ use_fp16_qk_reduction } _mask_{ mask_mode } _dtypeq_{ dtype_q } _dtypekv_{ dtype_kv } _dtypeout_{ dtype_q } _idtype_{ idtype } _sm90.cu"
132
132
content = generate_batch_ragged_prefill_sm90_inst .get_cu_file_str (
133
- head_dim , # head_dim_qk
134
- head_dim , # head_dim_vo
133
+ head_dim_qk ,
134
+ head_dim_vo ,
135
135
pos_encoding_mode ,
136
136
use_fp16_qk_reduction ,
137
137
mask_mode ,
138
- dtype_q , # dtype_q
139
- dtype_kv , # dtype_kv
138
+ dtype_q ,
139
+ dtype_kv ,
140
140
dtype_q , # dtype_out
141
141
idtype ,
142
142
)
@@ -152,8 +152,8 @@ def write_if_different(path: Path, content: str) -> None:
152
152
f"dtype_kv_{ dtype_kv } _"
153
153
f"dtype_o_{ dtype_q } _"
154
154
f"dtype_idx_{ idtype } _"
155
- f"head_dim_qk_{ head_dim } _"
156
- f"head_dim_vo_{ head_dim } _"
155
+ f"head_dim_qk_{ head_dim_qk } _"
156
+ f"head_dim_vo_{ head_dim_vo } _"
157
157
f"posenc_{ pos_encoding_mode } _"
158
158
f"use_swa_{ sliding_window } _"
159
159
f"use_logits_cap_{ logits_soft_cap } _"
@@ -169,7 +169,11 @@ def write_if_different(path: Path, content: str) -> None:
169
169
"--path" , type = Path , required = True , help = "Path to the dispatch inc file"
170
170
)
171
171
parser .add_argument (
172
- "--head_dims" , type = int , required = True , nargs = "+" , help = "Head dimensions"
172
+ "--head_dims" ,
173
+ type = str ,
174
+ required = True ,
175
+ nargs = "+" ,
176
+ help = "Head dimensions in format of 'head_dim_qk,head_dim_vo'" ,
173
177
)
174
178
parser .add_argument (
175
179
"--pos_encoding_modes" ,
@@ -207,4 +211,5 @@ def write_if_different(path: Path, content: str) -> None:
207
211
help = "Enable bf16" ,
208
212
)
209
213
args = parser .parse_args ()
214
+ args .head_dims = [tuple (map (int , x .split ("," ))) for x in args .head_dims ]
210
215
get_sm90_instantiation_cu (args )
0 commit comments