-
Notifications
You must be signed in to change notification settings - Fork 286
/
Copy pathgenerate.py
333 lines (312 loc) · 12.2 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
from itertools import product
from pathlib import Path
from typing import List
from . import (
generate_aot_default_additional_params_header,
generate_batch_paged_decode_inst,
generate_batch_paged_prefill_inst,
generate_batch_ragged_prefill_inst,
generate_dispatch_inc,
generate_single_decode_inst,
generate_single_prefill_inst,
)
def get_instantiation_cu(args: argparse.Namespace) -> List[str]:
    """Write the AOT kernel instantiation sources into ``args.path`` and return their URIs.

    Generates ``dispatch.inc``, ``aot_default_additional_params.h`` and one ``.cu``
    file per parameter/dtype combination for single decode, batch (paged) decode,
    single prefill and batch (paged and ragged) prefill. Files whose content is
    already up to date are not rewritten.

    Args:
        args: Namespace providing ``path`` (output directory, created if missing),
            ``head_dims``, ``pos_encoding_modes``, ``use_fp16_qk_reductions`` and
            ``mask_modes`` (lists of ints), and the boolean flags ``enable_f16``,
            ``enable_bf16``, ``enable_fp8_e4m3`` and ``enable_fp8_e5m2``.

    Returns:
        Module URIs for the generated kernels -- single decode, then batch decode,
        single prefill and batch prefill -- each listed once, in order of first
        appearance. Prefill URIs do not encode the mask mode, so each one covers
        every requested mask mode.
    """

    def write_if_different(path: Path, content: str) -> None:
        # Leave files with identical content untouched so their mtimes are preserved.
        if path.exists() and path.read_text() == content:
            return
        path.write_text(content)

    path: Path = args.path
    head_dims: List[int] = args.head_dims
    pos_encoding_modes: List[int] = args.pos_encoding_modes
    use_fp16_qk_reductions: List[int] = args.use_fp16_qk_reductions
    mask_modes: List[int] = args.mask_modes
    enable_f16: bool = args.enable_f16
    enable_bf16: bool = args.enable_bf16
    enable_fp8_e4m3: bool = args.enable_fp8_e4m3
    enable_fp8_e5m2: bool = args.enable_fp8_e5m2

    path.mkdir(parents=True, exist_ok=True)

    write_if_different(
        path / "dispatch.inc",
        generate_dispatch_inc.get_dispatch_inc_str(
            argparse.Namespace(
                head_dims=head_dims,
                head_dims_sm90=head_dims,
                # NOTE(review): the dispatch table is generated for position
                # encoding mode 0 and fp16 qk reduction 0 only, regardless of the
                # requested modes -- confirm this is intentional.
                pos_encoding_modes=[0],
                use_fp16_qk_reductions=[0],
                mask_modes=mask_modes,
            )
        ),
    )
    write_if_different(
        path / "aot_default_additional_params.h",
        generate_aot_default_additional_params_header.get_aot_default_additional_params_header_str(),
    )

    idtypes = ["i32"]
    prefill_dtypes: List[str] = []
    decode_dtypes: List[str] = []
    fp16_dtypes: List[str] = []
    fp8_dtypes: List[str] = []
    if enable_f16:
        prefill_dtypes.append("f16")
        decode_dtypes.append("f16")
        fp16_dtypes.append("f16")
    if enable_bf16:
        prefill_dtypes.append("bf16")
        decode_dtypes.append("bf16")
        fp16_dtypes.append("bf16")
    if enable_fp8_e4m3:
        fp8_dtypes.append("e4m3")
        decode_dtypes.append("e4m3")
    if enable_fp8_e5m2:
        fp8_dtypes.append("e5m2")
        decode_dtypes.append("e5m2")

    # (dtype_q, dtype_kv) pairs: every enabled dtype paired with itself, followed
    # by each 16-bit query dtype paired with each fp8 KV dtype.
    decode_dtype_pairs = [(dt, dt) for dt in decode_dtypes] + list(
        product(fp16_dtypes, fp8_dtypes)
    )
    prefill_dtype_pairs = [(dt, dt) for dt in prefill_dtypes] + list(
        product(prefill_dtypes, fp8_dtypes)
    )
    # One URI is recorded per (use_sliding_window, use_logits_soft_cap) combination.
    swa_softcap_variants = list(product([True, False], [True, False]))

    # single decode files
    single_decode_uris: List[str] = []
    for head_dim, pos_encoding_mode in product(head_dims, pos_encoding_modes):
        for dtype_q, dtype_kv in decode_dtype_pairs:
            dtype_out = dtype_q
            fname = f"single_decode_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu"
            content = generate_single_decode_inst.get_cu_file_str(
                head_dim,  # head_dim_qk
                head_dim,  # head_dim_vo
                pos_encoding_mode,
                dtype_q,
                dtype_kv,
                dtype_out,
            )
            for use_sliding_window, use_logits_soft_cap in swa_softcap_variants:
                single_decode_uris.append(
                    f"single_decode_with_kv_cache_dtype_q_{dtype_q}_"
                    f"dtype_kv_{dtype_kv}_"
                    f"dtype_o_{dtype_out}_"
                    f"head_dim_qk_{head_dim}_"
                    f"head_dim_vo_{head_dim}_"
                    f"posenc_{pos_encoding_mode}_"
                    f"use_swa_{use_sliding_window}_"
                    f"use_logits_cap_{use_logits_soft_cap}"
                )
            write_if_different(path / fname, content)

    # batch decode files
    batch_decode_uris: List[str] = []
    for head_dim, pos_encoding_mode in product(head_dims, pos_encoding_modes):
        for idtype in idtypes:
            for dtype_q, dtype_kv in decode_dtype_pairs:
                dtype_out = dtype_q
                fname = f"batch_paged_decode_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu"
                content = generate_batch_paged_decode_inst.get_cu_file_str(
                    head_dim,  # head_dim_qk
                    head_dim,  # head_dim_vo
                    pos_encoding_mode,
                    dtype_q,
                    dtype_kv,
                    dtype_out,
                    idtype,
                )
                for use_sliding_window, use_logits_soft_cap in swa_softcap_variants:
                    batch_decode_uris.append(
                        f"batch_decode_with_kv_cache_dtype_q_{dtype_q}_"
                        f"dtype_kv_{dtype_kv}_"
                        f"dtype_o_{dtype_out}_"
                        f"dtype_idx_{idtype}_"
                        f"head_dim_qk_{head_dim}_"
                        f"head_dim_vo_{head_dim}_"
                        f"posenc_{pos_encoding_mode}_"
                        f"use_swa_{use_sliding_window}_"
                        f"use_logits_cap_{use_logits_soft_cap}"
                    )
                write_if_different(path / fname, content)

    # single prefill files
    single_prefill_uris: List[str] = []
    for head_dim, pos_encoding_mode, use_fp16_qk_reduction, mask_mode in product(
        head_dims, pos_encoding_modes, use_fp16_qk_reductions, mask_modes
    ):
        for dtype_q, dtype_kv in prefill_dtype_pairs:
            fname = f"single_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu"
            content = generate_single_prefill_inst.get_cu_file_str(
                head_dim,  # head_dim_qk
                head_dim,  # head_dim_vo
                pos_encoding_mode,
                use_fp16_qk_reduction,
                mask_mode,
                dtype_q,  # dtype_q
                dtype_kv,  # dtype_kv
                dtype_q,  # dtype_out
            )
            # The URI omits the mask mode, so it repeats once per mask mode; the
            # repeats are dropped before returning. Recording it for every mask mode
            # (not only mode 0) keeps it present when mode 0 is not requested.
            for use_sliding_window, use_logits_soft_cap in swa_softcap_variants:
                single_prefill_uris.append(
                    f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_"
                    f"dtype_kv_{dtype_kv}_"
                    f"dtype_o_{dtype_q}_"
                    f"head_dim_qk_{head_dim}_"
                    f"head_dim_vo_{head_dim}_"
                    f"posenc_{pos_encoding_mode}_"
                    f"use_swa_{use_sliding_window}_"
                    f"use_logits_cap_{use_logits_soft_cap}_"
                    f"f16qk_{bool(use_fp16_qk_reduction)}"
                )
            write_if_different(path / fname, content)

    # batch prefill files (paged and ragged KV variants)
    batch_prefill_uris: List[str] = []
    for (
        head_dim,
        pos_encoding_mode,
        use_fp16_qk_reduction,
        mask_mode,
        idtype,
    ) in product(
        head_dims, pos_encoding_modes, use_fp16_qk_reductions, mask_modes, idtypes
    ):
        for dtype_q, dtype_kv in prefill_dtype_pairs:
            fname = f"batch_paged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
            content = generate_batch_paged_prefill_inst.get_cu_file_str(
                head_dim,  # head_dim_qk
                head_dim,  # head_dim_vo
                pos_encoding_mode,
                use_fp16_qk_reduction,
                mask_mode,
                dtype_q,  # dtype_q
                dtype_kv,  # dtype_kv
                dtype_q,  # dtype_out
                idtype,
            )
            write_if_different(path / fname, content)

            fname = f"batch_ragged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
            content = generate_batch_ragged_prefill_inst.get_cu_file_str(
                head_dim,  # head_dim_qk
                head_dim,  # head_dim_vo
                pos_encoding_mode,
                use_fp16_qk_reduction,
                mask_mode,
                dtype_q,  # dtype_q
                dtype_kv,  # dtype_kv
                dtype_q,  # dtype_out
                idtype,
            )
            write_if_different(path / fname, content)

            # One URI set covers both the paged and the ragged file; as above, it
            # omits the mask mode and is de-duplicated before returning.
            for sliding_window, logits_soft_cap in swa_softcap_variants:
                batch_prefill_uris.append(
                    f"batch_prefill_with_kv_cache_dtype_q_{dtype_q}_"
                    f"dtype_kv_{dtype_kv}_"
                    f"dtype_o_{dtype_q}_"
                    f"dtype_idx_{idtype}_"
                    f"head_dim_qk_{head_dim}_"
                    f"head_dim_vo_{head_dim}_"
                    f"posenc_{pos_encoding_mode}_"
                    f"use_swa_{sliding_window}_"
                    f"use_logits_cap_{logits_soft_cap}_"
                    f"f16qk_{bool(use_fp16_qk_reduction)}"
                )

    # dict.fromkeys keeps the first occurrence of each URI in insertion order.
    return list(
        dict.fromkeys(
            single_decode_uris
            + batch_decode_uris
            + single_prefill_uris
            + batch_prefill_uris
        )
    )
