@@ -136,7 +136,6 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
136
136
template <typename T>
137
137
void _TestMergeKernelCorrectness (size_t num_index_sets, size_t seq_len, size_t num_heads,
138
138
size_t head_dim, bool sparse_s) {
139
- EXPECT_GT (num_index_sets, 1 ) << " num_index_sets must be greater than 1" ;
140
139
std::vector<T> V_host (seq_len * num_index_sets * num_heads * head_dim);
141
140
std::vector<float > V_host_trans_f32 (num_index_sets * seq_len * num_heads * head_dim);
142
141
std::vector<float > S_host (seq_len * num_index_sets * num_heads);
@@ -178,20 +177,25 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n
178
177
thrust::device_vector<T> V_merged_1_device (seq_len * num_heads * head_dim);
179
178
thrust::device_vector<float > S_merged_1_device (seq_len * num_heads);
180
179
181
- // Method 0: use MergeState
182
- MergeState (thrust::raw_pointer_cast (V_device_trans_f32.data ()),
183
- thrust::raw_pointer_cast (S_device_trans.data ()),
184
- thrust::raw_pointer_cast (V_device_trans_f32.data () + seq_len * num_heads * head_dim),
185
- thrust::raw_pointer_cast (S_device_trans.data () + seq_len * num_heads),
186
- thrust::raw_pointer_cast (V_merged_0_device.data ()),
187
- thrust::raw_pointer_cast (S_merged_0_device.data ()), seq_len, num_heads, head_dim);
188
- for (uint i = 2 ; i < num_index_sets; ++i) {
189
- MergeStateInPlace (
190
- thrust::raw_pointer_cast (V_merged_0_device.data ()),
191
- thrust::raw_pointer_cast (S_merged_0_device.data ()),
192
- thrust::raw_pointer_cast (V_device_trans_f32.data () + i * seq_len * num_heads * head_dim),
193
- thrust::raw_pointer_cast (S_device_trans.data () + i * seq_len * num_heads), seq_len,
194
- num_heads, head_dim);
180
+ if (num_index_sets > 1 ) {
181
+ // Method 0: use MergeState
182
+ MergeState (thrust::raw_pointer_cast (V_device_trans_f32.data ()),
183
+ thrust::raw_pointer_cast (S_device_trans.data ()),
184
+ thrust::raw_pointer_cast (V_device_trans_f32.data () + seq_len * num_heads * head_dim),
185
+ thrust::raw_pointer_cast (S_device_trans.data () + seq_len * num_heads),
186
+ thrust::raw_pointer_cast (V_merged_0_device.data ()),
187
+ thrust::raw_pointer_cast (S_merged_0_device.data ()), seq_len, num_heads, head_dim);
188
+ for (uint i = 2 ; i < num_index_sets; ++i) {
189
+ MergeStateInPlace (
190
+ thrust::raw_pointer_cast (V_merged_0_device.data ()),
191
+ thrust::raw_pointer_cast (S_merged_0_device.data ()),
192
+ thrust::raw_pointer_cast (V_device_trans_f32.data () + i * seq_len * num_heads * head_dim),
193
+ thrust::raw_pointer_cast (S_device_trans.data () + i * seq_len * num_heads), seq_len,
194
+ num_heads, head_dim);
195
+ }
196
+ } else {
197
+ V_merged_0_device = V_device;
198
+ S_merged_0_device = S_device;
195
199
}
196
200
197
201
// Method 1: use MergeStates
@@ -479,7 +483,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,
479
483
480
484
template <typename T>
481
485
void TestMergeKernelCorrectness () {
482
- for (size_t num_index_sets : {2 , 9 , 81 , 513 }) {
486
+ for (size_t num_index_sets : {1 , 2 , 9 , 81 , 513 }) {
483
487
for (size_t seq_len : {4 , 16 , 77 }) {
484
488
for (size_t num_heads : {1 , 21 , 32 }) {
485
489
for (size_t head_dim : {64 , 128 , 256 }) {
0 commit comments