
Commit 472b771 (1 parent: 93e1a26)

Commit message: "upd". The commit converts the decode custom ops to out-variants: each op now writes into a caller-allocated output tensor `o`, declared in `mutates_args`, instead of allocating and returning a new tensor. It also adds the `-DFLASHINFER_ENABLE_F16` nvcc define to the JIT build flags.

File tree: 4 files changed, +100 −76 lines

flashinfer/decode.py (+53 −48)
@@ -75,22 +75,22 @@ def get_single_decode_module(*args):
 
     # torch library for single_decode_with_kv_cache
 
-    @register_custom_op(f"flashinfer::{uri}_run", mutates_args=("tmp",))
+    @register_custom_op(f"flashinfer::{uri}_run", mutates_args=("tmp", "o"))
     def run_single_decode(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
         tmp: torch.Tensor,
+        o: torch.Tensor,
         alibi_slopes: Optional[torch.Tensor],
         kv_layout_code: int,
         window_left: int,
         logits_soft_cap: float,
         sm_scale: float,
         rope_scale: float,
         rope_theta: float,
-    ) -> torch.Tensor:
+    ) -> None:
         with q.device as device:
-            o = torch.empty_like(q)
             run_func(
                 q,
                 k,

@@ -107,23 +107,22 @@ def run_single_decode(
                 get_cuda_stream(device),
             )
 
-        return o
-
     @register_fake_op(f"flashinfer::{uri}_run")
     def _fake_run_single_decode(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
         tmp: torch.Tensor,
+        o: torch.Tensor,
         alibi_slopes: Optional[torch.Tensor],
         kv_layout_code: int,
         window_left: int,
         logits_soft_cap: float,
         sm_scale: float,
         rope_scale: float,
         rope_theta: float,
-    ) -> torch.Tensor:
-        return torch.empty_like(q)
+    ) -> None:
+        pass
 
     # Register the module.
     _single_decode_modules[args] = SimpleNamespace(run=run_single_decode)
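The hunks above switch the single-decode op to the out-variant shape: the output buffer `o` joins `mutates_args`, the op returns `None`, and the fake (meta) op becomes a no-op, since there is nothing left for it to allocate once the caller owns the buffer. Below is a minimal, self-contained sketch of the same pattern against the public `torch.library` API; `register_custom_op` and `register_fake_op` appear to be thin wrappers over this machinery, and the op name and body here are purely illustrative:

```python
import torch

# Out-variant custom op: it writes into the caller-allocated `o` and
# returns None; `mutates_args` tells torch.compile which inputs are
# written in place so the mutation is not reordered or cached away.
@torch.library.custom_op("demo::scale_into", mutates_args=("o",))
def scale_into(q: torch.Tensor, o: torch.Tensor, scale: float) -> None:
    o.copy_(q * scale)  # stand-in for the real kernel launch

# The fake (meta) implementation allocates nothing and returns nothing,
# mirroring the `pass` bodies in the diff above.
@scale_into.register_fake
def _(q: torch.Tensor, o: torch.Tensor, scale: float) -> None:
    pass

q = torch.randn(8, 64)
o = torch.empty_like(q)  # the caller owns the output buffer
scale_into(q, o, 2.0)
```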
@@ -145,6 +144,7 @@ def get_batch_decode_jit_module(module_name: str, jit_module: Any):
             "int_workspace_buffer",
             "paged_k_cache",
             "paged_v_cache",
+            "o",
             "maybe_lse",
         ),
     )

@@ -158,13 +158,13 @@ def run_batch_decode(
         paged_kv_indptr: torch.Tensor,
         paged_kv_indices: torch.Tensor,
         paged_kv_last_page_len: torch.Tensor,
+        o: torch.Tensor,
         maybe_lse: Optional[torch.Tensor],
         kv_layout_code: int,
         window_left: int,
         *args,
-    ) -> torch.Tensor:
+    ) -> None:
         with q.device as device:
-            o = torch.empty_like(q)
             run_func(
                 float_workspace_buffer,
                 int_workspace_buffer,

@@ -182,7 +182,6 @@ def run_batch_decode(
                 *args,
                 get_cuda_stream(device),
             )
-        return o
 
     @register_fake_op(f"flashinfer::{module_name}_run")
     def _fake_run_batch_decode(

@@ -195,12 +194,13 @@ def _fake_run_batch_decode(
         paged_kv_indptr: torch.Tensor,
         paged_kv_indices: torch.Tensor,
         paged_kv_last_page_len: torch.Tensor,
+        o: torch.Tensor,
         maybe_lse: Optional[torch.Tensor],
         kv_layout_code: int,
         window_left: int,
         *args,
-    ) -> torch.Tensor:
-        return torch.empty_like(q)
+    ) -> None:
+        pass
 
     _batch_decode_jit_modules[module_name] = SimpleNamespace(
         plan=plan_func,

@@ -232,6 +232,7 @@ def get_batch_decode_module(*args):
             "int_workspace_buffer",
             "paged_k_cache",
             "paged_v_cache",
+            "o",
             "maybe_lse",
         ),
     )

@@ -245,6 +246,7 @@ def run_batch_decode(
         paged_kv_indptr: torch.Tensor,
         paged_kv_indices: torch.Tensor,
         paged_kv_last_page_len: torch.Tensor,
+        o: torch.Tensor,
         maybe_lse: Optional[torch.Tensor],
         kv_layout_code: int,
         window_left: int,

@@ -253,9 +255,8 @@ def run_batch_decode(
         sm_scale: float,
         rope_scale: float,
         rope_theta: float,
-    ) -> torch.Tensor:
+    ) -> None:
         with q.device as device:
-            o = torch.empty_like(q)
             run_func(
                 float_workspace_buffer,
                 int_workspace_buffer,

@@ -277,7 +278,6 @@ def run_batch_decode(
                 1.0 / rope_theta,  # rope_rcp_theta
                 get_cuda_stream(device),
             )
-        return o
 
     @register_fake_op(f"flashinfer::{uri}_run")
     def _fake_run_batch_decode(

@@ -290,6 +290,7 @@ def _fake_run_batch_decode(
         paged_kv_indptr: torch.Tensor,
         paged_kv_indices: torch.Tensor,
         paged_kv_last_page_len: torch.Tensor,
+        o: torch.Tensor,
         maybe_lse: Optional[torch.Tensor],
         kv_layout_code: int,
         window_left: int,

@@ -298,8 +299,8 @@ def _fake_run_batch_decode(
         sm_scale: float,
         rope_scale: float,
         rope_theta: float,
-    ) -> torch.Tensor:
-        return torch.empty_like(q)
+    ) -> None:
+        pass
 
     # Register the module.
     #
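The batch-decode registrations follow the same rule: every tensor the kernel writes in place, whether scratch workspace, KV-cache pages, or the new `o`, must appear in `mutates_args`, or a compiled graph is free to reorder or cache around the mutation. A tiny illustrative sketch with two mutated arguments (the op name and bodies are hypothetical):

```python
import torch

# Both the scratch workspace and the output are written in place, so
# both are declared, just like ("tmp", "o") and the buffer lists above.
@torch.library.custom_op("demo::accum_into", mutates_args=("workspace", "o"))
def accum_into(q: torch.Tensor, workspace: torch.Tensor, o: torch.Tensor) -> None:
    workspace.copy_(q)       # scratch write
    o.copy_(workspace * 2)   # result written into the caller's buffer

@accum_into.register_fake
def _(q: torch.Tensor, workspace: torch.Tensor, o: torch.Tensor) -> None:
    pass  # nothing to allocate: all outputs already exist
```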
@@ -454,37 +455,37 @@ def single_decode_with_kv_cache(
     num_qo_heads = q.shape[0]
 
     if use_tensor_cores:
-        out = (
-            get_single_prefill_module("fa2")(
-                q.dtype,
-                k.dtype,
-                q.dtype,
-                head_dim,
-                PosEncodingMode[pos_encoding_mode].value,
-                window_left != -1,  # use_sliding_window
-                logits_soft_cap > 0,  # use_logits_soft_cap
-                False,  # use_fp16_qk_reduction
-            )
-            .run(
-                q.unsqueeze(0),
-                k,
-                v,
-                tmp,
-                None,  # maybe_lse,
-                MaskMode.NON_CAUSAL.value,
-                TensorLayout[kv_layout].value,
-                window_left,
-                None,  # packed_custom_mask
-                _get_cache_alibi_slopes_buf(num_qo_heads, q.device),
-                logits_soft_cap,
-                sm_scale,
-                rope_scale,
-                rope_theta,
-            )[0]
-            .squeeze(0)
+        out = torch.empty_like(q.unsqueeze(0))
+        get_single_prefill_module("fa2")(
+            q.dtype,
+            k.dtype,
+            q.dtype,
+            head_dim,
+            PosEncodingMode[pos_encoding_mode].value,
+            window_left != -1,  # use_sliding_window
+            logits_soft_cap > 0,  # use_logits_soft_cap
+            False,  # use_fp16_qk_reduction
+        ).run(
+            q.unsqueeze(0),
+            k,
+            v,
+            tmp,
+            out,
+            None,  # maybe_lse,
+            MaskMode.NON_CAUSAL.value,
+            TensorLayout[kv_layout].value,
+            window_left,
+            None,  # packed_custom_mask
+            _get_cache_alibi_slopes_buf(num_qo_heads, q.device),
+            logits_soft_cap,
+            sm_scale,
+            rope_scale,
+            rope_theta,
         )
+        out = out.squeeze(0)
     else:
-        out = get_single_decode_module(
+        out = torch.empty_like(q)
+        get_single_decode_module(
             q.dtype,
             k.dtype,
             q.dtype,

@@ -497,6 +498,7 @@ def single_decode_with_kv_cache(
             k,
             v,
             tmp,
+            out,
             _get_cache_alibi_slopes_buf(num_qo_heads, q.device),
             TensorLayout[kv_layout].value,
             window_left,
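On the caller side, `single_decode_with_kv_cache` now allocates `out` itself before invoking the kernel. In the tensor-core path the prefill kernel expects a leading token dimension, hence the `unsqueeze(0)`/`squeeze(0)` pair around the call and the batched shape of the preallocated buffer. A standalone sketch of that shape bookkeeping (the head and dim sizes below are made up):

```python
import torch

q = torch.randn(32, 128)           # [num_qo_heads, head_dim], one decode token

# The prefill kernel expects a leading token dimension, so decode is
# treated as a 1-token batch and the output buffer is allocated in
# that batched shape before the call.
out = torch.empty_like(q.unsqueeze(0))   # [1, num_qo_heads, head_dim]
# ... the out-variant run(...) fills `out` in place ...
out = out.squeeze(0)                     # back to [num_qo_heads, head_dim]
```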
@@ -1056,6 +1058,7 @@ def run(
                 (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
             )
 
+        out = torch.empty_like(q)
         if self.use_tensor_cores:
             run_args = [
                 self._float_workspace_buffer,

@@ -1068,6 +1071,7 @@ def run(
                 self._paged_kv_indptr_buf,
                 self._paged_kv_indices_buf,
                 self._paged_kv_last_page_len_buf,
+                out,
                 lse,
                 MaskMode.NON_CAUSAL.value,
                 TensorLayout[self._kv_layout].value,

@@ -1087,7 +1091,7 @@ def run(
                 rope_theta,
             ]
 
-            out = self._cached_module.paged_run(*run_args)
+            self._cached_module.paged_run(*run_args)
         else:
             run_args = [
                 self._float_workspace_buffer,

@@ -1099,6 +1103,7 @@ def run(
                 self._paged_kv_indptr_buf,
                 self._paged_kv_indices_buf,
                 self._paged_kv_last_page_len_buf,
+                out,
                 lse,
                 TensorLayout[self._kv_layout].value,
                 window_left,

@@ -1115,7 +1120,7 @@ def run(
                 rope_theta,
             ]
 
-            out = self._cached_module.run(*run_args)
+            self._cached_module.run(*run_args)
         if v_scale is not None:
             out *= v_scale
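In the batch-decode `run` method, both branches now fill the same preallocated `out`, so the `v_scale` epilogue is an in-place scale of that buffer rather than of a tensor returned by the op. A condensed, runnable sketch of the resulting control flow; the helper and the toy kernel below are hypothetical stand-ins for `paged_run`/`run`:

```python
import torch

def run_out_variant(q: torch.Tensor, kernel, v_scale=None) -> torch.Tensor:
    out = torch.empty_like(q)   # single allocation up front, as in the diff
    kernel(q, out)              # out-variant op fills `out` in place
    if v_scale is not None:
        out *= v_scale          # epilogue mutates the same buffer
    return out

# Toy "kernel" that just copies q into out, standing in for the real op.
result = run_out_variant(torch.randn(4, 8), lambda q, out: out.copy_(q), v_scale=0.5)
```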

flashinfer/jit/core.py (+1)

@@ -93,6 +93,7 @@ def load_cuda_ops(
         "--threads",
         "4",
         "-use_fast_math",
+        "-DFLASHINFER_ENABLE_F16",
         "-DFLASHINFER_ENABLE_BF16",
         "-DFLASHINFER_ENABLE_FP8_E4M3",
         "-DFLASHINFER_ENABLE_FP8_E5M2",
