[Hardware][TPU] Enable ragged paged attention kernel and resolve recompilation issue #14310

Merged

12 changes: 6 additions & 6 deletions requirements-tpu.txt
@@ -17,9 +17,9 @@ ray[default]
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
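Note on the trailing clauses: `; python_version == "..."` is a standard PEP 508 environment marker, so pip resolves exactly one torch / torch_xla wheel per interpreter. A quick illustration using the third-party `packaging` library (an assumption for illustration, not part of this change):

```python
from packaging.markers import Marker

# Evaluate the same kind of marker pip checks when picking one of the wheels above.
marker = Marker('python_version == "3.10"')
print(marker.evaluate({"python_version": "3.10"}))  # True  -> wheel is installed
print(marker.evaluate({"python_version": "3.11"}))  # False -> wheel is skipped
```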
39 changes: 23 additions & 16 deletions vllm/v1/attention/backends/pallas.py
@@ -12,7 +12,7 @@
from vllm.attention.backends.utils import CommonAttentionState

# These are the 2 tunable parameters of the paged attention Pallas kernel.
-NUM_QUERIES_PER_BLOCK = 16
+NUM_QUERIES_PER_BLOCK = 32
NUM_KV_PAGES_PER_BLOCK = 128


@@ -41,7 +41,7 @@ def get_kv_cache_shape(
num_kv_heads: int,
head_size: int,
) -> tuple[int, ...]:
-return (num_kv_heads, num_blocks, block_size, head_size)
+return (num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def swap_blocks(
@@ -115,6 +115,17 @@ def __init__(
"are not implemented for "
"PallasAttentionBackendImpl")

+tpu_version = torch_xla.tpu.version()
+if tpu_version < 4:
+raise NotImplementedError("TPU version must be 4 or higher.")
+# NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
+# TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
+# NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
+if tpu_version == 4:
+self.vmem_limit_bytes = 16 * 1024 * 1024
+else:
+self.vmem_limit_bytes = 64 * 1024 * 1024

def forward(
self,
layer: AttentionLayer,
@@ -131,8 +142,8 @@ def forward(
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
-kv_cache = ([num_kv_heads, num_blocks, block_size, head_size],
-[num_kv_heads, num_blocks, block_size, head_size])
+kv_cache = ([num_blocks, block_size, num_kv_heads, head_size],
+[num_blocks, block_size, num_kv_heads, head_size])
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -154,10 +165,6 @@ def forward(
slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

-query = query * self.scale
-# use_kernel switches between using kernel or reference implementation
-# (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
-use_kernel = False
output = torch.ops.xla.ragged_paged_attention(
query,
key_cache,
@@ -168,8 +175,9 @@ def forward(
attn_metadata.num_seqs,
num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-use_kernel=use_kernel,
-)
+vmem_limit_bytes=self.vmem_limit_bytes,
+use_kernel=True,
+sm_scale=self.scale)

return output.reshape(num_tokens, hidden_size)

@@ -186,16 +194,15 @@ def write_to_kv_cache(
Args:
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
-k_cache = [num_kv_heads, num_blocks, block_size, head_size]
-v_cache = [num_kv_heads, num_blocks, block_size, head_size]
+k_cache = [num_blocks, block_size, num_kv_heads, head_size]
+v_cache = [num_blocks, block_size, num_kv_heads, head_size]
"""
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

key = key.flatten(0, 1)
value = value.flatten(0, 1)
-key_cache = key_cache.flatten(0, 2)
-value_cache = value_cache.flatten(0, 2)
+key_cache = key_cache.flatten(0, 1)
+value_cache = value_cache.flatten(0, 1)
slot_mapping = slot_mapping.flatten()
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)
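For reference, a minimal CPU-only sketch of what the new [num_blocks, block_size, num_kv_heads, head_size] layout buys here (illustrative shapes, not the backend's actual tensors): with the slot dimension leading, a single flatten(0, 1) plus index_copy_ writes every KV head of a token at once, so the per-head slot_mapping expansion previously done in the model runner is no longer needed.

```python
import torch

# Illustrative sizes only.
num_blocks, block_size, num_kv_heads, head_size = 4, 8, 2, 16
num_tokens = 3

key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
value_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# One flat slot per token: slot = block_id * block_size + block_offset.
slot_mapping = torch.tensor([0, 9, 17], dtype=torch.int64)

# Collapse (num_blocks, block_size) into a single leading slot dimension;
# the result is a view, so index_copy_ updates the caches in place.
flat_k = key_cache.flatten(0, 1)
flat_v = value_cache.flatten(0, 1)
flat_k.index_copy_(0, slot_mapping, key)
flat_v.index_copy_(0, slot_mapping, value)

# Slot 9 lands in block 1, offset 1, for all KV heads at once.
assert torch.equal(key_cache[1, 1], key[1])
```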
73 changes: 29 additions & 44 deletions vllm/v1/worker/tpu_model_runner.py
@@ -14,7 +14,7 @@
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
-from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
@@ -416,8 +416,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
num_scheduled_tokens_per_req)

# Do the padding and copy the tensors to the TPU.
-padded_total_num_scheduled_tokens = _get_padded_number(
-total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK)
+padded_total_num_scheduled_tokens = _get_padded_token_len(
+total_num_scheduled_tokens)
self.input_ids = self.input_ids_cpu[:
padded_total_num_scheduled_tokens].to(
self.device)
@@ -428,23 +428,22 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
slot_mapping = self.slot_mapping_cpu[:
padded_total_num_scheduled_tokens].to(
self.device)
-padded_block_table = self.block_table_cpu[:
-padded_total_num_scheduled_tokens]
-padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = (
+block_tables = self.block_table_cpu[:self.max_num_reqs]
+block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
-padded_block_table = padded_block_table.to(self.device)
-query_start_loc = self.query_start_loc_cpu[:
-padded_total_num_scheduled_tokens
-+ 1].to(self.device)
-seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(
+block_tables = block_tables.to(self.device)
+query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
self.device)
+seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)

attn_metadata = PallasMetadata(
slot_mapping=slot_mapping,
-block_tables=padded_block_table,
+block_tables=block_tables,
context_lens=seq_lens,
query_start_loc=query_start_loc,
-num_seqs=num_reqs,
+num_seqs=torch.tensor([num_reqs],
+dtype=torch.int32,
+device=self.device),
)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
@@ -693,29 +692,34 @@ def _dummy_run(
dtype=torch.int32,
device=self.device)
inputs_embeds = None
+actual_num_reqs = min(num_tokens, self.max_num_reqs)
position_ids = torch.zeros(num_tokens,
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros(num_tokens,
dtype=torch.int64,
device=self.device)
-block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]),
-dtype=torch.int32,
-device=self.device)
-query_lens = [1] * num_tokens
+block_tables = torch.zeros(
+(self.max_num_reqs, self.block_table_cpu.shape[1]),
+dtype=torch.int32,
+device=self.device)
+query_lens = [1] * self.max_num_reqs
query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
dtype=torch.int32),
dim=0,
dtype=torch.int32).to(self.device)
-context_lens = torch.ones((num_tokens, ),
+context_lens = torch.ones((self.max_num_reqs, ),
dtype=torch.int32,
device=self.device)
+num_seqs = torch.tensor([actual_num_reqs],
+dtype=torch.int32,
+device=self.device)
attn_metadata = PallasMetadata(
slot_mapping=slot_mapping,
block_tables=block_tables,
context_lens=context_lens,
query_start_loc=query_start_loc,
-num_seqs=num_tokens,
+num_seqs=num_seqs,
)

if self.is_multimodal_model:
@@ -724,9 +728,6 @@
torch._dynamo.mark_dynamic(input_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
-torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
-torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
-torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)

with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None
@@ -817,28 +818,6 @@ def forward(
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""
-# Skip this in memory profiling at initialization.
-if kv_caches[0][0].numel() > 0:
-attn_metadata = get_forward_context().attn_metadata
-# index_copy_(slot_mapping) only works when the inserted dimension
-# is 0. However, the KV cache in the Pallas backend has the shape
-# [num_kv_heads, num_blocks, block_size, head_size]. To make it
-# work, we need to flatten the first three dimensions and modify
-# the slot_mapping accordingly.
-# kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
-num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
-slot_mapping = attn_metadata.slot_mapping
-slot_mapping = slot_mapping.flatten()
-head_indicies = torch.arange(0,
-num_kv_heads,
-device=slot_mapping.device,
-dtype=slot_mapping.dtype)
-head_indicies *= block_size * num_blocks
-slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
--1, num_kv_heads)
-slot_mapping = slot_mapping + head_indicies.view(1, -1)
-slot_mapping = slot_mapping.flatten()
-attn_metadata.slot_mapping = slot_mapping

assert self.model is not None
hidden_states = self.model(
@@ -866,3 +845,9 @@ def get_input_embeddings(self, *args, **kwargs):

def _get_padded_number(n: int, multiple: int) -> int:
return ((n + multiple - 1) // multiple) * multiple


+def _get_padded_token_len(x: int) -> int:
+if x <= 16:
+return 16
+return 1 << (x - 1).bit_length()
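Taken together with the fixed-size metadata slicing in `_prepare_inputs`, this helper is what keeps the set of traced shapes small: token counts are bucketed to a minimum of 16 and then powers of two, while per-request tensors are always sized to the configured maximum. A rough standalone sketch of that idea (simplified, with made-up limits such as `MAX_NUM_REQS`; not the runner's actual code):

```python
import torch

MAX_NUM_REQS = 8          # stand-in for the runner's max_num_reqs
MAX_BLOCKS_PER_REQ = 4    # stand-in for max_num_blocks_per_req


def _get_padded_token_len(x: int) -> int:
    # Same bucketing as above: at least 16, then the next power of two.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()


def build_metadata_shapes(num_reqs: int, total_tokens: int):
    padded_tokens = _get_padded_token_len(total_tokens)
    slot_mapping = torch.zeros(padded_tokens, dtype=torch.int64)
    block_tables = torch.zeros((MAX_NUM_REQS, MAX_BLOCKS_PER_REQ),
                               dtype=torch.int32)
    query_start_loc = torch.zeros(MAX_NUM_REQS + 1, dtype=torch.int32)
    context_lens = torch.ones(MAX_NUM_REQS, dtype=torch.int32)
    # The live request count rides along as a 1-element tensor (as in the
    # PallasMetadata change above), so it can vary at runtime without
    # changing the traced graph.
    num_seqs = torch.tensor([num_reqs], dtype=torch.int32)
    return slot_mapping, block_tables, query_start_loc, context_lens, num_seqs


# Very different batches map onto only a handful of distinct shape signatures,
# so XLA can reuse previously compiled graphs instead of recompiling.
for num_reqs, total_tokens in [(1, 3), (5, 17), (8, 120)]:
    shapes = tuple(t.shape for t in build_metadata_shapes(num_reqs, total_tokens))
    print(num_reqs, total_tokens, shapes)
```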