Commit 7006835

[attn] fix device of tensors in attention (#25)
### What this PR does / why we need it?

Fix the device of tensors created in `AscendAttentionBackendImpl`. When a device other than card-0 is specified, a **device conflict** occurs because tensors such as `attn_mask` are placed on card-0 by default. This PR creates these tensors on the card corresponding to the input.

### Does this PR introduce _any_ user-facing change?

With this PR, users can specify the device by local rank. A corresponding change in vLLM is also needed; it will be linked to this PR once created.

### How was this patch tested?

Tested locally with the code below. A test case will be added once the corresponding vLLM change is completed.

```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
# Create an LLM.
llm = LLM(model="~/.cache/modelscope/hub/Qwen/Qwen2___5-7B-Instruct",
          device="npu:1")
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

Signed-off-by: MengqingCao <[email protected]>
1 parent c59375c commit 7006835
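The gist of the change, as a minimal standalone sketch (illustrative names only, not the actual `AscendAttentionBackendImpl` code): derive the device from the incoming tensor instead of hard-coding a device string, so helpers such as the attention-mask builder allocate on the same card as the input.

```python
import torch


def build_causal_mask(query: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Build a boolean causal mask on the same device as `query`.

    Hard-coding device="npu" (or "cuda") silently resolves to card 0,
    which conflicts with inputs living on e.g. "npu:1".
    """
    device = query.device  # follow the input tensor, not a global default
    return ~torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


# Usage: the mask lands wherever the query lives (CPU here, "npu:1" in the
# scenario this PR fixes).
query = torch.randn(4, 8, 64)
mask = build_causal_mask(query, seq_len=8)
assert mask.device == query.device
```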

File tree: 2 files changed, +12 −12 lines

examples/offline_distributed_inference_npu.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -29,11 +29,10 @@
 # Create a sampling params object.
 sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
 # Create an LLM.
-# TODO (cmq): ray is not supported currently, need some fixes
 llm = LLM(
     model="facebook/opt-125m",
     tensor_parallel_size=2,
-    distributed_executor_backend="mp",
+    distributed_executor_backend="ray",
     trust_remote_code=True,
 )
 
```
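For context, the example script after this hunk would look roughly as follows (a sketch reconstructed from the hunk above plus the generation loop from the commit message; the exact surrounding lines are not shown in the diff):

```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)

# Create an LLM, distributing the model over 2 devices via the Ray backend.
llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
    trust_remote_code=True,
)

# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")
```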

vllm_ascend/attention.py

Lines changed: 11 additions & 10 deletions
```diff
@@ -457,9 +457,7 @@ def __init__(
         self.kv_cache_dtype = kv_cache_dtype
         self.sliding_window = sliding_window
         if alibi_slopes is not None:
-            alibi_slopes = torch.tensor(alibi_slopes,
-                                        dtype=torch.float32,
-                                        device="npu")
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
         self.attn_type = attn_type
 
@@ -520,7 +518,7 @@ def forward(
             attn_metadata.sparse_mode = 2
             attention_mask = gen_input_mask(
                 attn_metadata.max_prefill_seq_len, self.sliding_window,
-                num_tokens)
+                num_tokens, query.device)
             attn_metadata.attn_mask = attention_mask
 
         if (self.alibi_slopes is not None
@@ -531,6 +529,7 @@ def forward(
                 dtype=query.dtype,
                 seq_len=attn_metadata.max_prefill_seq_len,
                 batch_size=num_tokens,
+                device=query.device,
             )
 
         if (len(kv_cache) == 0 or attn_metadata.block_tables is None
@@ -571,7 +570,7 @@ def forward(
             query = query.view(query.shape[0], -1,
                                self.num_heads * self.head_size)
             output = torch.zeros(query.shape,
-                                 device="npu",
+                                 device=query.device,
                                  dtype=query.dtype)
             # TODO (Mengqing Cao): torch_npu.npu_incre_flash_attention
             # support only when `S == 1`, OPTIMIZE ME when prefix caching
@@ -621,7 +620,7 @@ def forward(
         return output
 
 
-def gen_input_mask(seq_len, sliding_window, len):
+def gen_input_mask(seq_len, sliding_window, len, device):
     """
     Generating lower triangular matrix
     """
@@ -630,15 +629,15 @@ def gen_input_mask(seq_len, sliding_window, len):
     global SHARE_MASK_TRIL_PREFIX_CACHE
     if SHARE_MASK_TRIL_PREFIX_CACHE is None:
         SHARE_MASK_TRIL_PREFIX_CACHE = torch.triu(
-            torch.ones(1, 1, 2048, 2048, dtype=bool, device="npu"),
+            torch.ones(1, 1, 2048, 2048, dtype=bool, device=device),
             diagonal=1,
         )
         attention_mask = SHARE_MASK_TRIL_PREFIX_CACHE
     else:
         global SHARE_MASK_TRIL
         if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len:
             SHARE_MASK_TRIL = ~torch.tril(
-                torch.ones(seq_len, seq_len, dtype=bool, device="npu"))
+                torch.ones(seq_len, seq_len, dtype=bool, device=device))
 
         attention_mask = SHARE_MASK_TRIL
         if sliding_window is not None:
@@ -656,8 +655,10 @@ def _make_alibi_bias(
     dtype: torch.dtype,
     seq_len: int,
     batch_size: int,
+    device: torch.device,
 ):
-    bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
+    alibi_slopes = alibi_slopes.to(device)
+    bias = torch.arange(seq_len, dtype=dtype, device=device)
     # NOTE(zhuohan): HF uses
     # `bias = bias[None, :].repeat(seq_len, 1)`
     # here. We find that both biases give the same results, but
@@ -674,7 +675,7 @@ def _make_alibi_bias(
         num_heads,
         seq_len,
         padded_len,
-        device=alibi_slopes.device,
+        device=device,
         dtype=dtype,
     )[:, :, :, :seq_len].copy_(bias)
     bias.mul_(alibi_slopes[:, None, None])
```
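To illustrate the alibi half of the change (a simplified sketch of the idea, not the patched `_make_alibi_bias` itself), both the per-head slopes and the position bias are placed on the caller-supplied device before they are combined, so no intermediate tensor defaults to card 0:

```python
import torch


def make_alibi_bias_sketch(alibi_slopes: torch.Tensor, dtype: torch.dtype,
                           seq_len: int, batch_size: int,
                           device: torch.device) -> torch.Tensor:
    """Simplified alibi bias built entirely on `device` (illustrative only)."""
    # Move the per-head slopes first so the broadcasted multiply below
    # never mixes devices.
    alibi_slopes = alibi_slopes.to(device)
    num_heads = alibi_slopes.shape[0]

    # Relative positions: bias[i, j] = j - i.
    pos = torch.arange(seq_len, dtype=dtype, device=device)
    bias = pos[None, :] - pos[:, None]

    # Expand to (batch, heads, seq, seq) and scale by the per-head slope.
    bias = bias[None, None, :, :].expand(batch_size, num_heads, seq_len, seq_len)
    return bias * alibi_slopes[None, :, None, None]


# Usage (CPU here; "npu:1" in the scenario this PR addresses):
slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625])
bias = make_alibi_bias_sketch(slopes, torch.float32, seq_len=6, batch_size=2,
                              device=torch.device("cpu"))
assert bias.shape == (2, 4, 6, 6) and bias.device == torch.device("cpu")
```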
