
Commit 701c813

perf: slight optimization on merge states (#313)
When CUDA graphs are enabled, we still call the merge-states kernels even for short sequence lengths, which incurs unnecessary overhead. This PR accelerates the merge-states kernels when there is nothing to merge (`num_index_sets=1`). We could go further and write directly through to the target buffer for small sequence lengths, but I'll leave that for a future PR (if necessary).
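
For context, the kernels touched here combine partial attention states: each state is a partial attention output v plus a (log-)sum-exp score s, and merging two states re-weights the outputs by their scores. Below is a minimal host-side sketch of that merge, for illustration only; the function name MergeTwoStatesReference and the use of natural-log scores are assumptions, not the library API. With a single state there is nothing to re-weight, so the merge degenerates to a copy, which is exactly the case this commit short-circuits.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative reference: merge two partial attention states (v_a, s_a) and
// (v_b, s_b) into (v_out, s_out). v_* are partial attention outputs of length
// head_dim; s_* are their (log-)sum-exp scores.
void MergeTwoStatesReference(const std::vector<float>& v_a, float s_a,
                             const std::vector<float>& v_b, float s_b,
                             std::vector<float>& v_out, float& s_out) {
  float s_max = std::max(s_a, s_b);
  float w_a = std::exp(s_a - s_max);  // relative weight of state a
  float w_b = std::exp(s_b - s_max);  // relative weight of state b
  float w_sum = w_a + w_b;
  v_out.resize(v_a.size());
  for (std::size_t i = 0; i < v_a.size(); ++i) {
    v_out[i] = (w_a * v_a[i] + w_b * v_b[i]) / w_sum;
  }
  s_out = s_max + std::log(w_sum);  // merged log-sum-exp score
}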
Parent: 2ab2bca · Commit: 701c813

File tree: 2 files changed (+60, -17 lines)

include/flashinfer/attention/cascade.cuh
src/test_cascade.cu

Diff for: include/flashinfer/attention/cascade.cuh

+40-1
@@ -157,9 +157,29 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
   uint32_t pos = blockIdx.x;
   uint32_t head_idx = ty;
-  state_t<vec_size> st;
+
+  if (num_index_sets == 0) {
+    vec_t<DTypeOut, vec_size> v;
+    v.fill(DTypeOut(0));
+    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    if (s_merged != nullptr) {
+      s_merged[pos * num_heads + head_idx] = -5e4;
+    }
+    return;
+  }
+
+  if (num_index_sets == 1) {
+    vec_t<DTypeOut, vec_size> v;
+    v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    if (s_merged != nullptr) {
+      s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx];
+    }
+    return;
+  }
 
   vec_t<float, vec_size> v_merged_vec;
+  state_t<vec_size> st;
   v_merged_vec.fill(0.f);
 #pragma unroll 2
   for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
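
To make the intent of the two fast paths above concrete, here is a scalar, host-side equivalent of the num_index_sets == 1 branch (a hypothetical reference, not part of the library): with a single partial state per (position, head) pair, "merging" is just copying V into v_merged and S into s_merged.

#include <cstdint>

// Hypothetical host-side reference for the copy-through fast path: one partial
// state per (pos, head), so the merged output equals the input.
void CopyThroughReference(const float* V, const float* S, float* v_merged,
                          float* s_merged, uint32_t seq_len, uint32_t num_heads,
                          uint32_t head_dim) {
  for (uint32_t pos = 0; pos < seq_len; ++pos) {
    for (uint32_t head_idx = 0; head_idx < num_heads; ++head_idx) {
      uint32_t offset = (pos * num_heads + head_idx) * head_dim;
      for (uint32_t d = 0; d < head_dim; ++d) {
        v_merged[offset + d] = V[offset + d];  // copy the single partial output
      }
      if (s_merged != nullptr) {
        s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx];
      }
    }
  }
}
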
@@ -296,6 +316,25 @@ __global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float*
   float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
   const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
 
+  if (num_index_sets == 0) {
+    vec_t<DTypeOut, vec_size> v;
+    v.fill(DTypeOut(0));
+    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    if (s_merged != nullptr) {
+      s_merged[pos * num_heads + head_idx] = -5e4;
+    }
+    return;
+  }
+
+  if (num_index_sets == 1) {
+    vec_t<DTypeOut, vec_size> v;
+    v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
+    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    if (s_merged != nullptr) {
+      s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
+    }
+  }
+
 #pragma unroll
   for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
     cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
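
The variable-length kernel above picks a path per position from a CSR-style indptr array: indptr[pos + 1] - indptr[pos] is the number of partial states for that position. The snippet below only illustrates that bookkeeping; the indptr values and the 32-bit index type are made-up assumptions, not taken from the library.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical indptr: position 0 has 3 partial states, position 1 has none,
  // position 2 has exactly one.
  std::vector<int32_t> indptr{0, 3, 3, 4};
  for (std::size_t pos = 0; pos + 1 < indptr.size(); ++pos) {
    int32_t num_index_sets = indptr[pos + 1] - indptr[pos];
    if (num_index_sets == 0) {
      std::printf("pos %zu: no states, write zero output and sentinel score\n", pos);
    } else if (num_index_sets == 1) {
      std::printf("pos %zu: single state, copy through\n", pos);
    } else {
      std::printf("pos %zu: merge %d states\n", pos, num_index_sets);
    }
  }
  return 0;
}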

Diff for: src/test_cascade.cu

+20-16
@@ -136,7 +136,6 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
 template <typename T>
 void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads,
                                  size_t head_dim, bool sparse_s) {
-  EXPECT_GT(num_index_sets, 1) << "num_index_sets must be greater than 1";
   std::vector<T> V_host(seq_len * num_index_sets * num_heads * head_dim);
   std::vector<float> V_host_trans_f32(num_index_sets * seq_len * num_heads * head_dim);
   std::vector<float> S_host(seq_len * num_index_sets * num_heads);
@@ -178,20 +177,25 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n
   thrust::device_vector<T> V_merged_1_device(seq_len * num_heads * head_dim);
   thrust::device_vector<float> S_merged_1_device(seq_len * num_heads);
 
-  // Method 0: use MergeState
-  MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
-             thrust::raw_pointer_cast(S_device_trans.data()),
-             thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
-             thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
-             thrust::raw_pointer_cast(V_merged_0_device.data()),
-             thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
-  for (uint i = 2; i < num_index_sets; ++i) {
-    MergeStateInPlace(
-        thrust::raw_pointer_cast(V_merged_0_device.data()),
-        thrust::raw_pointer_cast(S_merged_0_device.data()),
-        thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
-        thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
-        num_heads, head_dim);
+  if (num_index_sets > 1) {
+    // Method 0: use MergeState
+    MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
+               thrust::raw_pointer_cast(S_device_trans.data()),
+               thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
+               thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
+               thrust::raw_pointer_cast(V_merged_0_device.data()),
+               thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
+    for (uint i = 2; i < num_index_sets; ++i) {
+      MergeStateInPlace(
+          thrust::raw_pointer_cast(V_merged_0_device.data()),
+          thrust::raw_pointer_cast(S_merged_0_device.data()),
+          thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
+          thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
+          num_heads, head_dim);
+    }
+  } else {
+    V_merged_0_device = V_device;
+    S_merged_0_device = S_device;
   }
 
   // Method 1: use MergeStates
@@ -479,7 +483,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,
 
 template <typename T>
 void TestMergeKernelCorrectness() {
-  for (size_t num_index_sets : {2, 9, 81, 513}) {
+  for (size_t num_index_sets : {1, 2, 9, 81, 513}) {
     for (size_t seq_len : {4, 16, 77}) {
       for (size_t num_heads : {1, 21, 32}) {
         for (size_t head_dim : {64, 128, 256}) {
