|
16 | 16 | #ifndef FLASHINFER_SAMPLING_CUH_
|
17 | 17 | #define FLASHINFER_SAMPLING_CUH_
|
18 | 18 |
|
| 19 | +#include <driver_types.h> |
| 20 | + |
19 | 21 | #include <cub/block/block_adjacent_difference.cuh>
|
20 | 22 | #include <cub/block/block_reduce.cuh>
|
21 | 23 | #include <cub/block/block_scan.cuh>
|
@@ -342,6 +344,96 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
|
342 | 344 | }
|
343 | 345 | }
|
344 | 346 |
|
/*!
 * \brief Joint top-k / top-p (nucleus) sampling from a categorical distribution,
 *   implemented as block-cooperative rejection sampling.
 *
 * Grid layout: one thread block per batch element (gridDim.x == batch_size),
 * BLOCK_THREADS threads per block. Requires dynamic shared memory of
 * sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>).
 *
 * Each round draws a candidate token id via inverse-CDF sampling
 * (DeviceSamplingFromProb) and then checks whether that candidate's probability
 * ("pivot") lies inside both the per-request top-k set and top-p mass; if not,
 * the next round re-draws restricted to probabilities above the pivot.
 *
 * \param probs           [batch_size, d] probability matrix, row-major.
 * \param uniform_samples [max_rounds, batch_size] pre-generated U(0,1) draws,
 *                        one per (round, batch element).
 * \param top_k           per-request k values.
 * \param top_p           per-request p values.
 * \param output          [batch_size] sampled token ids (written only on success).
 * \param success         optional [batch_size] flags; false when no accepted
 *                        sample was produced within max_rounds.
 * \param d               vocabulary size (number of categories per row).
 * \param max_rounds      maximum number of rejection-sampling rounds.
 */
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, typename DType, typename IdType>
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* top_k,
                                               DType* top_p, IdType* output, bool* success,
                                               uint32_t d, uint32_t max_rounds) {
  const uint32_t batch_size = gridDim.x;
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
  // Per-request thresholds for this batch element.
  IdType k = top_k[bx];
  DType p = top_p[bx];

  // Dynamic shared memory reinterpreted as the block-wide sampling scratchpad.
  extern __shared__ __align__(
      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
      uint8_t smem_sampling[];
  auto& temp_storage = reinterpret_cast<
      SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

  vec_t<DType, VEC_SIZE> probs_vec;
  DType aggregate;
  // q: probability mass at or below the current pivot (0 before the first round).
  DType q = DType(0);
  // pivot: probability of the most recently drawn candidate; DeviceSamplingFromProb
  // presumably restricts the draw to entries above it — confirm in its definition.
  DType pivot = DType(0);
  IdType sampled_id;
  for (uint32_t round = 0; round < max_rounds; ++round) {
    // Fallback id in case the scan never crosses u (e.g. numerical round-off):
    // the last category.
    temp_storage.data.sampled_id = d - 1;
    __syncthreads();
    // Scale the uniform draw by the mass still in play (1 - q), i.e. the mass
    // above the previous pivot.
    DType u = uniform_samples[round * batch_size + bx] * (DType(1) - q);
    aggregate = DType(0);
    // Inverse-CDF pass: stream the row in BLOCK_THREADS*VEC_SIZE chunks until the
    // running prefix sum exceeds u; the crossing position is stored in shared
    // memory by DeviceSamplingFromProb.
    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
      probs_vec.fill(DType(0));
      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
      }

      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DType>(
          i, d, pivot, u, probs_vec, aggregate, &temp_storage);
      if (aggregate > u) {
        break;
      }
    }
    __syncthreads();
    // All threads adopt the drawn candidate; its probability becomes the new pivot.
    sampled_id = temp_storage.data.sampled_id;
    pivot = probs[bx * d + sampled_id];

    // Second pass: block-reduce the (mass, count) of entries <= pivot to decide
    // acceptance.
    Pair<DType> aggregate_leq_pivot{DType(0), 0};
    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
      probs_vec.fill(DType(0));
      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
      }

      Pair<DType> probs_leq_pivot[VEC_SIZE];
