Commit 2d19893

add alibi_slopes to paged attention (#5483)

* add alibi_slopes to paged attention
* format
* remove -inf
* remove -inf
* format

1 parent 3e98d9a · commit 2d19893

8 files changed, 144 insertions(+), 41 deletions(-)
csrc/gpu/aten/operators/transformers/attention.cpp  (+34 −8)

@@ -1471,10 +1471,21 @@ void xetla_paged_attention_impl_v1(
   uint32_t num_kv_heads = key_cache.size(1);
   uint32_t max_num_blocks_per_seq = block_tables.size(1);

-  // TODO(zw): alibi_slopes is optional, not used currently.
-  const float* alibi_slopes_ptr = alibi_slopes
-      ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-      : nullptr;
+  if (alibi_slopes.has_value()) {
+    TORCH_CHECK(alibi_slopes->is_xpu(), "alibi_slopes_ must on XPU");
+    TORCH_CHECK(
+        alibi_slopes->is_contiguous(), "alibi_slopes_ must be contiguous");
+    TORCH_CHECK(
+        alibi_slopes->scalar_type() == at::kFloat,
+        "XeTLA VarlenAttention: The datatype of alibi_slopes should be float");
+    int ndim = alibi_slopes->ndimension();
+    TORCH_CHECK(
+        ndim == 1, "XeTLA VarlenAttention: only support 1 dim alibi tensor!");
+    int last_dim = alibi_slopes->size(-1);
+    TORCH_CHECK(
+        last_dim == num_heads,
+        "XeTLA VarlenAttention: The shape of alibi tensor should equal to [num_head]");
+  }

   auto dpcpp_queue = dpcppGetCurrentQueue();
 #if defined(USE_XETLA)
@@ -1490,6 +1501,8 @@ void xetla_paged_attention_impl_v1(
       reinterpret_cast<void*>(query.data_ptr()),
       reinterpret_cast<void*>(key_cache.data_ptr()),
       reinterpret_cast<void*>(value_cache.data_ptr()),
+      alibi_slopes.has_value() ? alibi_slopes.value().data_ptr()
+                               : (void*)nullptr,
       reinterpret_cast<void*>(block_tables.data_ptr()),
       reinterpret_cast<void*>(context_lens.data_ptr()),
       num_queries_per_tokens,
@@ -1560,10 +1573,21 @@ void xetla_paged_attention_impl_v2(
   uint32_t num_kv_heads = key_cache.size(1);
   uint32_t max_num_blocks_per_seq = block_tables.size(1);

-  // TODO(zw): alibi_slopes is optional, not used currently.
-  const float* alibi_slopes_ptr = alibi_slopes
-      ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-      : nullptr;
+  if (alibi_slopes.has_value()) {
+    TORCH_CHECK(alibi_slopes->is_xpu(), "alibi_slopes_ must on XPU");
+    TORCH_CHECK(
+        alibi_slopes->is_contiguous(), "alibi_slopes_ must be contiguous");
+    TORCH_CHECK(
+        alibi_slopes->scalar_type() == at::kFloat,
+        "XeTLA VarlenAttention: The datatype of alibi_slopes should be float");
+    int ndim = alibi_slopes->ndimension();
+    TORCH_CHECK(
+        ndim == 1, "XeTLA VarlenAttention: only support 1 dim alibi tensor!");
+    int last_dim = alibi_slopes->size(-1);
+    TORCH_CHECK(
+        last_dim == num_heads,
+        "XeTLA VarlenAttention: The shape of alibi tensor should equal to [num_head]");
+  }

   auto dpcpp_queue = dpcppGetCurrentQueue();
 #if defined(USE_XETLA)
@@ -1579,6 +1603,8 @@ void xetla_paged_attention_impl_v2(
       reinterpret_cast<void*>(query.data_ptr()),
       reinterpret_cast<void*>(key_cache.data_ptr()),
       reinterpret_cast<void*>(value_cache.data_ptr()),
+      alibi_slopes.has_value() ? alibi_slopes.value().data_ptr()
+                               : (void*)nullptr,
       reinterpret_cast<void*>(block_tables.data_ptr()),
       reinterpret_cast<void*>(context_lens.data_ptr()),
       num_queries_per_tokens,
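
Note: the new checks require alibi_slopes to be a contiguous 1-D float tensor of length num_heads on the XPU. For reference only (not part of this commit), below is a minimal host-side sketch of how such per-head ALiBi slopes are commonly generated, following the geometric-sequence rule from the ALiBi paper for power-of-two head counts; the helper name make_alibi_slopes is an assumption, and the paper's interpolation rule for non-power-of-two counts is omitted.

// Sketch only: build the [num_heads] ALiBi slope table on the host.
// The caller would copy this into a contiguous 1-D float tensor on the XPU
// so it satisfies the TORCH_CHECKs above.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> make_alibi_slopes(uint32_t num_heads) {
  // Geometric sequence: slope_h = 2^(-8 * (h + 1) / num_heads).
  std::vector<float> slopes(num_heads);
  const float step = -8.0f / static_cast<float>(num_heads);
  for (uint32_t h = 0; h < num_heads; ++h) {
    slopes[h] = std::exp2f(step * static_cast<float>(h + 1));
  }
  return slopes;
}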

csrc/gpu/aten/operators/xetla/kernels/SDP/fmha_utils.h  (−4)

@@ -270,8 +270,6 @@ struct tile_mask_t {
 #pragma unroll
     for (int k = 0; k < block_size_y; k++) {
       src_sub.row(k) += (blk_seq_x * alibi_slopes);
-      xetla_mask<block_size_x> mask = blk_seq_x > blk_start_y + k;
-      src_sub.row(k).xetla_merge(kNegInfinity, mask);
     }
   }
 }
@@ -296,8 +294,6 @@ struct tile_mask_t {
 #pragma unroll
     for (int k = 0; k < tail_size_y; k++) {
       src_sub.row(k) += (blk_seq_x * alibi_slopes);
-      xetla_mask<block_size_x> mask = blk_seq_x > blk_start_y + k;
-      src_sub.row(k).xetla_merge(kNegInfinity, mask);
     }
   }
 }
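
The deleted lines merged −inf into the score tile under a causal-style mask inside the ALiBi branch; per the commit message ("remove -inf"), the branch now applies only the position-scaled bias. A scalar sketch of what the surviving loop computes, with assumed flat layout and names rather than the real XeTLA tile types:

// Scalar reference for the simplified loop: add the ALiBi bias to every row
// of the score sub-tile; no causal clamp here anymore.
// 'scores' is a block_size_y x block_size_x tile stored row-major,
// 'key_pos' holds the absolute key positions covered by this tile,
// 'slope' is the per-head ALiBi slope.
void add_alibi_bias(float* scores, const int* key_pos, float slope,
                    int block_size_y, int block_size_x) {
  for (int k = 0; k < block_size_y; ++k) {
    for (int j = 0; j < block_size_x; ++j) {
      scores[k * block_size_x + j] += slope * static_cast<float>(key_pos[j]);
    }
  }
}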

csrc/gpu/aten/operators/xetla/kernels/SDP/paged_attention_kernel.hpp  (+23)

@@ -292,6 +292,7 @@ class paged_attention_kernel {
   scalar_t* query; // [num_seqs, num_heads, head_size]
   scalar_t* key_cache; // [num_blocks, num_kv_heads, head_size, block_size]
   scalar_t* value_cache; // [num_blocks, num_kv_heads, head_size, block_size]
+  float* alibi_slopes; // [num_heads] - alibi_slopes

   // Index
   index_t* block_tables; // [num_seqs, max_blocks_per_seq]
@@ -318,6 +319,7 @@ class paged_attention_kernel {
       scalar_t* query,
       scalar_t* key_cache,
       scalar_t* value_cache,
+      float* alibi_slopes,
       index_t* block_tables,
       index_t* context_lens,
       uint32_t num_queries_per_tokens,
@@ -334,6 +336,7 @@ class paged_attention_kernel {
         query(query),
         key_cache(key_cache),
         value_cache(value_cache),
+        alibi_slopes(alibi_slopes),
         block_tables(block_tables),
         context_lens(context_lens),
         num_queries_per_tokens(num_queries_per_tokens),
@@ -404,6 +407,8 @@ class paged_attention_kernel {
     int end_block_id;
     int loop_count;

+    float alibi_slopes;
+
     xetla_nbarrier_t<wg_size, wg_size, arch_tag> nbarrier;

     inline context_t() = default;
@@ -415,6 +420,10 @@ class paged_attention_kernel {
       partition_id = item.get_group(2);
       max_num_partitions = item.get_group_range(2);

+      if (args.alibi_slopes != nullptr) {
+        alibi_slopes = args.alibi_slopes[head_id];
+      }
+
      context_len = args.context_lens[seq_id];
      block_table = args.block_tables + seq_id * args.max_blocks_per_seq;
      num_blocks_per_sg = 0;
@@ -611,6 +620,15 @@ class paged_attention_kernel {
            xetla_tanh<typename score_tile_t::dtype, block_size>(score_sub);
        score_sub *= args.softcap;
      }
+
+      if (args.alibi_slopes != nullptr) {
+        int32_t mat_real_x = bid * block_size;
+        int32_t mat_real_y = ctx.seq_id;
+        xetla_vector<float, block_size> pos_id =
+            xetla_vector_gen<float, block_size>(mat_real_x, 1);
+        score_sub += (pos_id * ctx.alibi_slopes);
+      }
+
      uint32_t remained_len = ctx.context_len - bid * block_size;
      if (remained_len < block_size) {
        xetla_mask<block_size> mask =
@@ -646,6 +664,11 @@ class paged_attention_kernel {
      accum_t group_sum = wg_reduce_sum(mat_score);
      mat_score.reg /= group_sum;

+      if (use_partition && group_max == neg_infinity) {
+        mat_score.reg = 0.f;
+        group_sum = 0.f;
+      }
+
      if (use_partition && ctx.sg_id == 0) {
        // store the max and exp_sum
        using tile_desc_t = subgroup::tile_desc_t<1, 1, 1, 1>;
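
For clarity (not part of this commit), a scalar sketch of the bias the kernel adds per key block in the hunk above: key positions are generated from bid * block_size with unit stride and scaled by the slope picked up in context init (args.alibi_slopes[head_id]); the flat-array form below is an illustration only.

// Scalar sketch of the per-block ALiBi bias applied to one head's scores.
// 'score_sub' holds block_size raw attention scores for key block 'bid',
// 'slope' is ctx.alibi_slopes (the per-head slope loaded in context init).
void add_block_alibi_bias(float* score_sub, int bid, int block_size,
                          float slope) {
  for (int j = 0; j < block_size; ++j) {
    const float pos_id = static_cast<float>(bid * block_size + j);
    score_sub[j] += pos_id * slope;  // mirrors score_sub += pos_id * slope
  }
}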

csrc/gpu/aten/operators/xetla/kernels/SDP/paged_attention_v1.cpp  (+1)

@@ -98,6 +98,7 @@ cgfs_t launch_kernels(paged_attention_fwd_kernel_args_t fwd_args) {
       reinterpret_cast<T*>(fwd_args.query),
       reinterpret_cast<T*>(fwd_args.key_cache),
       reinterpret_cast<T*>(fwd_args.value_cache),
+      reinterpret_cast<float*>(fwd_args.alibi_slopes),
       reinterpret_cast<U*>(fwd_args.block_tables),
       reinterpret_cast<U*>(fwd_args.context_lens),
       fwd_args.num_queries_per_tokens,

csrc/gpu/aten/operators/xetla/kernels/SDP/paged_attention_v2.cpp  (+1)

@@ -62,6 +62,7 @@ std::vector<std::function<void(sycl::handler&)>> launch_split_kv_kernels(
       reinterpret_cast<T*>(fwd_args.query),
       reinterpret_cast<T*>(fwd_args.key_cache),
       reinterpret_cast<T*>(fwd_args.value_cache),
+      reinterpret_cast<float*>(fwd_args.alibi_slopes),
       reinterpret_cast<U*>(fwd_args.block_tables),
       reinterpret_cast<U*>(fwd_args.context_lens),
       fwd_args.num_queries_per_tokens,

csrc/gpu/aten/operators/xetla/mha.h  (+1)

@@ -67,6 +67,7 @@ struct paged_attention_fwd_kernel_args_t {
   void* query;
   void* key_cache;
   void* value_cache;
+  void* alibi_slopes;
   void* block_tables;
   void* context_lens;
   uint32_t num_queries_per_tokens;
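
As a usage illustration (field names taken from the struct above, other members omitted), this is roughly how a caller would fill the new field before launching the paged-attention kernels; make_fwd_args is a hypothetical helper, and a null pointer disables the ALiBi path inside the kernel.

// Sketch: wiring an optional alibi_slopes tensor into the forward args.
// Only the fields visible in the diff above are set.
#include <ATen/ATen.h>
#include <c10/util/Optional.h>

paged_attention_fwd_kernel_args_t make_fwd_args(
    const at::Tensor& query,
    const at::Tensor& key_cache,
    const at::Tensor& value_cache,
    const c10::optional<at::Tensor>& alibi_slopes,
    const at::Tensor& block_tables,
    const at::Tensor& context_lens,
    uint32_t num_queries_per_tokens) {
  paged_attention_fwd_kernel_args_t fwd_args{};
  fwd_args.query = query.data_ptr();
  fwd_args.key_cache = key_cache.data_ptr();
  fwd_args.value_cache = value_cache.data_ptr();
  fwd_args.alibi_slopes =
      alibi_slopes.has_value() ? alibi_slopes->data_ptr() : (void*)nullptr;
  fwd_args.block_tables = block_tables.data_ptr();
  fwd_args.context_lens = context_lens.data_ptr();
  fwd_args.num_queries_per_tokens = num_queries_per_tokens;
  return fwd_args;
}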
