Sampling kernel (top_k_top_p_sampling_from_probs) hangs #769

Closed
dskhudia opened this issue Jan 31, 2025 · 11 comments
dskhudia commented Jan 31, 2025

For certain inputs, the sampling kernels hang. I think the following code in sampling.cuh is the culprit: __syncthreads() is called conditionally here, and under certain input conditions some threads may never reach it.

  if (aggregate + aggregate_local > u) {
    if constexpr (DETERMINISTIC) {
      DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, T>(
          prob_greater_than_threshold, inclusive_cdf, temp_storage);
    } else {
      BlockScan<T, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
          .InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);

      __syncthreads();
    }

#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      greater_than_u[j] = (inclusive_cdf[j] + aggregate > u) && valid[j];
    }

    bool greater_than_u_diff[VEC_SIZE];
#ifdef FLASHINFER_CUB_SUBTRACTLEFT_DEFINED
    BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
        .SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
#else
    BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
        .FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
#endif
    __syncthreads();

#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      if (greater_than_u_diff[j]) {
        atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
      }
    }
    __syncthreads();
  }
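
For illustration, a minimal standalone kernel (hypothetical, not flashinfer code) showing this kind of hazard, assuming the branch condition could differ across threads of a block:

// Hypothetical kernel illustrating a conditionally executed __syncthreads().
// If the per-thread condition diverges within a block, only some threads reach
// the barrier, which is undefined behavior and typically shows up as a hang.
// Assumes a launch with at most 256 threads per block.
__global__ void conditional_sync_hazard(const float* x, float* out) {
  __shared__ float buf[256];
  const unsigned tx = threadIdx.x;
  buf[tx] = 0.0f;
  __syncthreads();
  if (x[tx] > 0.5f) {        // per-thread condition: threads can diverge here
    buf[tx] = x[tx];
    __syncthreads();         // only threads that took the branch arrive here
    out[tx] = buf[(tx + 1) % blockDim.x];
  }
}

(As discussed below, the condition in the actual kernel is expected to be uniform across the block, so this pattern by itself may not be what goes wrong here.)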
dskhudia changed the title from "Sampling kernel hangs" to "Sampling kernel (top_k_top_p_sampling_from_probs) hangs" on Jan 31, 2025
yzh119 (Collaborator) commented Jan 31, 2025

> __syncthreads() is called conditionally here and under certain input conditions some threads may never hit __syncthreads()

The expected behavior is that each thread holds the same aggregate + aggregate_local.

Do you have a script to reproduce the issue? That would make debugging easier.

dskhudia (Author) commented:

But u can be different for each thread?

dskhudia (Author) commented:

Let me see if I can get a minimal repro.

yzh119 (Collaborator) commented Jan 31, 2025

> But u can be different for each thread?

All threads in a block should read the same u in a round.

dskhudia (Author) commented:

Repro (hangs on an H100-80G). Instructions and hints for debugging, based on my reading of the code:

  • If you change the probs tensor, it passes, so the hang is input-dependent.
  • Changing uniform_samples doesn't make the hang go away.
  • Rename the attached tensors.txt to tensors.tar and untar it before running; GitHub didn't allow attaching a file with a .tar extension.

tensors.txt

import torch
from flashinfer.sampling import (
        min_p_sampling_from_probs,
        top_k_renorm_prob,
        top_k_top_p_sampling_from_probs,
        top_p_renorm_prob
        )

probs = torch.load('probs.0')
uniform_samples = torch.load('uniform_samples.0')
top_ks = torch.load('sampling_info.top_ks.0')
top_ps = torch.load('sampling_info.top_ps.0')


print(f'uniform_samples: {uniform_samples.shape}')

batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
        probs,
        uniform_samples,
        top_ks,
        top_ps,
    )
torch.cuda.synchronize()
print(f'success: {success}')

yzh119 (Collaborator) commented Jan 31, 2025

Cool, thanks so much for collecting this failed case :)

yzh119 (Collaborator) commented Feb 1, 2025

The failure case occurs in probs[5], which exhibits a nearly one-hot distribution—one value is close to 1 while all others are near zero.

Our sampling kernels get stuck at

  do {
    Pair<DType> threadlocal_sum{DType(0), 0};
    Pair<DType> probs_greater_than_pivot_pair[VEC_SIZE];  // pivot initialized to 0
    float mid = (low + high) / 2;
    min_gt_low = high;
    max_le_high = low;
    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
      probs_vec.fill(DType(0));
      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
        probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      }
#pragma unroll
      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
        probs_greater_than_pivot_pair[j] = {
            (probs_vec[j] > mid) ? probs_vec[j] : DType(0),
            (probs_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
        if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
          min_gt_low = min(min_gt_low, probs_vec[j]);
        }
        if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
          max_le_high = max(max_le_high, probs_vec[j]);
        }
      }
      threadlocal_sum += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                             temp_storage.block_prim.reduce_pair)
                             .Sum<VEC_SIZE>(probs_greater_than_pivot_pair);
      __syncthreads();
    }
    min_gt_low =
        BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
            .Reduce(min_gt_low, cub::Min());
    __syncthreads();
    max_le_high =
        BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
            .Reduce(max_le_high, cub::Max());
    if (tx == 0) {
      temp_storage.block_aggregate.pair = threadlocal_sum;
      temp_storage.min_val = min_gt_low;
      temp_storage.max_val = max_le_high;
    }
    __syncthreads();
    threadlocal_sum = temp_storage.block_aggregate.pair;
    min_gt_low = temp_storage.min_val;
    max_le_high = temp_storage.max_val;
    if (threadlocal_sum.count >= k) {
      low = mid;
      sum_low = float(threadlocal_sum.value);
    } else {
      high = min(mid, max_le_high);
    }
  } while (min_gt_low != max_le_high);

The issue is that min_gt_low and max_le_high end up very close to each other but never exactly equal, which prevents the loop from terminating.

yzh119 (Collaborator) commented Feb 1, 2025

I printed the values of threadlocal_sum.count, k, low, mid, high, min_gt_low, and max_le_high (the floating-point values are too small to read in decimal, so I print them in hex):

      if (threadIdx.x == 0 && blockIdx.x == 5) {
        printf("%d %d %x %x %x %x %x\n", threadlocal_sum.count, k, *reinterpret_cast<int*>(&low), *reinterpret_cast<int*>(&mid), *reinterpret_cast<int*>(&high), *reinterpret_cast<int*>(&min_gt_low), *reinterpret_cast<int*>(&max_le_high));
      }
uniform_samples: torch.Size([32, 7])
1 200 0 3f000000 3f000000 8c0000 3f800000
1 200 0 3e800000 1c830000 8c0000 1c830000
3 200 0 1c030000 1c030000 8c0000 1c830000
3 200 0 1b830000 1b690000 8c0000 1b690000
4 200 0 1ae90000 1ae90000 8c0000 1b690000
4 200 0 1a690000 18550000 8c0000 18550000
5 200 0 17d50000 17d50000 8c0000 18550000
5 200 0 17550000 16b30000 8c0000 16b30000
7 200 0 16330000 16330000 8c0000 16b30000
7 200 0 15b30000 12f10000 8c0000 12f10000
8 200 0 12710000 12710000 8c0000 12f10000
8 200 0 11f10000 10f70000 8c0000 10f70000
11 200 0 10770000 10770000 8c0000 10f70000
12 200 0 ff70000 ff70000 8c0000 10350000
14 200 0 f770000 f770000 8c0000 fab0000
14 200 0 ef70000 e6e0000 8c0000 e6e0000
15 200 0 dee0000 dee0000 8c0000 e6e0000
17 200 0 d6e0000 d6e0000 8c0000 daf0000
18 200 0 cee0000 cee0000 8c0000 d550000
20 200 0 c6e0000 c6e0000 8c0000 c9c0000
21 200 0 bee0000 bee0000 8c0000 c140000
24 200 0 b6e0000 b6e0000 8c0000 be60000
25 200 0 aee0000 aee0000 8c0000 b290000
28 200 0 a6e0000 a6e0000 8c0000 ace0000
30 200 0 9ee0000 9ee0000 8c0000 a420000
31 200 0 96e0000 96e0000 8c0000 9b70000
36 200 0 8ee0000 8ee0000 8c0000 95f0000
39 200 0 86e0000 86e0000 8c0000 8a40000
41 200 0 7ee0000 7ee0000 8c0000 8470000
42 200 0 76e0000 76e0000 8c0000 7920000
51 200 0 6ee0000 6ee0000 8c0000 7640000
56 200 0 66e0000 66e0000 8c0000 6d70000
64 200 0 5ee0000 5ee0000 8c0000 64b0000
72 200 0 56e0000 56e0000 8c0000 5c00000
84 200 0 4ee0000 4ee0000 8c0000 5690000
99 200 0 46e0000 46e0000 8c0000 4dc0000
104 200 0 3ee0000 3ee0000 8c0000 4500000
109 200 0 36e0000 36e0000 8c0000 3c40000
122 200 0 2ee0000 2ee0000 8c0000 3390000
141 200 0 26e0000 26e0000 8c0000 2e10000
155 200 0 1ee0000 1ee0000 8c0000 2550000
174 200 0 16e0000 16e0000 8c0000 1e40000
191 200 0 ee0000 ee0000 8c0000 13e0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000
212 200 0 0 ee0000 8c0000 cb0000

The following logs are omitted because they repeat the same pattern.
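
For reference, the last repeated row 212 200 0 0 ee0000 8c0000 cb0000 decodes to count = 212, k = 200, low = mid = 0x0, high ≈ 2.2e-38, min_gt_low ≈ 1.3e-38, and max_le_high ≈ 1.9e-38. A small host-side helper (not part of the kernel) to decode the bit patterns:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Decode the IEEE-754 bit patterns printed above into float values.
static float from_bits(uint32_t bits) {
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  printf("high        = %g\n", from_bits(0x00ee0000u));  // ~2.2e-38
  printf("min_gt_low  = %g\n", from_bits(0x008c0000u));  // ~1.3e-38
  printf("max_le_high = %g\n", from_bits(0x00cb0000u));  // ~1.9e-38
  return 0;
}

Since count >= k, the loop keeps taking the low = mid branch, and with mid stuck at 0 the bracket [low, high] never shrinks.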

yzh119 (Collaborator) commented Feb 1, 2025

It's interesting to see that on CUDA, the average of the float with bit pattern 0x0 and the float with bit pattern 0xee0000 is the float with bit pattern 0x0 (this is not the case on an x86 CPU):

low = float("0x0")        # bit pattern 0x00000000, i.e. 0.0
high = float("0xee0000")  # bit pattern 0x00ee0000, ~2.2e-38 (tiny but still normal)
mid = (low + high) / 2    # float("0x0") on the GPU

This breaks our loop invariant, because we assume low < mid < high.

P.S. This only happens when -use_fast_math is enabled, in which case the compiler emits a div.approx.ftz.f32 instruction.
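
As a standalone check (a hypothetical sketch, not flashinfer code; the bit pattern 0x00ee0000 is taken from the log above), the midpoint can be recomputed on the device. Compiled with nvcc -use_fast_math (which implies --ftz=true), the subnormal result is flushed to zero, so mid == low; without the flag, mid is a nonzero subnormal and low < mid < high holds:

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <cuda_runtime.h>

// Recompute only the problematic midpoint on the GPU.
//   nvcc -use_fast_math midpoint.cu  -> mid should be flushed to 0
//   nvcc midpoint.cu                 -> mid should be a nonzero subnormal
__global__ void midpoint_kernel(float low, float high, float* mid) {
  *mid = (low + high) / 2.0f;  // subnormal result is flushed under --ftz=true
}

int main() {
  uint32_t high_bits = 0x00ee0000u;  // ~2.2e-38, a tiny but still normal float32
  float low = 0.0f, high;
  std::memcpy(&high, &high_bits, sizeof(high));

  float *d_mid, mid;
  cudaMalloc(&d_mid, sizeof(float));
  midpoint_kernel<<<1, 1>>>(low, high, d_mid);
  cudaMemcpy(&mid, d_mid, sizeof(float), cudaMemcpyDeviceToHost);

  printf("high = %g, mid = %g, (mid > low) = %d\n", high, mid, (int)(mid > low));
  cudaFree(d_mid);
  return 0;
}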

yzh119 (Collaborator) commented Feb 1, 2025

Disabling the fast-math compiler flag, changing the data types of low, mid, and high to double, or adding an epsilon term to the loop condition (min_gt_low + eps < max_le_high) can each fix the issue.
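
For intuition on why the double variant holds up even under -use_fast_math: --ftz only flushes single-precision denormals, so a midpoint carried in double stays strictly between low and high. A hypothetical drop-in for midpoint_kernel in the sketch above (not the actual patch in #774):

// Same midpoint, with the bisection arithmetic done in double. --ftz (implied
// by -use_fast_math) only affects single-precision operations, so the result is
// not flushed and the invariant low < mid < high is preserved.
__global__ void midpoint_double_kernel(float low, float high, double* mid) {
  *mid = (static_cast<double>(low) + static_cast<double>(high)) / 2.0;
}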

yzh119 added a commit that referenced this issue Feb 1, 2025
…ampling Kernels (#774)

This PR addresses issue #769.  

As discussed in this comment on #769, the use of the approximate division instruction `div.approx.ftz.f32` can break the loop invariant, preventing the loop from terminating. To resolve this, this PR changes the data types of `low`, `high`, and `mid` to `double`, ensuring that the compiler maintains IEEE-754 compliance and preserves numerical stability.
yzh119 (Collaborator) commented Feb 1, 2025

Should have been fixed in #774.

yzh119 closed this as completed Feb 3, 2025