Skip to content

Commit 2211480

Browse files
authored
【infer】 predict_dy_insert support more inputs and append attn support excess_blocks input (#10446)
* support input input_ids, and all_rank_return in dy_insert * refine code * delete print * fix kv_cache block mismatch block_table * check input_ids len * check input_ids len logger * fix dy_insert cuda700, append attn support excess_blocks * check empty_cache order * check modeling forward input excess_blocks
1 parent 345ff57 commit 2211480

17 files changed

+324
-110
lines changed

csrc/gpu/append_attention.cu

+13-2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
5353
const paddle::optional<paddle::Tensor>& cache_v_zp,
5454
const paddle::optional<paddle::Tensor>& out_linear_shifts,
5555
const paddle::optional<paddle::Tensor>& out_linear_smooths,
56+
const paddle::optional<paddle::Tensor>& excess_blocks,
5657
const std::string& cache_quant_type_str,
5758
const bool use_neox_rotary_style,
5859
const int max_input_length,
@@ -140,6 +141,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
140141
cache_v_quant_scales,
141142
cache_k_zp,
142143
cache_v_zp,
144+
excess_blocks,
143145
cache_quant_type_str,
144146
kv_num_blocks_data,
145147
max_input_length,
@@ -167,6 +169,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
167169
cache_v_quant_scales,
168170
cache_k_zp,
169171
cache_v_zp,
172+
excess_blocks,
170173
cache_quant_type_str,
171174
kv_num_blocks_data,
172175
max_input_length,
@@ -568,6 +571,7 @@ std::vector<paddle::Tensor> AppendAttention(
568571
const paddle::optional<paddle::Tensor>& cache_v_zp,
569572
const paddle::optional<paddle::Tensor>& out_linear_shifts,
570573
const paddle::optional<paddle::Tensor>& out_linear_smooths,
574+
const paddle::optional<paddle::Tensor>& excess_blocks,
571575
const std::string& compute_dtype,
572576
const std::string& cache_quant_type_str,
573577
const bool use_neox_rotary_style,
@@ -632,6 +636,7 @@ std::vector<paddle::Tensor> AppendAttention(
632636
cache_v_zp,
633637
out_linear_shifts,
634638
out_linear_smooths,
639+
excess_blocks,
635640
cache_quant_type_str,
636641
use_neox_rotary_style,
637642
max_input_length,
@@ -679,6 +684,7 @@ std::vector<paddle::Tensor> AppendAttention(
679684
cache_v_zp,
680685
out_linear_shifts,
681686
out_linear_smooths,
687+
excess_blocks,
682688
cache_quant_type_str,
683689
use_neox_rotary_style,
684690
max_input_length,
@@ -727,6 +733,7 @@ std::vector<paddle::Tensor> AppendAttention(
727733
cache_v_zp,
728734
out_linear_shifts,
729735
out_linear_smooths,
736+
excess_blocks,
730737
cache_quant_type_str,
731738
use_neox_rotary_style,
732739
max_input_length,
@@ -773,6 +780,7 @@ std::vector<paddle::Tensor> AppendAttention(
773780
cache_v_zp,
774781
out_linear_shifts,
775782
out_linear_smooths,
783+
excess_blocks,
776784
cache_quant_type_str,
777785
use_neox_rotary_style,
778786
max_input_length,
@@ -831,7 +839,8 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
831839
const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
832840
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
833841
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
834-
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape) {
842+
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
843+
const paddle::optional<std::vector<int64_t>>& excess_blocks_shape) {
835844
const int token_num = qkv_shape[0];
836845
const int kv_num_heads = key_cache_shape[1];
837846
const int head_dim_qk = key_cache_shape[3];
@@ -876,6 +885,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
876885
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
877886
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
878887
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
888+
const paddle::optional<paddle::DataType>& excess_blocks_dtype,
879889
const std::string& compute_dtype,
880890
const std::string& cache_quant_type_str,
881891
const bool use_neox_rotary_style,
@@ -949,7 +959,8 @@ PD_BUILD_OP(append_attention)
949959
paddle::Optional("cache_k_zp"),
950960
paddle::Optional("cache_v_zp"),
951961
paddle::Optional("out_linear_shifts"),
952-
paddle::Optional("out_linear_smooths")})
962+
paddle::Optional("out_linear_smooths"),
963+
paddle::Optional("excess_blocks")})
953964
.Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
954965
.SetInplaceMap({{"key_cache", "key_cache_out"},
955966
{"value_cache", "value_cache_out"}})

csrc/gpu/append_attn/encoder_write_cache_with_rope_impl.cuh

+143-18
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,99 @@ __global__ void cache_kernel(
691691
}
692692
}
693693

694+
// Writes the K/V portion of a fused QKV activation into the paged KV cache,
// additionally spilling the tokens of each sequence's final *partial* block
// into separate "excess" blocks (excess_blocks) instead of the block table's
// last block. Grid-stride loop over elem_cnt vectorized elements; elem_cnt is
// expected to cover (token_num + bsz * excess_num * block_size) virtual
// tokens — indices >= token_num address the excess-block region.
// NOTE(review): assumes elem_cnt is a multiple of VecSize and head_size a
// multiple of VecSize — TODO confirm at the launch site.
template <typename T, int VecSize = 1>
__global__ void cache_use_excess_kernel(
    const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads,
                               // head_size]
    T *__restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
                               // head_size]
    T *__restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
                                 // head_size]
    const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq]
    const int *__restrict__ padding_offsets, // [num_tokens]
    const int *__restrict__ cum_offsets,     // [bsz] cumulative padding per batch
    const int *__restrict__ seq_lens, // [bsz]
    const int *__restrict__ seq_lens_decoder, // [bsz]
    const int *__restrict__ excess_blocks, // [bsz, excess_num]
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int num_heads,
    const int head_size,
    const int block_size,
    const uint32_t elem_cnt,
    const int kv_num_heads,
    const int token_num,   // number of real (unpadded) tokens in this step
    const int excess_num) { // excess blocks available per sequence
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  uint32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  // hidden_size = width of one of K or V; offset = K + V width per token.
  const uint32_t hidden_size = kv_num_heads * head_size;
  const uint32_t offset = 2 * hidden_size;
  for (uint32_t linear_index = global_thread_idx * VecSize,
                step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    uint32_t token_idx = linear_index / offset;
    const uint32_t bias = linear_index % offset;
    const uint32_t qkv_id = bias / hidden_size; // skip q
    const uint32_t qkv_bias = bias % hidden_size;
    const uint32_t hi = qkv_bias / head_size;     // kv head index
    const uint32_t h_bias = qkv_bias % head_size; // offset within the head

    uint32_t block_idx, block_offset;

    if (token_idx < token_num) {
      // Regular path: map the padded token back to (batch, position) and
      // write through the normal block table — EXCEPT tokens that land in
      // the sequence's trailing partial block, which are skipped here and
      // handled by the excess path below.
      const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
      const uint32_t ori_bi = ori_token_idx / max_seq_len;
      const uint32_t last_offset = seq_lens[ori_bi] % block_size;
      if (seq_lens[ori_bi] == 0) continue;

      const int32_t *block_table_now = nullptr;
      block_table_now = block_tables + ori_bi * max_blocks_per_seq;
      // Absolute cache position = position in this chunk + already-decoded len.
      const uint32_t ori_seq_id =
          ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
      // Skip tokens belonging to the final partial block (written via
      // excess_blocks instead).
      if (ori_seq_id >= seq_lens[ori_bi] - last_offset) continue;

      block_idx = block_table_now[ori_seq_id / block_size];
      block_offset = ori_seq_id % block_size;
    } else {
      // Excess path: virtual tokens past token_num address a dense
      // [bsz, excess_num, block_size] region; only the first last_offset
      // slots of each block correspond to real trailing tokens.
      const uint32_t excess_token_id = token_idx - token_num;
      const uint32_t ori_bi = excess_token_id / (excess_num * block_size);
      const uint32_t last_offset = seq_lens[ori_bi] % block_size;
      if (seq_lens[ori_bi] == 0) continue;

      const uint32_t excess_id =
          (excess_token_id % (excess_num * block_size)) / block_size;
      const uint32_t excess_token_offset = excess_token_id % block_size;

      if (excess_token_offset < last_offset) {
        // Recover the real source token index in the packed qkv buffer:
        // batch base (minus padding) + offset of the trailing partial block.
        token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi] +
                    seq_lens[ori_bi] - last_offset + excess_token_offset;
      } else {
        continue;
      }

      block_idx = excess_blocks[ori_bi * excess_num + excess_id];
      block_offset = excess_token_offset;
    }

    // Destination within the paged cache: [block, kv_head, slot, dim].
    const uint32_t tgt_idx =
        block_idx * kv_num_heads * block_size * head_size +
        hi * block_size * head_size + block_offset * head_size + h_bias;

    // Source within fused qkv: skip the Q section (num_heads * head_size),
    // then select K (qkv_id == 0) or V (qkv_id == 1).
    const uint32_t ori_idx =
        token_idx * (num_heads + 2 * kv_num_heads) * head_size +
        num_heads * head_size + qkv_id * hidden_size + hi * head_size + h_bias;

    Load<T, VecSize>(&qkv[ori_idx], &src_vec);
    if (qkv_id == 0) {
      Store<T, VecSize>(src_vec, &key_cache[tgt_idx]);
    } else {
      Store<T, VecSize>(src_vec, &value_cache[tgt_idx]);
    }
  }
}
694787

