Sampling kernel (top_k_top_p_sampling_from_probs) hangs #769
Comments
The expected behavior is each thread hold the same Do you have a script to reproduce the issue? It will be easier for debugging. |
But u can be different for each thread? |
Let me see if I can get a minimal repro. |
All threads in a block should read the same |
Repro (hangs on H100-80G). Instructions and hints for debugging, based on my reading of the code:
|
Cool, thanks so much for collecting this failed case :) |
The failure case occurs in Our sampling kernels get stuck at flashinfer/include/flashinfer/sampling.cuh Lines 1080 to 1130 in 090b100
The issue arises because |
I printed the values of (
if (threadIdx.x == 0 && blockIdx.x == 5) {
  printf("%d %d %x %x %x %x %x\n", threadlocal_sum.count, k,
         *reinterpret_cast<int*>(&low), *reinterpret_cast<int*>(&mid),
         *reinterpret_cast<int*>(&high), *reinterpret_cast<int*>(&min_gt_low),
         *reinterpret_cast<int*>(&max_le_high));
}
The following logs are omitted because they repeat the same pattern |
It's interesting to see that the average of float("0x0") and float("0xee0000") is float("0x0") on CUDA (this is not the case on an x86 CPU).
this somehow breaks our loop invariant because we assume p.s. this only happens when you enable |
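To make the observation above concrete, here is a minimal host-side C++ sketch. It is not the kernel code. The helper names (`bits_to_float`, `float_to_bits`, `flush_subnormal`, `midpoint_ieee`, `midpoint_ftz`) are hypothetical, and the flush-to-zero step is only an emulation, assuming that flushing a subnormal result is what zeroes the midpoint on the GPU. The bit pattern 0x00ee0000 is a small normal float, so its exact half (0x00770000) is subnormal. Once that result is flushed, the midpoint equals `low` and a bisection step makes no progress.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Reinterpret a 32-bit pattern as a float (memcpy avoids aliasing problems).
inline float bits_to_float(std::uint32_t bits) {
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}

// Reinterpret a float as its 32-bit pattern.
inline std::uint32_t float_to_bits(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);
  return bits;
}

// Hypothetical emulation of flush-to-zero: subnormal results become signed zero.
inline float flush_subnormal(float x) {
  return std::fpclassify(x) == FP_SUBNORMAL ? std::copysign(0.0f, x) : x;
}

// Midpoint under ordinary IEEE-754 single-precision arithmetic.
inline float midpoint_ieee(float low, float high) {
  return (low + high) / 2.0f;
}

// Midpoint with the (emulated) flush-to-zero applied to the result.
inline float midpoint_ftz(float low, float high) {
  return flush_subnormal(midpoint_ieee(low, high));
}
```

With `low = bits_to_float(0x0)` and `high = bits_to_float(0x00ee0000)`, `midpoint_ieee` yields the subnormal 0x00770000, while `midpoint_ftz` yields 0x0, matching the value reported on CUDA.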
Disabling the fast math compiler flag, or changing the data types of |
…ampling Kernels (#774) This PR addresses issue #769. As discussed in [this comment](#769 (comment)), the use of the approximate division instruction `div.approx.ftz.f32` can break the loop invariant, preventing the loop from terminating. To resolve this, this PR changes the data types of `low`, `high`, and `mid` to `double`, ensuring that the compiler maintains IEEE-754 compliance and preserves numerical stability.
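The effect of widening to `double` can be sketched on the host for the values reported in the issue. This is an illustrative C++ snippet, not the code from the PR. `float_from_bits`, `midpoint_f64`, and `strictly_between` are hypothetical names introduced here. The point is only that the half of 0x00ee0000, which is subnormal in single precision, is an ordinary normal number in double precision, so the midpoint stays strictly inside the interval.

```cpp
#include <cstdint>
#include <cstring>

// Reinterpret a 32-bit pattern as a float.
inline float float_from_bits(std::uint32_t bits) {
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}

// Midpoint computed in double precision from single-precision bounds.
inline double midpoint_f64(float low, float high) {
  return (static_cast<double>(low) + static_cast<double>(high)) / 2.0;
}

// True when mid splits [low, high] into two strictly smaller intervals,
// which is what a bisection-style loop needs in order to make progress.
inline bool strictly_between(double low, double mid, double high) {
  return low < mid && mid < high;
}
```

For `low = float_from_bits(0x0)` and `high = float_from_bits(0x00ee0000)`, `strictly_between(low, midpoint_f64(low, high), high)` holds, so a search step on this interval would shrink it rather than repeat.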
Should have been fixed in #774. |
For certain inputs, sampling kernels hang. I think the following in sampling.cuh is the culprit code.
`__syncthreads()` is called conditionally here, and under certain input conditions some threads may never hit `__syncthreads()`.