diff --git a/requirements-tpu.txt b/requirements-tpu.txt
index d999e8f1c90..4bc6a9b83bd 100644
--- a/requirements-tpu.txt
+++ b/requirements-tpu.txt
@@ -17,9 +17,9 @@ ray[default]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
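These pins are nightly torch/torch_xla builds, and the PEP 508 environment markers select exactly one wheel pair per Python version. A quick post-install check, for reference only (the exact version strings are an assumption inferred from the wheel filenames above):

```python
# Illustrative sanity check that the pinned 2025-03-06 nightlies resolved.
import torch
import torch_xla

print(torch.__version__)      # expected to start with "2.7.0.dev20250306"
print(torch_xla.__version__)  # expected to start with "2.7.0.dev20250306"
```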
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 543e8487e28..bbbdf50ac0c 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -12,7 +12,7 @@
 from vllm.attention.backends.utils import CommonAttentionState
 
 # These are the 2 tunable parameters of the paged attention Pallas kernel.
-NUM_QUERIES_PER_BLOCK = 16
+NUM_QUERIES_PER_BLOCK = 32
 NUM_KV_PAGES_PER_BLOCK = 128
@@ -41,7 +41,7 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> tuple[int, ...]:
-        return (num_kv_heads, num_blocks, block_size, head_size)
+        return (num_blocks, block_size, num_kv_heads, head_size)
 
     @staticmethod
     def swap_blocks(
@@ -115,6 +115,17 @@ def __init__(
                                       "are not implemented for "
                                       "PallasAttentionBackendImpl")
 
+        tpu_version = torch_xla.tpu.version()
+        if tpu_version < 4:
+            raise NotImplementedError("TPU version must be 4 or higher.")
+        # NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
+        # TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
+        # NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
+        if tpu_version == 4:
+            self.vmem_limit_bytes = 16 * 1024 * 1024
+        else:
+            self.vmem_limit_bytes = 64 * 1024 * 1024
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -131,8 +142,8 @@ def forward(
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
-            kv_cache = ([num_kv_heads, num_blocks, block_size, head_size],
-                        [num_kv_heads, num_blocks, block_size, head_size])
+            kv_cache = ([num_blocks, block_size, num_kv_heads, head_size],
+                        [num_blocks, block_size, num_kv_heads, head_size])
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -154,10 +165,6 @@ def forward(
             slot_mapping = attn_metadata.slot_mapping
             write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
 
-        query = query * self.scale
-        # use_kernel switches between using kernel or reference implementation
-        # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
-        use_kernel = False
         output = torch.ops.xla.ragged_paged_attention(
             query,
             key_cache,
@@ -168,8 +175,9 @@ def forward(
             attn_metadata.num_seqs,
             num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            use_kernel=use_kernel,
-        )
+            vmem_limit_bytes=self.vmem_limit_bytes,
+            use_kernel=True,
+            sm_scale=self.scale)
 
         return output.reshape(num_tokens, hidden_size)
@@ -186,16 +194,15 @@ def write_to_kv_cache(
     Args:
         key: shape = [num_tokens, num_kv_heads, head_size]
         value: shape = [num_tokens, num_kv_heads, head_size]
-        k_cache = [num_kv_heads, num_blocks, block_size, head_size]
-        v_cache = [num_kv_heads, num_blocks, block_size, head_size]
+        k_cache = [num_blocks, block_size, num_kv_heads, head_size]
+        v_cache = [num_blocks, block_size, num_kv_heads, head_size]
     """
     torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
     torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
 
-    key = key.flatten(0, 1)
-    value = value.flatten(0, 1)
-    key_cache = key_cache.flatten(0, 2)
-    value_cache = value_cache.flatten(0, 2)
+    key_cache = key_cache.flatten(0, 1)
+    value_cache = value_cache.flatten(0, 1)
+    slot_mapping = slot_mapping.flatten()
     key_cache.index_copy_(0, slot_mapping, key)
     value_cache.index_copy_(0, slot_mapping, value)
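With the new `[num_blocks, block_size, num_kv_heads, head_size]` cache layout, a token's slot index addresses the cache directly along dim 0 after a single `flatten(0, 1)`, so the per-head offset arithmetic that the old model-runner code needed goes away. A minimal sketch of that write path on toy CPU tensors (all sizes here are made up for illustration; the real code runs on XLA devices and donates the cache buffers):

```python
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 8, 2, 16
num_tokens = 3

key = torch.randn(num_tokens, num_kv_heads, head_size)
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
# slot = block_id * block_size + offset_in_block, one entry per token.
slot_mapping = torch.tensor([5, 12, 30])

# Same steps as write_to_kv_cache: collapse (num_blocks, block_size) into one
# slot dimension, then scatter each token's key into its slot in place.
flat_cache = key_cache.flatten(0, 1)           # view: [num_blocks * block_size, num_kv_heads, head_size]
flat_cache.index_copy_(0, slot_mapping, key)   # writes through to key_cache

assert torch.equal(key_cache[0, 5], key[0])    # block 0, offset 5
assert torch.equal(key_cache[1, 4], key[1])    # block 1, offset 4 (12 = 1 * 8 + 4)
```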
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index f9a3217fbef..f661412d937 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -14,7 +14,7 @@
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
-from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
@@ -416,8 +416,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             num_scheduled_tokens_per_req)
 
         # Do the padding and copy the tensors to the TPU.
-        padded_total_num_scheduled_tokens = _get_padded_number(
-            total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK)
+        padded_total_num_scheduled_tokens = _get_padded_token_len(
+            total_num_scheduled_tokens)
         self.input_ids = self.input_ids_cpu[:
                                             padded_total_num_scheduled_tokens].to(
                                                 self.device)
@@ -428,23 +428,22 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         slot_mapping = self.slot_mapping_cpu[:
                                              padded_total_num_scheduled_tokens].to(
                                                  self.device)
-        padded_block_table = self.block_table_cpu[:
-                                                   padded_total_num_scheduled_tokens]
-        padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = (
+        block_tables = self.block_table_cpu[:self.max_num_reqs]
+        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
             self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
-        padded_block_table = padded_block_table.to(self.device)
-        query_start_loc = self.query_start_loc_cpu[:
-                                                    padded_total_num_scheduled_tokens
-                                                    + 1].to(self.device)
-        seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(
+        block_tables = block_tables.to(self.device)
+        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
             self.device)
+        seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)
 
         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
-            block_tables=padded_block_table,
+            block_tables=block_tables,
             context_lens=seq_lens,
             query_start_loc=query_start_loc,
-            num_seqs=num_reqs,
+            num_seqs=torch.tensor([num_reqs],
+                                  dtype=torch.int32,
+                                  device=self.device),
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
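The hunk above sizes every per-request tensor to `self.max_num_reqs` instead of the padded token count, so the shapes handed to the compiled graph stay constant from step to step; only `num_seqs`, now a 1-element int32 device tensor, carries the true request count as data. A toy illustration of the resulting shapes (the limits 8 and 4 are assumptions for this sketch, not values from the diff):

```python
import torch

max_num_reqs, max_num_blocks_per_req = 8, 4   # assumed fixed limits
num_reqs = 3                                  # actual requests this step

# Per-request metadata is allocated at the fixed maximum size...
block_tables = torch.zeros(max_num_reqs, max_num_blocks_per_req, dtype=torch.int32)
seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32)
# ...and only the leading num_reqs entries hold real data.
seq_lens[:num_reqs] = torch.tensor([17, 5, 1], dtype=torch.int32)

# query_start_loc has max_num_reqs + 1 entries; trailing requests contribute 0 tokens.
num_scheduled = torch.tensor([0, 4, 2, 1, 0, 0, 0, 0, 0], dtype=torch.int32)
query_start_loc = torch.cumsum(num_scheduled, dim=0, dtype=torch.int32)

# The true batch size travels as tensor data, not as a tensor shape.
num_seqs = torch.tensor([num_reqs], dtype=torch.int32)
```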
@@ -693,29 +692,34 @@ def _dummy_run(
                                     dtype=torch.int32,
                                     device=self.device)
             inputs_embeds = None
+        actual_num_reqs = min(num_tokens, self.max_num_reqs)
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32,
                                    device=self.device)
         slot_mapping = torch.zeros(num_tokens,
                                    dtype=torch.int64,
                                    device=self.device)
-        block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]),
-                                   dtype=torch.int32,
-                                   device=self.device)
-        query_lens = [1] * num_tokens
+        block_tables = torch.zeros(
+            (self.max_num_reqs, self.block_table_cpu.shape[1]),
+            dtype=torch.int32,
+            device=self.device)
+        query_lens = [1] * self.max_num_reqs
         query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                     dtype=torch.int32),
                                        dim=0,
                                        dtype=torch.int32).to(self.device)
-        context_lens = torch.ones((num_tokens, ),
+        context_lens = torch.ones((self.max_num_reqs, ),
                                   dtype=torch.int32,
                                   device=self.device)
+        num_seqs = torch.tensor([actual_num_reqs],
+                                dtype=torch.int32,
+                                device=self.device)
         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
             block_tables=block_tables,
             context_lens=context_lens,
             query_start_loc=query_start_loc,
-            num_seqs=num_tokens,
+            num_seqs=num_seqs,
         )
 
         if self.is_multimodal_model:
@@ -724,9 +728,6 @@ def _dummy_run(
         torch._dynamo.mark_dynamic(input_ids, 0)
         torch._dynamo.mark_dynamic(position_ids, 0)
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
-        torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
-        torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
-        torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
 
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
@@ -817,28 +818,6 @@ def forward(
             inputs_embeds: The input embeddings of shape [num_tokens,
                 hidden_size]. It is used for multimodal models.
         """
-        # Skip this in memory profiling at initialization.
-        if kv_caches[0][0].numel() > 0:
-            attn_metadata = get_forward_context().attn_metadata
-            # index_copy_(slot_mapping) only works when the inserted dimension
-            # is 0. However, the KV cache in the Pallas backend has the shape
-            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
-            # work, we need to flatten the first three dimensions and modify
-            # the slot_mapping accordingly.
-            # kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
-            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
-            slot_mapping = attn_metadata.slot_mapping
-            slot_mapping = slot_mapping.flatten()
-            head_indicies = torch.arange(0,
-                                         num_kv_heads,
-                                         device=slot_mapping.device,
-                                         dtype=slot_mapping.dtype)
-            head_indicies *= block_size * num_blocks
-            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
-                -1, num_kv_heads)
-            slot_mapping = slot_mapping + head_indicies.view(1, -1)
-            slot_mapping = slot_mapping.flatten()
-            attn_metadata.slot_mapping = slot_mapping
 
         assert self.model is not None
         hidden_states = self.model(
@@ -866,3 +845,9 @@ def get_input_embeddings(self, *args, **kwargs):
 
 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
+
+
+def _get_padded_token_len(x: int) -> int:
+    if x <= 16:
+        return 16
+    return 1 << (x - 1).bit_length()
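The new `_get_padded_token_len` replaces multiple-of-`NUM_QUERIES_PER_BLOCK` padding with power-of-two padding floored at 16, which bounds the number of distinct token-count shapes the graph can be traced with. A quick worked check of its behaviour (standalone copy for illustration):

```python
def _get_padded_token_len(x: int) -> int:
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()

# Token counts are bucketed into powers of two, never below 16.
assert _get_padded_token_len(1) == 16
assert _get_padded_token_len(16) == 16
assert _get_padded_token_len(17) == 32
assert _get_padded_token_len(33) == 64
assert _get_padded_token_len(1000) == 1024
```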