
Commit cb5b704

[attn] fix device of tensors

Signed-off-by: MengqingCao <[email protected]>

1 parent: c59375c

2 files changed, 8 insertions(+), 10 deletions(-)

examples/offline_distributed_inference_npu.py (1 addition, 2 deletions)

@@ -29,11 +29,10 @@
 # Create a sampling params object.
 sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
 # Create an LLM.
-# TODO (cmq): ray is not supported currently, need some fixes
 llm = LLM(
     model="facebook/opt-125m",
     tensor_parallel_size=2,
-    distributed_executor_backend="mp",
+    distributed_executor_backend="ray",
     trust_remote_code=True,
 )
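The example's only functional change is switching the distributed executor backend from "mp" to "ray" and dropping the TODO that said Ray was unsupported. For context, here is a minimal sketch of how such a script is typically driven end to end; the prompt list and the generate loop are assumptions about the rest of the file, not part of this hunk.

# Illustrative driver for the example above; prompts and the print loop
# are assumptions, only the LLM(...) arguments come from this diff.
from vllm import LLM, SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)

llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=2,
    # Requires a Ray runtime; with tensor_parallel_size=2 the two model
    # shards run in Ray worker processes instead of multiprocessing ("mp") workers.
    distributed_executor_backend="ray",
    trust_remote_code=True,
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")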
vllm_ascend/attention.py (7 additions, 8 deletions)

@@ -458,8 +458,7 @@ def __init__(
         self.sliding_window = sliding_window
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes,
-                                        dtype=torch.float32,
-                                        device="npu")
+                                        dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
         self.attn_type = attn_type

@@ -520,13 +519,13 @@ def forward(
                 attn_metadata.sparse_mode = 2
                 attention_mask = gen_input_mask(
                     attn_metadata.max_prefill_seq_len, self.sliding_window,
-                    num_tokens)
+                    num_tokens, query.device)
                 attn_metadata.attn_mask = attention_mask

             if (self.alibi_slopes is not None
                     and attn_metadata.pse_shift is None):
                 attn_metadata.pse_shift = _make_alibi_bias(
-                    self.alibi_slopes,
+                    self.alibi_slopes.to(query.device),
                     self.num_kv_heads,
                     dtype=query.dtype,
                     seq_len=attn_metadata.max_prefill_seq_len,
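Taken together, the two hunks above stop pinning the alibi-slope tensor to a hard-coded "npu" device: it is now created on the default device in __init__ and moved to the query's device only when the bias is materialized in forward. A minimal sketch of that pattern follows; the class and method names are illustrative, not the actual vllm_ascend code.

import torch

class AlibiBiasHolder:
    """Illustrative holder showing deferred device placement."""

    def __init__(self, alibi_slopes):
        # No device argument: the slopes stay on the default (CPU) device
        # until they are needed next to the activations.
        self.alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)

    def slopes_for(self, query: torch.Tensor) -> torch.Tensor:
        # .to() returns the tensor unchanged when it already lives on
        # query.device, so no copy happens once placement matches.
        return self.alibi_slopes.to(query.device)

# Usage: works the same whether query lives on CPU or an NPU.
holder = AlibiBiasHolder([0.5, 0.25])
q = torch.randn(2, 4)
print(holder.slopes_for(q).device)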
@@ -571,7 +570,7 @@ def forward(
             query = query.view(query.shape[0], -1,
                                self.num_heads * self.head_size)
             output = torch.zeros(query.shape,
-                                 device="npu",
+                                 device=query.device,
                                  dtype=query.dtype)
             # TODO (Mengqing Cao): torch_npu.npu_incre_flash_attention
             # support only when `S == 1`, OPTIMIZE ME when prefix caching
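The same idea applies to the decode-path output buffer: allocating it with device=query.device (and the query's dtype) keeps it wherever the inputs already live instead of assuming an NPU. A tiny illustration, with a hypothetical helper name, runnable on CPU:

import torch

def alloc_output_like(query: torch.Tensor) -> torch.Tensor:
    # Match both device and dtype of the incoming query, so the same code
    # works unmodified for CPU, CUDA, or NPU tensors.
    return torch.zeros(query.shape, device=query.device, dtype=query.dtype)

q = torch.randn(3, 2, 64, dtype=torch.float16)
out = alloc_output_like(q)
assert out.device == q.device and out.dtype == q.dtype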
@@ -621,7 +620,7 @@ def forward(
         return output


-def gen_input_mask(seq_len, sliding_window, len):
+def gen_input_mask(seq_len, sliding_window, len, device):
     """
     Generating lower triangular matrix
     """
@@ -630,15 +629,15 @@ def gen_input_mask(seq_len, sliding_window, len):
         global SHARE_MASK_TRIL_PREFIX_CACHE
         if SHARE_MASK_TRIL_PREFIX_CACHE is None:
             SHARE_MASK_TRIL_PREFIX_CACHE = torch.triu(
-                torch.ones(1, 1, 2048, 2048, dtype=bool, device="npu"),
+                torch.ones(1, 1, 2048, 2048, dtype=bool, device=device),
                 diagonal=1,
             )
             attention_mask = SHARE_MASK_TRIL_PREFIX_CACHE
         else:
             global SHARE_MASK_TRIL
             if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len:
                 SHARE_MASK_TRIL = ~torch.tril(
-                    torch.ones(seq_len, seq_len, dtype=bool, device="npu"))
+                    torch.ones(seq_len, seq_len, dtype=bool, device=device))

             attention_mask = SHARE_MASK_TRIL
             if sliding_window is not None:
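With this change, gen_input_mask receives the target device from its caller (query.device in forward) instead of assuming "npu" when it builds its cached triangular masks. A standalone sketch of the non-cached construction, runnable on CPU; the helper name is illustrative:

import torch

def make_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    # True marks positions to mask out: everything above the diagonal,
    # i.e. the complement of the lower triangular matrix, as in the diff.
    return ~torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))

# The device comes from the tensors being attended over, not a constant.
query = torch.randn(4, 64)   # CPU here; would be an NPU tensor in vllm_ascend
mask = make_causal_mask(query.shape[0], query.device)
print(mask.int())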
