
Commit 86b12ad

yzh119 and youkaichao authored
perf: reduce torch.library dispatch overhead (#968)
This PR fixes #960. #764 refactored the codebase to use torch.library, which introduces some CPU-side overhead when operators are not captured by CUDAGraph. As suggested by @youkaichao, we can bypass the PyTorch dispatcher by calling `torch.ops.namespace.op_name.default` directly. Co-authored-by: Kaichao You <[email protected]>
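For illustration, a minimal, self-contained sketch (not part of this commit) of the overhead being avoided: it times an OpOverloadPacket call (`torch.ops.ns.op`) against the resolved OpOverload (`torch.ops.ns.op.default`). The `demo::add_one` op is hypothetical and registered here only for the benchmark.

import time

import torch


# Hypothetical toy op, registered only so the two call paths can be timed.
@torch.library.custom_op("demo::add_one", mutates_args=())
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1


x = torch.zeros(8)
packet = torch.ops.demo.add_one            # OpOverloadPacket: resolves the overload on every call
overload = torch.ops.demo.add_one.default  # OpOverload: overload already resolved


def bench(fn, iters=10_000):
    fn(x)  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    return (time.perf_counter() - start) / iters * 1e6


print(f"packet  : {bench(packet):.2f} us/call")
print(f"overload: {bench(overload):.2f} us/call")

Absolute numbers vary by machine and PyTorch version, but the `.default` path skips the packet's per-call overload resolution, which is the CPU-side cost this commit removes from eager-mode launches.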
1 parent bb49fac · commit 86b12ad

12 files changed: +67 -59 lines

flashinfer/activation.py (+1 -1)

@@ -66,7 +66,7 @@ def get_act_and_mul_module(act_func_name: str):
 
     # torch library for act_and_mul
     fname = f"{act_func_name}_and_mul"
-    fn = getattr(module, fname)
+    fn = getattr(module, fname).default
 
     @register_custom_op(f"flashinfer::{fname}", mutates_args=("out",))
     def _act_and_mul(

flashinfer/cascade.py (+3 -3)

@@ -98,7 +98,7 @@ def merge_state(
         s_b = s_b.to(torch.float32)
         v_merged = torch.empty_like(v_a)
         s_merged = torch.empty_like(s_a)
-        get_cascade_module().merge_state(
+        get_cascade_module().merge_state.default(
             v_a, s_a, v_b, s_b, v_merged, s_merged, get_cuda_stream(device)
         )
         return v_merged, s_merged
@@ -160,7 +160,7 @@ def merge_state_in_place(
     with v.device as device:  # device guard
         s = s.to(torch.float32)
         s_other = s_other.to(torch.float32)
-        get_cascade_module().merge_state_in_place(
+        get_cascade_module().merge_state_in_place.default(
             v, s, v_other, s_other, mask, get_cuda_stream(device)
         )
 
@@ -221,7 +221,7 @@ def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.
             seq_len, num_heads, head_dim, dtype=v.dtype, device=device
         )
         s_merged = torch.empty(seq_len, num_heads, dtype=torch.float32, device=device)
-        get_cascade_module().merge_states(
+        get_cascade_module().merge_states.default(
             v, s, v_merged, s_merged, get_cuda_stream(device)
         )
         return v_merged, s_merged

flashinfer/decode.py (+9 -9)

@@ -69,9 +69,9 @@ def get_single_decode_module(*args):
     if has_prebuilt_ops and uri in prebuilt_ops_uri:
         _kernels = torch.ops.flashinfer_kernels
 
-        run_func = _kernels.single_decode_with_kv_cache
+        run_func = _kernels.single_decode_with_kv_cache.default
     else:
-        run_func = gen_single_decode_module(*args).run
+        run_func = gen_single_decode_module(*args).run.default
 
     # torch library for single_decode_with_kv_cache
 
@@ -134,8 +134,8 @@ def get_batch_decode_jit_module(module_name: str, jit_module: Any):
     if module_name in _batch_decode_jit_modules:
         return _batch_decode_jit_modules[module_name]
 
-    plan_func = jit_module.plan
-    run_func = jit_module.run
+    plan_func = jit_module.plan.default
+    run_func = jit_module.run.default
 
     @register_custom_op(
         f"flashinfer::{module_name}_run",
@@ -216,12 +216,12 @@ def get_batch_decode_module(*args):
     if has_prebuilt_ops and uri in prebuilt_ops_uri:
         _kernels = torch.ops.flashinfer_kernels
 
-        plan_func = _kernels.batch_decode_with_paged_kv_cache_plan
-        run_func = _kernels.batch_decode_with_paged_kv_cache_run
+        plan_func = _kernels.batch_decode_with_paged_kv_cache_plan.default
+        run_func = _kernels.batch_decode_with_paged_kv_cache_run.default
     else:
         mod = gen_batch_decode_module(*args)
-        plan_func = mod.plan
-        run_func = mod.run
+        plan_func = mod.plan.default
+        run_func = mod.run.default
 
     # torch library for batch_decode_with_paged_kv_cache_run
 
@@ -327,7 +327,7 @@ def single_decode_with_kv_cache_with_jit_module(
         "single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, device
     )
     o = torch.empty_like(q)
-    jit_module.run(
+    jit_module.run.default(
         q,
         k,
         v,

flashinfer/gemm.py (+3 -3)

@@ -66,7 +66,7 @@ def bmm_fp8(
     ) -> None:
         with A.device as device:
             cublas_handle = torch.cuda.current_blas_handle()
-            module.bmm_fp8(
+            module.bmm_fp8.default(
                 A,
                 B,
                 D,
@@ -105,7 +105,7 @@ def cutlass_segment_gemm(
         weight_column_major: bool,
     ) -> None:
         with x_data.device as device:
-            module.cutlass_segment_gemm(
+            module.cutlass_segment_gemm.default(
                 workspace_buffer,
                 all_problems,
                 x_data,
@@ -182,7 +182,7 @@ def cutlass_segment_gemm_sm90(
         weight_column_major: bool,
    ) -> None:
        with x_data.device as device:
-           module.cutlass_segment_gemm_sm90(
+           module.cutlass_segment_gemm_sm90.default(
               workspace_buffer,
               int_workspace_buffer,
               all_problems,

flashinfer/mla.py (+2 -2)

@@ -254,7 +254,7 @@ def plan(
         self._use_profiler = use_profiler
 
         with self.device as device:
-            self._plan_info = self._cached_module.plan(
+            self._plan_info = self._cached_module.plan.default(
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
@@ -349,7 +349,7 @@ def run(
             lse, q_nope.shape[:2], torch.float32, q_nope.device, "lse"
         )
         profiler_args = (profiler_buffer,) if self._use_profiler else ()
-        self._cached_module.run(
+        self._cached_module.run.default(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._plan_info,

flashinfer/norm.py (+4 -4)

@@ -87,7 +87,7 @@ def _rmsnorm(
     enable_pdl: bool,
 ) -> None:
     with input.device as device:  # device guard
-        get_norm_module().rmsnorm(
+        get_norm_module().rmsnorm.default(
             out, input, weight, eps, enable_pdl, get_cuda_stream(device)
         )
 
@@ -134,7 +134,7 @@ def fused_add_rmsnorm(
     <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
     """
     with input.device as device:  # device guard
-        get_norm_module().fused_add_rmsnorm(
+        get_norm_module().fused_add_rmsnorm.default(
             input, residual, weight, eps, enable_pdl, get_cuda_stream(device)
         )
 
@@ -195,7 +195,7 @@ def _gemma_rmsnorm(
     enable_pdl: bool,
 ) -> None:
     with input.device as device:  # device guard
-        get_norm_module().gemma_rmsnorm(
+        get_norm_module().gemma_rmsnorm.default(
             out, input, weight, eps, enable_pdl, get_cuda_stream(device)
         )
 
@@ -244,7 +244,7 @@ def gemma_fused_add_rmsnorm(
     <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
     """
     with input.device as device:
-        get_norm_module().gemma_fused_add_rmsnorm(
+        get_norm_module().gemma_fused_add_rmsnorm.default(
             input, residual, weight, eps, enable_pdl, get_cuda_stream(device)
         )
 

flashinfer/page.py (+3 -3)

@@ -72,7 +72,7 @@ def block_sparse_indices_to_vector_sparse_offsets(
     assert vector_sparse_indptr.dtype == torch.int32
     assert kv_lens.dtype == torch.int32
     batch_size = block_sparse_indptr.size(0) - 1
-    get_page_module().block_sparse_indices_to_vector_sparse_offsets(
+    get_page_module().block_sparse_indices_to_vector_sparse_offsets.default(
         block_sparse_indices,
         block_sparse_indptr,
         vector_sparse_offsets,
@@ -108,7 +108,7 @@ def _append_paged_mla_kv_cache_kernel(
     kv_indices = kv_indices.int()
     kv_indptr = kv_indptr.int()
     kv_last_page_len = kv_last_page_len.int()
-    get_page_module().append_paged_mla_kv_cache(
+    get_page_module().append_paged_mla_kv_cache.default(
         append_ckv,
         append_kpe,
         batch_indices,
@@ -144,7 +144,7 @@ def _append_paged_kv_cache_kernel(
     kv_indices = kv_indices.int()
     kv_indptr = kv_indptr.int()
     kv_last_page_len = kv_last_page_len.int()
-    get_page_module().append_paged_kv_cache(
+    get_page_module().append_paged_kv_cache.default(
         append_key,
         append_value,
         batch_indices,

flashinfer/pod.py (+2 -2)

@@ -67,9 +67,9 @@ def get_pod_module(*args):
         _kernels = torch.ops.flashinfer_kernels
         # torch library for pod_with_kv_cache
         # No tensor deprecated due to poor performance. Just use tensor cores for both.
-        run_tensor = _kernels.pod_with_kv_cache_tensor
+        run_tensor = _kernels.pod_with_kv_cache_tensor.default
     else:
-        run_tensor = gen_pod_module(*args).pod_with_kv_cache_tensor
+        run_tensor = gen_pod_module(*args).pod_with_kv_cache_tensor.default
     # Register the module
     _pod_modules[args] = SimpleNamespace(run_tensor=run_tensor)
     return _pod_modules[args]

flashinfer/prefill.py (+22 -16)

@@ -73,13 +73,13 @@ def backend_module(*args):
         if backend == "fa2":
             _kernels = torch.ops.flashinfer_kernels
 
-            run_func = _kernels.single_prefill_with_kv_cache
+            run_func = _kernels.single_prefill_with_kv_cache.default
         else:
             _kernels_sm90 = torch.ops.flashinfer_kernels_sm90
 
-            run_func = _kernels_sm90.single_prefill_with_kv_cache_sm90
+            run_func = _kernels_sm90.single_prefill_with_kv_cache_sm90.default
     else:
-        run_func = gen_single_prefill_module(backend, *args).run
+        run_func = gen_single_prefill_module(backend, *args).run.default
 
     # torch library for single_prefill_with_kv_cache
 
@@ -180,24 +180,30 @@ def backend_module(*args):
         if backend == "fa2":
             _kernels = torch.ops.flashinfer_kernels
 
-            plan_func = _kernels.batch_prefill_with_kv_cache_plan
-            ragged_run_func = _kernels.batch_prefill_with_ragged_kv_cache_run
-            paged_run_func = _kernels.batch_prefill_with_paged_kv_cache_run
+            plan_func = _kernels.batch_prefill_with_kv_cache_plan.default
+            ragged_run_func = (
+                _kernels.batch_prefill_with_ragged_kv_cache_run.default
+            )
+            paged_run_func = (
+                _kernels.batch_prefill_with_paged_kv_cache_run.default
+            )
         else:
             _kernels_sm90 = torch.ops.flashinfer_kernels_sm90
 
-            plan_func = _kernels_sm90.batch_prefill_with_kv_cache_sm90_plan
+            plan_func = (
+                _kernels_sm90.batch_prefill_with_kv_cache_sm90_plan.default
+            )
             ragged_run_func = (
-                _kernels_sm90.batch_prefill_with_ragged_kv_cache_sm90_run
+                _kernels_sm90.batch_prefill_with_ragged_kv_cache_sm90_run.default
             )
             paged_run_func = (
-                _kernels_sm90.batch_prefill_with_paged_kv_cache_sm90_run
+                _kernels_sm90.batch_prefill_with_paged_kv_cache_sm90_run.default
             )
     else:
         module = gen_batch_prefill_module(backend, *args)
-        plan_func = module.plan
-        ragged_run_func = module.ragged_run
-        paged_run_func = module.paged_run
+        plan_func = module.plan.default
+        ragged_run_func = module.ragged_run.default
+        paged_run_func = module.paged_run.default
 
     # torch library for ragged_run
 
@@ -437,9 +443,9 @@ def get_batch_prefill_jit_module(module_name: str, jit_module: Any):
     if module_name in _batch_prefill_jit_modules:
         return _batch_prefill_jit_modules[module_name]
 
-    plan_func = jit_module.plan
-    ragged_run_func = jit_module.ragged_run
-    paged_run_func = jit_module.paged_run
+    plan_func = jit_module.plan.default
+    ragged_run_func = jit_module.ragged_run.default
+    paged_run_func = jit_module.paged_run.default
 
     # torch library for ragged_run
     @register_custom_op(
@@ -611,7 +617,7 @@ def single_prefill_with_kv_cache_with_jit_module(
     lse = torch.empty(
         (q.size(0), q.size(1)), dtype=torch.float32, device=device
     )
-    jit_module.run(
+    jit_module.run.default(
         q,
         k,
         v,

flashinfer/quantization.py (+4 -2)

@@ -47,7 +47,9 @@ def _packbits(x: torch.Tensor, bitorder: str) -> torch.Tensor:
     with x.device as device:  # device guard
         x = x.to(torch.bool)
         y = torch.empty((x.size(0) + 7) // 8, dtype=torch.uint8, device=device)
-        get_quantization_module().packbits(x, bitorder, y, get_cuda_stream(device))
+        get_quantization_module().packbits.default(
+            x, bitorder, y, get_cuda_stream(device)
+        )
         return y
 
 
@@ -146,7 +148,7 @@ def segment_packbits(
         indptr = indptr.to(torch.int32)
         indptr_new = indptr_new.to(torch.int32)
         y = torch.empty(output_nnzs, dtype=torch.uint8, device=device)
-        get_quantization_module().segment_packbits(
+        get_quantization_module().segment_packbits.default(
             x, indptr, indptr_new, bitorder, y, get_cuda_stream(device)
         )
         return (

flashinfer/rope.py (+5 -5)

@@ -58,7 +58,7 @@ def _apply_rope(
     with q.device as device:
         indptr = indptr.int()
         offsets = offsets.int()
-        get_rope_module().apply_rope(
+        get_rope_module().apply_rope.default(
             q,
             k,
             q_rope,
@@ -108,7 +108,7 @@ def _apply_llama31_rope(
     with q.device as device:
         indptr = indptr.int()
         offsets = offsets.int()
-        get_rope_module().apply_llama31_rope(
+        get_rope_module().apply_llama31_rope.default(
             q,
             k,
             q_rope,
@@ -159,7 +159,7 @@ def _apply_rope_pos_ids(
 ) -> None:
     with q.device as device:
         pos_ids = pos_ids.int()
-        get_rope_module().apply_rope_pos_ids(
+        get_rope_module().apply_rope_pos_ids.default(
             q,
             k,
             q_rope,
@@ -202,7 +202,7 @@ def _apply_rope_pos_ids_cos_sin_cache(
 ) -> None:
     with q.device as device:
         pos_ids = pos_ids.int()
-        get_rope_module().apply_rope_pos_ids_cos_sin_cache(
+        get_rope_module().apply_rope_pos_ids_cos_sin_cache.default(
             q,
             k,
             q_rope,
@@ -247,7 +247,7 @@ def _apply_llama31_rope_pos_ids(
 ) -> None:
     with q.device as device:
         pos_ids = pos_ids.int()
-        get_rope_module().apply_llama31_rope_pos_ids(
+        get_rope_module().apply_llama31_rope_pos_ids.default(
             q,
             k,
             q_rope,
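The pod.py hunk above also shows the caching side of the pattern: the resolved `.default` callable is stored once in a SimpleNamespace keyed by the argument tuple, so hot paths never repeat the lookup. A self-contained sketch of that idea, using a toy `demo::scale` op and hypothetical `get_demo_module` / `_demo_modules` names (not flashinfer APIs):

from types import SimpleNamespace

import torch


# Hypothetical stand-in for a JIT-compiled kernel module.
@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


_demo_modules: dict = {}


def get_demo_module(*args):
    if args not in _demo_modules:
        # Resolve the concrete OpOverload once, at registration time, so the
        # hot path never pays the OpOverloadPacket lookup again.
        run = torch.ops.demo.scale.default
        _demo_modules[args] = SimpleNamespace(run=run)
    return _demo_modules[args]


out = get_demo_module("fp16").run(torch.ones(4), 2.0)

Callers then write `get_demo_module(args).run(...)`, mirroring how the flashinfer getters in the hunks above hand back SimpleNamespace-wrapped overloads.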