695788
template <typename T,
696789
uint32_t num_frags_y,
@@ -1463,9 +1556,12 @@ void CascadeAppendWriteCacheKVQKV(
14631556
// kv_num_heads, head_dim] if GQA)
14641557
const paddle::Tensor &block_table,
14651558
const paddle::Tensor &padding_offsets,
1559+
const paddle::Tensor &cum_offsets,
14661560
const paddle::Tensor &seq_lens_encoder,
14671561
const paddle::Tensor &seq_lens_decoder,
14681562
const int max_seq_len,
1563+
const int bsz,
1564+
const paddle::optional<paddle::Tensor>& excess_blocks,
14691565
cudaStream_t &stream,
14701566
paddle::Tensor *key_cache_out,
14711567
paddle::Tensor *value_cache_out) {
@@ -1477,29 +1573,58 @@ void CascadeAppendWriteCacheKVQKV(
14771573
auto head_dim_v = meta_data.head_dims_v;
14781574
auto block_size = meta_data.block_size;
14791575

1480-
const uint32_t elem_nums =
1481-
num_tokens * kv_num_heads * (head_dim_qk + head_dim_v);
1576+
int excess_block_num = 0;
1577+
int *excess_blocks_ptr = nullptr;
1578+
if (excess_blocks) {
1579+
excess_block_num = excess_blocks.get().dims()[1];
1580+
excess_blocks_ptr = const_cast<int*>(excess_blocks.get().data<int>());
1581+
}
1582+
uint32_t elem_nums = (num_tokens + bsz * excess_block_num * block_size) * kv_num_heads * (head_dim_qk + head_dim_v);
1583+
// Additionally allocate excess_block_num * block_size extra slots per batch id (bid).
1584+
14821585
constexpr int PackSize = 16 / sizeof(T);
14831586
const int pack_num = elem_nums / PackSize;
14841587
const int blocksize = 128;
14851588
int grid_size = 1;
14861589
GetNumBlocks<128>(pack_num, &grid_size);
1487-
cache_kernel<T, PackSize><<<grid_size, blocksize, 0, stream>>>(
1488-
reinterpret_cast<T *>(const_cast<T *>(qkv.data<T>())),
1489-
reinterpret_cast<T *>(key_cache_out->data<T>()),
1490-
reinterpret_cast<T *>(value_cache_out->data<T>()),
1491-
block_table.data<int>(),
1492-
padding_offsets.data<int>(),
1493-
seq_lens_encoder.data<int>(),
1494-
seq_lens_decoder.data<int>(),
1495-
max_seq_len,
1496-
max_blocks_per_seq,
1497-
num_heads,
1498-
head_dim_qk,
1499-
head_dim_v,
1500-
block_size,
1501-
elem_nums,
1502-
kv_num_heads);
1590+
if (excess_blocks_ptr) {
1591+
cache_use_excess_kernel<T, PackSize><<<grid_size, blocksize, 0, stream>>>(
1592+
reinterpret_cast<T *>(const_cast<T *>(qkv.data<T>())),
1593+
reinterpret_cast<T *>(key_cache_out->data<T>()),
1594+
reinterpret_cast<T *>(value_cache_out->data<T>()),
1595+
block_table.data<int>(),
1596+
padding_offsets.data<int>(),
1597+
cum_offsets.data<int>(),
1598+
seq_lens_encoder.data<int>(),
1599+
seq_lens_decoder.data<int>(),
1600+
excess_blocks_ptr,
1601+
max_seq_len,
1602+
max_blocks_per_seq,
1603+
num_heads,
1604+
head_dim_qk,
1605+
block_size,
1606+
elem_nums,
1607+
kv_num_heads,
1608+
num_tokens,
1609+
excess_block_num);
1610+
} else {
1611+
cache_kernel<T, PackSize><<<grid_size, blocksize, 0, stream>>>(
1612+
reinterpret_cast<T *>(const_cast<T *>(qkv.data<T>())),
1613+
reinterpret_cast<T *>(key_cache_out->data<T>()),
1614+
reinterpret_cast<T *>(value_cache_out->data<T>()),
1615+
block_table.data<int>(),
1616+
padding_offsets.data<int>(),
1617+
seq_lens_encoder.data<int>(),
1618+
seq_lens_decoder.data<int>(),
1619+
max_seq_len,
1620+
max_blocks_per_seq,
1621+
num_heads,
1622+
head_dim_qk,
1623+
head_dim_v,
1624+
block_size,
1625+
elem_nums,
1626+
kv_num_heads);
1627+
}
15031628
}
15041629

15051630
template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>

csrc/gpu/append_attn/encoder_write_cache_with_rope_kernel.h

+5
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ void EncoderWriteCacheWithRopeKernel(
3636
const paddle::optional<paddle::Tensor>& cache_v_scale,
3737
const paddle::optional<paddle::Tensor>& cache_k_zp,
3838
const paddle::optional<paddle::Tensor>& cache_v_zp,
39+
const paddle::optional<paddle::Tensor>& excess_blocks,
3940
const std::string& cache_quant_type_str,
4041
const int num_blocks,
4142
const int max_seq_len,
@@ -48,6 +49,7 @@ void EncoderWriteCacheWithRopeKernel(
4849
auto num_heads = meta_data.q_num_heads;
4950
auto kv_num_heads = meta_data.kv_num_heads;
5051
auto head_dim = meta_data.head_dims;
52+
int bsz = cum_offsets.dims()[0];
5153
if (rotary_embs) {
5254
if (num_heads == kv_num_heads) {
5355
rotary_qk_variable(
@@ -93,9 +95,12 @@ void EncoderWriteCacheWithRopeKernel(
9395
*qkv_out,
9496
block_tables,
9597
padding_offsets,
98+
cum_offsets,
9699
seq_lens_encoder,
97100
seq_lens_decoder,
98101
max_seq_len,
102+
bsz,
103+
excess_blocks,
99104
stream,
100105
key_cache_out,
101106
value_cache_out);

csrc/gpu/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
3434
const paddle::optional<paddle::Tensor>& cache_v_scale,
3535
const paddle::optional<paddle::Tensor>& cache_k_zp,
3636
const paddle::optional<paddle::Tensor>& cache_v_zp,
37+
const paddle::optional<paddle::Tensor>& excess_blocks,
3738
const std::string& cache_quant_type_str,
3839
const int num_blocks,
3940
const int max_seq_len,

csrc/gpu/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_int_kernel.cu

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
3333
const paddle::optional<paddle::Tensor>& cache_v_scale,
3434
const paddle::optional<paddle::Tensor>& cache_k_zp,
3535
const paddle::optional<paddle::Tensor>& cache_v_zp,
36+
const paddle::optional<paddle::Tensor>& excess_blocks,
3637
const std::string& cache_quant_type_str,
3738
const int num_blocks,
3839
const int max_seq_len,

csrc/gpu/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_float16_kernel.cu

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
3333
const paddle::optional<paddle::Tensor>& cache_v_scale,
3434
const paddle::optional<paddle::Tensor>& cache_k_zp,
3535
const paddle::optional<paddle::Tensor>& cache_v_zp,
36+
const paddle::optional<paddle::Tensor>& excess_blocks,
3637
const std::string& cache_quant_type_str,
3738
const int num_blocks,
3839
const int max_seq_len,

csrc/gpu/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_int_kernel.cu

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
3333
const paddle::optional<paddle::Tensor>& cache_v_scale,
3434
const paddle::optional<paddle::Tensor>& cache_k_zp,
3535
const paddle::optional<paddle::Tensor>& cache_v_zp,
36+
const paddle::optional<paddle::Tensor>& excess_blocks,
3637
const std::string& cache_quant_type_str,
3738
const int num_blocks,
3839
const int max_seq_len,

csrc/gpu/cpp_extensions.cu

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ std::vector<paddle::Tensor> AppendAttention(
5151
const paddle::optional<paddle::Tensor>& cache_v_zp,
5252
const paddle::optional<paddle::Tensor>& out_linear_shifts,
5353
const paddle::optional<paddle::Tensor>& out_linear_smooths,
54+
const paddle::optional<paddle::Tensor>& excess_blocks,
5455
const std::string& compute_dtype,
5556
const std::string& cache_quant_type_str,
5657
const bool use_neox_rotary_style,

0 commit comments

Comments
 (0)