Perf drop in rmsnorm #960
Hi @xiaoqi35, that's not a fair way to compare CUDA kernel performance: you are not measuring the kernel execution time (see https://pytorch.org/docs/stable/notes/cuda.html#asynchronous-execution). You can try triton's do_bench function:

```python
from triton.testing import do_bench
print(do_bench(lambda: flashinfer.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps), warmup=100, rep=1000))
```

or directly run the benchmark: https://github.com/flashinfer-ai/flashinfer/blob/main/benchmarks/bench_fused_add_rmsnorm.py
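For reference, a self-contained version of that measurement (a sketch: the shapes are illustrative and eps=1e-6 is an assumption, not taken from the original report):

```python
import torch
import flashinfer
from triton.testing import do_bench

batch_size, hidden_size = 19, 4096
x = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.randn(hidden_size, dtype=torch.float16, device="cuda")
eps = 1e-6  # assumed value

# do_bench synchronizes around the launches, so the reported number is the
# GPU execution time in milliseconds rather than the Python call overhead.
ms = do_bench(
    lambda: flashinfer.gemma_fused_add_rmsnorm(x, residual, weight, eps),
    warmup=100,
    rep=1000,
)
print(f"gemma_fused_add_rmsnorm: {ms:.4f} ms")
```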
Thanks for your idea! I haven't found the root cause yet; the phenomenon above occurred even in 0.2.0. I would appreciate it if you could provide any clue about the end-to-end (Python API) perf data.
I also see a similar problem: the 0.1.6+cu124torch2.4 BatchPrefillWithPagedKVCachePyTorchWrapper run is faster than the 0.2.0 BatchPrefillWithPagedKVCachePyTorchWrapper run.
Can you confirm whether it's because of JIT? Run

```python
import flashinfer.flashinfer_kernels
import flashinfer.flashinfer_kernels_sm90
```

to check whether you are using prebuilt kernels (AOT) or JIT. For JIT, there should be a warmup phase to make sure we compile and cache all kernels ahead of inference; otherwise compilation time will be counted.
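For example, the check can be wrapped so the outcome is explicit (a sketch; the module names are the ones quoted above and may differ between flashinfer builds):

```python
# Quick AOT-vs-JIT check, assuming the prebuilt kernel modules are named as
# quoted above; an ImportError means kernels will be JIT-compiled on demand.
try:
    import flashinfer.flashinfer_kernels        # noqa: F401
    import flashinfer.flashinfer_kernels_sm90   # noqa: F401
    print("Prebuilt (AOT) kernels found.")
except ImportError:
    print("No AOT kernels: JIT will be used; warm up before timing "
          "so compilation time is not counted.")
```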
At the kernel level, the implementations in 0.1.6 and 0.2.* should be equivalent, and we actually fixed some issues, so 0.2.* should run faster (according to my benchmark suite). Can you share more details with us (what's your GPU architecture, and how did you measure the execution time, using nsys)?
I used nsys with the 0.2.3 AOT version (no JIT at runtime) and with 0.1.6. Comparing the two nsys profiles, I find the spacing between kernels is a bit larger in the 0.2.3 AOT version.
Not in JIT. Hardware: NVIDIA A100 80G
Hi @xiaoqi35 @MichoChan, the kernels are expected to be captured and replayed by CUDAGraph, which removes the CPU-side overheads. The traces provided by @MichoChan have CUDAGraph disabled. For more context: we now use the torch.library interface instead of the previous PyTorch CUDA Extension (which v0.1.6 relied on) for the Python/C++ FFI, and this might add some overhead on the CPU side.

In general, we care about end-to-end performance when all kernels are captured with CUDAGraph (that's the common case for LLM inference frameworks). The CPU overhead is also something we should investigate (although it should disappear once CUDAGraph is enabled). @abcdabcd987 @youkaichao do you have any insights into where these CPU overheads might come from?
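As a rough illustration of what "captured and replayed by CUDAGraph" means here (a minimal sketch, not the exact integration used by inference frameworks; shapes and eps are placeholders):

```python
import torch
import flashinfer

hidden_size = 4096
x = torch.randn(19, hidden_size, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.randn(hidden_size, dtype=torch.float16, device="cuda")

# Warm up on a side stream before capture, as recommended by the PyTorch
# CUDA Graphs documentation.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    flashinfer.gemma_fused_add_rmsnorm(x, residual, weight, 1e-6)
torch.cuda.current_stream().wait_stream(s)

# Capture once, then replay: the Python-side launch cost is paid only during
# capture, so replays expose essentially the pure kernel time.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    flashinfer.gemma_fused_add_rmsnorm(x, residual, weight, 1e-6)

g.replay()
torch.cuda.synchronize()
```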
When we use torch.library to bind ops, sometimes we can bypass the pytorch dispatcher by accessing the underlying op object directly.
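This reads as the common trick of calling the resolved OpOverload (e.g. `op.default`) instead of the OpOverloadPacket; here is a toy sketch of that pattern with a hypothetical op registered via torch.library (not flashinfer's real registration):

```python
import torch
from torch.library import Library

# Register a toy op purely for illustration; namespace, name and schema are
# placeholders, not flashinfer's actual bindings.
lib = Library("demo_ns", "DEF")
lib.define("scale(Tensor x, float alpha) -> Tensor")
lib.impl("scale", lambda x, alpha: x * alpha, "CompositeExplicitAutograd")

x = torch.randn(4)

# Normal call: goes through the OpOverloadPacket, which resolves the overload
# on every invocation.
y1 = torch.ops.demo_ns.scale(x, 2.0)

# Grab the resolved OpOverload once and call it directly, skipping the
# per-call overload resolution and trimming some Python-side overhead.
scale_default = torch.ops.demo_ns.scale.default
y2 = scale_default(x, 2.0)

assert torch.allclose(y1, y2)
```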
@youkaichao thank you for the hint, let me try it!
Hi @MichoChan @xiaoqi35, does #968 address your issue?
Not solved, if #968 doesn't require a re-compile and rebuild. Maybe a higher-level guard policy, at the end-to-end LLM level, is necessary when releasing a major flashinfer version. Thanks a lot again for your great work!
Can you provide an nsys trace after enabling CUDAGraph? We haven't observed that in TRTLLM.
@MichoChan @xiaoqi35 here is a script to reproduce the CPU overhead issue: https://gist.github.com/yzh119/d9bf2abbb667abcbb806979f4bbea633 We indeed observe huge Python-side overhead in v0.2; one of the reasons is the device guard (see line 136 at commit 86b12ad).
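For intuition (this is not the linked gist), one rough way to isolate the CPU-side per-call cost of the Python API is to time a loop of launches without synchronizing inside it; shapes and eps are placeholders, and the result only approximates the launch overhead as long as the CUDA launch queue does not saturate:

```python
import time
import torch
import flashinfer

hidden_size = 4096
x = torch.randn(19, hidden_size, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.randn(hidden_size, dtype=torch.float16, device="cuda")

# Warm up so JIT compilation / caching is excluded from the measurement.
for _ in range(10):
    flashinfer.gemma_fused_add_rmsnorm(x, residual, weight, 1e-6)
torch.cuda.synchronize()

# Launches are asynchronous, so timing the loop itself (no sync inside)
# mostly measures the Python + FFI cost per call, not the kernel time.
n = 1000
t0 = time.perf_counter()
for _ in range(n):
    flashinfer.gemma_fused_add_rmsnorm(x, residual, weight, 1e-6)
t1 = time.perf_counter()
torch.cuda.synchronize()
print(f"~{(t1 - t0) / n * 1e6:.1f} us CPU-side overhead per call")
```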
I tried to reproduce your issue on A100, benchmarking three configurations: v0.2.3, v0.2.3 with the device guard removed, and v0.1.6.
The gap without CUDAGraph is large; we should fix the issue ASAP, thanks for reporting. There are some changes made after the v0.1.6 release that improve numerical precision (#587, #592) and slow the kernel down a little.
Results on H100 (we added PDL support in #930, which benefits H100 and later architectures when CUDAGraph is turned on, but is not useful for earlier GPUs such as Ampere). Four configurations were benchmarked: v0.2.3 with PDL (you can turn PDL on by setting ...), v0.2.3 with the device guard removed, v0.2.3 without PDL, and v0.1.6.
This PR fixes issue #960. We identified several performance bottlenecks in our Python APIs when kernels are not captured by CUDAGraph:

1. The device guard in Python is slow (`with input.device as device:`).
2. Getting the current CUDA stream in Python is time-consuming.

These issues were introduced in the JIT refactor after v0.1.6 (mainly to accelerate JIT compilation). In this PR, we moved the stream query and the device guard back to C++. @MichoChan @xiaoqi35
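For a rough sense of the magnitude, here is a micro-benchmark of that per-call pattern (a sketch: it approximates the guard with torch.cuda.device rather than flashinfer's exact code path, and absolute numbers depend on the machine):

```python
import time
import torch

x = torch.randn(8, device="cuda")

def per_call_pattern():
    # Mirrors what used to happen on every API call: enter a device guard
    # and query the current CUDA stream from Python.
    with torch.cuda.device(x.device):
        torch.cuda.current_stream(x.device)

n = 100_000
t0 = time.perf_counter()
for _ in range(n):
    per_call_pattern()
t1 = time.perf_counter()
print(f"~{(t1 - t0) / n * 1e6:.2f} us of pure Python overhead per call")
```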
#969 should greatly reduce the Python overhead.
I'll close it for now; feel free to re-open the issue if it still exists.
The perf of a single norm with CUDAGraph enabled is now aligned with 0.1.6. My tests also verify it.
Original issue report: take the file https://github.com/flashinfer-ai/flashinfer/blob/main/tests/test_norm.py and run the test cases with these changes:
```python
...
# Timing added around the kernel call inside the test:
t0 = time.time()
flashinfer.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
t1 = time.time()
...

if __name__ == "__main__":
    test_gemma_fused_add_rmsnorm(19, 4096, torch.float16, False)
    test_gemma_norm(19, 3072, torch.float16, False, True)
```
Timings in seconds (wall-clock around the kernel call):

| version | gemma_fused_add_rmsnorm | gemma_rmsnorm |
| --- | --- | --- |
| flashinfer 0.1.6+cu124torch2.4 | 0.0009355545043945312 | 5.435943603515625e-05 |
| flashinfer-python-0.2.3+cu124torch2.5 | 0.0019216537475585938 | 0.00036144256591796875 |