Commit 5644f44

Integrate ragged paged attention v2 (#8791)
1 parent 17270e2 commit 5644f44

3 files changed: +930 -271 lines changed

test/test_pallas.py

Lines changed: 85 additions & 122 deletions
@@ -87,8 +87,9 @@ def _pagedattention_generate_qkv(
     q = torch.randn(batch_size, query_len, num_heads, head_dim, dtype=dtype)
     return q, k_pages, v_pages, page_indices

-  def _round_up_closest_multiple_of(self, x, base):
-    return (x + base - 1) // base * base
+  def _ceil_div(self, a, b):
+    assert b != 0
+    return (a + b - 1) // b

   def _ragged_pagedattention_generate_qkv(
       self,
@@ -97,64 +98,50 @@ def _ragged_pagedattention_generate_qkv(
       head_dim,
       page_size,
       num_pages,
-      dtype=torch.float32,
-      num_queries_per_block=None,
-      pad_num_q_tokens=False,
+      dtype,
+      *,
+      num_kv_pages_per_block=None,
+      max_num_batched_tokens=None,
+      max_num_seqs=16,
   ):
-    num_seqs = len(seq_lens)
-    # Make sure the q_len is no longer than the kv_len. For example,
-    # seq_lens = [(1, 1328), (5, 18), (506, 463)] is not a valid test case because
-    # the 3rd sequence has q_len(506) > kv_len(463).
-    for i in range(num_seqs):
-      cur_q_len = seq_lens[i][0]
-      cur_kv_len = seq_lens[i][1]
-      assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}"
-
-    query_lens = [seq_len[0] for seq_len in seq_lens]
-    actual_num_q_tokens = sum(query_lens)
-    num_q_tokens = self._round_up_closest_multiple_of(
-        actual_num_q_tokens,
-        num_queries_per_block) if pad_num_q_tokens else actual_num_q_tokens
-    kv_lens = torch.tensor([seq_len[1] for seq_len in seq_lens],
-                           dtype=torch.int32)
-    num_q_heads = num_heads[0]
-    num_kv_heads = num_heads[1]
-    assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0."
-    queries = torch.randn((num_q_tokens, num_q_heads, head_dim), dtype=dtype)
-    k_pages = torch.randn((num_kv_heads, num_pages, page_size, head_dim),
+    cu_q_lens = [0]
+    kv_lens = []
+    for q_len, kv_len in seq_lens:
+      assert q_len <= kv_len
+      cu_q_lens.append(cu_q_lens[-1] + q_len)
+      kv_lens.append(kv_len)
+
+    if max_num_batched_tokens is None:
+      max_num_batched_tokens = cu_q_lens[-1]
+    else:
+      max_num_batched_tokens = max(cu_q_lens[-1], max_num_batched_tokens)
+    if max_num_seqs is None:
+      max_num_seqs = len(seq_lens)
+    else:
+      max_num_seqs = max(len(seq_lens), max_num_seqs)
+    max_kv_len = max(kv_lens)
+    pages_per_seq = self._ceil_div(max_kv_len, page_size)
+    pages_per_seq = (
+        self._ceil_div(pages_per_seq, num_kv_pages_per_block) *
+        num_kv_pages_per_block)
+
+    num_q_heads, num_kv_heads = num_heads
+    cu_q_lens = torch.tensor(cu_q_lens, dtype=torch.int32)
+    kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
+    cu_q_lens = torch.nn.functional.pad(
+        cu_q_lens, (0, max_num_seqs + 1 - cu_q_lens.shape[0]), "constant", 0)
+    kv_lens = torch.nn.functional.pad(kv_lens,
+                                      (0, max_num_seqs - kv_lens.shape[0]),
+                                      "constant", 0)
+    q = torch.randn((max_num_batched_tokens, num_q_heads, head_dim),
+                    dtype=dtype)
+    k_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim),
                           dtype=dtype)
-    v_pages = torch.randn((num_kv_heads, num_pages, page_size, head_dim),
+    v_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim),
                           dtype=dtype)
-
-    # Create a kv_lens: i32[num_tokens]
-    kv_lens_with_paddings = [0] * num_q_tokens
-    for i in range(num_seqs):
-      kv_lens_with_paddings[i] = kv_lens[i]
-    kv_lens_ = torch.tensor(kv_lens_with_paddings, dtype=torch.int32)
-
-    # Create a page_indices i32[num_tokens, pages_per_sequence]
-    max_kv_len = max([seq_len[1] for seq_len in seq_lens])
-    max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size
-
-    # The reason why we need to pad max_num_pages_per_seq is that
-    # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
-    max_num_pages_per_seq = 2**int(np.ceil(np.log2(max_num_pages_per_seq)))
-
-    # The assert below mimics the reality that each page get a unique index.
-    # But for testing, the assert could be omitted.
-    # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}"
     page_indices = torch.randint(
-        0, num_pages, (num_q_tokens, max_num_pages_per_seq), dtype=torch.int32)
-
-    # Create a cu_q_lens i32[num_tokens + 1]
-    q_lens_with_paddings = [0] * num_q_tokens
-    for i in range(num_seqs):
-      q_lens_with_paddings[i] = query_lens[i]
-    cu_q_lens = torch.cumsum(
-        torch.tensor([0] + q_lens_with_paddings, dtype=torch.int32),
-        dim=0,
-        dtype=torch.int32)
-    return queries, k_pages, v_pages, page_indices, cu_q_lens, kv_lens_
+        0, num_pages, (max_num_seqs, pages_per_seq), dtype=torch.int32)
+    return q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_tpu_custom_call_pallas_add(self):
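Note on the rewritten helper above: it now pads cu_q_lens and kv_lens out to max_num_seqs entries and rounds pages_per_seq up to a whole number of KV-page blocks. The sketch below is a minimal, self-contained reproduction of that metadata shaping under example inputs (it is not part of the commit; only the padding and rounding logic is mirrored, and the seq_lens values are illustrative):

```python
# Illustrative sketch: mirrors the padding/rounding in the new
# _ragged_pagedattention_generate_qkv helper, with example inputs.
import torch


def ceil_div(a, b):
  assert b != 0
  return (a + b - 1) // b


seq_lens = [(1, 1328), (5, 18), (99, 123)]  # hypothetical [(q_len, kv_len), ...]
page_size, num_kv_pages_per_block = 16, 16
max_num_seqs = 16

cu_q_lens, kv_lens = [0], []
for q_len, kv_len in seq_lens:
  assert q_len <= kv_len
  cu_q_lens.append(cu_q_lens[-1] + q_len)
  kv_lens.append(kv_len)

# pages_per_seq covers the longest kv sequence, rounded up to a whole
# number of KV-page blocks.
pages_per_seq = ceil_div(max(kv_lens), page_size)
pages_per_seq = ceil_div(pages_per_seq, num_kv_pages_per_block) * num_kv_pages_per_block

# Metadata is padded out to max_num_seqs (+1 for the cumulative lengths).
cu_q_lens = torch.tensor(cu_q_lens, dtype=torch.int32)
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
cu_q_lens = torch.nn.functional.pad(
    cu_q_lens, (0, max_num_seqs + 1 - cu_q_lens.shape[0]), "constant", 0)
kv_lens = torch.nn.functional.pad(
    kv_lens, (0, max_num_seqs - kv_lens.shape[0]), "constant", 0)

print(cu_q_lens.shape, kv_lens.shape, pages_per_seq)
# torch.Size([17]) torch.Size([16]) 96   (ceil(1328 / 16) = 83, rounded up to 96)
```

q, k_pages, v_pages, and page_indices then follow the shapes shown in the diff: q is [max_num_batched_tokens, num_q_heads, head_dim] and page_indices is [max_num_seqs, pages_per_seq].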
@@ -648,7 +635,7 @@ def test_paged_attention_wrapper(self):
                    "This test only works on TPUv4+.")
   def test_ragged_paged_attention_wrapper_without_dynamo(self):
     from torch_xla.experimental.custom_kernel import ragged_paged_attention
-    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention
+    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention

     seq_lens = [
         (1, 1328),
@@ -663,18 +650,25 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
         (1, 17),
         (99, 123)
     ] # last 3 physical q blocks [(q_len, kv_len),...]
-    num_heads = (4, 4)
+    num_heads = (32, 8)
     head_dim = 128
     dtype = torch.float32
     page_size = 16
     num_pages = 32768
     num_seqs = len(seq_lens)
-    num_kv_pages_per_block = 128
+    num_kv_pages_per_block = 16
     num_queries_per_block = 8
-    block_kv_size = 256

     q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv(
-        seq_lens, num_heads, head_dim, page_size, num_pages, dtype=dtype)
+        seq_lens,
+        num_heads,
+        head_dim,
+        page_size,
+        num_pages,
+        dtype,
+        num_kv_pages_per_block=num_kv_pages_per_block,
+        max_num_batched_tokens=1024,
+        max_num_seqs=16)

     q_xla = q.to("xla")
     k_pages_xla = k_pages.to("xla")
@@ -693,7 +687,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
         num_seqs=num_seqs,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
-        use_kernel=True)
+        use_kernel=True)[:cu_q_lens[num_seqs]]

     nonkernel_output = ragged_paged_attention(
         q_xla,
@@ -726,7 +720,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
                     num_seqs=num_seqs,
                     num_kv_pages_per_block=num_kv_pages_per_block,
                     num_queries_per_block=num_queries_per_block,
-                )[1]))
+                )[:cu_q_lens[num_seqs]]))

     self.assertTrue(
         torch.allclose(
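The `[:cu_q_lens[num_seqs]]` slices introduced above (and repeated in the dynamo hunks below) exist because the padded q tensor, and therefore the kernel output, has max_num_batched_tokens rows, while only the first cu_q_lens[num_seqs] rows correspond to real query tokens. A small illustration with hypothetical lengths; `output` here is just a stand-in tensor, not the kernel result:

```python
import torch

seq_lens = [(1, 1328), (5, 18), (99, 123)]  # hypothetical [(q_len, kv_len), ...]
num_seqs = len(seq_lens)
max_num_batched_tokens, max_num_seqs = 1024, 16

# cu_q_lens is padded out to max_num_seqs + 1 entries, as in the helper above.
cu_q_lens = torch.zeros(max_num_seqs + 1, dtype=torch.int32)
cu_q_lens[1:num_seqs + 1] = torch.cumsum(
    torch.tensor([q_len for q_len, _ in seq_lens], dtype=torch.int32), dim=0)

# Stand-in for the kernel output: one row per (padded) query token.
output = torch.randn(max_num_batched_tokens, 32, 128)
valid = output[:cu_q_lens[num_seqs]]  # keep only the rows for real query tokens
print(int(cu_q_lens[num_seqs]), valid.shape)  # 105 torch.Size([105, 32, 128])
```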
@@ -745,19 +739,25 @@ def _verify_ragged_paged_attention_with_dynamo(
       dtype,
       num_kv_pages_per_block,
       num_queries_per_block,
-      pad_num_q_tokens=False,
+      pad_tokens_and_seqs=False,
       sm_scale=1.0,
   ):
     num_seqs = len(seq_lens)
+    max_num_batched_tokens = None
+    max_num_seqs = None
+    if pad_tokens_and_seqs:
+      max_num_batched_tokens = 1024
+      max_num_seqs = 16
     q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv(
         seq_lens,
         num_heads,
         head_dim,
         page_size,
         num_pages,
-        dtype=dtype,
-        num_queries_per_block=num_queries_per_block,
-        pad_num_q_tokens=pad_num_q_tokens)
+        dtype,
+        num_kv_pages_per_block=num_kv_pages_per_block,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs)

     q_xla = q.to("xla")
     k_pages_xla = k_pages.to("xla")
@@ -766,29 +766,7 @@ def _verify_ragged_paged_attention_with_dynamo(
     page_indices_xla = page_indices.to("xla")
     cu_q_lens_xla = cu_q_lens.to("xla")

-    def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
-                                       page_indices, cu_q_lens, num_seqs,
-                                       num_kv_pages_per_block,
-                                       num_queries_per_block, use_kernel,
-                                       sm_scale):
-      return torch.ops.xla.ragged_paged_attention(
-          q,
-          k_pages,
-          v_pages,
-          kv_lens,
-          page_indices,
-          cu_q_lens,
-          num_seqs,
-          num_kv_pages_per_block,
-          num_queries_per_block,
-          use_kernel=use_kernel,
-          sm_scale=sm_scale,
-      )
-
-    compiled_paged_attention = torch.compile(
-        ragged_paged_attention_wrapper, backend="openxla")
-
-    kernel_output = compiled_paged_attention(
+    kernel_output = torch.ops.xla.ragged_paged_attention(
         q_xla,
         k_pages_xla,
         v_pages_xla,
@@ -800,9 +778,9 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
         num_queries_per_block=num_queries_per_block,
         use_kernel=True,
         sm_scale=sm_scale,
-    )
+    )[:cu_q_lens[num_seqs]]

-    nonkernel_output = compiled_paged_attention(
+    nonkernel_output = torch.ops.xla.ragged_paged_attention(
         q_xla,
         k_pages_xla,
         v_pages_xla,
@@ -828,7 +806,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
     page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
     cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32)

-    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention
+    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
     jax_kernel_output = torch.from_numpy(
         np.array(
             jax_ragged_paged_attention(
@@ -842,34 +820,19 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
                 num_kv_pages_per_block=num_kv_pages_per_block,
                 num_queries_per_block=num_queries_per_block,
                 sm_scale=sm_scale,
-            )[1]))
+            )[:cu_q_lens[num_seqs]]))
     jax_kernel_output_cpu = jax_kernel_output.cpu()

-    if pad_num_q_tokens:
-      actual_num_q_tokens = cu_q_lens[num_seqs]
-      self.assertTrue(
-          torch.allclose(
-              kernel_output_cpu[:actual_num_q_tokens],
-              nonkernel_output_cpu[:actual_num_q_tokens],
-              atol=2e-2,
-              rtol=1e-2))
-      self.assertTrue(
-          torch.allclose(
-              kernel_output_cpu[:actual_num_q_tokens],
-              jax_kernel_output_cpu[:actual_num_q_tokens],
-              atol=2e-2,
-              rtol=1e-2))
-    else:
-      self.assertTrue(
-          torch.allclose(
-              kernel_output_cpu, nonkernel_output_cpu, atol=2e-2, rtol=1e-2))
-      self.assertTrue(
-          torch.allclose(
-              kernel_output_cpu, jax_kernel_output_cpu, atol=2e-2, rtol=1e-2))
+    self.assertTrue(
+        torch.allclose(
+            kernel_output_cpu, nonkernel_output_cpu, atol=2e-2, rtol=1e-2))
+    self.assertTrue(
+        torch.allclose(
+            kernel_output_cpu, jax_kernel_output_cpu, atol=2e-2, rtol=1e-2))

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
-  def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
+  def test_ragged_paged_attention_wrapper_no_padding_with_dynamo(self):
     seq_lens = [
         (1, 1328),
         (5, 18),
@@ -883,7 +846,7 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
         (1, 17),
         (99, 123)
     ] # last 3 physical q blocks [(q_len, kv_len),...]
-    num_heads = (4, 4)
+    num_heads = (32, 8)
     head_dim = 128
     dtype = torch.float32
     page_size = 16
@@ -897,7 +860,7 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
         page_size,
         num_pages,
         dtype,
-        num_kv_pages_per_block=128,
+        num_kv_pages_per_block=16,
         num_queries_per_block=8,
         sm_scale=sm_scale,
     )
@@ -908,12 +871,12 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
   )
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
-  def test_ragged_paged_attention_wrapper_with_query_padding_with_dynamo(
+  def test_ragged_paged_attention_wrapper_with_padding_with_dynamo(
       self,
       seq_lens,
       num_queries_per_block,
   ):
-    num_heads = (4, 4)
+    num_heads = (32, 8)
     head_dim = 128
     dtype = torch.float32
     page_size = 16
@@ -927,9 +890,9 @@ def test_ragged_paged_attention_wrapper_with_query_padding_with_dynamo(
         page_size,
         num_pages,
         dtype,
-        num_kv_pages_per_block=128,
+        num_kv_pages_per_block=16,
         num_queries_per_block=num_queries_per_block,
-        pad_num_q_tokens=True,
+        pad_tokens_and_seqs=True,
         sm_scale=sm_scale,
     )

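One layout change worth highlighting from the hunks above: the v2 tests build KV cache pages as [num_pages, page_size, num_kv_heads, head_dim] instead of the earlier [num_kv_heads, num_pages, page_size, head_dim]. A minimal sketch (example sizes, hypothetical variable names, not part of the commit) of moving a v1-layout cache to the v2 layout:

```python
# Sketch only: illustrates the KV-cache layout change between the v1 and v2 tests.
import torch

num_kv_heads, num_pages, page_size, head_dim = 8, 64, 16, 128

# v1 tests built pages as [num_kv_heads, num_pages, page_size, head_dim].
k_pages_v1 = torch.randn(num_kv_heads, num_pages, page_size, head_dim)

# v2 tests build them as [num_pages, page_size, num_kv_heads, head_dim];
# an existing v1-layout cache can be rearranged with a single permute.
k_pages_v2 = k_pages_v1.permute(1, 2, 0, 3).contiguous()
assert k_pages_v2.shape == (num_pages, page_size, num_kv_heads, head_dim)
```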