@@ -32,8 +32,11 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n
32
32
bool use_fp16_qk_reduction) {
33
33
uint32_t batch_size = 9 ;
34
34
std::vector<int32_t > q_lens (batch_size), kv_lens (batch_size);
35
- utils::vec_randint_ (q_lens, 1 , 15 );
36
- utils::vec_randint_ (kv_lens, 15 , 257 );
35
+ // utils::vec_randint_(q_lens, 1, 15);
36
+ // utils::vec_randint_(kv_lens, 15, 257);
37
+ q_lens = {21 , 20 , 40 , 4 , 8 , 99 };
38
+ kv_lens = {21 , 1024 , 8072 , 30 , 27 , 999 };
39
+
37
40
std::vector<int32_t > append_indptr{0 };
38
41
for (size_t request_idx = 0 ; request_idx < batch_size; ++request_idx) {
39
42
append_indptr.push_back (append_indptr.back () + kv_lens[request_idx]);
@@ -132,7 +135,7 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n
132
135
nan_detected = true ;
133
136
}
134
137
num_result_errors_atol_1e_3_rtol_1e_3 +=
135
- (!utils::isclose (float (o_host[i]), float (o_ref[i]), 1e-3 , 1e-3 ));
138
+ (!utils::isclose (float (o_host[i]), float (o_ref[i]), 1e-3 , 5e-4 ));
136
139
}
137
140
float result_accuracy = 1 . - float (num_result_errors_atol_1e_3_rtol_1e_3) /
138
141
max (float (q_len * num_qo_heads * head_dim), 1 .f );
@@ -601,12 +604,12 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz
601
604
602
605
template <typename T>
603
606
void TestBatchPagedPrefillKernelOneHotCorrectness (bool use_fp16_qk_reduction) {
604
- for (size_t num_kv_heads : {4 , 8 , 32 }) {
605
- for (size_t num_qo_heads : {32 }) {
606
- for (size_t page_size : {1 , 16 }) {
607
- for (size_t head_dim : {64 , 128 , 256 }) {
608
- for (size_t causal : {false , true }) {
609
- for (size_t pos_encoding_mode : {0 , 1 }) {
607
+ for (size_t num_kv_heads : {1 }) {
608
+ for (size_t num_qo_heads : {8 }) {
609
+ for (size_t page_size : {16 }) {
610
+ for (size_t head_dim : {128 }) {
611
+ for (size_t causal : {true }) {
612
+ for (size_t pos_encoding_mode : {0 }) {
610
613
_TestBatchPagedPrefillKernelOneHotCorrectness<T, T>(
611
614
num_kv_heads, num_qo_heads, page_size, head_dim, causal,
612
615
PosEncodingMode (pos_encoding_mode), use_fp16_qk_reduction);
@@ -771,6 +774,10 @@ TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestFP16QKHalfAccum)
771
774
TestBatchPagedPrefillKernelLongContextCorrectness<half>(true );
772
775
}
773
776
777
+ TEST (FlashInferCorrectnessTest, BatchPagedPrefillKernelCorrectnessTestOneHotBF16) {  // One-hot correctness driver instantiated with nv_bfloat16 as the element type.
778
+ TestBatchPagedPrefillKernelOneHotCorrectness<nv_bfloat16>(false );  // The bool argument is use_fp16_qk_reduction; it is disabled for this case.
779
+ }
780
+
774
781
TEST (FlashInferCorrectnessTest, BatchPagedPrefillKernelCorrectnessTestOneHotFP16) {  // One-hot correctness driver instantiated with half as the element type.
775
782
TestBatchPagedPrefillKernelOneHotCorrectness<half>(false );  // The bool argument is use_fp16_qk_reduction; it is disabled for this case.
776
783
}
0 commit comments