@@ -97,16 +97,18 @@ __device__ __forceinline__ void compute_qk(const typename AttentionVariant::Para
     st.m = max(st.m, s[j]);
   }

-  float o_scale = math::ptx_exp2(m_prev - st.m);
-  st.d *= o_scale;
+  if constexpr (variant.use_softmax) {
+    float o_scale = math::ptx_exp2(m_prev - st.m);
+    st.d *= o_scale;
 #pragma unroll
-  for (uint32_t j = 0; j < tile_size; ++j) {
-    s[j] = math::ptx_exp2(s[j] - st.m);
-    st.d += s[j];
-  }
+    for (uint32_t j = 0; j < tile_size; ++j) {
+      s[j] = math::ptx_exp2(s[j] - st.m);
+      st.d += s[j];
+    }
 #pragma unroll
-  for (uint32_t i = 0; i < vec_size; ++i) {
-    st.o[i] = st.o[i] * o_scale;
+    for (uint32_t i = 0; i < vec_size; ++i) {
+      st.o[i] = st.o[i] * o_scale;
+    }
   }
 }
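Note on the hunk above: with use_softmax enabled, the guarded block performs one step of the online (streaming) softmax in the base-2 domain via math::ptx_exp2 -- track the running max, rescale the old denominator and output by 2^(m_prev - m_new), then fold in the new tile of scores. With it disabled, none of this bookkeeping runs. A minimal host-side sketch of the recurrence; the names State and online_softmax_step, and std::exp2 standing in for math::ptx_exp2, are illustrative assumptions, not the kernel's code:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    struct State {
      float m = -INFINITY;   // running row maximum (log2 domain)
      float d = 0.f;         // running denominator: sum of 2^(s_j - m)
      std::vector<float> o;  // un-normalized output accumulator
    };

    // Fold one tile of attention scores `s` into the running state.
    void online_softmax_step(State& st, std::vector<float>& s) {
      float m_prev = st.m;
      for (float sj : s) st.m = std::max(st.m, sj);
      float o_scale = std::exp2(m_prev - st.m);  // rescale old contributions
      st.d *= o_scale;
      for (float& sj : s) {
        sj = std::exp2(sj - st.m);  // p_j, the un-normalized probability
        st.d += sj;
      }
      for (float& oi : st.o) oi *= o_scale;
      // ...the kernel then accumulates p_j * v_j into st.o...
    }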
@@ -148,23 +150,38 @@ __device__ __forceinline__ void update_local_state(const T* smem, const float* s
  * \param smem The pointer to shared memory buffer for o
  * \param smem_md The pointer to shared memory buffer for m/d
  */
-template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz>
-__device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, float* smem_md) {
+template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, typename AttentionVariant>
+__device__ __forceinline__ void sync_state(AttentionVariant variant, state_t<vec_size>& st,
+                                           float* smem, float* smem_md) {
   if constexpr (bdz > 1) {
     constexpr uint32_t head_dim = bdx * vec_size;
     auto block = cg::this_thread_block();
     uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
     st.o.store(smem + (tz * bdy + ty) * head_dim + tx * vec_size);
-    smem_md[(tz * bdy + ty) * 2] = st.m;
-    smem_md[(tz * bdy + ty) * 2 + 1] = st.d;
-    block.sync();
-    st.init();
+    if constexpr (variant.use_softmax) {
+      smem_md[(tz * bdy + ty) * 2] = st.m;
+      smem_md[(tz * bdy + ty) * 2 + 1] = st.d;
+      block.sync();
+      st.init();
 #pragma unroll
-    for (uint32_t j = 0; j < bdz; ++j) {
-      float mz = smem_md[(j * bdy + ty) * 2], dz = smem_md[(j * bdy + ty) * 2 + 1];
-      vec_t<float, vec_size> oz;
-      oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size);
-      st.merge(oz, mz, dz);
+      for (uint32_t j = 0; j < bdz; ++j) {
+        float mz = smem_md[(j * bdy + ty) * 2], dz = smem_md[(j * bdy + ty) * 2 + 1];
+        vec_t<float, vec_size> oz;
+        oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size);
+        st.merge(oz, mz, dz);
+      }
+    } else {
+      block.sync();
+      st.init();
+#pragma unroll
+      for (uint32_t j = 0; j < bdz; ++j) {
+        vec_t<float, vec_size> oz;
+        oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size);
+#pragma unroll
+        for (uint32_t i = 0; i < vec_size; ++i) {
+          st.o[i] += oz[i];
+        }
+      }
     }
   }
 }
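Note on sync_state: on the softmax path each warp publishes an (o, m, d) triple to shared memory and the triples are combined with st.merge; on the no-softmax path there is no m/d bookkeeping, so the per-warp outputs combine by plain summation, which is exactly what the new else branch does. A sketch of the merge rule, assuming state_t::merge follows the standard log-sum-exp combination; the Partial struct and merge function here are illustrative, not the library's implementation:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    struct Partial {
      std::vector<float> o;  // un-normalized output accumulator
      float m;               // running max (log2 domain)
      float d;               // running denominator
    };

    // Fold partial state `b` into `a` by rescaling both to the larger max.
    void merge(Partial& a, const Partial& b) {
      float m_new = std::max(a.m, b.m);
      float sa = std::exp2(a.m - m_new);  // rescale factor for a
      float sb = std::exp2(b.m - m_new);  // rescale factor for b
      for (std::size_t i = 0; i < a.o.size(); ++i) {
        a.o[i] = a.o[i] * sa + b.o[i] * sb;
      }
      a.d = a.d * sa + b.d * sb;
      a.m = m_new;
    }

Rescaling both sides to the shared maximum keeps every exponential bounded by 1, which is the point of carrying m through the computation at all.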
@@ -338,8 +355,10 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__
   block.sync();

   // sync local state of all warps inside a threadblock
-  sync_state<vec_size, bdx, bdy, bdz>(st_local, reinterpret_cast<float*>(smem), smem_md);
-  st_local.normalize();
+  sync_state<vec_size, bdx, bdy, bdz>(variant, st_local, reinterpret_cast<float*>(smem), smem_md);
+  if constexpr (variant.use_softmax) {
+    st_local.normalize();
+  }

   st_local.o.cast_store(o + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
   if (lse != nullptr) {
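The guard around normalize() follows the same pattern as the earlier hunks: on the softmax path the merged state still holds an un-normalized output, o = sum_j p_j * v_j with p_j = 2^(s_j - m), alongside d = sum_j p_j, so the final softmax-weighted output is o / d. On the no-softmax path d is never accumulated, so the division would be meaningless and is skipped. Assuming state_t::normalize has roughly this shape (illustrative sketch only):

    #include <vector>

    // Divide the un-normalized accumulator by the softmax denominator.
    void normalize(std::vector<float>& o, float d) {
      for (float& oi : o) oi /= d;
    }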
@@ -557,8 +576,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
   block.sync();

   // sync local state of all warps inside a threadblock
-  sync_state<vec_size, bdx, bdy, bdz>(st, reinterpret_cast<float*>(smem), smem_md);
-  st.normalize();
+  sync_state<vec_size, bdx, bdy, bdz>(variant, st, reinterpret_cast<float*>(smem), smem_md);
+  if constexpr (variant.use_softmax) {
+    st.normalize();
+  }

   if (tz == 0) {
     st.o.cast_store(o + (bx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);