Skip to content

Commit f65b93f

Browse files
authored
feat: Added tvm binding for sampling kernel (#958)
1 parent 86b12ad commit f65b93f

File tree

5 files changed

+57
-0
lines changed

5 files changed

+57
-0
lines changed

flashinfer/jit/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
gen_customize_single_prefill_module as gen_customize_single_prefill_module,
4545
)
4646
from .attention import gen_pod_module as gen_pod_module
47+
from .attention import gen_sampling_tvm_binding as gen_sampling_tvm_binding
4748
from .attention import gen_single_decode_module as gen_single_decode_module
4849
from .attention import gen_single_prefill_module as gen_single_prefill_module
4950
from .attention import get_batch_decode_mla_uri as get_batch_decode_mla_uri

flashinfer/jit/attention/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,4 @@
4848
from .tvm import (
4949
gen_customize_batch_prefill_tvm_binding as gen_customize_batch_prefill_tvm_binding,
5050
)
51+
from .tvm import gen_sampling_tvm_binding as gen_sampling_tvm_binding

flashinfer/jit/attention/tvm.py

+16
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@
3535
from .utils import generate_additional_params
3636

3737

38+
def gen_sampling_tvm_binding(uri: str):
39+
gen_directory = FLASHINFER_GEN_SRC_DIR / uri
40+
os.makedirs(gen_directory, exist_ok=True)
41+
42+
source_paths = []
43+
for filename in ["sampling.cu", "sampling_jit_tvm_binding.cu"]:
44+
src_path = FLASHINFER_TVM_BINDING_DIR / filename
45+
dest_path = gen_directory / filename
46+
source_paths.append(dest_path)
47+
with open(src_path, "r") as f:
48+
source = f.read()
49+
write_if_different(dest_path, source)
50+
51+
return uri, source_paths
52+
53+
3854
def gen_customize_batch_prefill_tvm_binding(
3955
backend: str,
4056
uri: str,

tvm_binding/sampling.cu

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include <flashinfer/attention/hopper/attention_updater.cuh>
2+
#include <flashinfer/attention/hopper/variant_helper.cuh>
3+
#include <flashinfer/cutlass_utils.cuh>
4+
#include <flashinfer/layout.cuh>
5+
#include <flashinfer/math.cuh>
6+
#include <flashinfer/sampling.cuh>
7+
8+
#include "tvm_binding_utils.h"
9+
10+
using namespace flashinfer;
11+
12+
// TODO: change the philox seeds and offsets to DLTensor once the underlying API for sampling
13+
// changes to support multiple seeds
14+
void SamplingFromProbs(DLTensor* probs, DLTensor* output, DLTensor* maybe_indices,
15+
bool deterministic, uint64_t philox_seed, uint64_t philox_offset,
16+
int64_t cuda_stream) {
17+
CHECK(probs->ndim == 2) << "Probs should have 2 dimensions";
18+
unsigned int batch_size = output->shape[0];
19+
unsigned int vocab_size = probs->shape[1];
20+
21+
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
22+
float* probs_cast = static_cast<float*>(probs->data) + probs->byte_offset;
23+
int* output_cast = static_cast<int*>(output->data) + output->byte_offset;
24+
int* maybe_indices_cast =
25+
maybe_indices ? static_cast<int*>(maybe_indices->data) + maybe_indices->byte_offset : nullptr;
26+
27+
cudaError_t status =
28+
sampling::SamplingFromProb(probs_cast, output_cast, maybe_indices_cast, batch_size,
29+
vocab_size, deterministic, philox_seed, philox_offset, stream);
30+
CHECK(status == cudaSuccess) << "SamplingFromProbs failed with error "
31+
<< cudaGetErrorString(status);
32+
}
+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#include "tvm_binding_utils.h"
2+
3+
void SamplingFromProbs(DLTensor* probs, DLTensor* output, DLTensor* maybe_indices,
4+
bool deterministic, uint64_t philox_seed, uint64_t philox_offset,
5+
int64_t cuda_stream);
6+
7+
TVM_DLL_EXPORT_TYPED_FUNC(sampling_from_probs, SamplingFromProbs);

0 commit comments

Comments
 (0)