
Commit d81af97

feat: add a use_softmax field in variant class (#533)
If true, use online softmax rules to update the attention state. If not, simply add the partial attention results. In the future, we can generalize this as a state update function.
1 parent 7bb53d9 commit d81af97
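For context, a minimal sketch of how such a variant field can be consumed at compile time (the struct names below are illustrative, not the actual variant classes in FlashInfer): `use_softmax` is a `static constexpr bool`, so the kernels can branch on it with `if constexpr` and the unused path is compiled out.

// Illustrative only: hypothetical variant types exposing a compile-time use_softmax flag.
struct SoftmaxVariantExample {
  static constexpr bool use_softmax = true;   // merge partial results with online softmax
};

struct SumVariantExample {
  static constexpr bool use_softmax = false;  // merge partial results by plain addition
};

template <typename AttentionVariant>
__device__ __forceinline__ void state_update_example(AttentionVariant variant) {
  if constexpr (variant.use_softmax) {
    // rescale by exp(m_prev - m_new), accumulate the denominator, normalize at the end
  } else {
    // simply add the partial attention outputs; no rescaling or normalization
  }
}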

File tree

3 files changed: +319, -222 lines

include/flashinfer/attention/decode.cuh (+44, -23)
@@ -97,16 +97,18 @@ __device__ __forceinline__ void compute_qk(const typename AttentionVariant::Para
     st.m = max(st.m, s[j]);
   }
 
-  float o_scale = math::ptx_exp2(m_prev - st.m);
-  st.d *= o_scale;
+  if constexpr (variant.use_softmax) {
+    float o_scale = math::ptx_exp2(m_prev - st.m);
+    st.d *= o_scale;
 #pragma unroll
-  for (uint32_t j = 0; j < tile_size; ++j) {
-    s[j] = math::ptx_exp2(s[j] - st.m);
-    st.d += s[j];
-  }
+    for (uint32_t j = 0; j < tile_size; ++j) {
+      s[j] = math::ptx_exp2(s[j] - st.m);
+      st.d += s[j];
+    }
 #pragma unroll
-  for (uint32_t i = 0; i < vec_size; ++i) {
-    st.o[i] = st.o[i] * o_scale;
+    for (uint32_t i = 0; i < vec_size; ++i) {
+      st.o[i] = st.o[i] * o_scale;
+    }
   }
 }
 
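For readers less familiar with the pattern, the block now guarded by `variant.use_softmax` is the usual online-softmax update. A scalar, host-side sketch of the same arithmetic (using `expf` instead of `math::ptx_exp2` and plain arrays instead of the kernel's vector types, so names and types here are illustrative):

#include <cmath>

// Scalar sketch of the online-softmax state update (illustrative, not the kernel code).
void online_softmax_update(const float* s, int tile_size, float& m, float& d,
                           float* o, int vec_size) {
  float m_prev = m;
  for (int j = 0; j < tile_size; ++j) m = fmaxf(m, s[j]);  // running max of scores
  float o_scale = expf(m_prev - m);                        // rescale factor for the old state
  d *= o_scale;
  for (int j = 0; j < tile_size; ++j) d += expf(s[j] - m); // accumulate the denominator
  for (int i = 0; i < vec_size; ++i) o[i] *= o_scale;      // rescale the partial output
}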

@@ -148,23 +150,38 @@ __device__ __forceinline__ void update_local_state(const T* smem, const float* s
  * \param smem The pointer to shared memory buffer for o
  * \param smem_md The pointer to shared memory buffer for m/d
  */
-template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz>
-__device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, float* smem_md) {
+template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, typename AttentionVariant>
+__device__ __forceinline__ void sync_state(AttentionVariant variant, state_t<vec_size>& st,
+                                           float* smem, float* smem_md) {
   if constexpr (bdz > 1) {
     constexpr uint32_t head_dim = bdx * vec_size;
     auto block = cg::this_thread_block();
     uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
     st.o.store(smem + (tz * bdy + ty) * head_dim + tx * vec_size);
-    smem_md[(tz * bdy + ty) * 2] = st.m;
-    smem_md[(tz * bdy + ty) * 2 + 1] = st.d;
-    block.sync();
-    st.init();
+    if constexpr (variant.use_softmax) {
+      smem_md[(tz * bdy + ty) * 2] = st.m;
+      smem_md[(tz * bdy + ty) * 2 + 1] = st.d;
+      block.sync();
+      st.init();
 #pragma unroll
-    for (uint32_t j = 0; j < bdz; ++j) {
-      float mz = smem_md[(j * bdy + ty) * 2], dz = smem_md[(j * bdy + ty) * 2 + 1];
-      vec_t<float, vec_size> oz;
-      oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size);
-      st.merge(oz, mz, dz);
+      for (uint32_t j = 0; j < bdz; ++j) {
+        float mz = smem_md[(j * bdy + ty) * 2], dz = smem_md[(j * bdy + ty) * 2 + 1];
+        vec_t<float, vec_size> oz;
+        oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size);
+        st.merge(oz, mz, dz);
+      }
+    } else {
+      block.sync();
+      st.init();
+#pragma unroll
+      for (uint32_t j = 0; j < bdz; ++j) {
+        vec_t<float, vec_size> oz;
+        oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size);
+#pragma unroll
+        for (uint32_t i = 0; i < vec_size; ++i) {
+          st.o[i] += oz[i];
+        }
+      }
     }
   }
 }
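To make the two branches of `sync_state` concrete: each warp stores its partial state to shared memory, and the running state folds the partials in either with the softmax merge rule or by plain summation. A scalar sketch of the two merge rules follows (a hypothetical helper, not the actual `state_t::merge` implementation):

#include <algorithm>
#include <cmath>

// Illustrative merge of one partial state (oz, mz, dz) into the running state (o, m, d).
void merge_with_softmax(float& o, float& m, float& d, float oz, float mz, float dz) {
  float m_new = std::max(m, mz);
  float scale = std::exp(m - m_new), scale_z = std::exp(mz - m_new);
  o = o * scale + oz * scale_z;  // unnormalized outputs rescaled to the shared max
  d = d * scale + dz * scale_z;  // denominators rescaled and summed
  m = m_new;
}

// use_softmax == false: no m/d bookkeeping, partial outputs are simply added.
void merge_without_softmax(float& o, float oz) { o += oz; }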
@@ -338,8 +355,10 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__
   block.sync();
 
   // sync local state of all warps inside a threadblock
-  sync_state<vec_size, bdx, bdy, bdz>(st_local, reinterpret_cast<float*>(smem), smem_md);
-  st_local.normalize();
+  sync_state<vec_size, bdx, bdy, bdz>(variant, st_local, reinterpret_cast<float*>(smem), smem_md);
+  if constexpr (variant.use_softmax) {
+    st_local.normalize();
+  }
 
   st_local.o.cast_store(o + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
   if (lse != nullptr) {
@@ -557,8 +576,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
   block.sync();
 
   // sync local state of all warps inside a threadblock
-  sync_state<vec_size, bdx, bdy, bdz>(st, reinterpret_cast<float*>(smem), smem_md);
-  st.normalize();
+  sync_state<vec_size, bdx, bdy, bdz>(variant, st, reinterpret_cast<float*>(smem), smem_md);
+  if constexpr (variant.use_softmax) {
+    st.normalize();
+  }
 
   if (tz == 0) {
     st.o.cast_store(o + (bx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
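At both kernel call sites the final `normalize()` is now skipped when `use_softmax` is false. Assuming `normalize()` divides the accumulated output by the softmax denominator `d` (the usual finalization step for online softmax), the intent is roughly the following sketch:

// Rough sketch of the conditional finalization (assumed behavior of st.normalize()).
template <typename AttentionVariant, uint32_t vec_size>
__device__ __forceinline__ void finalize_example(AttentionVariant variant, state_t<vec_size>& st) {
  if constexpr (variant.use_softmax) {
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      st.o[i] /= st.d;  // divide by the softmax denominator
    }
  }
  // When use_softmax is false, st.o already holds the plain sum of partial results.
}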
