Skip to content

Commit a8bc999

Browse files
committed
upd
ok how come upd
1 parent 2ab2bca commit a8bc999

File tree

2 files changed

+60
-17
lines changed

2 files changed

+60
-17
lines changed

include/flashinfer/attention/cascade.cuh

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,29 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
157157
uint32_t tx = threadIdx.x, ty = threadIdx.y;
158158
uint32_t pos = blockIdx.x;
159159
uint32_t head_idx = ty;
160-
state_t<vec_size> st;
160+
161+
if (num_index_sets == 0) {
162+
vec_t<DTypeOut, vec_size> v;
163+
v.fill(DTypeOut(0));
164+
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
165+
if (s_merged != nullptr) {
166+
s_merged[pos * num_heads + head_idx] = -5e4;
167+
}
168+
return;
169+
}
170+
171+
if (num_index_sets == 1) {
172+
vec_t<DTypeOut, vec_size> v;
173+
v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
174+
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
175+
if (s_merged != nullptr) {
176+
s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx];
177+
}
178+
return;
179+
}
161180

162181
vec_t<float, vec_size> v_merged_vec;
182+
state_t<vec_size> st;
163183
v_merged_vec.fill(0.f);
164184
#pragma unroll 2
165185
for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
@@ -296,6 +316,25 @@ __global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float*
296316
float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
297317
const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
298318

319+
if (num_index_sets == 0) {
320+
vec_t<DTypeOut, vec_size> v;
321+
v.fill(DTypeOut(0));
322+
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
323+
if (s_merged != nullptr) {
324+
s_merged[pos * num_heads + head_idx] = -5e4;
325+
}
326+
return;
327+
}
328+
329+
if (num_index_sets == 1) {
330+
vec_t<DTypeOut, vec_size> v;
331+
v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
332+
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
333+
if (s_merged != nullptr) {
334+
s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
335+
}
336+
}
337+
299338
#pragma unroll
300339
for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
301340
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(

src/test_cascade.cu

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
136136
template <typename T>
137137
void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads,
138138
size_t head_dim, bool sparse_s) {
139-
EXPECT_GT(num_index_sets, 1) << "num_index_sets must be greater than 1";
140139
std::vector<T> V_host(seq_len * num_index_sets * num_heads * head_dim);
141140
std::vector<float> V_host_trans_f32(num_index_sets * seq_len * num_heads * head_dim);
142141
std::vector<float> S_host(seq_len * num_index_sets * num_heads);
@@ -178,20 +177,25 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n
178177
thrust::device_vector<T> V_merged_1_device(seq_len * num_heads * head_dim);
179178
thrust::device_vector<float> S_merged_1_device(seq_len * num_heads);
180179

181-
// Method 0: use MergeState
182-
MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
183-
thrust::raw_pointer_cast(S_device_trans.data()),
184-
thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
185-
thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
186-
thrust::raw_pointer_cast(V_merged_0_device.data()),
187-
thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
188-
for (uint i = 2; i < num_index_sets; ++i) {
189-
MergeStateInPlace(
190-
thrust::raw_pointer_cast(V_merged_0_device.data()),
191-
thrust::raw_pointer_cast(S_merged_0_device.data()),
192-
thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
193-
thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
194-
num_heads, head_dim);
180+
if (num_index_sets > 1) {
181+
// Method 0: use MergeState
182+
MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
183+
thrust::raw_pointer_cast(S_device_trans.data()),
184+
thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
185+
thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
186+
thrust::raw_pointer_cast(V_merged_0_device.data()),
187+
thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
188+
for (uint i = 2; i < num_index_sets; ++i) {
189+
MergeStateInPlace(
190+
thrust::raw_pointer_cast(V_merged_0_device.data()),
191+
thrust::raw_pointer_cast(S_merged_0_device.data()),
192+
thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
193+
thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
194+
num_heads, head_dim);
195+
}
196+
} else {
197+
V_merged_0_device = V_device;
198+
S_merged_0_device = S_device;
195199
}
196200

197201
// Method 1: use MergeStates
@@ -479,7 +483,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,
479483

480484
template <typename T>
481485
void TestMergeKernelCorrectness() {
482-
for (size_t num_index_sets : {2, 9, 81, 513}) {
486+
for (size_t num_index_sets : {1, 2, 9, 81, 513}) {
483487
for (size_t seq_len : {4, 16, 77}) {
484488
for (size_t num_heads : {1, 21, 32}) {
485489
for (size_t head_dim : {64, 128, 256}) {

0 commit comments

Comments
 (0)