Skip to content

ci: improve jenkins #943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ def unpack_lib(name, libs) {
"""
}

// Abort superseded in-flight builds of the same branch so only the newest
// commit keeps running. No-op on 'main', where every build must complete.
def cancel_previous_build() {
// cancel previous build if it is not on main.
if (env.BRANCH_NAME != 'main') {
def buildNumber = env.BUILD_NUMBER as int
// Milestone API allows us to cancel previous build
// with the same milestone number
// NOTE(review): per the Jenkins milestone-step plugin, passing an older
// milestone then the current one causes earlier builds that have not
// reached it to be aborted — confirm plugin is installed on the controller.
if (buildNumber > 1) milestone(buildNumber - 1)
milestone(buildNumber)
}
}

def init_git(submodule = false) {
cleanWs()
// add retry in case checkout timeouts
Expand All @@ -84,10 +95,21 @@ def init_git(submodule = false) {
// }
// }

stage('JIT Unittest') {
stage('Unittest') {
cancel_previous_build()
parallel(
failFast: true,
'GPU-G5-Test-1': {
'AOT-Build-Import': {
node('CPU-LARGE-SPOT') {
ws(per_exec_ws('flashinfer-aot')) {
init_git(true)
sh(script: "ls -alh", label: 'Show work directory')
sh(script: "./scripts/task_show_node_info.sh", label: 'Show node info')
sh(script: "${docker_run} --no-gpu ./scripts/task_test_aot_build_import.sh", label: 'Test AOT Build and Import')
}
}
},
'JIT-Unittest-1': {
node('GPU-G5-SPOT') {
ws(per_exec_ws('flashinfer-unittest')) {
init_git(true) // we need cutlass submodule
Expand All @@ -97,7 +119,7 @@ stage('JIT Unittest') {
}
}
},
'GPU-G5-Test-2': {
'JIT-Unittest-2': {
node('GPU-G5-SPOT') {
ws(per_exec_ws('flashinfer-unittest')) {
init_git(true) // we need cutlass submodule
Expand All @@ -107,7 +129,17 @@ stage('JIT Unittest') {
}
}
},
'GPU-G5-Test-4': {
'JIT-Unittest-3': {
node('GPU-G5-SPOT') {
ws(per_exec_ws('flashinfer-unittest')) {
init_git(true) // we need cutlass submodule
sh(script: "ls -alh", label: 'Show work directory')
sh(script: "./scripts/task_show_node_info.sh", label: 'Show node info')
sh(script: "${docker_run} ./scripts/task_jit_run_tests_part3.sh", label: 'JIT Unittest Part 3')
}
}
},
'JIT-Unittest-4': {
node('GPU-G5-SPOT') {
ws(per_exec_ws('flashinfer-unittest')) {
init_git(true) // we need cutlass submodule
Expand Down
144 changes: 144 additions & 0 deletions benchmarks/bench_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import torch
from triton.testing import do_bench

import flashinfer


def normal_distribution(std):
    """Build a noise generator drawing i.i.d. Gaussian samples with stddev *std*.

    The returned callable takes ``(shape, device)`` and its ``__name__`` encodes
    the parameter so benchmark logs are self-describing.
    """

    def normal_noise(shape, device):
        # Scale a standard-normal draw by the requested standard deviation.
        return std * torch.randn(shape, device=device)

    normal_noise.__name__ = f"normal_distribution(std={std})"
    return normal_noise


def gumbel_distribution(beta):
    """Build a noise generator producing Gumbel-style samples scaled by 1/beta.

    The returned callable takes ``(shape, device)``; its ``__name__`` encodes
    *beta* so benchmark logs are self-describing.
    """
    eps = 1e-20  # guards both log() calls against log(0)

    def gumbel_noise(shape, device):
        uniform = torch.rand(shape, device=device)
        gumbel = torch.log(-torch.log(uniform + eps) + eps)
        return gumbel / beta

    gumbel_noise.__name__ = f"gumbel_distribution(beta={beta})"
    return gumbel_noise


def init_seed_sampling(*args, **kwargs):
    """Reset the torch RNG to a fixed seed, then run naive sampling.

    Seeding makes every timed invocation draw the same random numbers,
    keeping benchmark runs comparable.
    """
    torch.manual_seed(42)
    sample = flashinfer.sampling.sampling_from_probs
    return sample(*args, **kwargs)


def init_seed_top_k_sampling(*args, **kwargs):
    """Reset the torch RNG to a fixed seed, then run top-k sampling.

    Seeding makes every timed invocation draw the same random numbers,
    keeping benchmark runs comparable.
    """
    torch.manual_seed(42)
    sample = flashinfer.sampling.top_k_sampling_from_probs
    return sample(*args, **kwargs)


def init_seed_top_p_sampling(*args, **kwargs):
    """Reset the torch RNG to a fixed seed, then run top-p sampling.

    Seeding makes every timed invocation draw the same random numbers,
    keeping benchmark runs comparable.
    """
    torch.manual_seed(42)
    sample = flashinfer.sampling.top_p_sampling_from_probs
    return sample(*args, **kwargs)


def _make_inputs(batch_size, vocab_size, distrib):
    """Draw logits from *distrib* on CUDA, softmax them into probs, and
    allocate the int32 per-row output buffer. Returns (probs, samples)."""
    logits = distrib((batch_size, vocab_size), device="cuda")
    probs = torch.softmax(logits, dim=-1)
    samples = torch.zeros(batch_size, dtype=torch.int32, device=probs.device)
    return probs, samples


def _bench_and_report(run_fn, probs, samples, desc):
    """Time *run_fn* with triton's do_bench and print one result line.

    Effective bandwidth counts one read of the probability matrix plus one
    write of the sample ids; bytes * 1e-6 / ms == GB/s.
    """
    ms = do_bench(run_fn, warmup=100, rep=1000)
    io = (
        probs.numel() * probs.element_size()
        + samples.numel() * samples.element_size()
    )
    bandwidth = io * 1e-6 / ms
    print(
        f"{desc}, duration: {ms*1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
    )


@torch.inference_mode()
def main():
    """Benchmark flashinfer sampling kernels: naive, top-k, and top-p
    sampling-from-probs, sweeping batch size, logit distribution, and the
    deterministic flag (plus k / p where applicable). Requires a CUDA device.
    """
    vocab_sizes = [128512]
    batch_sizes = [1, 16, 32, 64, 128, 256, 512]
    distribs = [
        normal_distribution(1),
        normal_distribution(5),
        gumbel_distribution(0.1),
        gumbel_distribution(1),
    ]

    print("---")
    print("naive sampling")
    for vocab_size in vocab_sizes:
        for batch_size in batch_sizes:
            for distrib in distribs:
                for deterministic in [True, False]:
                    probs, samples = _make_inputs(batch_size, vocab_size, distrib)
                    desc = (
                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, "
                        f"distrib: {distrib.__name__}, deterministic: {deterministic}"
                    )
                    _bench_and_report(
                        lambda: init_seed_sampling(probs, deterministic=deterministic),
                        probs,
                        samples,
                        desc,
                    )

    print("---")
    print("top-k sampling")
    for vocab_size in vocab_sizes:
        for batch_size in batch_sizes:
            for distrib in distribs:
                for deterministic in [True, False]:
                    for k in [10, 100, 1000, 5000]:
                        # Fresh inputs per k, matching the original benchmark's
                        # RNG consumption pattern.
                        probs, samples = _make_inputs(
                            batch_size, vocab_size, distrib
                        )
                        desc = (
                            f"vocab_size: {vocab_size}, batch_size: {batch_size}, "
                            f"distrib: {distrib.__name__}, "
                            f"deterministic: {deterministic}, k: {k}"
                        )
                        _bench_and_report(
                            lambda: init_seed_top_k_sampling(
                                probs, k, deterministic=deterministic
                            ),
                            probs,
                            samples,
                            desc,
                        )

    print("---")
    print("top-p sampling")
    for vocab_size in vocab_sizes:
        for batch_size in batch_sizes:
            for distrib in distribs:
                for deterministic in [True, False]:
                    for p in [0.1, 0.5, 0.9]:
                        probs, samples = _make_inputs(
                            batch_size, vocab_size, distrib
                        )
                        desc = (
                            f"vocab_size: {vocab_size}, batch_size: {batch_size}, "
                            f"distrib: {distrib.__name__}, "
                            f"deterministic: {deterministic}, p: {p}"
                        )
                        _bench_and_report(
                            lambda: init_seed_top_p_sampling(
                                probs, p, deterministic=deterministic
                            ),
                            probs,
                            samples,
                            desc,
                        )


# Script entry point: run the full sampling benchmark suite.
if __name__ == "__main__":
    main()
35 changes: 20 additions & 15 deletions include/flashinfer/sampling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* ind
float aggregate(0);
float u = curand_uniform(&state);

#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand Down Expand Up @@ -405,14 +406,10 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);

float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);

vec_t<float, VEC_SIZE> probs_vec;
float aggregate;
float q = 1;
double low = 0, high = max_val;
double low = 0, high = 1.f;
int sampled_id;
int round = 0;
do {
Expand All @@ -421,6 +418,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
__syncthreads();
float u = curand_uniform(&state) * q;
aggregate = 0;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand All @@ -446,6 +444,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
double pivot_1 = (pivot_0 + high) / 2;

ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand Down Expand Up @@ -522,20 +521,17 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);

float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);

vec_t<float, VEC_SIZE> probs_vec;
float aggregate;
float q = 1;
double low = 0, high = max_val;
double low = 0, high = 1.f;
int sampled_id;
do {
temp_storage.sampled_id = d;
__syncthreads();
float u = curand_uniform(&state) * q;
aggregate = 0;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand All @@ -561,6 +557,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
double pivot_1 = (pivot_0 + high) / 2;

float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand Down Expand Up @@ -637,6 +634,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp

vec_t<float, VEC_SIZE> probs_vec;
float aggregate_gt_pivot = 0;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand Down Expand Up @@ -664,6 +662,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
temp_storage.sampled_id = d;
__syncthreads();
float u = curand_uniform(&state) * q;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand Down Expand Up @@ -709,20 +708,17 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);

float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);

vec_t<float, VEC_SIZE> probs_vec;
float aggregate;
float q = 1;
double low = 0, high = max_val;
double low = 0, high = 1.f;
int sampled_id;
do {
temp_storage.sampled_id = d;
__syncthreads();
float u = curand_uniform(&state) * q;
aggregate = 0;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand All @@ -748,6 +744,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
double pivot_1 = (pivot_0 + high) / 2;

ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand Down Expand Up @@ -988,6 +985,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
double mid = (low + high) / 2;
min_gt_low = high;
max_le_high = low;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand Down Expand Up @@ -1034,6 +1032,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
float normalizer = math::ptx_rcp(max(sum_low, 1e-8));

// normalize
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand Down Expand Up @@ -1085,6 +1084,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
double mid = (low + high) / 2;
min_gt_low = high;
max_le_high = low;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
logits_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand Down Expand Up @@ -1132,6 +1132,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
}

// masking
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
logits_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand Down Expand Up @@ -1185,6 +1186,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
double mid = (low + high) / 2;
min_gt_low = high;
max_le_high = low;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand Down Expand Up @@ -1236,6 +1238,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
}

// normalize
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
Expand Down Expand Up @@ -1372,6 +1375,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
float sum_relu_q_minus_p = 0;
vec_t<float, VEC_SIZE> q_vec, p_vec;
float relu_q_minus_p[VEC_SIZE];
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
q_vec.fill(0);
p_vec.fill(0);
Expand Down Expand Up @@ -1403,6 +1407,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
float u = curand_uniform(&curand_state) * sum_relu_q_minus_p;

float aggregate_relu_q_minus_p(0);
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
q_vec.fill(0);
p_vec.fill(0);
Expand Down
Loading