@@ -75,22 +75,22 @@ def get_single_decode_module(*args):


         # torch library for single_decode_with_kv_cache
-        @register_custom_op(f"flashinfer::{uri}_run", mutates_args=("tmp",))
+        @register_custom_op(f"flashinfer::{uri}_run", mutates_args=("tmp", "o"))
         def run_single_decode(
             q: torch.Tensor,
             k: torch.Tensor,
             v: torch.Tensor,
             tmp: torch.Tensor,
+            o: torch.Tensor,
             alibi_slopes: Optional[torch.Tensor],
             kv_layout_code: int,
             window_left: int,
             logits_soft_cap: float,
             sm_scale: float,
             rope_scale: float,
             rope_theta: float,
-        ) -> torch.Tensor:
+        ) -> None:
             with q.device as device:
-                o = torch.empty_like(q)
                 run_func(
                     q,
                     k,
@@ -107,23 +107,22 @@ def run_single_decode(
                     get_cuda_stream(device),
                 )

-                return o
-
         @register_fake_op(f"flashinfer::{uri}_run")
         def _fake_run_single_decode(
             q: torch.Tensor,
             k: torch.Tensor,
             v: torch.Tensor,
             tmp: torch.Tensor,
+            o: torch.Tensor,
             alibi_slopes: Optional[torch.Tensor],
             kv_layout_code: int,
             window_left: int,
             logits_soft_cap: float,
             sm_scale: float,
             rope_scale: float,
             rope_theta: float,
-        ) -> torch.Tensor:
-            return torch.empty_like(q)
+        ) -> None:
+            pass

         # Register the module.
         _single_decode_modules[args] = SimpleNamespace(run=run_single_decode)
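
The single-decode registration above is the core of the change: instead of allocating and returning the output inside the registered op, the kernel now writes into a caller-provided `o`, `o` is declared in `mutates_args`, the op returns `None`, and the fake (meta) registration shrinks to a no-op because tracing no longer has to materialize an output tensor. A minimal, self-contained sketch of that pattern using `torch.library.custom_op` directly (assuming FlashInfer's `register_custom_op`/`register_fake_op` helpers route to the same machinery; the op name `demo::run_decode` and the `o.copy_(q)` body are placeholders, not FlashInfer's kernel):

import torch

@torch.library.custom_op("demo::run_decode", mutates_args=("tmp", "o"))
def run_decode(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    tmp: torch.Tensor,
    o: torch.Tensor,
) -> None:
    # Stand-in for the CUDA kernel: write the result into the caller-owned buffer.
    o.copy_(q)

@run_decode.register_fake
def _(q, k, v, tmp, o) -> None:
    # The output buffer already exists on the caller side, so the fake impl allocates nothing.
    pass

# Caller side mirrors the diff: pre-allocate o, then call the op purely for its side effect.
q, k, v = torch.randn(8, 64), torch.randn(16, 8, 64), torch.randn(16, 8, 64)
tmp = torch.empty(1024)
o = torch.empty_like(q)
run_decode(q, k, v, tmp, o)
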
@@ -145,6 +144,7 @@ def get_batch_decode_jit_module(module_name: str, jit_module: Any):
             "int_workspace_buffer",
             "paged_k_cache",
             "paged_v_cache",
+            "o",
             "maybe_lse",
         ),
     )
@@ -158,13 +158,13 @@ def run_batch_decode(
         paged_kv_indptr: torch.Tensor,
         paged_kv_indices: torch.Tensor,
         paged_kv_last_page_len: torch.Tensor,
+        o: torch.Tensor,
         maybe_lse: Optional[torch.Tensor],
         kv_layout_code: int,
         window_left: int,
         *args,
-    ) -> torch.Tensor:
+    ) -> None:
         with q.device as device:
-            o = torch.empty_like(q)
             run_func(
                 float_workspace_buffer,
                 int_workspace_buffer,
@@ -182,7 +182,6 @@ def run_batch_decode(
                 *args,
                 get_cuda_stream(device),
             )
-            return o

     @register_fake_op(f"flashinfer::{module_name}_run")
     def _fake_run_batch_decode(
@@ -195,12 +194,13 @@ def _fake_run_batch_decode(
         paged_kv_indptr: torch.Tensor,
         paged_kv_indices: torch.Tensor,
         paged_kv_last_page_len: torch.Tensor,
+        o: torch.Tensor,
         maybe_lse: Optional[torch.Tensor],
         kv_layout_code: int,
         window_left: int,
         *args,
-    ) -> torch.Tensor:
-        return torch.empty_like(q)
+    ) -> None:
+        pass

     _batch_decode_jit_modules[module_name] = SimpleNamespace(
         plan=plan_func,
@@ -232,6 +232,7 @@ def get_batch_decode_module(*args):
                 "int_workspace_buffer",
                 "paged_k_cache",
                 "paged_v_cache",
+                "o",
                 "maybe_lse",
             ),
         )
@@ -245,6 +246,7 @@ def run_batch_decode(
             paged_kv_indptr: torch.Tensor,
             paged_kv_indices: torch.Tensor,
             paged_kv_last_page_len: torch.Tensor,
+            o: torch.Tensor,
             maybe_lse: Optional[torch.Tensor],
             kv_layout_code: int,
             window_left: int,
@@ -253,9 +255,8 @@ def run_batch_decode(
             sm_scale: float,
             rope_scale: float,
             rope_theta: float,
-        ) -> torch.Tensor:
+        ) -> None:
             with q.device as device:
-                o = torch.empty_like(q)
                 run_func(
                     float_workspace_buffer,
                     int_workspace_buffer,
@@ -277,7 +278,6 @@ def run_batch_decode(
                     1.0 / rope_theta,  # rope_rcp_theta
                     get_cuda_stream(device),
                 )
-                return o

         @register_fake_op(f"flashinfer::{uri}_run")
         def _fake_run_batch_decode(
@@ -290,6 +290,7 @@ def _fake_run_batch_decode(
             paged_kv_indptr: torch.Tensor,
             paged_kv_indices: torch.Tensor,
             paged_kv_last_page_len: torch.Tensor,
+            o: torch.Tensor,
             maybe_lse: Optional[torch.Tensor],
             kv_layout_code: int,
             window_left: int,
@@ -298,8 +299,8 @@ def _fake_run_batch_decode(
             sm_scale: float,
             rope_scale: float,
             rope_theta: float,
-        ) -> torch.Tensor:
-            return torch.empty_like(q)
+        ) -> None:
+            pass

         # Register the module.
         #
@@ -454,37 +455,37 @@ def single_decode_with_kv_cache(
     num_qo_heads = q.shape[0]

     if use_tensor_cores:
-        out = (
-            get_single_prefill_module("fa2")(
-                q.dtype,
-                k.dtype,
-                q.dtype,
-                head_dim,
-                PosEncodingMode[pos_encoding_mode].value,
-                window_left != -1,  # use_sliding_window
-                logits_soft_cap > 0,  # use_logits_soft_cap
-                False,  # use_fp16_qk_reduction
-            )
-            .run(
-                q.unsqueeze(0),
-                k,
-                v,
-                tmp,
-                None,  # maybe_lse,
-                MaskMode.NON_CAUSAL.value,
-                TensorLayout[kv_layout].value,
-                window_left,
-                None,  # packed_custom_mask
-                _get_cache_alibi_slopes_buf(num_qo_heads, q.device),
-                logits_soft_cap,
-                sm_scale,
-                rope_scale,
-                rope_theta,
-            )[0]
-            .squeeze(0)
+        out = torch.empty_like(q.unsqueeze(0))
+        get_single_prefill_module("fa2")(
+            q.dtype,
+            k.dtype,
+            q.dtype,
+            head_dim,
+            PosEncodingMode[pos_encoding_mode].value,
+            window_left != -1,  # use_sliding_window
+            logits_soft_cap > 0,  # use_logits_soft_cap
+            False,  # use_fp16_qk_reduction
+        ).run(
+            q.unsqueeze(0),
+            k,
+            v,
+            tmp,
+            out,
+            None,  # maybe_lse,
+            MaskMode.NON_CAUSAL.value,
+            TensorLayout[kv_layout].value,
+            window_left,
+            None,  # packed_custom_mask
+            _get_cache_alibi_slopes_buf(num_qo_heads, q.device),
+            logits_soft_cap,
+            sm_scale,
+            rope_scale,
+            rope_theta,
         )
+        out = out.squeeze(0)
     else:
-        out = get_single_decode_module(
+        out = torch.empty_like(q)
+        get_single_decode_module(
             q.dtype,
             k.dtype,
             q.dtype,
@@ -497,6 +498,7 @@ def single_decode_with_kv_cache(
             k,
             v,
             tmp,
+            out,
             _get_cache_alibi_slopes_buf(num_qo_heads, q.device),
             TensorLayout[kv_layout].value,
             window_left,
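
From the user's side nothing changes in `single_decode_with_kv_cache`: it still returns the output tensor, it just fills a buffer it allocated itself before invoking the registered op. A hedged usage sketch (shapes follow the default NHD layout; the sizes are illustrative):

import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, kv_len = 32, 8, 128, 4096
q = torch.randn(num_qo_heads, head_dim, dtype=torch.half, device="cuda")
k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
o = flashinfer.single_decode_with_kv_cache(q, k, v)  # output is still returned to the caller
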
@@ -1056,6 +1058,7 @@ def run(
                 (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
             )

+        out = torch.empty_like(q)
         if self.use_tensor_cores:
             run_args = [
                 self._float_workspace_buffer,
@@ -1068,6 +1071,7 @@ def run(
                 self._paged_kv_indptr_buf,
                 self._paged_kv_indices_buf,
                 self._paged_kv_last_page_len_buf,
+                out,
                 lse,
                 MaskMode.NON_CAUSAL.value,
                 TensorLayout[self._kv_layout].value,
@@ -1087,7 +1091,7 @@ def run(
                 rope_theta,
             ]

-            out = self._cached_module.paged_run(*run_args)
+            self._cached_module.paged_run(*run_args)
         else:
             run_args = [
                 self._float_workspace_buffer,
@@ -1099,6 +1103,7 @@ def run(
                 self._paged_kv_indptr_buf,
                 self._paged_kv_indices_buf,
                 self._paged_kv_last_page_len_buf,
+                out,
                 lse,
                 TensorLayout[self._kv_layout].value,
                 window_left,
@@ -1115,7 +1120,7 @@ def run(
                 rope_theta,
             ]

-            out = self._cached_module.run(*run_args)
+            self._cached_module.run(*run_args)
         if v_scale is not None:
             out *= v_scale
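
The batch path in the hunks above keeps the same shape: `run()` allocates `out = torch.empty_like(q)` once per call, threads it through `run_args` so the cached module fills it in place, and returns it after the optional `v_scale` scaling. A hedged end-to-end sketch of the public wrapper API this feeds (page-table contents are illustrative, and `plan()` keyword options are left at their defaults since their exact names vary across FlashInfer versions):

import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
batch_size, pages_per_seq = 4, 8
total_pages = batch_size * pages_per_seq

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# Paged KV cache in NHD layout: [num_pages, 2, page_size, num_kv_heads, head_dim].
kv_cache = torch.randn(
    total_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
)
kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_seq
kv_indices = torch.arange(total_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

wrapper.plan(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
)
q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
o = wrapper.run(q, kv_cache)  # the pre-allocated out buffer is handed back here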