@@ -158,39 +158,42 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
   uint32_t pos = blockIdx.x;
   uint32_t head_idx = ty;
 
-  if (num_index_sets > 1) {
-    vec_t<float, vec_size> v_merged_vec;
-    state_t<vec_size> st;
-    v_merged_vec.fill(0.f);
-#pragma unroll 2
-    for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
-      float s = S[(pos * num_index_sets + iter) * num_heads + head_idx];
-      vec_t<float, vec_size> v;
-      v.cast_load(V + ((pos * num_index_sets + iter) * num_heads + head_idx) * head_dim +
-                  tx * vec_size);
-      st.merge(v, s, 1);
-    }
-
-    st.normalize();
-    st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
-    if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = st.get_lse();
-    }
-  } else if (num_index_sets == 1) {
+  if (num_index_sets == 0) {
     vec_t<DTypeOut, vec_size> v;
-    v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    v.fill(DTypeOut(0));
     v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
     if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx];
+      s_merged[pos * num_heads + head_idx] = -5e4;
     }
-  } else {
-    // num_index_sets == 0
+    return;
+  }
+
+  if (num_index_sets == 1) {
     vec_t<DTypeOut, vec_size> v;
-    v.fill(DTypeOut(0));
+    v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
     v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
     if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = -5e4;
+      s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx];
     }
+    return;
+  }
+
+  vec_t<float, vec_size> v_merged_vec;
+  state_t<vec_size> st;
+  v_merged_vec.fill(0.f);
+#pragma unroll 2
+  for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
+    float s = S[(pos * num_index_sets + iter) * num_heads + head_idx];
+    vec_t<float, vec_size> v;
+    v.cast_load(V + ((pos * num_index_sets + iter) * num_heads + head_idx) * head_dim +
+                tx * vec_size);
+    st.merge(v, s, 1);
+  }
+
+  st.normalize();
+  st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+  if (s_merged != nullptr) {
+    s_merged[pos * num_heads + head_idx] = st.get_lse();
   }
 }
 
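For reference, below is a minimal host-side sketch of the arithmetic the st.merge / st.normalize / st.get_lse path above performs per (pos, head). It assumes the standard log-sum-exp merge of partial attention states; the function and variable names are illustrative only, not the library's API.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative scalar reference (assumed semantics): each index set contributes a
// partial output v[i] (already normalized within its own set) and a log-sum-exp
// score s[i]; the merged output is the softmax(s)-weighted sum of the v[i], and the
// merged score is log(sum_i exp(s[i])). An empty set yields a zero vector and the
// -5e4 sentinel, matching the num_index_sets == 0 branch in the kernel above.
void merge_states_reference(const std::vector<float>& s,
                            const std::vector<std::vector<float>>& v, size_t head_dim,
                            std::vector<float>& v_merged, float& s_merged) {
  v_merged.assign(head_dim, 0.f);
  if (s.empty()) {
    s_merged = -5e4f;
    return;
  }
  float m = s[0];
  for (float x : s) m = std::max(m, x);  // running max keeps exp() in range
  float d = 0.f;
  for (size_t i = 0; i < s.size(); ++i) {
    float w = std::exp(s[i] - m);
    d += w;
    for (size_t j = 0; j < head_dim; ++j) v_merged[j] += w * v[i][j];
  }
  for (size_t j = 0; j < head_dim; ++j) v_merged[j] /= d;  // cf. st.normalize()
  s_merged = m + std::log(d);                              // cf. st.get_lse()
}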
@@ -313,68 +316,70 @@ __global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float*
   float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
   const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
 
-  if (num_index_sets > 1) {
-#pragma unroll
-    for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
-      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
-          v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
-          V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
-          (iter * bdy + ty) < num_index_sets);
-      cp_async::commit_group();
-    }
-#pragma unroll 4
-    for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
-      if (iter % bdx == 0) {
-        s_smem[ty * bdx + tx] =
-            iter * bdy + (ty * bdx + tx) < num_index_sets
-                ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
-                : 0.f;
-        __syncthreads();
-      }
-      cp_async::wait_group<num_smem_stages - 1>();
-      __syncthreads();
-      vec_t<float, vec_size> v;
-      v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
-      if (iter * bdy + ty < num_index_sets) {
-        float s = s_smem[(iter % bdx) * bdy + ty];
-        st.merge(v, s, 1);
-      }
-      __syncthreads();
-      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
-          v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
-          V +
-              ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
-                  head_dim +
-              tx * vec_size,
-          (iter + num_smem_stages) * bdy + ty < num_index_sets);
-      cp_async::commit_group();
-    }
-    cp_async::wait_group<0>();
-    __syncthreads();
-
-    st.normalize();
-    threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
-    st.normalize();
-
-    st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+  if (num_index_sets == 0) {
+    vec_t<DTypeOut, vec_size> v;
+    v.fill(DTypeOut(0));
+    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
     if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = st.get_lse();
+      s_merged[pos * num_heads + head_idx] = -5e4;
     }
-  } else if (num_index_sets == 1) {
+    return;
+  }
+
+  if (num_index_sets == 1) {
     vec_t<DTypeOut, vec_size> v;
     v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
     v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
     if (s_merged != nullptr) {
       s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
     }
-  } else {
-    // num_index_sets == 0
-    vec_t<DTypeOut, vec_size> v;
-    v.fill(DTypeOut(0));
-    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
-    if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = -5e4;
+  }
+
+#pragma unroll
+  for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
+    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+        v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
+        V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
+        (iter * bdy + ty) < num_index_sets);
+    cp_async::commit_group();
+  }
+#pragma unroll 4
+  for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
+    if (iter % bdx == 0) {
+      s_smem[ty * bdx + tx] =
+          iter * bdy + (ty * bdx + tx) < num_index_sets
+              ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
+              : 0.f;
+      __syncthreads();
     }
+    cp_async::wait_group<num_smem_stages - 1>();
+    __syncthreads();
+    vec_t<float, vec_size> v;
+    v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
+    if (iter * bdy + ty < num_index_sets) {
+      float s = s_smem[(iter % bdx) * bdy + ty];
+      st.merge(v, s, 1);
+    }
+    __syncthreads();
+    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+        v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
+        V +
+            ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
+                head_dim +
+            tx * vec_size,
+        (iter + num_smem_stages) * bdy + ty < num_index_sets);
+    cp_async::commit_group();
+  }
+  cp_async::wait_group<0>();
+  __syncthreads();
+
+  st.normalize();
+  threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
+  st.normalize();
+
+  st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+  if (s_merged != nullptr) {
+    s_merged[pos * num_heads + head_idx] = st.get_lse();
   }
 }
 
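The variable-length kernel reads its work list from a CSR-style indptr array; the sketch below illustrates that layout and the behavior of the empty-segment early return added above. It is illustrative only: the example indptr values and printed messages are made up, not taken from the library.

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative only: segment `pos` owns rows indptr[pos] .. indptr[pos + 1] - 1 of
// V (shape [total_sets, num_heads, head_dim]) and S (shape [total_sets, num_heads]).
// An empty segment produces a zero output and the sentinel LSE -5e4; a singleton
// segment is copied through unchanged; everything else goes through the pipelined
// cp_async merge loop.
int main() {
  std::vector<uint32_t> indptr = {0, 2, 2, 5};  // three segments with 2, 0, and 3 index sets
  for (uint32_t pos = 0; pos + 1 < indptr.size(); ++pos) {
    uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
    if (num_index_sets == 0) {
      std::printf("segment %u: empty -> v_merged = 0, s_merged = -5e4\n", pos);
    } else {
      std::printf("segment %u: merges V/S rows [%u, %u)\n", pos, indptr[pos], indptr[pos + 1]);
    }
  }
  return 0;
}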