From 9310c2b2ac89249cd47bd6101fd947b88e5d6ff6 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Fri, 25 Apr 2025 00:51:29 +0000 Subject: [PATCH 1/3] Add tuned block sizes --- .../ragged_paged_attention_v2.py | 720 +++++++++++++++--- 1 file changed, 635 insertions(+), 85 deletions(-) diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py index 48c28825418b..89693b5fe952 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """TPU-Friendly Ragged Paged Attention kernel. This kernel offers a highly optimized implementation of ragged paged attention, @@ -23,9 +24,448 @@ from jax import lax from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu.ragged_paged_attention.tuned_block_sizes import get_tuned_block_sizes import jax.numpy as jnp DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) +# The page size is too small. We only have 32 SREGs in TC. If the pages +# per seq is too large, SREGs will spill. +MAX_PAGES_PER_SEQ = 16 + +# key: +# - q_dtype_name +# - kv_dtype_name +# - num_q_heads_per_blk +# - num_kv_heads_per_blk +# - head_dim +# - page_size +# - max_num_batched_tokens +# - max_model_len = page_size * pages_per_seq +# value: +# - num_kv_pages_per_block +# - num_queries_per_block +TUNED_BLOCK_SIZES = { + 'TPU v6': { + # go/keep-sorted start + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 4096): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 2048): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 4096): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 4096): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 4096): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 1024): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 2048): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 512): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 1024): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 2048): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 512): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 
128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 1024): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 2048): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 512): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 1024): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 2048): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 512): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 4096): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 2048): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 4096): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 2048): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 4096): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 2048): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 4096): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 2048): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 4096): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 2048): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 4096): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32), + 
('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 2048): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 4096): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 2048): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 4096): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 2048): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 4096): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 4096): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 4096): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 4096): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 4096): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 1024): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 2048): (128, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 512): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 1024): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 2048): (128, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 512): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 1024): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 2048): (128, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 512): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 1024): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 2048): (128, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 512): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 
'bfloat16', 8, 1, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 1024): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 2048): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 4096): (128, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 1024): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 2048): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 1024): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 2048): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 1024): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 2048): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 4096): (128, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 2048): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 4096): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 2048): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 2048): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 2048): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 4096): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32), + # go/keep-sorted end + }, + 'TPU 
v5': { + # go/keep-sorted start + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 
'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): 
(4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32), + # go/keep-sorted end + }, +} + +def next_power_of_2(x: int): + """Finds the smallest power of 2 >= x using bit manipulation. + + Args: + x: The input number (should be an integer). + + Returns: + The smallest integer power of 2 that is >= x. + """ + assert x > 0 + if x == 1: + return 1 + return 1 << (x - 1).bit_length() + + +def simplify_key(key): + """Simplify the key to reduce the number of combinations.""" + ( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, + ) = key + return ( + jnp.dtype(q_dtype).name, + jnp.dtype(kv_dtype).name, + next_power_of_2(num_q_heads_per_blk), + next_power_of_2(num_kv_heads_per_blk), + (head_dim + 127) // 128 * 128, + next_power_of_2(page_size), + next_power_of_2(max_num_batched_tokens), + next_power_of_2(page_size * pages_per_seq), + ) + + +def get_tpu_version() -> int: + """Returns the numeric version of the TPU, or -1 if not on TPU.""" + kind = jax.devices()[0].device_kind + if 'TPU' not in kind: + return -1 + if kind.endswith(' lite'): + kind = kind[: -len(' lite')] + assert kind[:-1] == 'TPU v', kind + return int(kind[-1]) + + +def get_device_name(num_devices:int | None = None): + name = ' '.join(jax.devices()[0].device_kind.split()[:2]) + if num_devices is not None: + name += f'-{num_devices}' + return name + + +def get_tuned_block_sizes( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, +) -> tuple[int, int]: + """Look up for the best (num_kv_pages_per_blk, num_queries_per_blk) from auto-tuned table.""" + tpu_version = get_tpu_version() + if tpu_version < 4: + raise NotImplementedError('TPU version must be 4 or higher.') + key = ( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, + ) + key = simplify_key(key) + device_name = get_device_name() + + # Default block sizes. + bkv, bq = (128, 32) + if tpu_version == 4: + # This default block size is not tuned, only make sure there's no + # OOM in vmem + bkv, bq = (32, 32) + elif device_name in TUNED_BLOCK_SIZES: + if key in TUNED_BLOCK_SIZES[device_name]: + bkv, bq = TUNED_BLOCK_SIZES[device_name][key] + return (min(pages_per_seq, bkv), min(max_num_batched_tokens, bq)) + + +def get_min_page_size(max_model_len, min_page_size=16): + """Recommended min page size for high-performance kernel.""" + return max(next_power_of_2(max_model_len) // MAX_PAGES_PER_SEQ, min_page_size) class MultiPageAsyncCopyDescriptor: @@ -47,14 +487,16 @@ def __init__( # a bunch of if-ops. Check the performance when we have benchmarking setup. 
for i in range(vmem_buf.shape[0]): page_idx = kv_pages_start + i - page_idx = jax.lax.select(page_idx < pages_per_seq, page_idx, - pages_per_seq - 1) + page_idx = jax.lax.select( + page_idx < pages_per_seq, page_idx, pages_per_seq - 1 + ) self._async_copies.append( pltpu.make_async_copy( pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]], vmem_buf.at[i], sem, - )) + ) + ) def start(self): """Starts the async copies.""" @@ -69,8 +511,7 @@ def wait(self): def ref_ragged_paged_attention( queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - kv_pages: jax. - Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] @@ -81,8 +522,18 @@ def ref_ragged_paged_attention( soft_cap: float | None = None, mask_value: float | None = DEFAULT_MASK_VALUE, ): - validate_static_inputs(queries, kv_pages, kv_lens, page_indices, cu_q_lens, - num_seqs, sliding_window, soft_cap) + static_validate_inputs( + queries, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + ) if mask_value is None: mask_value = DEFAULT_MASK_VALUE _, _, num_combined_kv_heads, head_dim = kv_pages.shape @@ -99,16 +550,19 @@ def ref_ragged_paged_attention( kv_len = kv_lens[i] indices = page_indices[i] q = queries[q_start:q_end] - k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, - head_dim)[:kv_len] - v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, - head_dim)[:kv_len] + k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim)[ + :kv_len + ] + v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)[ + :kv_len + ] k = jnp.repeat(k, num_query_per_kv, axis=1) v = jnp.repeat(v, num_query_per_kv, axis=1) attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) attn *= sm_scale q_span = (kv_len - q_len) + jax.lax.broadcasted_iota( - jnp.int32, attn.shape, 1) + jnp.int32, attn.shape, 1 + ) kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2) mask = q_span < kv_span if sliding_window is not None: @@ -124,19 +578,39 @@ def ref_ragged_paged_attention( # Expect to run these checks during runtime. -def validate_dynamic_inputs( +def dynamic_validate_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - kv_pages: jax. - Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs, # i32[1] + num_seqs: jax.Array, # i32[1] + *, + # These inputs are optional. If not specified, we will not validate them. + sm_scale: float | None = None, sliding_window: int | None = None, soft_cap: float | None = None, + mask_value: float | None = None, + # Kernel specific params. 
+ num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, + vmem_limit_bytes: int | None = None, ): - validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, - num_seqs, sliding_window, soft_cap) + static_validate_inputs( + q, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, + ) max_num_batched_tokens = q.shape[0] page_size = kv_pages.shape[1] max_num_seqs, pages_per_seq = page_indices.shape @@ -147,60 +621,91 @@ def validate_dynamic_inputs( if pages_per_seq < min_pages_per_seq: raise ValueError( f"{pages_per_seq=} must be greater or equal to" - f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.") + f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}." + ) if cu_q_lens[num_seqs[0]] > max_num_batched_tokens: raise ValueError( f"Total q tokens {cu_q_lens[num_seqs[0]]} must be less or equal to" - f" {max_num_batched_tokens=}.") + f" {max_num_batched_tokens=}." + ) for i in range(num_seqs[0]): q_len = cu_q_lens[i + 1] - cu_q_lens[i] kv_len = kv_lens[i] if q_len > kv_len: raise ValueError( - f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.") + f"{q_len=} must be less or equal to {kv_len=} at sequence {i}." + ) # Expect to run these checks during compile time. -def validate_static_inputs( +def static_validate_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - kv_pages: jax. - Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs, # i32[1] + num_seqs: jax.Array, # i32[1] + *, + # These inputs are optional. If not specified, we will not validate them. + sm_scale: float | None = None, sliding_window: int | None = None, soft_cap: float | None = None, + mask_value: float | None = None, + # Kernel specific params. + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, + vmem_limit_bytes: int | None = None, ): _, num_q_heads, head_dim = q.shape _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape assert num_combined_kv_heads % 2 == 0 num_kv_heads = num_combined_kv_heads // 2 - max_num_seqs, _ = page_indices.shape + max_num_seqs, pages_per_seq = page_indices.shape if num_seqs.shape != (1,): raise ValueError(f"{num_seqs.shape=} must be (1,)") if head_dim_k != head_dim: raise ValueError( - f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}.") + f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}." + ) if kv_lens.shape != (max_num_seqs,): - raise ValueError(f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where" - " `max_num_seqs` is `page_indices.shape[0]`.") + raise ValueError( + f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where" + " `max_num_seqs` is `page_indices.shape[0]`." + ) if cu_q_lens.shape != (max_num_seqs + 1,): raise ValueError( f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where" - " `max_num_seqs` is `page_indices.shape[0]`.") - if (kv_lens.dtype != jnp.int32 or page_indices.dtype != jnp.int32 or - cu_q_lens.dtype != jnp.int32): + " `max_num_seqs` is `page_indices.shape[0]`." 
+ ) + if ( + kv_lens.dtype != jnp.int32 + or page_indices.dtype != jnp.int32 + or cu_q_lens.dtype != jnp.int32 + ): raise ValueError( "The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be" f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=}," - f" {cu_q_lens.dtype=}.") + f" {cu_q_lens.dtype=}." + ) if num_q_heads % num_kv_heads != 0: raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") if sliding_window is not None and sliding_window <= 0: raise ValueError(f"{sliding_window=} must be positive.") if soft_cap is not None and soft_cap == 0.0: raise ValueError(f"{soft_cap=} must not be 0.0.") + if ( + num_kv_pages_per_block is not None + and not 0 < num_kv_pages_per_block <= pages_per_seq + ): + raise ValueError( + f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}]." + ) + if num_queries_per_block is not None and num_queries_per_block <= 0: + raise ValueError(f"{num_queries_per_block=} must be positive.") + if vmem_limit_bytes is not None and vmem_limit_bytes <= 0: + raise ValueError(f"{vmem_limit_bytes=} must be positive.") + del sm_scale # No constraints on sm_scale. + del mask_value # No consstraints on mask_value. def ragged_paged_attention_kernel( @@ -233,7 +738,8 @@ def ragged_paged_attention_kernel( num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape num_seqs = num_seqs_ref[0] _, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = ( - kv_bufs.shape) + kv_bufs.shape + ) num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 num_kv_per_blk = num_kv_pages_per_blk * page_size num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk @@ -247,14 +753,15 @@ def ragged_paged_attention_kernel( q_len_start = q_blk_idx * num_q_per_blk q_len_end = q_len_start + num_q_per_blk - def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx, - buf_idx): + def create_kv_async_copy_descriptors( + heads_blk_idx, seq_idx, kv_blk_idx, buf_idx + ): offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk) heads_start = heads_blk_idx * num_combined_kv_heads_per_blk async_copy_kv = MultiPageAsyncCopyDescriptor( - kv_pages_hbm_ref.at[:, :, - pl.ds(heads_start, num_combined_kv_heads_per_blk - ), :], + kv_pages_hbm_ref.at[ + :, :, pl.ds(heads_start, num_combined_kv_heads_per_blk), : + ], kv_bufs.at[buf_idx], sems.at[buf_idx], page_indices_ref, @@ -267,7 +774,7 @@ def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx, # 2. Support arbitrary strided load/store for any last dimension. 
def strided_load_kv(ref, start, step): if ref.dtype == jnp.float32: - return ref[start::step, :], ref[start + 1::step, :] + return ref[start::step, :], ref[start + 1 :: step, :] packing = get_dtype_packing(ref.dtype) assert ref.dtype == jnp.bfloat16 assert step % packing == 0 @@ -292,9 +799,9 @@ def fold_on_2nd_minor(vec): @pl.when(heads_blk_idx + q_blk_idx == 0) def prefetch_first_kv_blk(): - async_copy_kv = create_kv_async_copy_descriptors(heads_blk_idx, - init_seq_idx, 0, - init_buf_idx) + async_copy_kv = create_kv_async_copy_descriptors( + heads_blk_idx, init_seq_idx, 0, init_buf_idx + ) async_copy_kv.start() def is_cur_q_blk_needed(q_states): @@ -310,8 +817,9 @@ def compute_with_cur_q_blk(q_states): q_len = q_end - q_start kv_len = kv_lens_ref[cur_seq_idx] - def get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx, - cur_buf_idx): + def get_next_prefetch_ids( + heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx + ): next_kv_blk_idx = kv_blk_idx + 1 is_last_kv_blk = next_kv_blk_idx * num_kv_per_blk >= kv_len next_kv_blk_idx = lax.select( @@ -376,12 +884,12 @@ def flash_attention( def masked_store(ref, val, start, end, group=1): iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group mask = jnp.logical_and(iota >= start, iota < end) - pl.store( - ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask) + pl.store(ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask) qk = ( - jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) * - sm_scale) + jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) + * sm_scale + ) store_start = jnp.maximum(q_start - q_len_start, 0) store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk) @@ -408,12 +916,17 @@ def init_scratch_ref(): store_end, ) - row_ids = ((kv_len - q_len) + q_len_start - q_start + - jax.lax.broadcasted_iota( - jnp.int32, - (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), - 0, - ) // num_q_heads_per_kv_head) + row_ids = ( + (kv_len - q_len) + + q_len_start + - q_start + + jax.lax.broadcasted_iota( + jnp.int32, + (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), + 0, + ) + // num_q_heads_per_kv_head + ) col_ids = kv_len_start + jax.lax.broadcasted_iota( jnp.int32, (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), @@ -421,8 +934,8 @@ def init_scratch_ref(): ) causal_mask = row_ids < col_ids if sliding_window is not None: - causal_mask = jnp.logical_or(causal_mask, row_ids - sliding_window - >= col_ids) + causal_mask = jnp.logical_or(causal_mask, + row_ids - sliding_window >= col_ids) if soft_cap is not None: qk = soft_cap * jnp.tanh(qk / soft_cap) qk += jnp.where(causal_mask, mask_value, 0.0) @@ -432,12 +945,14 @@ def init_scratch_ref(): lm_store_shape = head_m_ref.shape m_curr = jnp.broadcast_to(m_curr, lm_store_shape) l_curr = jnp.broadcast_to( - s_curr.sum(axis=1, keepdims=True), lm_store_shape) + s_curr.sum(axis=1, keepdims=True), lm_store_shape + ) m_prev = head_m_ref[...] l_prev = head_l_ref[...] m_next = jnp.maximum(m_prev, m_curr) - masked_store(head_m_ref, m_next, store_start, store_end, - num_q_heads_per_kv_head) + masked_store( + head_m_ref, m_next, store_start, store_end, num_q_heads_per_kv_head + ) alpha = jnp.exp(m_prev - m_next) beta = jnp.exp(m_curr - m_next) l_alpha = alpha * l_prev @@ -458,8 +973,9 @@ def broadcast_to_shape(arr, shape): assert arr.shape[0] == shape[0] assert shape[1] % arr.shape[1] == 0 # no-op concatenation. 
- return jnp.concatenate([arr for _ in range(shape[1] // arr.shape[1])], - axis=1) + return jnp.concatenate( + [arr for _ in range(shape[1] // arr.shape[1])], axis=1 + ) o_curr = head_acc_ref[...].reshape(-1, head_dim) l_alpha = broadcast_to_shape(l_alpha, qkv.shape) @@ -483,8 +999,10 @@ def is_valid_kv_blk_in_cur_seq(kv_states): def compute_with_kv_blk_in_cur_seq(kv_states): kv_blk_idx, cur_buf_idx = kv_states next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx = ( - get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx, - cur_buf_idx)) + get_next_prefetch_ids( + heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx + ) + ) @pl.when(next_heads_blk_idx < num_heads_blks) def prefetch_next_kv_blk(): @@ -492,11 +1010,13 @@ def prefetch_next_kv_blk(): # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and # DMA to fixed size buffer! next_async_copy_kv = create_kv_async_copy_descriptors( - next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx) + next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx + ) next_async_copy_kv.start() cur_async_copy_kv = create_kv_async_copy_descriptors( - heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx) + heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx + ) kv_ref = cur_async_copy_kv.wait().reshape( num_kv_pages_per_blk * page_size * num_combined_kv_heads_per_blk, head_dim, @@ -505,17 +1025,19 @@ def prefetch_next_kv_blk(): q_head_idx = kv_head_idx * num_q_heads_per_kv_head # TODO(jevinjiang): extra handlig for packed type that can start at # unaligned position! - q = fold_on_2nd_minor(q_ref[:, q_head_idx:q_head_idx + - num_q_heads_per_kv_head, :]) - k, v = strided_load_kv(kv_ref, kv_head_idx * 2, - num_combined_kv_heads_per_blk) + q = fold_on_2nd_minor( + q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :] + ) + k, v = strided_load_kv( + kv_ref, kv_head_idx * 2, num_combined_kv_heads_per_blk + ) flash_attention( q, k, v, l_ref.at[kv_head_idx], m_ref.at[kv_head_idx], - acc_ref.at[:, q_head_idx:q_head_idx + num_q_heads_per_kv_head, :], + acc_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :], kv_blk_idx=kv_blk_idx, ) return kv_blk_idx + 1, next_buf_idx @@ -557,8 +1079,9 @@ def get_dtype_packing(dtype): raise ValueError(f"Not implemented: unsupported {dtype=}") -def get_min_heads_per_blk(num_q_heads, num_combined_kv_heads, q_dtype, - kv_dtype): +def get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q_dtype, kv_dtype +): q_packing = get_dtype_packing(q_dtype) kv_packing = get_dtype_packing(kv_dtype) @@ -581,8 +1104,10 @@ def can_be_xla_fully_tiled(x, packing): # second minor tiling is not on. max_combined_kv_tiling = 8 * kv_packing min_combined_kv_heads = ( - max_combined_kv_tiling if num_combined_kv_heads % - max_combined_kv_tiling == 0 else num_combined_kv_heads) + max_combined_kv_tiling + if num_combined_kv_heads % max_combined_kv_tiling == 0 + else num_combined_kv_heads + ) min_q_heads = min_combined_kv_heads // 2 * ratio if can_be_xla_fully_tiled(min_q_heads, q_packing): return min_q_heads, min_combined_kv_heads @@ -604,8 +1129,7 @@ def can_be_xla_fully_tiled(x, packing): def ragged_paged_attention( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] # TODO(jevinjiang): create a write_to_kv_cache kernel! - kv_pages: jax. 
- Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] @@ -615,8 +1139,8 @@ def ragged_paged_attention( sliding_window: int | None = None, soft_cap: float | None = None, mask_value: float | None = DEFAULT_MASK_VALUE, - num_kv_pages_per_block: int = 16, - num_queries_per_block: int = 128, + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None, ): """Ragged paged attention that supports mixed prefill and decode. @@ -643,20 +1167,46 @@ def ragged_paged_attention( Returns: The output of the attention. """ - validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, - num_seqs, sliding_window, soft_cap) + static_validate_inputs( + q, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, + ) if mask_value is None: mask_value = DEFAULT_MASK_VALUE - num_q, num_q_heads, head_dim = q.shape + num_q_tokens, num_q_heads, head_dim = q.shape _, page_size, num_combined_kv_heads, _ = kv_pages.shape assert num_combined_kv_heads % 2 == 0 num_kv_heads = num_combined_kv_heads // 2 + _, pages_per_seq = page_indices.shape + num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype + ) num_q_per_blk = num_queries_per_block num_kv_pages_per_blk = num_kv_pages_per_block + if num_q_per_blk is None or num_kv_pages_per_blk is None: + num_kv_pages_per_blk, num_q_per_blk = get_tuned_block_sizes( + q.dtype, + kv_pages.dtype, + num_q_heads_per_blk, + num_combined_kv_heads_per_blk // 2, + head_dim, + page_size, + num_q_tokens, + pages_per_seq, + ) num_q_heads_per_kv_head = num_q_heads // num_kv_heads - num_q_blks = cdiv(num_q, num_q_per_blk) - num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( - num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype) + num_q_blks = cdiv(num_q_tokens, num_q_per_blk) assert num_combined_kv_heads_per_blk % 2 == 0 num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 From 88e7c790c042eaf12e3135a74d160a5b7b04d7be Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Fri, 25 Apr 2025 01:16:15 +0000 Subject: [PATCH 2/3] Update tests --- test/test_pallas.py | 26 +-- torch_xla/experimental/custom_kernel.py | 11 -- .../ragged_paged_attention_v2.py | 186 +++++++----------- torch_xla/experimental/tuned_block_sizes.py | 86 -------- 4 files changed, 78 insertions(+), 231 deletions(-) delete mode 100644 torch_xla/experimental/tuned_block_sizes.py diff --git a/test/test_pallas.py b/test/test_pallas.py index c1f9df9ba0c2..d3f0e0f30c31 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -637,8 +637,8 @@ def _test_ragged_paged_attention( sm_scale=1.0, sliding_window=None, soft_cap=None, - num_kv_pages_per_block=16, - num_queries_per_block=128, + num_kv_pages_per_block=None, + num_queries_per_block=None, pad_tokens_and_seqs=False, use_dynamo=True, ): @@ -751,16 +751,6 @@ def ragged_paged_attention_wrapper( num_seqs_jax = jnp.array([num_seqs], dtype=jnp.int32) from 
torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention - from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size - if num_kv_pages_per_block is None: - assert num_queries_per_block is None - token_num, q_head_num, _ = q.shape - _, page_size, num_combined_kv_heads, _ = kv_pages.shape - _, pages_per_seq = page_indices.shape - num_kv_heads = num_combined_kv_heads // 2 - max_model_len = pages_per_seq * page_size - num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size( - q_head_num, num_kv_heads, token_num, max_model_len) jax_kernel_output = torch.from_numpy( np.array( jax_ragged_paged_attention( @@ -790,8 +780,7 @@ def ragged_paged_attention_wrapper( sm_scale=[1.0, 0.5], sliding_window=[None, 128], soft_cap=[None, 10.0], - pad_tokens_and_seqs=[False, True], - block_sizes=[(16, 128), (None, None)]) + pad_tokens_and_seqs=[False, True]) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "This test only works on TPUv4+.") def test_ragged_paged_attention_wrapper_with_dynamo( @@ -803,12 +792,10 @@ def test_ragged_paged_attention_wrapper_with_dynamo( sliding_window, soft_cap, pad_tokens_and_seqs, - block_sizes, ): head_dim = 128 page_size = 16 num_pages = 1000 - num_kv_pages_per_block, num_queries_per_block = block_sizes self._test_ragged_paged_attention( seq_lens, @@ -822,8 +809,6 @@ def test_ragged_paged_attention_wrapper_with_dynamo( soft_cap=soft_cap, pad_tokens_and_seqs=pad_tokens_and_seqs, use_dynamo=True, - num_kv_pages_per_block=num_kv_pages_per_block, - num_queries_per_block=num_queries_per_block, ) @parameterized.product( @@ -834,7 +819,6 @@ def test_ragged_paged_attention_wrapper_with_dynamo( sliding_window=[None, 128], soft_cap=[None, 10.0], pad_tokens_and_seqs=[False, True], - block_sizes=[(16, 128), (None, None)], ) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "This test only works on TPUv4+.") @@ -847,12 +831,10 @@ def test_ragged_paged_attention_wrapper_without_dynamo( sliding_window, soft_cap, pad_tokens_and_seqs, - block_sizes, ): head_dim = 128 page_size = 16 num_pages = 1000 - num_kv_pages_per_block, num_queries_per_block = block_sizes self._test_ragged_paged_attention( seq_lens, @@ -866,8 +848,6 @@ def test_ragged_paged_attention_wrapper_without_dynamo( soft_cap=soft_cap, pad_tokens_and_seqs=pad_tokens_and_seqs, use_dynamo=False, - num_kv_pages_per_block=num_kv_pages_per_block, - num_queries_per_block=num_queries_per_block, ) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index cfc3fb66617d..347ad7907d5f 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -9,7 +9,6 @@ from torch_xla.distributed.spmd import Mesh import torch_xla.distributed.spmd as xs from torch_xla._internal.jax_workarounds import requires_jax -from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size # Re-expose this API used that is referenced by docs from torch_xla._internal.jax_workarounds import jax_import_guard # noqa: F401, pylint: disable=unused-import @@ -956,16 +955,6 @@ def ragged_paged_attention( # in the global scope which could cause problems for xmp.spawn. 
from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_attention - if num_kv_pages_per_block is None: - assert num_queries_per_block is None - token_num, q_head_num, _ = q.shape - _, page_size, num_combined_kv_heads, _ = kv_pages.shape - _, pages_per_seq = page_indices.shape - num_kv_heads = num_combined_kv_heads // 2 - max_model_len = pages_per_seq * page_size - num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size( - q_head_num, num_kv_heads, token_num, max_model_len) - if vmem_limit_bytes is None: vmem_limit_bytes = 64 * 1024 * 1024 diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py index 89693b5fe952..55fc906cf263 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """TPU-Friendly Ragged Paged Attention kernel. This kernel offers a highly optimized implementation of ragged paged attention, @@ -367,6 +366,7 @@ }, } + def next_power_of_2(x: int): """Finds the smallest power of 2 >= x using bit manipulation. @@ -412,12 +412,12 @@ def get_tpu_version() -> int: if 'TPU' not in kind: return -1 if kind.endswith(' lite'): - kind = kind[: -len(' lite')] + kind = kind[:-len(' lite')] assert kind[:-1] == 'TPU v', kind return int(kind[-1]) -def get_device_name(num_devices:int | None = None): +def get_device_name(num_devices: int | None = None): name = ' '.join(jax.devices()[0].device_kind.split()[:2]) if num_devices is not None: name += f'-{num_devices}' @@ -487,16 +487,14 @@ def __init__( # a bunch of if-ops. Check the performance when we have benchmarking setup. for i in range(vmem_buf.shape[0]): page_idx = kv_pages_start + i - page_idx = jax.lax.select( - page_idx < pages_per_seq, page_idx, pages_per_seq - 1 - ) + page_idx = jax.lax.select(page_idx < pages_per_seq, page_idx, + pages_per_seq - 1) self._async_copies.append( pltpu.make_async_copy( pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]], vmem_buf.at[i], sem, - ) - ) + )) def start(self): """Starts the async copies.""" @@ -511,7 +509,8 @@ def wait(self): def ref_ragged_paged_attention( queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + kv_pages: jax. 
+ Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] @@ -550,19 +549,16 @@ def ref_ragged_paged_attention( kv_len = kv_lens[i] indices = page_indices[i] q = queries[q_start:q_end] - k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim)[ - :kv_len - ] - v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)[ - :kv_len - ] + k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, + head_dim)[:kv_len] + v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, + head_dim)[:kv_len] k = jnp.repeat(k, num_query_per_kv, axis=1) v = jnp.repeat(v, num_query_per_kv, axis=1) attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) attn *= sm_scale q_span = (kv_len - q_len) + jax.lax.broadcasted_iota( - jnp.int32, attn.shape, 1 - ) + jnp.int32, attn.shape, 1) kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2) mask = q_span < kv_span if sliding_window is not None: @@ -580,7 +576,8 @@ def ref_ragged_paged_attention( # Expect to run these checks during runtime. def dynamic_validate_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + kv_pages: jax. + Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] @@ -621,26 +618,24 @@ def dynamic_validate_inputs( if pages_per_seq < min_pages_per_seq: raise ValueError( f"{pages_per_seq=} must be greater or equal to" - f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}." - ) + f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.") if cu_q_lens[num_seqs[0]] > max_num_batched_tokens: raise ValueError( f"Total q tokens {cu_q_lens[num_seqs[0]]} must be less or equal to" - f" {max_num_batched_tokens=}." - ) + f" {max_num_batched_tokens=}.") for i in range(num_seqs[0]): q_len = cu_q_lens[i + 1] - cu_q_lens[i] kv_len = kv_lens[i] if q_len > kv_len: raise ValueError( - f"{q_len=} must be less or equal to {kv_len=} at sequence {i}." - ) + f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.") # Expect to run these checks during compile time. def static_validate_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + kv_pages: jax. + Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] @@ -665,41 +660,30 @@ def static_validate_inputs( raise ValueError(f"{num_seqs.shape=} must be (1,)") if head_dim_k != head_dim: raise ValueError( - f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}." - ) + f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}.") if kv_lens.shape != (max_num_seqs,): - raise ValueError( - f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where" - " `max_num_seqs` is `page_indices.shape[0]`." 
- ) + raise ValueError(f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where" + " `max_num_seqs` is `page_indices.shape[0]`.") if cu_q_lens.shape != (max_num_seqs + 1,): raise ValueError( f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where" - " `max_num_seqs` is `page_indices.shape[0]`." - ) - if ( - kv_lens.dtype != jnp.int32 - or page_indices.dtype != jnp.int32 - or cu_q_lens.dtype != jnp.int32 - ): + " `max_num_seqs` is `page_indices.shape[0]`.") + if (kv_lens.dtype != jnp.int32 or page_indices.dtype != jnp.int32 or + cu_q_lens.dtype != jnp.int32): raise ValueError( "The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be" f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=}," - f" {cu_q_lens.dtype=}." - ) + f" {cu_q_lens.dtype=}.") if num_q_heads % num_kv_heads != 0: raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") if sliding_window is not None and sliding_window <= 0: raise ValueError(f"{sliding_window=} must be positive.") if soft_cap is not None and soft_cap == 0.0: raise ValueError(f"{soft_cap=} must not be 0.0.") - if ( - num_kv_pages_per_block is not None - and not 0 < num_kv_pages_per_block <= pages_per_seq - ): + if (num_kv_pages_per_block is not None and + not 0 < num_kv_pages_per_block <= pages_per_seq): raise ValueError( - f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}]." - ) + f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}].") if num_queries_per_block is not None and num_queries_per_block <= 0: raise ValueError(f"{num_queries_per_block=} must be positive.") if vmem_limit_bytes is not None and vmem_limit_bytes <= 0: @@ -738,8 +722,7 @@ def ragged_paged_attention_kernel( num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape num_seqs = num_seqs_ref[0] _, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = ( - kv_bufs.shape - ) + kv_bufs.shape) num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 num_kv_per_blk = num_kv_pages_per_blk * page_size num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk @@ -753,15 +736,14 @@ def ragged_paged_attention_kernel( q_len_start = q_blk_idx * num_q_per_blk q_len_end = q_len_start + num_q_per_blk - def create_kv_async_copy_descriptors( - heads_blk_idx, seq_idx, kv_blk_idx, buf_idx - ): + def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx, + buf_idx): offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk) heads_start = heads_blk_idx * num_combined_kv_heads_per_blk async_copy_kv = MultiPageAsyncCopyDescriptor( - kv_pages_hbm_ref.at[ - :, :, pl.ds(heads_start, num_combined_kv_heads_per_blk), : - ], + kv_pages_hbm_ref.at[:, :, + pl.ds(heads_start, num_combined_kv_heads_per_blk + ), :], kv_bufs.at[buf_idx], sems.at[buf_idx], page_indices_ref, @@ -774,7 +756,7 @@ def create_kv_async_copy_descriptors( # 2. Support arbitrary strided load/store for any last dimension. 
def strided_load_kv(ref, start, step): if ref.dtype == jnp.float32: - return ref[start::step, :], ref[start + 1 :: step, :] + return ref[start::step, :], ref[start + 1::step, :] packing = get_dtype_packing(ref.dtype) assert ref.dtype == jnp.bfloat16 assert step % packing == 0 @@ -799,9 +781,9 @@ def fold_on_2nd_minor(vec): @pl.when(heads_blk_idx + q_blk_idx == 0) def prefetch_first_kv_blk(): - async_copy_kv = create_kv_async_copy_descriptors( - heads_blk_idx, init_seq_idx, 0, init_buf_idx - ) + async_copy_kv = create_kv_async_copy_descriptors(heads_blk_idx, + init_seq_idx, 0, + init_buf_idx) async_copy_kv.start() def is_cur_q_blk_needed(q_states): @@ -817,9 +799,8 @@ def compute_with_cur_q_blk(q_states): q_len = q_end - q_start kv_len = kv_lens_ref[cur_seq_idx] - def get_next_prefetch_ids( - heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx - ): + def get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx, + cur_buf_idx): next_kv_blk_idx = kv_blk_idx + 1 is_last_kv_blk = next_kv_blk_idx * num_kv_per_blk >= kv_len next_kv_blk_idx = lax.select( @@ -884,12 +865,12 @@ def flash_attention( def masked_store(ref, val, start, end, group=1): iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group mask = jnp.logical_and(iota >= start, iota < end) - pl.store(ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask) + pl.store( + ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask) qk = ( - jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) - * sm_scale - ) + jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) * + sm_scale) store_start = jnp.maximum(q_start - q_len_start, 0) store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk) @@ -916,17 +897,12 @@ def init_scratch_ref(): store_end, ) - row_ids = ( - (kv_len - q_len) - + q_len_start - - q_start - + jax.lax.broadcasted_iota( - jnp.int32, - (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), - 0, - ) - // num_q_heads_per_kv_head - ) + row_ids = ((kv_len - q_len) + q_len_start - q_start + + jax.lax.broadcasted_iota( + jnp.int32, + (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), + 0, + ) // num_q_heads_per_kv_head) col_ids = kv_len_start + jax.lax.broadcasted_iota( jnp.int32, (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), @@ -934,8 +910,8 @@ def init_scratch_ref(): ) causal_mask = row_ids < col_ids if sliding_window is not None: - causal_mask = jnp.logical_or(causal_mask, - row_ids - sliding_window >= col_ids) + causal_mask = jnp.logical_or(causal_mask, row_ids - sliding_window + >= col_ids) if soft_cap is not None: qk = soft_cap * jnp.tanh(qk / soft_cap) qk += jnp.where(causal_mask, mask_value, 0.0) @@ -945,14 +921,12 @@ def init_scratch_ref(): lm_store_shape = head_m_ref.shape m_curr = jnp.broadcast_to(m_curr, lm_store_shape) l_curr = jnp.broadcast_to( - s_curr.sum(axis=1, keepdims=True), lm_store_shape - ) + s_curr.sum(axis=1, keepdims=True), lm_store_shape) m_prev = head_m_ref[...] l_prev = head_l_ref[...] m_next = jnp.maximum(m_prev, m_curr) - masked_store( - head_m_ref, m_next, store_start, store_end, num_q_heads_per_kv_head - ) + masked_store(head_m_ref, m_next, store_start, store_end, + num_q_heads_per_kv_head) alpha = jnp.exp(m_prev - m_next) beta = jnp.exp(m_curr - m_next) l_alpha = alpha * l_prev @@ -973,9 +947,8 @@ def broadcast_to_shape(arr, shape): assert arr.shape[0] == shape[0] assert shape[1] % arr.shape[1] == 0 # no-op concatenation. 
-        return jnp.concatenate(
-            [arr for _ in range(shape[1] // arr.shape[1])], axis=1
-        )
+        return jnp.concatenate([arr for _ in range(shape[1] // arr.shape[1])],
+                               axis=1)
 
       o_curr = head_acc_ref[...].reshape(-1, head_dim)
       l_alpha = broadcast_to_shape(l_alpha, qkv.shape)
@@ -999,10 +972,8 @@ def is_valid_kv_blk_in_cur_seq(kv_states):
     def compute_with_kv_blk_in_cur_seq(kv_states):
       kv_blk_idx, cur_buf_idx = kv_states
       next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx = (
-          get_next_prefetch_ids(
-              heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
-          )
-      )
+          get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx,
+                                cur_buf_idx))
 
       @pl.when(next_heads_blk_idx < num_heads_blks)
       def prefetch_next_kv_blk():
@@ -1010,13 +981,11 @@ def prefetch_next_kv_blk():
         # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and
         # DMA to fixed size buffer!
         next_async_copy_kv = create_kv_async_copy_descriptors(
-            next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx
-        )
+            next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx)
        next_async_copy_kv.start()
 
      cur_async_copy_kv = create_kv_async_copy_descriptors(
-          heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
-      )
+          heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx)
      kv_ref = cur_async_copy_kv.wait().reshape(
          num_kv_pages_per_blk * page_size * num_combined_kv_heads_per_blk,
          head_dim,
@@ -1025,19 +994,17 @@ def prefetch_next_kv_blk():
         q_head_idx = kv_head_idx * num_q_heads_per_kv_head
         # TODO(jevinjiang): extra handlig for packed type that can start at
         # unaligned position!
-        q = fold_on_2nd_minor(
-            q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :]
-        )
-        k, v = strided_load_kv(
-            kv_ref, kv_head_idx * 2, num_combined_kv_heads_per_blk
-        )
+        q = fold_on_2nd_minor(q_ref[:, q_head_idx:q_head_idx +
+                                    num_q_heads_per_kv_head, :])
+        k, v = strided_load_kv(kv_ref, kv_head_idx * 2,
+                               num_combined_kv_heads_per_blk)
         flash_attention(
             q,
             k,
             v,
             l_ref.at[kv_head_idx],
             m_ref.at[kv_head_idx],
-            acc_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :],
+            acc_ref.at[:, q_head_idx:q_head_idx + num_q_heads_per_kv_head, :],
             kv_blk_idx=kv_blk_idx,
         )
       return kv_blk_idx + 1, next_buf_idx
@@ -1079,9 +1046,8 @@ def get_dtype_packing(dtype):
   raise ValueError(f"Not implemented: unsupported {dtype=}")
 
 
-def get_min_heads_per_blk(
-    num_q_heads, num_combined_kv_heads, q_dtype, kv_dtype
-):
+def get_min_heads_per_blk(num_q_heads, num_combined_kv_heads, q_dtype,
+                          kv_dtype):
   q_packing = get_dtype_packing(q_dtype)
   kv_packing = get_dtype_packing(kv_dtype)
 
@@ -1104,10 +1070,8 @@ def can_be_xla_fully_tiled(x, packing):
   # second minor tiling is not on.
   max_combined_kv_tiling = 8 * kv_packing
   min_combined_kv_heads = (
-      max_combined_kv_tiling
-      if num_combined_kv_heads % max_combined_kv_tiling == 0
-      else num_combined_kv_heads
-  )
+      max_combined_kv_tiling if num_combined_kv_heads %
+      max_combined_kv_tiling == 0 else num_combined_kv_heads)
   min_q_heads = min_combined_kv_heads // 2 * ratio
   if can_be_xla_fully_tiled(min_q_heads, q_packing):
     return min_q_heads, min_combined_kv_heads
@@ -1129,7 +1093,8 @@ def can_be_xla_fully_tiled(x, packing):
 def ragged_paged_attention(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     # TODO(jevinjiang): create a write_to_kv_cache kernel!
-    kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
+    kv_pages: jax.
+    Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
     kv_lens: jax.Array,  # i32[max_num_seqs]
     page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
@@ -1190,8 +1155,7 @@ def ragged_paged_attention(
   num_kv_heads = num_combined_kv_heads // 2
   _, pages_per_seq = page_indices.shape
   num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk(
-      num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype
-  )
+      num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype)
   num_q_per_blk = num_queries_per_block
   num_kv_pages_per_blk = num_kv_pages_per_block
   if num_q_per_blk is None or num_kv_pages_per_blk is None:
diff --git a/torch_xla/experimental/tuned_block_sizes.py b/torch_xla/experimental/tuned_block_sizes.py
deleted file mode 100644
index d3cb0ae25da2..000000000000
--- a/torch_xla/experimental/tuned_block_sizes.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import torch_xla
-
-
-def _next_power_of_2_bit_manipulation(x):
-  """
-  Finds the smallest power of 2 >= x using bit manipulation.
-  Assumes x is an integer.
-
-  Args:
-    x: The input number (should be an integer).
-
-  Returns:
-    The smallest integer power of 2 that is >= x.
-    Returns 1 if x <= 0.
-  """
-  if x <= 0:
-    return 1
-  if x == 1:
-    return 1
-  return 1 << (x - 1).bit_length()
-
-
-# ragged_paged_attention
-# key: (q_head_num, kv_head_num, token_num, max_model_len)
-# value: (num_kv_pages_per_block, num_queries_per_block)
-
-
-def _simplify_key_ragged_paged_attention(q_head_num, kv_head_num, token_num,
-                                         max_model_len):
-  token_num = _next_power_of_2_bit_manipulation(token_num)
-  max_model_len = _next_power_of_2_bit_manipulation(max_model_len)
-  return q_head_num, kv_head_num, token_num, max_model_len
-
-
-# TODO: add more tuned block sizes in the table
-# q_head_num, kv_head_num, token_num, max_model_len
-_ragged_attention_table = {
-    (32, 8, 4096, 2048): (128, 64),
-    (4, 1, 4096, 2048): (128, 128),
-    (32, 8, 2048, 2048): (128, 32),
-    (4, 1, 2048, 2048): (128, 64),
-    (32, 8, 1024, 2048): (64, 32),
-    (1, 1, 1024, 2048): (64, 32),
-    (32, 8, 4096, 4096): (128, 64),
-    (4, 1, 4096, 4096): (128, 128),
-    (32, 8, 2048, 4096): (128, 32),
-    (4, 1, 2048, 4096): (128, 64),
-    (32, 8, 1024, 4096): (64, 32),
-    (1, 1, 1024, 4096): (64, 32),
-    (32, 8, 4096, 64): (32, 32),
-    (4, 1, 4096, 64): (32, 32),
-    (32, 8, 2048, 64): (32, 32),
-    (4, 1, 2048, 64): (32, 32),
-    (32, 8, 1024, 64): (32, 32),
-    (1, 1, 1024, 64): (32, 32),
-    (32, 8, 4096, 128): (32, 32),
-    (4, 1, 4096, 128): (32, 32),
-    (32, 8, 2048, 128): (32, 32),
-    (4, 1, 2048, 128): (32, 32),
-    (32, 8, 1024, 128): (32, 32),
-    (1, 1, 1024, 128): (32, 32),
-    (10, 2, 4096, 2048): (128, 32),  # Qwen/Qwen2.5-32B
-    (10, 2, 2048, 2048): (128, 32),  # Qwen/Qwen2.5-32B
-    (10, 2, 1024, 2048): (128, 32),  # Qwen/Qwen2.5-32B
-    (5, 1, 4098, 2048): (128, 64),  # Qwen/Qwen2.5-32B
-    (5, 1, 2048, 2048): (128, 32),  # Qwen/Qwen2.5-32B
-    (5, 1, 1024, 2048): (128, 32),  # Qwen/Qwen2.5-32B
-}
-
-
-def get_ragged_attention_tuned_block_size(q_head_num, kv_head_num, token_num,
-                                          max_model_len):
-  tpu_version = torch_xla.tpu.version()
-  if tpu_version < 4:
-    raise NotImplementedError("TPU version must be 4 or higher.")
-  if tpu_version == 4:
-    # This default block size is not tuned, only make sure there's no
-    # OOM in vmem
-    num_kv_pages_per_block = 16
-    num_queries_per_block = 128
-    return num_kv_pages_per_block, num_queries_per_block
-
-  key = _simplify_key_ragged_paged_attention(q_head_num, kv_head_num, token_num,
-                                             max_model_len)
-  block_sizes = _ragged_attention_table.get(key, (128, 32))
-  return block_sizes

From a43e89db4f3767506a78ef64ae294eec9e28c461 Mon Sep 17 00:00:00 2001
From: Jevin Jiang
Date: Fri, 25 Apr 2025 06:01:49 +0000
Subject: [PATCH 3/3] Remove unused import

---
 .../experimental/pallas_kernels/ragged_paged_attention_v2.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py
index 55fc906cf263..4f2724a42deb 100644
--- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py
+++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py
@@ -23,7 +23,6 @@
 from jax import lax
 from jax.experimental import pallas as pl
 from jax.experimental.pallas import tpu as pltpu
-from jax.experimental.pallas.ops.tpu.ragged_paged_attention.tuned_block_sizes import get_tuned_block_sizes
 import jax.numpy as jnp
 
 DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
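
The deleted torch_xla/experimental/tuned_block_sizes.py above shows the lookup pattern this series replaces: round the dynamic dimensions (token count and max model length) up to the next power of two, look the simplified key up in a hand-tuned table, and fall back to a safe default when the key is missing. The standalone Python sketch below reproduces that flow for reference only; it is not part of the patch, the trimmed table and the helper names (next_power_of_2, lookup_block_sizes) are illustrative, and it omits the TPU-version check the real helper performed via torch_xla.tpu.version().

def next_power_of_2(x: int) -> int:
  # Smallest power of 2 >= x; returns 1 for x <= 0, mirroring the deleted helper.
  if x <= 0:
    return 1
  return 1 << (x - 1).bit_length()


# key: (q_head_num, kv_head_num, token_num, max_model_len)
# value: (num_kv_pages_per_block, num_queries_per_block)
# Trimmed copy of _ragged_attention_table; the full table is in the diff above.
_TRIMMED_TABLE = {
    (32, 8, 4096, 2048): (128, 64),
    (32, 8, 2048, 2048): (128, 32),
    (32, 8, 1024, 2048): (64, 32),
}


def lookup_block_sizes(q_head_num, kv_head_num, token_num, max_model_len):
  # Round the dynamic dimensions up to powers of two so nearby shapes share
  # one table entry, then fall back to (128, 32) when the key is missing.
  key = (q_head_num, kv_head_num, next_power_of_2(token_num),
         next_power_of_2(max_model_len))
  return _TRIMMED_TABLE.get(key, (128, 32))


if __name__ == "__main__":
  # 3000 tokens rounds up to 4096, so this hits the (32, 8, 4096, 2048) entry.
  assert lookup_block_sizes(32, 8, 3000, 2048) == (128, 64)
  # An untuned shape falls back to the default block sizes.
  assert lookup_block_sizes(4, 1, 512, 1024) == (128, 32)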