@@ -1471,10 +1471,21 @@ void xetla_paged_attention_impl_v1(
   uint32_t num_kv_heads = key_cache.size(1);
   uint32_t max_num_blocks_per_seq = block_tables.size(1);
 
-  // TODO(zw): alibi_slopes is optional, not used currently.
-  const float* alibi_slopes_ptr = alibi_slopes
-      ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-      : nullptr;
+  if (alibi_slopes.has_value()) {
+    TORCH_CHECK(alibi_slopes->is_xpu(), "alibi_slopes_ must on XPU");
+    TORCH_CHECK(
+        alibi_slopes->is_contiguous(), "alibi_slopes_ must be contiguous");
+    TORCH_CHECK(
+        alibi_slopes->scalar_type() == at::kFloat,
+        "XeTLA VarlenAttention: The datatype of alibi_slopes should be float");
+    int ndim = alibi_slopes->ndimension();
+    TORCH_CHECK(
+        ndim == 1, "XeTLA VarlenAttention: only support 1 dim alibi tensor!");
+    int last_dim = alibi_slopes->size(-1);
+    TORCH_CHECK(
+        last_dim == num_heads,
+        "XeTLA VarlenAttention: The shape of alibi tensor should equal to [num_head]");
+  }
 
   auto dpcpp_queue = dpcppGetCurrentQueue();
 #if defined(USE_XETLA)
@@ -1490,6 +1501,8 @@ void xetla_paged_attention_impl_v1(
       reinterpret_cast<void*>(query.data_ptr()),
       reinterpret_cast<void*>(key_cache.data_ptr()),
       reinterpret_cast<void*>(value_cache.data_ptr()),
+      alibi_slopes.has_value() ? alibi_slopes.value().data_ptr()
+                               : (void*)nullptr,
       reinterpret_cast<void*>(block_tables.data_ptr()),
       reinterpret_cast<void*>(context_lens.data_ptr()),
       num_queries_per_tokens,
@@ -1560,10 +1573,21 @@ void xetla_paged_attention_impl_v2(
   uint32_t num_kv_heads = key_cache.size(1);
   uint32_t max_num_blocks_per_seq = block_tables.size(1);
 
-  // TODO(zw): alibi_slopes is optional, not used currently.
-  const float* alibi_slopes_ptr = alibi_slopes
-      ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-      : nullptr;
+  if (alibi_slopes.has_value()) {
+    TORCH_CHECK(alibi_slopes->is_xpu(), "alibi_slopes_ must on XPU");
+    TORCH_CHECK(
+        alibi_slopes->is_contiguous(), "alibi_slopes_ must be contiguous");
+    TORCH_CHECK(
+        alibi_slopes->scalar_type() == at::kFloat,
+        "XeTLA VarlenAttention: The datatype of alibi_slopes should be float");
+    int ndim = alibi_slopes->ndimension();
+    TORCH_CHECK(
+        ndim == 1, "XeTLA VarlenAttention: only support 1 dim alibi tensor!");
+    int last_dim = alibi_slopes->size(-1);
+    TORCH_CHECK(
+        last_dim == num_heads,
+        "XeTLA VarlenAttention: The shape of alibi tensor should equal to [num_head]");
+  }
 
   auto dpcpp_queue = dpcppGetCurrentQueue();
 #if defined(USE_XETLA)
@@ -1579,6 +1603,8 @@ void xetla_paged_attention_impl_v2(
       reinterpret_cast<void*>(query.data_ptr()),
       reinterpret_cast<void*>(key_cache.data_ptr()),
       reinterpret_cast<void*>(value_cache.data_ptr()),
+      alibi_slopes.has_value() ? alibi_slopes.value().data_ptr()
+                               : (void*)nullptr,
       reinterpret_cast<void*>(block_tables.data_ptr()),
       reinterpret_cast<void*>(context_lens.data_ptr()),
       num_queries_per_tokens,
0 commit comments