diff --git a/Jenkinsfile b/Jenkinsfile index 4b0584c6b..68b0974c0 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -61,6 +61,17 @@ def unpack_lib(name, libs) { """ } +def cancel_previous_build() { + // cancel previous build if it is not on main. + if (env.BRANCH_NAME != 'main') { + def buildNumber = env.BUILD_NUMBER as int + // Milestone API allows us to cancel previous build + // with the same milestone number + if (buildNumber > 1) milestone(buildNumber - 1) + milestone(buildNumber) + } +} + def init_git(submodule = false) { cleanWs() // add retry in case checkout timeouts @@ -84,10 +95,21 @@ def init_git(submodule = false) { // } // } -stage('JIT Unittest') { +stage('Unittest') { + cancel_previous_build() parallel( failFast: true, - 'GPU-G5-Test-1': { + 'AOT-Build-Import': { + node('CPU-LARGE-SPOT') { + ws(per_exec_ws('flashinfer-aot')) { + init_git(true) + sh(script: "ls -alh", label: 'Show work directory') + sh(script: "./scripts/task_show_node_info.sh", label: 'Show node info') + sh(script: "${docker_run} --no-gpu ./scripts/task_test_aot_build_import.sh", label: 'Test AOT Build and Import') + } + } + }, + 'JIT-Unittest-1': { node('GPU-G5-SPOT') { ws(per_exec_ws('flashinfer-unittest')) { init_git(true) // we need cutlass submodule @@ -97,7 +119,7 @@ stage('JIT Unittest') { } } }, - 'GPU-G5-Test-2': { + 'JIT-Unittest-2': { node('GPU-G5-SPOT') { ws(per_exec_ws('flashinfer-unittest')) { init_git(true) // we need cutlass submodule @@ -107,7 +129,17 @@ stage('JIT Unittest') { } } }, - 'GPU-G5-Test-4': { + 'JIT-Unittest-3': { + node('GPU-G5-SPOT') { + ws(per_exec_ws('flashinfer-unittest')) { + init_git(true) // we need cutlass submodule + sh(script: "ls -alh", label: 'Show work directory') + sh(script: "./scripts/task_show_node_info.sh", label: 'Show node info') + sh(script: "${docker_run} ./scripts/task_jit_run_tests_part3.sh", label: 'JIT Unittest Part 3') + } + } + }, + 'JIT-Unittest-4': { node('GPU-G5-SPOT') { ws(per_exec_ws('flashinfer-unittest')) { init_git(true) // we need cutlass submodule diff --git a/benchmarks/bench_sampling.py b/benchmarks/bench_sampling.py new file mode 100644 index 000000000..3eba949c9 --- /dev/null +++ b/benchmarks/bench_sampling.py @@ -0,0 +1,144 @@ +import torch +from triton.testing import do_bench + +import flashinfer + + +def normal_distribution(std): + def normal_noise(shape, device): + return torch.randn(shape, device=device) * std + + normal_noise.__name__ = f"normal_distribution(std={std})" + return normal_noise + + +def gumbel_distribution(beta): + def gumbel_noise(shape, device): + U = torch.rand(shape, device=device) + eps = 1e-20 + return torch.log(-torch.log(U + eps) + eps) / beta + + gumbel_noise.__name__ = f"gumbel_distribution(beta={beta})" + return gumbel_noise + + +def init_seed_sampling(*args, **kwargs): + torch.manual_seed(42) + return flashinfer.sampling.sampling_from_probs(*args, **kwargs) + + +def init_seed_top_k_sampling(*args, **kwargs): + torch.manual_seed(42) + return flashinfer.sampling.top_k_sampling_from_probs(*args, **kwargs) + + +def init_seed_top_p_sampling(*args, **kwargs): + torch.manual_seed(42) + return flashinfer.sampling.top_p_sampling_from_probs(*args, **kwargs) + + +@torch.inference_mode() +def main(): + print("---") + print("naive sampling") + for vocab_size in [128512]: + for batch_size in [1, 16, 32, 64, 128, 256, 512]: + for distrib in [ + normal_distribution(1), + normal_distribution(5), + gumbel_distribution(0.1), + gumbel_distribution(1), + ]: + for deterministic in [True, False]: + logits = distrib((batch_size, vocab_size), device="cuda") + probs = torch.softmax(logits, dim=-1) + samples = torch.zeros( + batch_size, dtype=torch.int32, device=probs.device + ) + ms = do_bench( + lambda: init_seed_sampling(probs, deterministic=deterministic), + warmup=100, + rep=1000, + ) + + io = ( + probs.numel() * probs.element_size() + + samples.numel() * samples.element_size() + ) + bandwidth = io * 1e-6 / ms + print( + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms*1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + ) + + print("---") + print("top-k sampling") + for vocab_size in [128512]: + for batch_size in [1, 16, 32, 64, 128, 256, 512]: + for distrib in [ + normal_distribution(1), + normal_distribution(5), + gumbel_distribution(0.1), + gumbel_distribution(1), + ]: + for deterministic in [True, False]: + for k in [10, 100, 1000, 5000]: + logits = distrib((batch_size, vocab_size), device="cuda") + probs = torch.softmax(logits, dim=-1) + samples = torch.zeros( + batch_size, dtype=torch.int32, device=probs.device + ) + ms = do_bench( + lambda: init_seed_top_k_sampling( + probs, k, deterministic=deterministic + ), + warmup=100, + rep=1000, + ) + + io = ( + probs.numel() * probs.element_size() + + samples.numel() * samples.element_size() + ) + bandwidth = io * 1e-6 / ms + print( + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, k: {k}, duration: {ms*1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + ) + + print("---") + print("top-p sampling") + + for vocab_size in [128512]: + for batch_size in [1, 16, 32, 64, 128, 256, 512]: + for distrib in [ + normal_distribution(1), + normal_distribution(5), + gumbel_distribution(0.1), + gumbel_distribution(1), + ]: + for deterministic in [True, False]: + for p in [0.1, 0.5, 0.9]: + logits = distrib((batch_size, vocab_size), device="cuda") + probs = torch.softmax(logits, dim=-1) + samples = torch.zeros( + batch_size, dtype=torch.int32, device=probs.device + ) + ms = do_bench( + lambda: init_seed_top_p_sampling( + probs, p, deterministic=deterministic + ), + warmup=100, + rep=1000, + ) + + io = ( + probs.numel() * probs.element_size() + + samples.numel() * samples.element_size() + ) + bandwidth = io * 1e-6 / ms + print( + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, p: {p}, duration: {ms*1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + ) + + +if __name__ == "__main__": + main() diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 61243056c..56449716a 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -362,6 +362,7 @@ __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* ind float aggregate(0); float u = curand_uniform(&state); +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -405,14 +406,10 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* reinterpret_cast&>( smem_sampling); - float max_val = GetMaxValue>( - probs, row_idx, d, temp_storage); - vec_t probs_vec; float aggregate; float q = 1; - double low = 0, high = max_val; + double low = 0, high = 1.f; int sampled_id; int round = 0; do { @@ -421,6 +418,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* __syncthreads(); float u = curand_uniform(&state) * q; aggregate = 0; +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -446,6 +444,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* double pivot_1 = (pivot_0 + high) / 2; ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -522,20 +521,17 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* reinterpret_cast&>( smem_sampling); - float max_val = GetMaxValue>( - probs, row_idx, d, temp_storage); - vec_t probs_vec; float aggregate; float q = 1; - double low = 0, high = max_val; + double low = 0, high = 1.f; int sampled_id; do { temp_storage.sampled_id = d; __syncthreads(); float u = curand_uniform(&state) * q; aggregate = 0; +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -561,6 +557,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* double pivot_1 = (pivot_0 + high) / 2; float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0; +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -637,6 +634,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp vec_t probs_vec; float aggregate_gt_pivot = 0; +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -664,6 +662,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp temp_storage.sampled_id = d; __syncthreads(); float u = curand_uniform(&state) * q; +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -709,20 +708,17 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr, reinterpret_cast&>( smem_sampling); - float max_val = GetMaxValue>( - probs, row_idx, d, temp_storage); - vec_t probs_vec; float aggregate; float q = 1; - double low = 0, high = max_val; + double low = 0, high = 1.f; int sampled_id; do { temp_storage.sampled_id = d; __syncthreads(); float u = curand_uniform(&state) * q; aggregate = 0; +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -748,6 +744,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr, double pivot_1 = (pivot_0 + high) / 2; ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -988,6 +985,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float* double mid = (low + high) / 2; min_gt_low = high; max_le_high = low; +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -1034,6 +1032,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float* float normalizer = math::ptx_rcp(max(sum_low, 1e-8)); // normalize +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -1085,6 +1084,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType double mid = (low + high) / 2; min_gt_low = high; max_le_high = low; +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { logits_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -1132,6 +1132,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType } // masking +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { logits_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -1185,6 +1186,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* double mid = (low + high) / 2; min_gt_low = high; max_le_high = low; +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -1236,6 +1238,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* } // normalize +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -1372,6 +1375,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token float sum_relu_q_minus_p = 0; vec_t q_vec, p_vec; float relu_q_minus_p[VEC_SIZE]; +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { q_vec.fill(0); p_vec.fill(0); @@ -1403,6 +1407,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token float u = curand_uniform(&curand_state) * sum_relu_q_minus_p; float aggregate_relu_q_minus_p(0); +#pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { q_vec.fill(0); p_vec.fill(0); diff --git a/scripts/task_jit_run_tests_part2.sh b/scripts/task_jit_run_tests_part2.sh index e1dea434f..0c64e711e 100755 --- a/scripts/task_jit_run_tests_part2.sh +++ b/scripts/task_jit_run_tests_part2.sh @@ -14,4 +14,3 @@ pytest -s tests/test_norm.py pytest -s tests/test_rope.py pytest -s tests/test_mla_page.py pytest -s tests/test_quantization.py -# pytest -s tests/test_sampling.py diff --git a/scripts/task_jit_run_tests_part3.sh b/scripts/task_jit_run_tests_part3.sh new file mode 100755 index 000000000..322338e1a --- /dev/null +++ b/scripts/task_jit_run_tests_part3.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +set -eo pipefail +set -x +: ${MAX_JOBS:=$(nproc)} +: ${CUDA_VISIBLE_DEVICES:=0} + +pip install -e . -v + +pytest -s tests/test_sampling.py diff --git a/scripts/task_test_aot_build_import.sh b/scripts/task_test_aot_build_import.sh new file mode 100755 index 000000000..e487cbb1d --- /dev/null +++ b/scripts/task_test_aot_build_import.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +set -eo pipefail +set -x +: ${MAX_JOBS:=$(nproc)} +: ${CUDA_VISIBLE_DEVICES:=""} +export TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0+PTX" +export FLASHINFER_ENABLE_AOT=1 + +python -c "import torch; print(torch._C._GLIBCXX_USE_CXX11_ABI)" +python -m build --no-isolation --wheel +pip install dist/*.whl + +# test import +mkdir -p tmp +cd tmp +python -c "import flashinfer.flashinfer_kernels" +python -c "import flashinfer.flashinfer_kernels_sm90" +python -c "import flashinfer" diff --git a/setup.py b/setup.py index 7e72d0f94..6aa2bc2d2 100644 --- a/setup.py +++ b/setup.py @@ -169,6 +169,10 @@ def __init__(self, *args, **kwargs) -> None: if arch < 75: raise RuntimeError("FlashInfer requires sm75+") + if os.environ.get("FLASHINFER_USE_CXX11_ABI"): + # force use cxx11 abi + torch._C._GLIBCXX_USE_CXX11_ABI = 1 + cuda_version = get_cuda_version() torch_full_version = Version(torch.__version__) torch_version = f"{torch_full_version.major}.{torch_full_version.minor}" diff --git a/tests/rope_reference.py b/tests/rope_reference.py index a3ab74c2b..82df2e848 100644 --- a/tests/rope_reference.py +++ b/tests/rope_reference.py @@ -39,9 +39,15 @@ def apply_scaling(freqs: torch.Tensor): def precompute_freqs_cis( - dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False + dim: int, + end: int, + theta: float = 10000.0, + use_scaled: bool = False, + device: str = "cuda:0", ): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim) + ) t = torch.arange(end, device=freqs.device, dtype=torch.float32) if use_scaled: freqs = apply_scaling(freqs) @@ -86,10 +92,15 @@ def rotate_half(x): def generate_cos_sin_f32_cache( - max_seq_len, head_dim, theta=1e4, use_scaled: bool = False + max_seq_len, head_dim, theta=1e4, use_scaled: bool = False, device: str = "cuda:0" ): - position = torch.arange(max_seq_len).float().unsqueeze(1) - freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) + position = torch.arange(max_seq_len, device=device, dtype=torch.float32).unsqueeze( + 1 + ) + freqs = 1.0 / ( + theta + ** (torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim) + ) freqs = torch.cat([freqs, freqs], dim=-1).contiguous() if use_scaled: freqs = apply_scaling(freqs) @@ -112,6 +123,7 @@ def __init__( base: int, is_neox_style: bool, dtype: torch.dtype, + device: str = "cuda:0", ) -> None: super().__init__() self.head_size = head_size @@ -120,7 +132,7 @@ def __init__( self.base = base self.is_neox_style = is_neox_style self.dtype = dtype - + self.device = device cache = self._compute_cos_sin_cache() self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) @@ -129,7 +141,10 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: inv_freq = 1.0 / ( base ** ( - torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float, device=self.device + ) + / self.rotary_dim ) ) return inv_freq @@ -137,7 +152,9 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor: """Compute the cos and sin cache.""" inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, dtype=torch.float) + t = torch.arange( + self.max_position_embeddings, dtype=torch.float, device=self.device + ) freqs = torch.einsum("i,j -> ij", t, inv_freq) cos = freqs.cos() diff --git a/tests/test_rope.py b/tests/test_rope.py index 20e5ac12a..4e0c40b1c 100644 --- a/tests/test_rope.py +++ b/tests/test_rope.py @@ -61,11 +61,11 @@ def test_rope( # reference implementation if llama_version == "llama": freqs_cis = precompute_freqs_cis( - rotary_dim, qkv_len + offset, 10000.0, use_scaled=False + rotary_dim, qkv_len + offset, 10000.0, use_scaled=False, device="cuda:0" ).to("cuda:0") else: freqs_cis = precompute_freqs_cis( - rotary_dim, qkv_len + offset, 5e5, use_scaled=True + rotary_dim, qkv_len + offset, 5e5, use_scaled=True, device="cuda:0" ).to("cuda:0") q_rot_ref, k_rot_ref = apply_rotary_emb( q.reshape(batch_size, qkv_len, num_qo_heads, head_dim)[..., :rotary_dim], @@ -315,11 +315,23 @@ def test_rope_cos_sin_cache( num_kv_heads: int, ): rope_ref = RotaryEmbedding( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype - ).to(device) + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, + ) rope_flashinfer = FlashInferRotaryEmbedding( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype - ).to(device) + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, + ) pos_ids = torch.arange(seq_len, device=device).repeat(batch_size) query = torch.randn( diff --git a/tests/test_sampling.py b/tests/test_sampling.py index 1d8548166..d4cc0e31a 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -38,20 +38,19 @@ def gumbel_noise(shape, device): return gumbel_noise -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", [ normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1), - gumbel_distribution(1), ], ) @pytest.mark.parametrize("zero_ratio", [0.0, 0.5, 0.9]) def test_sampling_freq(vocab_size, distribution, zero_ratio): torch.manual_seed(42) - num_trials = 1000000 + num_trials = 5000000 logits = distribution((1, vocab_size), "cuda:0") zero_indices = torch.randperm(vocab_size)[: int(vocab_size * zero_ratio)] logits[:, zero_indices] = -float("inf") @@ -69,18 +68,18 @@ def test_sampling_freq(vocab_size, distribution, zero_ratio): assert similarity > 0.99, f"similarity: {similarity}" -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", [ normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1), - gumbel_distribution(1), ], ) @pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) def test_top_p_sampling_freq(vocab_size, distribution, p): + # use torch profiler to check the performance of the code torch.manual_seed(42) logits = distribution((1, vocab_size), "cuda:0") probs = torch.softmax(logits, dim=-1) @@ -91,7 +90,7 @@ def test_top_p_sampling_freq(vocab_size, distribution, p): renorm_probs = flashinfer.sampling.top_p_renorm_probs(probs, p) counter = torch.zeros(vocab_size, dtype=torch.int32, device=logits.device) - num_trials = 1000000 + num_trials = 5000000 samples = flashinfer.sampling.top_p_sampling_from_probs( probs, p, @@ -104,14 +103,13 @@ def test_top_p_sampling_freq(vocab_size, distribution, p): assert similarity > 0.99, f"similarity: {similarity}" -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", [ normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1), - gumbel_distribution(1), ], ) @pytest.mark.parametrize("k", [10, 100, 500]) @@ -127,7 +125,7 @@ def test_top_k_sampling_freq(vocab_size, distribution, k): renorm_probs = flashinfer.sampling.top_k_renorm_probs(probs, k) counter = torch.zeros(vocab_size, dtype=torch.int32, device=logits.device) - num_trials = 1000000 + num_trials = 5000000 samples = flashinfer.sampling.top_k_sampling_from_probs( probs, k, @@ -140,8 +138,8 @@ def test_top_k_sampling_freq(vocab_size, distribution, k): assert similarity > 0.99, f"similarity: {similarity}" -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) def test_sampling(batch_size, vocab_size): torch.manual_seed(42) pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") @@ -153,8 +151,8 @@ def test_sampling(batch_size, vocab_size): assert torch.all(samples < vocab_size) and torch.all(samples >= 0) -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) def test_top_p_sampling(batch_size, vocab_size, p): torch.manual_seed(42) @@ -173,8 +171,8 @@ def test_top_p_sampling(batch_size, vocab_size, p): assert torch.all(mask[torch.arange(batch_size), samples] == 1) -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("k", [10, 100, 500]) def test_top_k_sampling(batch_size, vocab_size, k): if k > vocab_size: @@ -195,8 +193,8 @@ def test_top_k_sampling(batch_size, vocab_size, k): ] -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1]) def test_min_p_sampling(batch_size, vocab_size, p): torch.manual_seed(42) @@ -223,8 +221,8 @@ def test_min_p_sampling(batch_size, vocab_size, p): ] -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5]) def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): torch.manual_seed(42) @@ -265,8 +263,8 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): ] -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("k", [100]) @pytest.mark.parametrize("p", [0.1, 0.5]) def test_top_k_top_p_sampling_from_probs_logits_alignment(batch_size, vocab_size, k, p): @@ -287,8 +285,8 @@ def test_top_k_top_p_sampling_from_probs_logits_alignment(batch_size, vocab_size assert torch.all(samples == samples_ref) -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5]) def test_top_k_top_p_joint_sampling_from_logits(batch_size, vocab_size, p): torch.manual_seed(42) @@ -316,8 +314,8 @@ def test_top_k_top_p_joint_sampling_from_logits(batch_size, vocab_size, p): assert torch.all(samples == samples_ref) -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) def test_top_p_renorm_probs(batch_size, vocab_size, p): torch.manual_seed(42) @@ -342,8 +340,8 @@ def test_top_p_renorm_probs(batch_size, vocab_size, p): ) -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("k", [10, 100, 500]) def test_top_k_renorm_probs(batch_size, vocab_size, k): if k > vocab_size: @@ -370,8 +368,8 @@ def test_top_k_renorm_probs(batch_size, vocab_size, k): ) -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("k", [10, 100, 500]) def test_top_k_mask_logits(batch_size, vocab_size, k): if k > vocab_size: @@ -391,8 +389,8 @@ def test_top_k_mask_logits(batch_size, vocab_size, k): ) -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("num_speculate_tokens", [1, 3, 5, 7]) @pytest.mark.parametrize("onehot_target", [False, True]) def test_chain_speculative_sampling( @@ -472,15 +470,15 @@ def test_chain_speculative_sampling( if __name__ == "__main__": - test_sampling_freq(128256, gumbel_distribution(0.1), 0.5) + # test_sampling_freq(128256, gumbel_distribution(0.1), 0.5) test_top_p_sampling_freq(128256, gumbel_distribution(0.1), 0.5) - test_top_k_sampling_freq(1, 128256, 10) - test_sampling(19, 500) - test_sampling(1, 111) - test_top_p_sampling(3, 111, 0.9) - test_top_k_sampling(3, 111, 10) - test_top_p_renorm_probs(3, 111, 0.9) - test_top_k_renorm_probs(3, 111, 10) - test_top_k_mask_logits(99, 989, 10) - test_chain_speculative_sampling(3, 111, 3, False) - test_chain_speculative_sampling(3, 111, 3, True) + # test_top_k_sampling_freq(1, 128256, 10) + # test_sampling(19, 500) + # test_sampling(1, 111) + # test_top_p_sampling(3, 111, 0.9) + # test_top_k_sampling(3, 111, 10) + # test_top_p_renorm_probs(3, 111, 0.9) + # test_top_k_renorm_probs(3, 111, 10) + # test_top_k_mask_logits(99, 989, 10) + # test_chain_speculative_sampling(3, 111, 3, False) + # test_chain_speculative_sampling(3, 111, 3, True)