
Commit cf7a7d4 ("upd")

1 parent fb16238

File tree: 2 files changed (+98, -93 lines)

include/flashinfer/attention/cascade.cuh

Lines changed: 83 additions & 78 deletions
@@ -158,39 +158,42 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
   uint32_t pos = blockIdx.x;
   uint32_t head_idx = ty;

-  if (num_index_sets > 1) {
-    vec_t<float, vec_size> v_merged_vec;
-    state_t<vec_size> st;
-    v_merged_vec.fill(0.f);
-#pragma unroll 2
-    for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
-      float s = S[(pos * num_index_sets + iter) * num_heads + head_idx];
-      vec_t<float, vec_size> v;
-      v.cast_load(V + ((pos * num_index_sets + iter) * num_heads + head_idx) * head_dim +
-                  tx * vec_size);
-      st.merge(v, s, 1);
-    }
-
-    st.normalize();
-    st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
-    if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = st.get_lse();
-    }
-  } else if (num_index_sets == 1) {
+  if (num_index_sets == 0) {
     vec_t<DTypeOut, vec_size> v;
-    v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    v.fill(DTypeOut(0));
     v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
     if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx];
+      s_merged[pos * num_heads + head_idx] = -5e4;
     }
-  } else {
-    // num_index_sets == 0
+    return;
+  }
+
+  if (num_index_sets == 1) {
     vec_t<DTypeOut, vec_size> v;
-    v.fill(DTypeOut(0));
+    v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
     v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
     if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = -5e4;
+      s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx];
     }
+    return;
+  }
+
+  vec_t<float, vec_size> v_merged_vec;
+  state_t<vec_size> st;
+  v_merged_vec.fill(0.f);
+#pragma unroll 2
+  for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
+    float s = S[(pos * num_index_sets + iter) * num_heads + head_idx];
+    vec_t<float, vec_size> v;
+    v.cast_load(V + ((pos * num_index_sets + iter) * num_heads + head_idx) * head_dim +
+                tx * vec_size);
+    st.merge(v, s, 1);
+  }
+
+  st.normalize();
+  st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+  if (s_merged != nullptr) {
+    s_merged[pos * num_heads + head_idx] = st.get_lse();
   }
 }

@@ -313,68 +316,70 @@ __global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float*
   float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
   const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];

-  if (num_index_sets > 1) {
-#pragma unroll
-    for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
-      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
-          v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
-          V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
-          (iter * bdy + ty) < num_index_sets);
-      cp_async::commit_group();
-    }
-#pragma unroll 4
-    for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
-      if (iter % bdx == 0) {
-        s_smem[ty * bdx + tx] =
-            iter * bdy + (ty * bdx + tx) < num_index_sets
-                ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
-                : 0.f;
-        __syncthreads();
-      }
-      cp_async::wait_group<num_smem_stages - 1>();
-      __syncthreads();
-      vec_t<float, vec_size> v;
-      v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
-      if (iter * bdy + ty < num_index_sets) {
-        float s = s_smem[(iter % bdx) * bdy + ty];
-        st.merge(v, s, 1);
-      }
-      __syncthreads();
-      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
-          v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
-          V +
-              ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
-                  head_dim +
-              tx * vec_size,
-          (iter + num_smem_stages) * bdy + ty < num_index_sets);
-      cp_async::commit_group();
-    }
-    cp_async::wait_group<0>();
-    __syncthreads();
-
-    st.normalize();
-    threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
-    st.normalize();
-
-    st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+  if (num_index_sets == 0) {
+    vec_t<DTypeOut, vec_size> v;
+    v.fill(DTypeOut(0));
+    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
     if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = st.get_lse();
+      s_merged[pos * num_heads + head_idx] = -5e4;
     }
-  } else if (num_index_sets == 1) {
+    return;
+  }
+
+  if (num_index_sets == 1) {
     vec_t<DTypeOut, vec_size> v;
     v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
     v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
     if (s_merged != nullptr) {
       s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
     }
-  } else {
-    // num_index_sets == 0
-    vec_t<DTypeOut, vec_size> v;
-    v.fill(DTypeOut(0));
-    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
-    if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = -5e4;
+  }
+
+#pragma unroll
+  for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
+    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+        v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
+        V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
+        (iter * bdy + ty) < num_index_sets);
+    cp_async::commit_group();
+  }
+#pragma unroll 4
+  for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
+    if (iter % bdx == 0) {
+      s_smem[ty * bdx + tx] =
+          iter * bdy + (ty * bdx + tx) < num_index_sets
+              ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
+              : 0.f;
+      __syncthreads();
     }
+    cp_async::wait_group<num_smem_stages - 1>();
+    __syncthreads();
+    vec_t<float, vec_size> v;
+    v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
+    if (iter * bdy + ty < num_index_sets) {
+      float s = s_smem[(iter % bdx) * bdy + ty];
+      st.merge(v, s, 1);
+    }
+    __syncthreads();
+    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+        v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
+        V +
+            ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
+                head_dim +
+            tx * vec_size,
+        (iter + num_smem_stages) * bdy + ty < num_index_sets);
+    cp_async::commit_group();
+  }
+  cp_async::wait_group<0>();
+  __syncthreads();
+
+  st.normalize();
+  threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
+  st.normalize();
+
+  st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+  if (s_merged != nullptr) {
+    s_merged[pos * num_heads + head_idx] = st.get_lse();
   }
 }

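For reference, both kernels compute, per (position, head), the standard log-sum-exp merge of partial attention states; the refactor hoists the num_index_sets == 0 and == 1 special cases ahead of the general merge loop. The sketch below is a minimal CPU reference for that merge, not the flashinfer API: the name merge_states_ref, its head_dim/lse_out parameters, and the natural-log LSE convention are illustrative assumptions (state_t<vec_size> performs the same accumulation incrementally on vectorized register tiles).

// Minimal CPU sketch of the per-(position, head) merge; illustrative only.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// V holds num_index_sets partial outputs of length head_dim; S holds their
// log-sum-exp (LSE) values. Returns the softmax-weighted combination and,
// optionally, the merged LSE.
std::vector<float> merge_states_ref(const std::vector<std::vector<float>>& V,
                                    const std::vector<float>& S, size_t head_dim,
                                    float* lse_out) {
  std::vector<float> out(head_dim, 0.f);
  if (V.empty()) {                  // num_index_sets == 0: zero output, sentinel LSE
    if (lse_out) *lse_out = -5e4f;  // exp(-5e4) underflows to 0, so later merges ignore it
    return out;
  }
  out = V[0];
  float lse = S[0];
  for (size_t i = 1; i < V.size(); ++i) {  // fold in one state at a time (online softmax)
    float m = std::max(lse, S[i]);
    float new_lse = m + std::log(std::exp(lse - m) + std::exp(S[i] - m));
    float w_old = std::exp(lse - new_lse), w_new = std::exp(S[i] - new_lse);
    for (size_t d = 0; d < head_dim; ++d) out[d] = out[d] * w_old + V[i][d] * w_new;
    lse = new_lse;
  }
  if (lse_out) *lse_out = lse;
  return out;
}

With a single index set the loop body never executes and the function returns V[0] and S[0] unchanged, which is what the new num_index_sets == 1 path stores directly.
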
src/test_cascade.cu

Lines changed: 15 additions & 15 deletions
@@ -178,21 +178,21 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n
   thrust::device_vector<float> S_merged_1_device(seq_len * num_heads);

   if (num_index_sets > 1) {
-      // Method 0: use MergeState
-      MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
-                 thrust::raw_pointer_cast(S_device_trans.data()),
-                 thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
-                 thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
-                 thrust::raw_pointer_cast(V_merged_0_device.data()),
-                 thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
-      for (uint i = 2; i < num_index_sets; ++i) {
-        MergeStateInPlace(
-            thrust::raw_pointer_cast(V_merged_0_device.data()),
-            thrust::raw_pointer_cast(S_merged_0_device.data()),
-            thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
-            thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
-            num_heads, head_dim);
-      }
+    // Method 0: use MergeState
+    MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
+               thrust::raw_pointer_cast(S_device_trans.data()),
+               thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
+               thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
+               thrust::raw_pointer_cast(V_merged_0_device.data()),
+               thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
+    for (uint i = 2; i < num_index_sets; ++i) {
+      MergeStateInPlace(
+          thrust::raw_pointer_cast(V_merged_0_device.data()),
+          thrust::raw_pointer_cast(S_merged_0_device.data()),
+          thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
+          thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
+          num_heads, head_dim);
+    }
   } else {
     V_merged_0_device = V_device;
     S_merged_0_device = S_device;

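The hunk above is an indentation-only cleanup of the test's "Method 0" path, which builds the merged result incrementally: one MergeState call combines the first two index sets, then each MergeStateInPlace call folds one more (V, S) pair into the running result. A rough CPU sketch of that two-state fold, using the hypothetical name merge_state_in_place_ref and the same natural-log LSE convention as the sketch above:

// Illustrative two-state merge: fold (v_other, s_other) into the running (v, s).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void merge_state_in_place_ref(std::vector<float>& v, float& s,
                              const std::vector<float>& v_other, float s_other) {
  float m = std::max(s, s_other);                      // shift by the max for stability
  float w = std::exp(s - m), w_other = std::exp(s_other - m);
  float denom = w + w_other;
  for (size_t d = 0; d < v.size(); ++d)
    v[d] = (v[d] * w + v_other[d] * w_other) / denom;  // softmax-weighted average
  s = m + std::log(denom);                             // merged log-sum-exp
}

This fold is associative and commutative up to floating-point rounding, so folding the index sets one at a time should agree with a single batched merge over all of them, which is presumably what the remainder of the test compares the two methods against.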