@@ -458,8 +458,7 @@ def __init__(
         self.sliding_window = sliding_window
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes,
-                                        dtype=torch.float32,
-                                        device="npu")
+                                        dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
         self.attn_type = attn_type
 
@@ -520,13 +519,13 @@ def forward(
                 attn_metadata.sparse_mode = 2
                 attention_mask = gen_input_mask(
                     attn_metadata.max_prefill_seq_len, self.sliding_window,
-                    num_tokens)
+                    num_tokens, query.device)
                 attn_metadata.attn_mask = attention_mask
 
             if (self.alibi_slopes is not None
                     and attn_metadata.pse_shift is None):
                 attn_metadata.pse_shift = _make_alibi_bias(
-                    self.alibi_slopes,
+                    self.alibi_slopes.to(query.device),
                     self.num_kv_heads,
                     dtype=query.dtype,
                     seq_len=attn_metadata.max_prefill_seq_len,
@@ -571,7 +570,7 @@ def forward(
             query = query.view(query.shape[0], -1,
                                self.num_heads * self.head_size)
             output = torch.zeros(query.shape,
-                                 device="npu",
+                                 device=query.device,
                                  dtype=query.dtype)
             # TODO (Mengqing Cao): torch_npu.npu_incre_flash_attention
             # support only when `S == 1`, OPTIMIZE ME when prefix caching
@@ -621,7 +620,7 @@ def forward(
         return output
 
 
-def gen_input_mask(seq_len, sliding_window, len):
+def gen_input_mask(seq_len, sliding_window, len, device):
    """
    Generating lower triangular matrix
    """
@@ -630,15 +629,15 @@ def gen_input_mask(seq_len, sliding_window, len):
         global SHARE_MASK_TRIL_PREFIX_CACHE
         if SHARE_MASK_TRIL_PREFIX_CACHE is None:
             SHARE_MASK_TRIL_PREFIX_CACHE = torch.triu(
-                torch.ones(1, 1, 2048, 2048, dtype=bool, device="npu"),
+                torch.ones(1, 1, 2048, 2048, dtype=bool, device=device),
                 diagonal=1,
             )
         attention_mask = SHARE_MASK_TRIL_PREFIX_CACHE
     else:
         global SHARE_MASK_TRIL
         if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len:
             SHARE_MASK_TRIL = ~torch.tril(
-                torch.ones(seq_len, seq_len, dtype=bool, device="npu"))
+                torch.ones(seq_len, seq_len, dtype=bool, device=device))
 
         attention_mask = SHARE_MASK_TRIL
         if sliding_window is not None:
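The pattern applied throughout this diff is device derivation: new tensors are allocated with `device=query.device` (or a caller-supplied `device` argument), and cached buffers are moved with `.to(query.device)`, instead of hardcoding `device="npu"`. The sketch below is not part of the patch; it is a minimal, self-contained illustration of that pattern under assumed helper names (`build_causal_mask`, `zeros_like_query`, `_MASK_CACHE`), loosely mirroring `gen_input_mask` and the `torch.zeros` call in `forward`.

```python
import torch

_MASK_CACHE = None  # hypothetical module-level cache, analogous to SHARE_MASK_TRIL


def build_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Boolean mask that is True above the diagonal, built on `device`."""
    global _MASK_CACHE
    if (_MASK_CACHE is None or _MASK_CACHE.shape[0] < seq_len
            or _MASK_CACHE.device != device):
        # Allocate on the caller-supplied device rather than a hardcoded "npu".
        _MASK_CACHE = ~torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
    return _MASK_CACHE[:seq_len, :seq_len]


def zeros_like_query(query: torch.Tensor) -> torch.Tensor:
    # Mirrors the diff's `device=query.device`: new buffers inherit the
    # device (and dtype) of the tensor they will be combined with.
    return torch.zeros(query.shape, device=query.device, dtype=query.dtype)


if __name__ == "__main__":
    q = torch.randn(4, 8)  # CPU tensor here; on Ascend this would live on NPU
    mask = build_causal_mask(q.shape[0], q.device)
    out = zeros_like_query(q)
    assert mask.device == q.device and out.device == q.device
```

Because nothing names the accelerator explicitly, the same code path works on CPU, CUDA, or NPU, which is the point of removing the `device="npu"` literals.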