@@ -73,13 +73,13 @@ def backend_module(*args):
73
73
if backend == "fa2" :
74
74
_kernels = torch .ops .flashinfer_kernels
75
75
76
- run_func = _kernels .single_prefill_with_kv_cache
76
+ run_func = _kernels .single_prefill_with_kv_cache . default
77
77
else :
78
78
_kernels_sm90 = torch .ops .flashinfer_kernels_sm90
79
79
80
- run_func = _kernels_sm90 .single_prefill_with_kv_cache_sm90
80
+ run_func = _kernels_sm90 .single_prefill_with_kv_cache_sm90 . default
81
81
else :
82
- run_func = gen_single_prefill_module (backend , * args ).run
82
+ run_func = gen_single_prefill_module (backend , * args ).run . default
83
83
84
84
# torch library for single_prefill_with_kv_cache
85
85
@@ -180,24 +180,30 @@ def backend_module(*args):
180
180
if backend == "fa2" :
181
181
_kernels = torch .ops .flashinfer_kernels
182
182
183
- plan_func = _kernels .batch_prefill_with_kv_cache_plan
184
- ragged_run_func = _kernels .batch_prefill_with_ragged_kv_cache_run
185
- paged_run_func = _kernels .batch_prefill_with_paged_kv_cache_run
183
+ plan_func = _kernels .batch_prefill_with_kv_cache_plan .default
184
+ ragged_run_func = (
185
+ _kernels .batch_prefill_with_ragged_kv_cache_run .default
186
+ )
187
+ paged_run_func = (
188
+ _kernels .batch_prefill_with_paged_kv_cache_run .default
189
+ )
186
190
else :
187
191
_kernels_sm90 = torch .ops .flashinfer_kernels_sm90
188
192
189
- plan_func = _kernels_sm90 .batch_prefill_with_kv_cache_sm90_plan
193
+ plan_func = (
194
+ _kernels_sm90 .batch_prefill_with_kv_cache_sm90_plan .default
195
+ )
190
196
ragged_run_func = (
191
- _kernels_sm90 .batch_prefill_with_ragged_kv_cache_sm90_run
197
+ _kernels_sm90 .batch_prefill_with_ragged_kv_cache_sm90_run . default
192
198
)
193
199
paged_run_func = (
194
- _kernels_sm90 .batch_prefill_with_paged_kv_cache_sm90_run
200
+ _kernels_sm90 .batch_prefill_with_paged_kv_cache_sm90_run . default
195
201
)
196
202
else :
197
203
module = gen_batch_prefill_module (backend , * args )
198
- plan_func = module .plan
199
- ragged_run_func = module .ragged_run
200
- paged_run_func = module .paged_run
204
+ plan_func = module .plan . default
205
+ ragged_run_func = module .ragged_run . default
206
+ paged_run_func = module .paged_run . default
201
207
202
208
# torch library for ragged_run
203
209
@@ -437,9 +443,9 @@ def get_batch_prefill_jit_module(module_name: str, jit_module: Any):
437
443
if module_name in _batch_prefill_jit_modules :
438
444
return _batch_prefill_jit_modules [module_name ]
439
445
440
- plan_func = jit_module .plan
441
- ragged_run_func = jit_module .ragged_run
442
- paged_run_func = jit_module .paged_run
446
+ plan_func = jit_module .plan . default
447
+ ragged_run_func = jit_module .ragged_run . default
448
+ paged_run_func = jit_module .paged_run . default
443
449
444
450
# torch library for ragged_run
445
451
@register_custom_op (
@@ -611,7 +617,7 @@ def single_prefill_with_kv_cache_with_jit_module(
611
617
lse = torch .empty (
612
618
(q .size (0 ), q .size (1 )), dtype = torch .float32 , device = device
613
619
)
614
- jit_module .run (
620
+ jit_module .run . default (
615
621
q ,
616
622
k ,
617
623
v ,
0 commit comments