Commit 81a23a7

Use flat block layout for PA (#92)
* Cleanup AttentionMetadata on HPU
* Flat PA - POC
* Decode warmup overhaul
* Debugging OOM
* Experimental profiling
* Fix input_hash calculation
* Block bucket size 32 -> 16
* Improve host time
* Skip UTs
* Add GQA/MQA
* Add mask instead of filling
* 2d block mapping
* Optional flipping in PA
* Runner updated for 2d block mapping
* Restore mark_step
* Eliminate physical transposes
* Disable warmup_mode
* Revert changes to test_attention.py
* POC: build block_bias on device
* Cleanup
* Fix seq_len calculation
* Experimental profiling
* Add missing call to kv_matmul_op
* Fix block_usage calculation
* Change default block bucket step for decode to 128
* Fix max decode block bucket calculation
* Fix block_usage calculations
* Cleanup
* Cleanup profiler code
* Print values for bucketing vars
* Pass block size to HpuModelAdapter

---------

Co-authored-by: barak goldberg <[email protected]>
1 parent d4e72b8 commit 81a23a7

5 files changed: +329 additions, -325 deletions
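Note on the flat layout (illustrative, not part of the commit): the diffs below replace per-sequence block tables with three flat tensors: block_list (all physical block ids for the batch, concatenated), block_mapping (a block-to-sequence ownership matrix), and block_usage (number of valid tokens in each block). A minimal sketch of how such tensors could be derived from ordinary block tables; the helper name make_flat_block_metadata and the dtypes are assumptions, not code from this commit:

import torch

def make_flat_block_metadata(block_tables, seq_lens, block_size):
    # block_tables: per-sequence physical block ids, e.g. [[0, 1], [5]]
    # Returns the flat layout used by the new metadata:
    #   block_list    - 1-D tensor of all physical block ids, concatenated
    #   block_mapping - (num_blocks, batch_size) one-hot map: block -> owning sequence
    #   block_usage   - number of valid tokens in each block (<= block_size)
    block_list, owners, usage = [], [], []
    for seq_idx, (blocks, seq_len) in enumerate(zip(block_tables, seq_lens)):
        for blk_idx, blk in enumerate(blocks):
            block_list.append(blk)
            owners.append(seq_idx)
            remaining = seq_len - blk_idx * block_size
            usage.append(min(block_size, remaining))
    batch_size = len(block_tables)
    block_mapping = torch.nn.functional.one_hot(
        torch.tensor(owners), num_classes=batch_size).to(torch.float32)
    return torch.tensor(block_list), block_mapping, torch.tensor(usage)

block_list, block_mapping, block_usage = make_flat_block_metadata(
    [[0, 1], [5]], seq_lens=[20, 7], block_size=16)
# block_list == [0, 1, 5]; block_usage == [16, 4, 7]
# block_mapping is 3x2: blocks 0 and 1 belong to sequence 0, block 5 to sequence 1

With this layout the decode path treats every block in the batch uniformly instead of iterating over a padded (batch, max_blocks_per_seq) table.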

vllm/attention/backends/habana_attn.py

Lines changed: 30 additions & 99 deletions
@@ -59,59 +59,12 @@ def copy_blocks(
         HabanaPagedAttention.copy_blocks(kv_caches, src_to_dists)
 
 
-@dataclass
-class HabanaAttentionMetadata(AttentionMetadataPerStage, HabanaPagedAttentionMetadata):
-    """Metadata for HabanaAttentionbackend.
-
-    NOTE: Any python object stored here is not updated when it is
-    cuda-graph replayed. If you have values that need to be changed
-    dynamically, it should be stored in tensor. The tensor has to be
-    updated from `CUDAGraphRunner.forward` API.
-    """
-    # Currently, input sequences can only contain all prompts
-    # or all decoding. True if all sequences are prompts.
-    is_prompt: bool
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    seq_lens: Optional[List[int]]
-    # seq_lens stored as a tensor.
+@dataclass(frozen=True)
+class HabanaAttentionMetadata(HabanaPagedAttentionMetadata, AttentionMetadataPerStage):
+    """Metadata for HabanaAttentionbackend."""
+    attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
 
-    # |---------- N-1 iteration --------|
-    # |---------------- N iteration ---------------------|
-    # |- tokenA -|......................|-- newTokens ---|
-    # |---------- context_len ----------|
-    # |-------------------- seq_len ----------------------|
-    #                                   |-- query_len ---|
-
-    # Maximum query length in the batch.
-    max_query_len: Optional[int]
-    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
-    # the batch, used to index into subquery. E.g., if the subquery length
-    # is [4, 6], it is [0, 4, 10].
-    subquery_start_loc: Optional[torch.Tensor]
-    # FIXME: It is for flash attn.
-    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
-    # the batch, used to index into sequence. E.g., if the sequence length is
-    # [4, 6], it is [0, 4, 10].
-    seq_start_loc: Optional[torch.Tensor]
-    # (batch_size,) A tensor of context lengths (tokens that are computed
-    # so far).
-    context_lens_tensor: Optional[torch.Tensor]
-
-    # Whether or not if cuda graph is enabled.
-    # Cuda-graph is currently enabled for decoding only.
-    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
-    use_cuda_graph: bool
-
-    def __post_init__(self):
-        # Set during the execution of the first attention op.
-        # It is a list because it is needed to set per prompt
-        # when alibi slopes is used. It is because of the limitation
-        # from xformer API.
-        # will not appear in the __repr__ and __init__
-        self.attn_bias: Optional[List[AttentionBias]] = None
-
 
 class HabanaAttentionImpl(AttentionImpl, torch.nn.Module):
     """

@@ -202,57 +155,35 @@ def forward(
 
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
-                # TODO: move this outside of model
-                assert prefill_meta.attn_bias is not None, 'attn_bias must be set before calling model.forward!'
-                query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
-                kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size)
-                out = xops.prompt_attention(
-                    query.view(query_shape),
-                    key.view(kv_shape),
-                    value.view(kv_shape),
-                    attn_bias=prefill_meta.attn_bias,
-                    p=0.0,
-                    scale=self.scale,
-                    qk_matmul_op=self.qk_matmul,
-                    softmax_op=self.softmax,
-                    kv_matmul_op=self.kv_matmul,
-                )
-                output = out.reshape(batch_size, seq_len, hidden_size)
-            else:
-                # prefix-enabled attention
-                output = HabanaPagedAttention.forward_prefix(
-                    query,
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    prefill_meta.block_tables,
-                    prefill_meta.subquery_start_loc,
-                    prefill_meta.seq_lens_tensor,
-                    prefill_meta.context_lens_tensor,
-                    prefill_meta.max_query_len,
-                    self.alibi_slopes,
-                )
+            assert prefill_meta.attn_bias is not None, 'attn_bias must be set before calling model.forward!'
+            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
+            kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size)
+            out = xops.prompt_attention(
+                query.view(query_shape),
+                key.view(kv_shape),
+                value.view(kv_shape),
+                attn_bias=prefill_meta.attn_bias,
+                p=0.0,
+                scale=self.scale,
+                qk_matmul_op=self.qk_matmul,
+                softmax_op=self.softmax,
+                kv_matmul_op=self.kv_matmul,
+            )
+            output = out.reshape(batch_size, seq_len, hidden_size)
         if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output = HabanaPagedAttention.forward_decode(
-                query,
-                key_cache,
-                value_cache,
-                decode_meta.block_tables,
-                decode_meta.seq_lens_tensor,
-                attn_metadata.kv_cache_dtype,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-                kv_scale,
-                self.qk_matmul,
-                self.softmax,
-                self.kv_matmul,
-                self.key_cache.fetch_from_cache,
-                self.value_cache.fetch_from_cache,
-            )
+                query=query,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                block_list=decode_meta.block_list,
+                block_mapping=decode_meta.block_mapping,
+                block_bias=decode_meta.attn_bias,
+                scale=self.scale,
+                qk_matmul_op=self.qk_matmul,
+                kv_matmul_op=self.kv_matmul,
+                keys_fetch_func=self.key_cache.fetch_from_cache,
+                values_fetch_func=self.value_cache.fetch_from_cache)
 
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)

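The prompt branch now asserts that prefill_meta.attn_bias is populated before model.forward is called; the in-model fallback paths are gone. As a rough illustration only (the real bias is prepared outside this file and is not shown in this commit), a standard additive causal bias of the kind such prompt-attention kernels broadcast over heads could be built like this:

import torch

def make_causal_attn_bias(batch_size, seq_len, dtype=torch.bfloat16):
    # Additive bias: 0 where attention is allowed, a large negative value above
    # the diagonal so future positions vanish after softmax.
    # Shaped (batch_size, 1, seq_len, seq_len) so it broadcasts over attention heads.
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    bias = torch.zeros(seq_len, seq_len, dtype=dtype)
    bias.masked_fill_(mask, torch.finfo(dtype).min)
    return bias.view(1, 1, seq_len, seq_len).expand(batch_size, -1, -1, -1)

attn_bias = make_causal_attn_bias(batch_size=2, seq_len=128)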
vllm/attention/ops/habana_paged_attn.py

Lines changed: 6 additions & 46 deletions
@@ -13,19 +13,12 @@
 _PARTITION_SIZE = 512
 
 
-@dataclass
+@dataclass(frozen=True)
 class HabanaPagedAttentionMetadata:
     """Metadata for PagedAttention."""
-    # (batch_size,). The length of sequences (entire tokens seen so far) per
-    # sequence.
-    seq_lens_tensor: Optional[torch.Tensor]
-    # (batch_size, max_blocks_per_seq).
-    # Block addresses per sequence. (Seq id -> list of physical block)
-    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
-    # in the kv cache. Each block can contain up to block_size tokens.
-    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
-    # captured.
-    block_tables: Optional[torch.Tensor]
+    block_list: Optional[torch.Tensor]
+    block_mapping: Optional[torch.Tensor]
+    block_usage: Optional[torch.Tensor]
 
 
 class HabanaPagedAttention:

@@ -74,41 +67,8 @@ def write_to_paged_cache(
         )
 
     @staticmethod
-    def forward_decode(
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        block_tables: torch.Tensor,
-        seq_lens: torch.Tensor,
-        kv_cache_dtype: str,
-        num_kv_heads: int,
-        scale: float,
-        alibi_slopes: Optional[torch.Tensor],
-        kv_scale: float,
-        qk_op=torch.matmul,
-        softmax_op=torch.softmax,
-        kv_op=torch.matmul,
-        keys_fetch=ops.fetch_from_cache,
-        values_fetch=ops.fetch_from_cache,
-    ) -> torch.Tensor:
-        block_size = value_cache.shape[1]
-        return ops.paged_attention_v1(
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            block_tables,
-            seq_lens,
-            block_size,
-            alibi_slopes,
-            kv_cache_dtype,
-            qk_op,
-            softmax_op,
-            kv_op,
-            keys_fetch,
-            values_fetch,
-        )
+    def forward_decode(**kwargs) -> torch.Tensor:
+        return ops.flat_pa(**kwargs)
 
     @staticmethod
     def forward_prefix(

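forward_decode is now a thin pass-through: the caller prepares everything and forwards it to ops.flat_pa as keyword arguments. The per-block bias travels as attn_bias on the decode metadata and is viewed as (num_blocks, 1, 1, block_size) inside flat_pa, which suggests it marks which slots of each block hold valid tokens. A plausible sketch of deriving it from block_usage; the helper name and dtype are assumptions, not code from this commit:

import torch

def make_block_bias(block_usage, block_size, dtype=torch.bfloat16):
    # block_usage: (num_blocks,) number of valid tokens in each block.
    # Returns (num_blocks, block_size): 0 for valid slots, a large negative
    # value for padding, matching the block_bias.view(num_blocks, 1, 1, -1)
    # reshape in flat_pa.
    positions = torch.arange(block_size).view(1, -1)      # (1, block_size)
    invalid = positions >= block_usage.view(-1, 1)        # (num_blocks, block_size)
    bias = torch.zeros(block_usage.size(0), block_size, dtype=dtype)
    return bias.masked_fill(invalid, torch.finfo(dtype).min)

bias = make_block_bias(torch.tensor([16, 4, 7]), block_size=16)

Padding slots get a large negative bias, so they contribute nothing after the exponentiation in block_softmax.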
vllm/hpu/ops.py

Lines changed: 54 additions & 37 deletions
@@ -31,45 +31,62 @@ def gelu_fast(output, input):
     raise NotImplementedError
 
 
-def fetch_from_cache(cache, blocks, permutations):
-    return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))]
-
-
-def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None,
-                       qk_matmul_op=torch.matmul, softmax_op=torch.softmax, kv_matmul_op=torch.matmul, keys_fetch_func=fetch_from_cache, values_fetch_func=fetch_from_cache) -> None:
-    seq_len = block_tables.size(1)
-    batch_size, query_heads, _ = query.shape
-    _, _, kv_heads, _ = key_cache.shape
-    min_inf = torch.finfo(query.dtype).min
-    mask = (torch.arange(0, seq_len * block_size, dtype=torch.int32, device=key_cache.device)
-            .view(1, -1)
-            .expand(batch_size, -1)
-            .ge(context_lens.view(-1, 1))
-            .view(batch_size, 1, 1, -1))
-    query.mul_(scale)
-    query = query.unsqueeze(-2)
-    keys = keys_fetch_func(key_cache, block_tables, (0, 2, 3, 1))
-    if query_heads != kv_heads:
+def batch2block(tensor, block_mapping):
+    shape = tuple(tensor.shape)
+    return (block_mapping @ tensor.view(shape[0], -1)).view(-1, *shape[1:])
+
+
+def block2batch(tensor, block_mapping):
+    shape = tuple(tensor.shape)
+    return (block_mapping.t() @ tensor.view(shape[0], -1)).view(-1, *shape[1:])
+
+
+def block_softmax(batch_size, attn, block_mapping):
+    attn = attn.exp_()
+    sums = attn.sum(dim=-1).unsqueeze(-1)
+    sums = block2batch(sums, block_mapping)
+    sums = batch2block(sums, block_mapping)
+    attn.div_(sums)
+    return attn
+
+
+def flat_pa(query,
+            key_cache,
+            value_cache,
+            block_list,
+            block_mapping,
+            block_bias,
+            scale,
+            qk_matmul_op,
+            kv_matmul_op,
+            keys_fetch_func,
+            values_fetch_func):
+    batch_size = query.size(0)
+    q_heads = query.size(1)
+    kv_heads = key_cache.size(2)
+
+    query = batch2block(scale * query, block_mapping).unsqueeze(-2)
+    key = keys_fetch_func(key_cache, block_list).transpose(1, 2)
+    value = values_fetch_func(value_cache, block_list).transpose(1, 2)
+    block_bias = block_bias.view(key.size(0), 1, 1, -1)
+
+    if kv_heads != q_heads:
+        block_bias = block_bias.unsqueeze(1)
         query = query.unflatten(1, (kv_heads, -1))
-        keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
-        mask = mask.unsqueeze(2)
-    attn_weights = [qk_matmul_op(query, k) for k in keys]
-    attn_weights = softmax_op(torch.cat(attn_weights, dim=-1).masked_fill(mask, min_inf),
-                              dim=-1)
-
-    values = values_fetch_func(value_cache, block_tables, (0, 2, 1, 3))
-    if PA_SPLIT_VALUE:
-        attn_weights = attn_weights.split(block_size, dim=-1)
+        key = key.unflatten(1, (kv_heads, 1))
+        value = value.unflatten(1, (kv_heads, 1))
+        key = key.transpose(3, 4)
     else:
-        values = [torch.cat(values, dim=-2)]
-        attn_weights = [attn_weights]
-    if query_heads != kv_heads:
-        values = [v.unflatten(1, (kv_heads, 1)) for v in values]
-    attn_weights = [kv_matmul_op(a, v) for a, v in zip(attn_weights, values)]
-    if query_heads != kv_heads:
-        attn_weights = [a.flatten(1, 2) for a in attn_weights]
-    attn_weights = sum(attn_weights)
-    return attn_weights.squeeze(-2)
+        key = key.transpose(2, 3)
+
+    attn = qk_matmul_op(query, key) + block_bias
+    attn = block_softmax(batch_size, attn, block_mapping)
+    attn = kv_matmul_op(attn, value)
+    attn = block2batch(attn, block_mapping)
+    attn = attn.squeeze(-2)
+    if kv_heads != q_heads:
+        attn = attn.flatten(1, 2)
+    return attn
 
 
 def rms_norm(out, hidden_states, weight, eps):

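The batch2block / block2batch pair added above is the core of the flat layout: block_mapping acts as a (num_blocks, batch_size) selection matrix, so one matmul scatters per-sequence values out to that sequence's blocks, and the transposed matmul sums per-block results back per sequence. block_softmax uses exactly this round trip to normalize exponentiated scores across all blocks of a sequence without gathering them contiguously (note it divides by plain sums of exponentials rather than applying a max-subtracted per-row softmax). A small self-contained check of the round trip; the shapes and the one-hot block_mapping are illustrative:

import torch

# Three blocks, two sequences: blocks 0 and 1 belong to sequence 0, block 2 to sequence 1.
block_mapping = torch.tensor([[1., 0.],
                              [1., 0.],
                              [0., 1.]])

def batch2block(tensor, block_mapping):
    shape = tuple(tensor.shape)
    return (block_mapping @ tensor.view(shape[0], -1)).view(-1, *shape[1:])

def block2batch(tensor, block_mapping):
    shape = tuple(tensor.shape)
    return (block_mapping.t() @ tensor.view(shape[0], -1)).view(-1, *shape[1:])

per_seq = torch.tensor([[10.], [20.]])           # one value per sequence
per_block = batch2block(per_seq, block_mapping)  # [[10.], [10.], [20.]]: broadcast to each sequence's blocks
summed = block2batch(per_block, block_mapping)   # [[20.], [20.]]: sums over each sequence's blocks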
vllm/hpu/utils.py

Lines changed: 2 additions & 2 deletions
@@ -125,5 +125,5 @@ def forward(self, input, cache, block_indices, block_offset):
         insert_or_update_cache(input, cache, block_indices, block_offset)
         return cache
 
-    def fetch_from_cache(self, cache, blocks, permutations):
-        return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))]
+    def fetch_from_cache(self, cache, blocks):
+        return cache.index_select(0, blocks)

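With block_list flattened to one dimension, fetch_from_cache shrinks from a per-column index_select plus permute into a single gather along the block axis. An illustrative call; the cache layout (num_blocks, block_size, num_kv_heads, head_size) is inferred from key_cache.size(2) being read as kv_heads and from the transpose(1, 2) applied to the fetched result in flat_pa, and is an assumption rather than something stated in this diff:

import torch

num_blocks, block_size, kv_heads, head_size = 8, 16, 4, 64
key_cache = torch.randn(num_blocks, block_size, kv_heads, head_size)
block_list = torch.tensor([0, 1, 5])               # flat list of physical block ids

fetched = key_cache.index_select(0, block_list)    # (3, block_size, kv_heads, head_size)
key = fetched.transpose(1, 2)                      # (3, kv_heads, block_size, head_size), as in flat_pa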