#pragma unroll
      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
        // Count only in-bounds lanes; out-of-range tail lanes were zero-filled.
        probs_leq_pivot[j] = {
            (probs_vec[j] <= pivot) ? probs_vec[j] : DType(0),
            (probs_vec[j] <= pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
      }

      aggregate_leq_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                                 temp_storage.block_prim.reduce_pair)
                                 .Sum<VEC_SIZE>(probs_leq_pivot);
      if (tx == 0) {
        temp_storage.data.block_aggregate.pair = aggregate_leq_pivot;
      }
      __syncthreads();
      // Acceptance test (same condition as below): count(<=pivot) + k > d means
      // fewer than k entries strictly exceed pivot (candidate is in the top-k set);
      // value + p > 1 + eps means the mass above pivot is below p - eps (candidate
      // is inside the nucleus). Early-exit the scan once both hold.
      if (temp_storage.data.block_aggregate.pair.count + k > d &&
          float(temp_storage.data.block_aggregate.pair.value) + p > 1 + eps) {
        break;
      }
    }
    // Carry the mass <= pivot into the next round's renormalization factor.
    q = temp_storage.data.block_aggregate.pair.value;
    if (temp_storage.data.block_aggregate.pair.count + k > d && float(q) + p > 1 + eps) {
      break;  // candidate accepted
    }
  }
  __syncthreads();
  if (tx == 0) {
    // Negation of the acceptance condition evaluated on the last round's
    // aggregates: the loop exhausted max_rounds without an accepted sample.
    if (temp_storage.data.block_aggregate.pair.count + k <= d || float(q) + p <= 1 + eps) {
      // failed to draw an accepted sample within max_rounds rounds
      if (success != nullptr) {
        success[bx] = false;
      }
    } else {
      output[bx] = sampled_id;
      if (success != nullptr) {
        success[bx] = true;
      }
    }
  }
}
| 436 | + |
345 | 437 | template <typename T, typename IdType>
|
346 | 438 | cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint32_t batch_size,
|
347 | 439 | uint32_t d, cudaStream_t stream = 0) {
|
@@ -434,6 +526,28 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
|
434 | 526 | return cudaSuccess;
|
435 | 527 | }
|
436 | 528 |
|
/*!
 * \brief Host-side launcher for joint top-k / top-p rejection sampling.
 *
 * Launches one block of 1024 threads per batch element, with dynamic shared
 * memory sized for the block-wide sampling scratchpad. The vector width is the
 * largest aligned width (16 bytes / sizeof(T)) that divides d.
 *
 * \param probs           [batch_size, d] device probability matrix.
 * \param uniform_samples [max_rounds, batch_size] device U(0,1) draws.
 * \param top_k           per-request k values (device).
 * \param top_p           per-request p values (device).
 * \param output          [batch_size] device output token ids.
 * \param success         optional [batch_size] device success flags.
 * \param batch_size      number of requests.
 * \param d               vocabulary size.
 * \param max_rounds      maximum rejection-sampling rounds per request.
 * \param stream          CUDA stream for the launch (default stream if omitted).
 * \return cudaSuccess, or the first error reported by the runtime calls.
 */
template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k, T* top_p,
                                     IdType* output, bool* success, uint32_t batch_size, uint32_t d,
                                     uint32_t max_rounds, cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;

  dim3 grid_dim(batch_size);
  dim3 block_dim(BLOCK_THREADS);
  // Widest 16-byte-aligned vector load that evenly divides the row length.
  const uint32_t vector_width = std::gcd(16 / sizeof(T), d);
  const uint32_t smem_bytes = sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  // Argument order must match the kernel's parameter list exactly.
  void* launch_args[] = {&probs, &uniform_samples, &top_k, &top_p,
                         &output, &success,         &d,     &max_rounds};

  DISPATCH_ALIGNED_VEC_SIZE(vector_width, VEC_SIZE, {
    auto kernel =
        TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE, T, IdType>;
    // Opt in to the requested dynamic shared-memory size before launching.
    FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
    FLASHINFER_CUDA_CALL(
        cudaLaunchKernel((void*)kernel, grid_dim, block_dim, launch_args, smem_bytes, stream));
  });
  return cudaSuccess;
}
| 550 | + |
437 | 551 | template <typename T, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
|
438 | 552 | struct RenormTempStorage {
|
439 | 553 | union {
|
|
0 commit comments