@@ -691,6 +691,99 @@ __global__ void cache_kernel(
  }
}

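+// Variant of cache_kernel used when extra ("excess") cache blocks are
+// provided. Threads with token_idx < token_num scatter K/V into the regular
+// paged cache but skip tokens that fall into a sequence's trailing partial
+// block; threads with token_idx >= token_num write those tail tokens into
+// each of the sequence's excess blocks (excess_blocks is [bsz, excess_num]).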
+template <typename T, int VecSize = 1>
+__global__ void cache_use_excess_kernel(
+    const T *__restrict__ qkv,    // [num_tokens, num_heads + 2 * kv_num_heads,
+                                  // head_size]
+    T *__restrict__ key_cache,    // [num_blocks, kv_num_heads, block_size,
+                                  // head_size]
+    T *__restrict__ value_cache,  // [num_blocks, kv_num_heads, block_size,
+                                  // head_size]
+    const int *__restrict__ block_tables,     // [bsz, max_blocks_per_seq]
+    const int *__restrict__ padding_offsets,  // [num_tokens]
+    const int *__restrict__ cum_offsets,
+    const int *__restrict__ seq_lens,          // [bsz]
+    const int *__restrict__ seq_lens_decoder,  // [bsz]
+    const int *__restrict__ excess_blocks,     // [bsz, excess_num]
+    const int max_seq_len,
+    const int max_blocks_per_seq,
+    const int num_heads,
+    const int head_size,
+    const int block_size,
+    const uint32_t elem_cnt,
+    const int kv_num_heads,
+    const int token_num,
+    const int excess_num) {
+  using LoadT = AlignedVector<T, VecSize>;
+  LoadT src_vec;
+
+  uint32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+  const uint32_t hidden_size = kv_num_heads * head_size;
+  const uint32_t offset = 2 * hidden_size;
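+  // Grid-stride loop: each linear index addresses one VecSize-wide vector in a
+  // virtual [token_idx, {K,V}, kv_head, head_dim] layout of the K/V elements.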
+  for (uint32_t linear_index = global_thread_idx * VecSize,
+                step = gridDim.x * blockDim.x * VecSize;
+       linear_index < elem_cnt;
+       linear_index += step) {
+    uint32_t token_idx = linear_index / offset;
+    const uint32_t bias = linear_index % offset;
+    const uint32_t qkv_id = bias / hidden_size;  // skip q
+    const uint32_t qkv_bias = bias % hidden_size;
+    const uint32_t hi = qkv_bias / head_size;
+    const uint32_t h_bias = qkv_bias % head_size;
+
+    uint32_t block_idx, block_offset;
+
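+    // Regular tokens: look up the block via block_tables and skip positions
+    // that fall inside the sequence's trailing partial block.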
+    if (token_idx < token_num) {
+      const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
+      const uint32_t ori_bi = ori_token_idx / max_seq_len;
+      const uint32_t last_offset = seq_lens[ori_bi] % block_size;
+      if (seq_lens[ori_bi] == 0) continue;
+
+      const int32_t *block_table_now = nullptr;
+      block_table_now = block_tables + ori_bi * max_blocks_per_seq;
+      const uint32_t ori_seq_id =
+          ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
+      if (ori_seq_id >= seq_lens[ori_bi] - last_offset) continue;
+
+      block_idx = block_table_now[ori_seq_id / block_size];
+      block_offset = ori_seq_id % block_size;
+    } else {
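+      // Excess tokens: map this thread back to one of the last last_offset
+      // tokens of sequence ori_bi and redirect the write to one of that
+      // sequence's excess blocks.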
+      const uint32_t excess_token_id = token_idx - token_num;
+      const uint32_t ori_bi = excess_token_id / (excess_num * block_size);
+      const uint32_t last_offset = seq_lens[ori_bi] % block_size;
+      if (seq_lens[ori_bi] == 0) continue;
+
+      const uint32_t excess_id =
+          (excess_token_id % (excess_num * block_size)) / block_size;
+      const uint32_t excess_token_offset = excess_token_id % block_size;
+
+      if (excess_token_offset < last_offset) {
+        token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi] +
+                    seq_lens[ori_bi] - last_offset + excess_token_offset;
+      } else {
+        continue;
+      }
+
+      block_idx = excess_blocks[ori_bi * excess_num + excess_id];
+      block_offset = excess_token_offset;
+    }
+
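+    // tgt_idx: destination offset in the [num_blocks, kv_num_heads,
+    // block_size, head_size] cache; ori_idx: source offset in the fused qkv
+    // tensor, skipping the Q section.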
+    const uint32_t tgt_idx =
+        block_idx * kv_num_heads * block_size * head_size +
+        hi * block_size * head_size + block_offset * head_size + h_bias;
+
+    const uint32_t ori_idx =
+        token_idx * (num_heads + 2 * kv_num_heads) * head_size +
+        num_heads * head_size + qkv_id * hidden_size + hi * head_size + h_bias;
+
+    Load<T, VecSize>(&qkv[ori_idx], &src_vec);
+    if (qkv_id == 0) {
+      Store<T, VecSize>(src_vec, &key_cache[tgt_idx]);
+    } else {
+      Store<T, VecSize>(src_vec, &value_cache[tgt_idx]);
+    }
+  }
+}

template <typename T,
          uint32_t num_frags_y,
@@ -1463,9 +1556,12 @@ void CascadeAppendWriteCacheKVQKV(
                                       // kv_num_heads, head_dim] if GQA)
    const paddle::Tensor &block_table,
    const paddle::Tensor &padding_offsets,
+    const paddle::Tensor &cum_offsets,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const int max_seq_len,
+    const int bsz,
+    const paddle::optional<paddle::Tensor> &excess_blocks,
    cudaStream_t &stream,
    paddle::Tensor *key_cache_out,
    paddle::Tensor *value_cache_out) {
@@ -1477,29 +1573,58 @@ void CascadeAppendWriteCacheKVQKV(
  auto head_dim_v = meta_data.head_dims_v;
  auto block_size = meta_data.block_size;

-  const uint32_t elem_nums =
-      num_tokens * kv_num_heads * (head_dim_qk + head_dim_v);
+  int excess_block_num = 0;
+  int *excess_blocks_ptr = nullptr;
+  if (excess_blocks) {
+    excess_block_num = excess_blocks.get().dims()[1];
+    excess_blocks_ptr = const_cast<int *>(excess_blocks.get().data<int>());
+  }
+  uint32_t elem_nums = (num_tokens + bsz * excess_block_num * block_size) *
+                       kv_num_heads * (head_dim_qk + head_dim_v);
+  // Additionally reserve excess_block_num * block_size elements per batch id
+  // (bid), so the element count also covers the writes into the excess blocks.
+
  constexpr int PackSize = 16 / sizeof(T);
  const int pack_num = elem_nums / PackSize;
  const int blocksize = 128;
  int grid_size = 1;
  GetNumBlocks<128>(pack_num, &grid_size);
-  cache_kernel<T, PackSize><<<grid_size, blocksize, 0, stream>>>(
-      reinterpret_cast<T *>(const_cast<T *>(qkv.data<T>())),
-      reinterpret_cast<T *>(key_cache_out->data<T>()),
-      reinterpret_cast<T *>(value_cache_out->data<T>()),
-      block_table.data<int>(),
-      padding_offsets.data<int>(),
-      seq_lens_encoder.data<int>(),
-      seq_lens_decoder.data<int>(),
-      max_seq_len,
-      max_blocks_per_seq,
-      num_heads,
-      head_dim_qk,
-      head_dim_v,
-      block_size,
-      elem_nums,
-      kv_num_heads);
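+  // With excess blocks present, use the variant that writes each sequence's
+  // trailing partial-block tokens into its excess blocks instead of the
+  // regular block table; otherwise fall back to the plain cache_kernel.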
+  if (excess_blocks_ptr) {
+    cache_use_excess_kernel<T, PackSize><<<grid_size, blocksize, 0, stream>>>(
+        reinterpret_cast<T *>(const_cast<T *>(qkv.data<T>())),
+        reinterpret_cast<T *>(key_cache_out->data<T>()),
+        reinterpret_cast<T *>(value_cache_out->data<T>()),
+        block_table.data<int>(),
+        padding_offsets.data<int>(),
+        cum_offsets.data<int>(),
+        seq_lens_encoder.data<int>(),
+        seq_lens_decoder.data<int>(),
+        excess_blocks_ptr,
+        max_seq_len,
+        max_blocks_per_seq,
+        num_heads,
+        head_dim_qk,
+        block_size,
+        elem_nums,
+        kv_num_heads,
+        num_tokens,
+        excess_block_num);
+  } else {
+    cache_kernel<T, PackSize><<<grid_size, blocksize, 0, stream>>>(
+        reinterpret_cast<T *>(const_cast<T *>(qkv.data<T>())),
+        reinterpret_cast<T *>(key_cache_out->data<T>()),
+        reinterpret_cast<T *>(value_cache_out->data<T>()),
+        block_table.data<int>(),
+        padding_offsets.data<int>(),
+        seq_lens_encoder.data<int>(),
+        seq_lens_decoder.data<int>(),
+        max_seq_len,
+        max_blocks_per_seq,
+        num_heads,
+        head_dim_qk,
+        head_dim_v,
+        block_size,
+        elem_nums,
+        kv_num_heads);
+  }
}

template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>