Skip to content

Commit af31796

Browse files
committed
Add tests for BatchPrefill
1 parent d7cf5d2 commit af31796

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

src/test_batch_prefill.cu

+16-9
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n
3232
bool use_fp16_qk_reduction) {
3333
uint32_t batch_size = 9;
3434
std::vector<int32_t> q_lens(batch_size), kv_lens(batch_size);
35-
utils::vec_randint_(q_lens, 1, 15);
36-
utils::vec_randint_(kv_lens, 15, 257);
35+
// utils::vec_randint_(q_lens, 1, 15);
36+
// utils::vec_randint_(kv_lens, 15, 257);
37+
q_lens = {21, 20, 40, 4, 8, 99};
38+
kv_lens = {21, 1024, 8072, 30, 27, 999};
39+
3740
std::vector<int32_t> append_indptr{0};
3841
for (size_t request_idx = 0; request_idx < batch_size; ++request_idx) {
3942
append_indptr.push_back(append_indptr.back() + kv_lens[request_idx]);
@@ -132,7 +135,7 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n
132135
nan_detected = true;
133136
}
134137
num_result_errors_atol_1e_3_rtol_1e_3 +=
135-
(!utils::isclose(float(o_host[i]), float(o_ref[i]), 1e-3, 1e-3));
138+
(!utils::isclose(float(o_host[i]), float(o_ref[i]), 1e-3, 5e-4));
136139
}
137140
float result_accuracy = 1. - float(num_result_errors_atol_1e_3_rtol_1e_3) /
138141
max(float(q_len * num_qo_heads * head_dim), 1.f);
@@ -601,12 +604,12 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz
601604

602605
template <typename T>
603606
void TestBatchPagedPrefillKernelOneHotCorrectness(bool use_fp16_qk_reduction) {
604-
for (size_t num_kv_heads : {4, 8, 32}) {
605-
for (size_t num_qo_heads : {32}) {
606-
for (size_t page_size : {1, 16}) {
607-
for (size_t head_dim : {64, 128, 256}) {
608-
for (size_t causal : {false, true}) {
609-
for (size_t pos_encoding_mode : {0, 1}) {
607+
for (size_t num_kv_heads : {1}) {
608+
for (size_t num_qo_heads : {8}) {
609+
for (size_t page_size : {16}) {
610+
for (size_t head_dim : {128}) {
611+
for (size_t causal : {true}) {
612+
for (size_t pos_encoding_mode : {0}) {
610613
_TestBatchPagedPrefillKernelOneHotCorrectness<T, T>(
611614
num_kv_heads, num_qo_heads, page_size, head_dim, causal,
612615
PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction);
@@ -771,6 +774,10 @@ TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestFP16QKHalfAccum)
771774
TestBatchPagedPrefillKernelLongContextCorrectness<half>(true);
772775
}
773776

777+
TEST(FlashInferCorrectnessTest, BatchPagedPrefillKernelCorrectnessTestOneHotBF16) {
778+
TestBatchPagedPrefillKernelOneHotCorrectness<nv_bfloat16>(false);
779+
}
780+
774781
TEST(FlashInferCorrectnessTest, BatchPagedPrefillKernelCorrectnessTestOneHotFP16) {
775782
TestBatchPagedPrefillKernelOneHotCorrectness<half>(false);
776783
}

0 commit comments

Comments
 (0)