Perf drop in rmsnorm #960

Closed
xiaoqi35 opened this issue Mar 19, 2025 · 17 comments · Fixed by #968

Comments

@xiaoqi35

File: https://github.com/flashinfer-ai/flashinfer/blob/main/tests/test_norm.py

Run the test case with these changes:

...
t0 = time.time()
flashinfer.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
t1 = time.time()

if __name__ == "__main__":
    test_gemma_fused_add_rmsnorm(19, 4096, torch.float16, False)
    test_gemma_norm(19, 3072, torch.float16, False, True)
...

flashinfer 0.1.6+cu124torch2.4 (t1 - t0, seconds):
gemma_fused_add_rmsnorm: 0.0009355545043945312
gemma_rmsnorm: 5.435943603515625e-05

flashinfer-python-0.2.3+cu124torch2.5 (t1 - t0, seconds):
gemma_fused_add_rmsnorm: 0.0019216537475585938
gemma_rmsnorm: 0.00036144256591796875

@yzh119
Collaborator

yzh119 commented Mar 19, 2025

Hi @xiaoqi35 , that's not a fair way to compare CUDA kernel performance (you are not measuring the kernel execution time; see https://pytorch.org/docs/stable/notes/cuda.html#asynchronous-execution). You can try using Triton's do_bench function:

from triton.testing import do_bench

print(do_bench(lambda: flashinfer.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps), warmup=100, rep=1000))

or directly run the benchmark: https://github.com/flashinfer-ai/flashinfer/blob/main/benchmarks/bench_fused_add_rmsnorm.py
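
For reference, a CUDA-event-based measurement is another option. The following is a sketch only (shapes and eps mirror the test snippet above; the iteration counts and eps value are arbitrary assumptions): events time the kernels on the GPU rather than the asynchronous Python call.

import torch
import flashinfer

# Time the kernel on the GPU with CUDA events instead of time.time().
x_fused = torch.randn(19, 4096, dtype=torch.float16, device="cuda")
residual_fused = torch.randn_like(x_fused)
weight = torch.randn(4096, dtype=torch.float16, device="cuda")
eps = 1e-6

for _ in range(10):  # warmup (JIT compilation, caches)
    flashinfer.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    flashinfer.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
end.record()
torch.cuda.synchronize()
print(start.elapsed_time(end) / 100, "ms per call")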

@xiaoqi35
Author

xiaoqi35 commented Mar 20, 2025

Hi @xiaoqi35 , that's not a fair way to compare CUDA kernel performance (you are not measuring the kernel execution time; see https://pytorch.org/docs/stable/notes/cuda.html#asynchronous-execution). You can try using Triton's do_bench function:

from triton.testing import do_bench

print(do_bench(lambda: flashinfer.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps), warmup=100, rep=1000))
or directly run the benchmark: https://github.com/flashinfer-ai/flashinfer/blob/main/benchmarks/bench_fused_add_rmsnorm.py

Thanks for the suggestion!
As you say, the performance of the CUDA kernel source has not changed.
But from the user's point of view, many Python interfaces have become much slower, including attention (with flashinfer 0.1.6 as the baseline).
Inference for some models (like gemma2) served by sglang slows down noticeably after upgrading from flashinfer 0.1.6 to flashinfer 0.2.3.

I haven't found the root cause yet; the same behavior occurs even in 0.2.0. I would appreciate any clues about the end-to-end (Python API) performance.

@MichoChan

MichoChan commented Mar 20, 2025

I also see a similar problem: the 0.1.6+cu124torch2.4 BatchPrefillWithPagedKVCachePyTorchWrapper run is faster than the 0.2.0 BatchPrefillWithPagedKVCachePyTorchWrapper run.

@yzh119
Collaborator

yzh119 commented Mar 20, 2025

@MichoChan @xiaoqi35

Can you confirm whether it's because of JIT?
You can try

import flashinfer.flashinfer_kernels
import flashinfer.flashinfer_kernels_sm90

to check whether you are using prebuilt (AOT) kernels or JIT.

For JIT, there should be a warmup phase to make sure we compile and cache all kernels ahead of inference; otherwise compilation time will be counted.
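
A minimal check could look like the following sketch (module names as listed above; they may differ between releases):

try:
    import flashinfer.flashinfer_kernels        # prebuilt (AOT) kernels
    import flashinfer.flashinfer_kernels_sm90   # prebuilt SM90 kernels
    print("AOT kernels available")
except ImportError:
    print("AOT kernels not found; JIT will compile on first use, so warm up before timing")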

I also see a similar problem: the 0.1.6+cu124torch2.4 BatchPrefillWithPagedKVCachePyTorchWrapper run is faster than the 0.2.0 BatchPrefillWithPagedKVCachePyTorchWrapper run.

At the kernel level, the implementations in 0.1.6 and 0.2.* should be equivalent; in fact we fixed some issues, so 0.2.* should run faster (according to my benchmark suite).

Can you share more details (your GPU architecture, and how you measured the execution time, e.g., with nsys)?

@MichoChan

I used nsys with the 0.2.3 AOT version (no JIT at runtime) and with 0.1.6.
I profiled the llama3.2 3b model served by sglang with each flashinfer version.

Comparing the two nsys profiles, I find that the spacing between kernels is a bit larger with the 0.2.3 AOT version.

nsys_res.tgz

@xiaoqi35
Author

xiaoqi35 commented Mar 21, 2025

@MichoChan @xiaoqi35

Can you confirm whether it's because of JIT? You can try

import flashinfer.flashinfer_kernels
import flashinfer.flashinfer_kernels_sm90
to check whether you are using prebuilt (AOT) kernels or JIT.

For JIT, there should be a warmup phase to make sure we compile and cache all kernels ahead of inference; otherwise compilation time will be counted.

I also see a similar problem: the 0.1.6+cu124torch2.4 BatchPrefillWithPagedKVCachePyTorchWrapper run is faster than the 0.2.0 BatchPrefillWithPagedKVCachePyTorchWrapper run.

At the kernel level, the implementations in 0.1.6 and 0.2.* should be equivalent; in fact we fixed some issues, so 0.2.* should run faster (according to my benchmark suite).

Can you share more details (your GPU architecture, and how you measured the execution time, e.g., with nsys)?

Not JIT. Hardware: NVIDIA A100 80G.
The performance drop is in CPU time. I'm not sure which metric the flashinfer contributors care about: the kernel duration only, or end-to-end (including pre-/post-kernel CPU time).
The nsys report supports my view: in v0.2.3 the CUDA kernel is launched later, not executed for longer.
The minimal reproduction code and nsys profile are attached as norm.tar.gz.

norm.tar.gz

@yzh119
Collaborator

yzh119 commented Mar 21, 2025

Hi @xiaoqi35 @MichoChan , the kernels are expected to be captured and replayed by CUDAGraph, which removes the CPU-side overheads. The traces provided by @MichoChan were collected with CUDAGraph disabled.

For more context: we now use the torch.library interface instead of the previous PyTorch CUDA extension (which v0.1.6 relied on) for the Python-C++ FFI, which might add some overhead on the CPU side.

I'm not sure which metric the flashinfer contributors care about: the kernel duration only, or end-to-end (including pre-/post-kernel CPU time).

In general, we care about end-to-end performance with all kernels captured by CUDAGraph (the common case for LLM inference frameworks). CPU overhead is also something we should investigate (although it should disappear when CUDAGraph is enabled). @abcdabcd987 @youkaichao do you have any insight into where these CPU overheads might come from?
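
For illustration, capturing the fused norm in a CUDA graph could look like this sketch (shapes and eps reuse the test snippet at the top of the thread and are assumptions; the side-stream warmup follows PyTorch's documented capture pattern):

import torch
import flashinfer

x = torch.randn(19, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.randn(4096, dtype=torch.float16, device="cuda")
eps = 1e-6

# Warm up on a side stream before capture, as required for CUDA graph capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    flashinfer.gemma_fused_add_rmsnorm(x, residual, weight, eps)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    flashinfer.gemma_fused_add_rmsnorm(x, residual, weight, eps)

g.replay()  # replays the captured kernel without the Python launch path
torch.cuda.synchronize()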

@youkaichao
Collaborator

When we use torch.library to bind ops, we can sometimes bypass the PyTorch dispatcher by accessing torch.ops.namespace.op_name.default directly, which might help reduce CPU overhead.

The key is the .default attribute, which calls the C++-side function directly.
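
A self-contained sketch with a toy op (the demo namespace and scale op below are made up for illustration; flashinfer's real ops are registered on the C++ side):

import torch

lib = torch.library.Library("demo", "DEF")  # hypothetical namespace, for illustration only
lib.define("scale(Tensor x, float s) -> Tensor")

def scale_impl(x, s):
    return x * s

lib.impl("scale", scale_impl, "CompositeExplicitAutograd")

x = torch.randn(4)
y1 = torch.ops.demo.scale(x, 2.0)   # resolves the overload on every call
op = torch.ops.demo.scale.default   # resolve the default overload once
y2 = op(x, 2.0)                     # calls the bound implementation more directly
assert torch.allclose(y1, y2)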

@yzh119
Collaborator

yzh119 commented Mar 22, 2025

@youkaichao thank you for the hint, let me try it!

@yzh119
Collaborator

yzh119 commented Mar 22, 2025

Hi @MichoChan @xiaoqi35 , does #968 address your issue?

@yzh119 yzh119 closed this as completed in 86b12ad Mar 22, 2025
@yzh119 yzh119 reopened this Mar 22, 2025
@xiaoqi35
Author

xiaoqi35 commented Mar 23, 2025

Hi @MichoChan @xiaoqi35 , does #968 address your issue?

Not solved, assuming #968 does not require recompiling and rebuilding.
The norm module in the gemma2 decode layer is still slower than in 0.1.6, even when captured in a CUDA graph.

Maybe a higher-level guard policy, at the end-to-end LLM level, is needed when releasing a new flashinfer major version.

Thanks again for your great work!

@yzh119
Collaborator

yzh119 commented Mar 23, 2025

The norm module in the gemma2 decode layer is still slower than in 0.1.6, even when captured in a CUDA graph.

Can you provide an nsys trace with CUDAGraph enabled? We haven't observed that in TRTLLM.

@yzh119
Collaborator

yzh119 commented Mar 23, 2025

@MichoChan @xiaoqi35 here is a script to reproduce the CPU overhead issue:

https://gist.github.com/yzh119/d9bf2abbb667abcbb806979f4bbea633

We do observe a large Python-side overhead in v0.2; one of the reasons is the device guard:

with input.device as device:  # device guard
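
A rough sketch of measuring just the guard's per-call CPU cost (assuming PyTorch 2.x, where a torch.device works as a context manager; the shapes and iteration count are arbitrary):

import time
import torch

x = torch.randn(19, 4096, dtype=torch.float16, device="cuda")

def guarded():
    with x.device:  # same guard pattern as quoted above
        pass

def bare():
    pass

for fn in (bare, guarded):
    t0 = time.time()
    for _ in range(10000):
        fn()
    print(fn.__name__, (time.time() - t0) / 10000 * 1e6, "us per call")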

I tried to reproduce your issue on an A100; here are the results I get:

v0.2.3

w/o CUDAGraph 0.02311849594116211
w/  CUDAGraph 0.003802776336669922

v0.2.3 (remove device guard)

w/o CUDAGraph 0.01325082778930664
w/  CUDAGraph 0.003797769546508789

v0.1.6

w/o CUDAGraph 0.004994630813598633
w/  CUDAGraph 0.0035665035247802734

The gap without CUDAGraph is large; we should fix this ASAP. Thanks for reporting. There are also some changes after the v0.1.6 release that improve numerical precision (#587, #592) and slow the kernel down a little.

@yzh119
Collaborator

yzh119 commented Mar 23, 2025

Results on H100 (we added PDL support in #930, which benefits H100 and later architectures when CUDAGraph is turned on, but is not useful for earlier GPUs such as Ampere):

v0.2.3 (w/ PDL)

w/o CUDAGraph 0.011697769165039062
w/  CUDAGraph 0.002118825912475586

You can turn PDL on by setting enable_pdl=True in the norm kernels.
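For example (a sketch; the argument name comes from the note above, and the exact signature may differ between releases):

flashinfer.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps, enable_pdl=True)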

v0.2.3 (remove device guard)

w/o CUDAGraph 0.008272647857666016
w/  CUDAGraph 0.002228260040283203

v0.2.3 (no PDL)

w/o CUDAGraph 0.020609617233276367
w/  CUDAGraph 0.002488374710083008

v0.1.6

w/o CUDAGraph 0.0037620067596435547
w/  CUDAGraph 0.0026903152465820312

yzh119 added a commit that referenced this issue Mar 24, 2025
This PR fixes issue #960. We identified several performance bottlenecks
in our Python APIs when kernels are not captured by CUDAGraph:
1. The device guard in Python is slow (`with input.device as device:`).
2. Getting the current CUDA stream in Python is time-consuming.

These issues were introduced in the JIT refactor after v0.1.6 (mainly to
accelerate JIT compilation). In this PR, we moved the stream query and
device guard back to C++.
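
A rough sketch of the second bottleneck in isolation (per-call cost of querying the current stream from Python; the iteration count is arbitrary):

import time
import torch

torch.cuda.init()
t0 = time.time()
for _ in range(10000):
    torch.cuda.current_stream()
print((time.time() - t0) / 10000 * 1e6, "us per call")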

@MichoChan @xiaoqi35
@yzh119
Collaborator

yzh119 commented Mar 24, 2025

#969 should greatly reduce the Python-side overhead.

@yzh119
Collaborator

yzh119 commented Mar 24, 2025

I'll close this for now; feel free to re-open the issue if it still exists.

@yzh119 yzh119 closed this as completed Mar 24, 2025
@xiaoqi35
Author

The norm module in the gemma2 decode layer is still slower than in 0.1.6, even when captured in a CUDA graph.

Can you provide an nsys trace with CUDAGraph enabled? We haven't observed that in TRTLLM.

The performance of a single norm with CUDAGraph enabled is now aligned with 0.1.6; my own tests have verified this.
My test method was mistaken (I added timing code inside the gemma2 decode layer while CUDAGraph was enabled during the decode phase).
It's okay to close the issue for now. I'll re-open it if I find evidence that 0.2.x is still slower than 0.1.6 during the gemma2 decode phase.
Thanks for your professional responses and code commits.
