|
| 1 | +#include <flashinfer/attention/hopper/attention_updater.cuh> |
| 2 | +#include <flashinfer/attention/hopper/variant_helper.cuh> |
| 3 | +#include <flashinfer/cutlass_utils.cuh> |
| 4 | +#include <flashinfer/layout.cuh> |
| 5 | +#include <flashinfer/math.cuh> |
| 6 | +#include <flashinfer/sampling.cuh> |
| 7 | + |
| 8 | +#include "tvm_binding_utils.h" |
| 9 | + |
| 10 | +using namespace flashinfer; |
| 11 | + |
| 12 | +// TODO: change the philox seeds and offsets to DLTensor once the underlying API for sampling |
| 13 | +// changes to support multiple seeds |
| 14 | +void SamplingFromProbs(DLTensor* probs, DLTensor* output, DLTensor* maybe_indices, |
| 15 | + bool deterministic, uint64_t philox_seed, uint64_t philox_offset, |
| 16 | + int64_t cuda_stream) { |
| 17 | + CHECK(probs->ndim == 2) << "Probs should have 2 dimensions"; |
| 18 | + unsigned int batch_size = output->shape[0]; |
| 19 | + unsigned int vocab_size = probs->shape[1]; |
| 20 | + |
| 21 | + cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream); |
| 22 | + float* probs_cast = static_cast<float*>(probs->data) + probs->byte_offset; |
| 23 | + int* output_cast = static_cast<int*>(output->data) + output->byte_offset; |
| 24 | + int* maybe_indices_cast = |
| 25 | + maybe_indices ? static_cast<int*>(maybe_indices->data) + maybe_indices->byte_offset : nullptr; |
| 26 | + |
| 27 | + cudaError_t status = |
| 28 | + sampling::SamplingFromProb(probs_cast, output_cast, maybe_indices_cast, batch_size, |
| 29 | + vocab_size, deterministic, philox_seed, philox_offset, stream); |
| 30 | + CHECK(status == cudaSuccess) << "SamplingFromProbs failed with error " |
| 31 | + << cudaGetErrorString(status); |
| 32 | +} |
0 commit comments