def _str_to_bool(value: str) -> bool:
    """Parse a command-line boolean such as ``true``/``false`` or ``1``/``0``.

    Matching is case-insensitive and ignores surrounding whitespace.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognized boolean
            spelling, so argparse reports a usage error instead of silently
            treating the value as false.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "1", "yes", "on"):
        return True
    if normalized in ("false", "0", "no", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {value!r}")


if __name__ == "__main__":
    # The first positional argument of ArgumentParser is ``prog``; pass the text
    # as the description instead.
    parser = argparse.ArgumentParser(description="Generate cuda files")
    parser.add_argument(
        "--path",
        type=Path,
        required=True,
        help="Output directory for the generated files",
    )
    parser.add_argument(
        "--head_dims", type=int, required=True, nargs="+", help="Head dimensions"
    )
    parser.add_argument(
        "--pos_encoding_modes",
        type=int,
        required=True,
        nargs="+",
        help="Position encoding modes",
    )
    parser.add_argument(
        "--use_fp16_qk_reductions",
        # Stored as 0/1 ints; accepts true/false as well as 1/0.
        type=lambda x: int(_str_to_bool(x)),
        required=True,
        nargs="+",
        help="Allow fp16 qk reductions",
    )
    parser.add_argument(
        "--mask_modes",
        type=int,
        required=True,
        nargs="+",
        help="Mask modes",
    )
    parser.add_argument(
        "--enable_f16",
        type=_str_to_bool,
        required=True,
        nargs="+",
        help="Enable fp16",
    )
    parser.add_argument(
        "--enable_bf16",
        type=_str_to_bool,
        required=True,
        nargs="+",
        help="Enable bf16",
    )
    parser.add_argument(
        "--enable_fp8_e4m3",
        type=_str_to_bool,
        default=True,
        nargs="+",
        help="Enable fp8_e4m3",
    )
    parser.add_argument(
        "--enable_fp8_e5m2",
        type=_str_to_bool,
        default=True,
        nargs="+",
        help="Enable fp8_e5m2",
    )
    args = parser.parse_args()
    # With nargs="+" the enable flags parse to lists, and any non-empty list is
    # truthy (so "--enable_f16 false" would still enable f16). Collapse each to a
    # single bool that is true only if every given value is true.
    for flag in ("enable_f16", "enable_bf16", "enable_fp8_e4m3", "enable_fp8_e5m2"):
        value = getattr(args, flag)
        if isinstance(value, list):
            setattr(args, flag, all(value))
    get_instantiation_cu(args)