Commit 0578e5a

[Hardware][TPU] Enable ragged paged attention kernel and resolve recompilation issue (#14310)
Signed-off-by: Chengji Yao <[email protected]>
1 parent 0422298 commit 0578e5a

3 files changed: 58 additions & 66 deletions

requirements-tpu.txt
Lines changed: 6 additions & 6 deletions

@@ -17,9 +17,9 @@ ray[default]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"

vllm/v1/attention/backends/pallas.py
Lines changed: 23 additions & 16 deletions
@@ -12,7 +12,7 @@
 from vllm.attention.backends.utils import CommonAttentionState

 # These are the 2 tunable parameters of the paged attention Pallas kernel.
-NUM_QUERIES_PER_BLOCK = 16
+NUM_QUERIES_PER_BLOCK = 32
 NUM_KV_PAGES_PER_BLOCK = 128


@@ -41,7 +41,7 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> tuple[int, ...]:
-        return (num_kv_heads, num_blocks, block_size, head_size)
+        return (num_blocks, block_size, num_kv_heads, head_size)

     @staticmethod
     def swap_blocks(
@@ -115,6 +115,17 @@ def __init__(
                                       "are not implemented for "
                                       "PallasAttentionBackendImpl")

+        tpu_version = torch_xla.tpu.version()
+        if tpu_version < 4:
+            raise NotImplementedError("TPU version must be 4 or higher.")
+        # NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
+        # TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
+        # NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
+        if tpu_version == 4:
+            self.vmem_limit_bytes = 16 * 1024 * 1024
+        else:
+            self.vmem_limit_bytes = 64 * 1024 * 1024
+
     def forward(
         self,
         layer: AttentionLayer,
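
The constructor now derives the kernel's vmem budget from the detected TPU generation. Below is a minimal sketch of that selection written as a free function for illustration; the helper name pick_vmem_limit_bytes is hypothetical and not part of the commit.

# Illustrative sketch, not part of the commit: mirror the __init__ logic above
# as a pure function and show the resulting budgets.
def pick_vmem_limit_bytes(tpu_version: int) -> int:
    if tpu_version < 4:
        raise NotImplementedError("TPU version must be 4 or higher.")
    if tpu_version == 4:
        return 16 * 1024 * 1024   # TPU v4 has roughly 16MB of vmem
    return 64 * 1024 * 1024       # newer generations get a larger budget

assert pick_vmem_limit_bytes(4) == 16 * 1024 * 1024
assert pick_vmem_limit_bytes(5) == 64 * 1024 * 1024
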
@@ -131,8 +142,8 @@ def forward(
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
-            kv_cache = ([num_kv_heads, num_blocks, block_size, head_size],
-                        [num_kv_heads, num_blocks, block_size, head_size])
+            kv_cache = ([num_blocks, block_size, num_kv_heads, head_size],
+                        [num_blocks, block_size, num_kv_heads, head_size])
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -154,10 +165,6 @@ def forward(
             slot_mapping = attn_metadata.slot_mapping
             write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

-        query = query * self.scale
-        # use_kernel switches between using kernel or reference implementation
-        # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
-        use_kernel = False
         output = torch.ops.xla.ragged_paged_attention(
             query,
             key_cache,
@@ -168,8 +175,9 @@ def forward(
             attn_metadata.num_seqs,
             num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            use_kernel=use_kernel,
-        )
+            vmem_limit_bytes=self.vmem_limit_bytes,
+            use_kernel=True,
+            sm_scale=self.scale)

         return output.reshape(num_tokens, hidden_size)
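
The explicit pre-scaling of the query is dropped in favor of passing sm_scale to the kernel. A minimal CPU sketch, using plain dense attention scores rather than the Pallas kernel, shows why the two are interchangeable:

# Illustrative sketch, not part of the commit: scaling the query before the
# QK^T matmul gives the same softmax as applying the scale to the scores,
# which is why `query = query * self.scale` could be replaced by sm_scale.
import torch

torch.manual_seed(0)
q = torch.randn(4, 8)       # [num_tokens, head_size]
k = torch.randn(16, 8)      # [num_kv_tokens, head_size]
scale = 8 ** -0.5

old_path = torch.softmax((q * scale) @ k.T, dim=-1)   # pre-scaled query
new_path = torch.softmax((q @ k.T) * scale, dim=-1)   # kernel-style sm_scale
assert torch.allclose(old_path, new_path, atol=1e-6)
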

@@ -186,16 +194,15 @@ def write_to_kv_cache(
     Args:
         key: shape = [num_tokens, num_kv_heads, head_size]
         value: shape = [num_tokens, num_kv_heads, head_size]
-        k_cache = [num_kv_heads, num_blocks, block_size, head_size]
-        v_cache = [num_kv_heads, num_blocks, block_size, head_size]
+        k_cache = [num_blocks, block_size, num_kv_heads, head_size]
+        v_cache = [num_blocks, block_size, num_kv_heads, head_size]

    """
    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

-    key = key.flatten(0, 1)
-    value = value.flatten(0, 1)
-    key_cache = key_cache.flatten(0, 2)
-    value_cache = value_cache.flatten(0, 2)
+    key_cache = key_cache.flatten(0, 1)
+    value_cache = value_cache.flatten(0, 1)
+    slot_mapping = slot_mapping.flatten()
     key_cache.index_copy_(0, slot_mapping, key)
     value_cache.index_copy_(0, slot_mapping, value)
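
With the cache laid out as [num_blocks, block_size, num_kv_heads, head_size], a single index_copy_ along the flattened slot dimension is enough; the per-head slot remapping that the old layout required goes away. A minimal CPU-only sketch with tiny, made-up shapes (the real caches live on the XLA device and are marked as buffer donors):

# Illustrative sketch, not part of the commit: the new KV-cache write path
# demonstrated on CPU with toy shapes.
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 2, 2, 3
num_tokens = 3

key = torch.randn(num_tokens, num_kv_heads, head_size)
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# One slot per token: slot = block_number * block_size + block_offset.
slot_mapping = torch.tensor([0, 3, 5])

# flatten(0, 1) is a view of the contiguous cache with shape
# [num_blocks * block_size, num_kv_heads, head_size], so the in-place
# index_copy_ writes straight into key_cache.
flat_cache = key_cache.flatten(0, 1)
flat_cache.index_copy_(0, slot_mapping, key)

assert torch.equal(key_cache[1, 1], key[1])   # slot 3 -> block 1, offset 1
assert torch.equal(key_cache[2, 1], key[2])   # slot 5 -> block 2, offset 1
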

vllm/v1/worker/tpu_model_runner.py
Lines changed: 29 additions & 44 deletions
@@ -14,7 +14,7 @@
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
-from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
@@ -416,8 +416,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             num_scheduled_tokens_per_req)

         # Do the padding and copy the tensors to the TPU.
-        padded_total_num_scheduled_tokens = _get_padded_number(
-            total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK)
+        padded_total_num_scheduled_tokens = _get_padded_token_len(
+            total_num_scheduled_tokens)
         self.input_ids = self.input_ids_cpu[:
                                             padded_total_num_scheduled_tokens].to(
                                                 self.device)
@@ -428,23 +428,22 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         slot_mapping = self.slot_mapping_cpu[:
                                              padded_total_num_scheduled_tokens].to(
                                                  self.device)
-        padded_block_table = self.block_table_cpu[:
-                                                  padded_total_num_scheduled_tokens]
-        padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = (
+        block_tables = self.block_table_cpu[:self.max_num_reqs]
+        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
             self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
-        padded_block_table = padded_block_table.to(self.device)
-        query_start_loc = self.query_start_loc_cpu[:
-                                                   padded_total_num_scheduled_tokens
-                                                   + 1].to(self.device)
-        seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(
+        block_tables = block_tables.to(self.device)
+        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
             self.device)
+        seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)

         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
-            block_tables=padded_block_table,
+            block_tables=block_tables,
             context_lens=seq_lens,
             query_start_loc=query_start_loc,
-            num_seqs=num_reqs,
+            num_seqs=torch.tensor([num_reqs],
+                                  dtype=torch.int32,
+                                  device=self.device),
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
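
The attention metadata handed to XLA is now sized by self.max_num_reqs, a fixed upper bound, rather than by the per-step padded token count, and num_seqs travels as a device tensor instead of a Python int, which keeps per-step values out of the traced graph. A minimal sketch of the idea under assumed toy values; pad_1d is a hypothetical helper, not vLLM API.

# Illustrative sketch, not part of the commit: padding per-request metadata to
# a fixed max_num_reqs keeps tensor shapes constant across scheduler steps,
# so changing batch composition does not trigger recompilation.
import torch

def pad_1d(values: list[int], target_len: int, pad_value: int = 0) -> torch.Tensor:
    out = torch.full((target_len,), pad_value, dtype=torch.int32)
    out[:len(values)] = torch.tensor(values, dtype=torch.int32)
    return out

max_num_reqs = 8                                   # assumed compile-time bound
seq_lens = pad_1d([17, 5, 42], max_num_reqs)       # always shape [8]
num_seqs = torch.tensor([3], dtype=torch.int32)    # real count stays a tensor,
                                                   # not an int baked into the graph
assert seq_lens.shape == (max_num_reqs,)
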
@@ -693,29 +692,34 @@ def _dummy_run(
                                     dtype=torch.int32,
                                     device=self.device)
             inputs_embeds = None
+        actual_num_reqs = min(num_tokens, self.max_num_reqs)
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32,
                                    device=self.device)
         slot_mapping = torch.zeros(num_tokens,
                                    dtype=torch.int64,
                                    device=self.device)
-        block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]),
-                                   dtype=torch.int32,
-                                   device=self.device)
-        query_lens = [1] * num_tokens
+        block_tables = torch.zeros(
+            (self.max_num_reqs, self.block_table_cpu.shape[1]),
+            dtype=torch.int32,
+            device=self.device)
+        query_lens = [1] * self.max_num_reqs
         query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                     dtype=torch.int32),
                                        dim=0,
                                        dtype=torch.int32).to(self.device)
-        context_lens = torch.ones((num_tokens, ),
+        context_lens = torch.ones((self.max_num_reqs, ),
                                   dtype=torch.int32,
                                   device=self.device)
+        num_seqs = torch.tensor([actual_num_reqs],
+                                dtype=torch.int32,
+                                device=self.device)
         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
             block_tables=block_tables,
             context_lens=context_lens,
             query_start_loc=query_start_loc,
-            num_seqs=num_tokens,
+            num_seqs=num_seqs,
         )

         if self.is_multimodal_model:
@@ -724,9 +728,6 @@ def _dummy_run(
             torch._dynamo.mark_dynamic(input_ids, 0)
         torch._dynamo.mark_dynamic(position_ids, 0)
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
-        torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
-        torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
-        torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)

         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
@@ -817,28 +818,6 @@ def forward(
             inputs_embeds: The input embeddings of shape [num_tokens,
                 hidden_size]. It is used for multimodal models.
         """
-        # Skip this in memory profiling at initialization.
-        if kv_caches[0][0].numel() > 0:
-            attn_metadata = get_forward_context().attn_metadata
-            # index_copy_(slot_mapping) only works when the inserted dimension
-            # is 0. However, the KV cache in the Pallas backend has the shape
-            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
-            # work, we need to flatten the first three dimensions and modify
-            # the slot_mapping accordingly.
-            # kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
-            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
-            slot_mapping = attn_metadata.slot_mapping
-            slot_mapping = slot_mapping.flatten()
-            head_indicies = torch.arange(0,
-                                         num_kv_heads,
-                                         device=slot_mapping.device,
-                                         dtype=slot_mapping.dtype)
-            head_indicies *= block_size * num_blocks
-            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
-                -1, num_kv_heads)
-            slot_mapping = slot_mapping + head_indicies.view(1, -1)
-            slot_mapping = slot_mapping.flatten()
-            attn_metadata.slot_mapping = slot_mapping

         assert self.model is not None
         hidden_states = self.model(
@@ -866,3 +845,9 @@ def get_input_embeddings(self, *args, **kwargs):

 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
+
+
+def _get_padded_token_len(x: int) -> int:
+    if x <= 16:
+        return 16
+    return 1 << (x - 1).bit_length()
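
The new helper replaces padding to a multiple of NUM_QUERIES_PER_BLOCK with padding to 16 or to the next power of two, which bounds the set of distinct padded token lengths (and so the number of compiled graphs for that dimension) to roughly log2 of the token budget. A few worked values, using the function exactly as added above:

# Illustrative check of the padding helper added in this commit.
def _get_padded_token_len(x: int) -> int:
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()

assert _get_padded_token_len(1) == 16
assert _get_padded_token_len(16) == 16
assert _get_padded_token_len(17) == 32
assert _get_padded_token_len(100) == 128
assert _get_padded_token_len(1024) == 1024
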
