@@ -87,8 +87,9 @@ def _pagedattention_generate_qkv(
     q = torch.randn(batch_size, query_len, num_heads, head_dim, dtype=dtype)
     return q, k_pages, v_pages, page_indices
 
-  def _round_up_closest_multiple_of(self, x, base):
-    return (x + base - 1) // base * base
+  def _ceil_div(self, a, b):
+    assert b != 0
+    return (a + b - 1) // b
 
   def _ragged_pagedattention_generate_qkv(
       self,
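Aside: the renamed helper only performs the ceiling division; the old round-up-to-a-multiple behaviour is recovered at the call sites by multiplying the result back. A minimal, self-contained sketch of that relationship (a standalone copy of the arithmetic, not the class method; the 1328/16 figures are taken from the test's seq_lens, page_size and num_kv_pages_per_block below):

```python
def ceil_div(a, b):
  # Same arithmetic as the new _ceil_div method above.
  assert b != 0
  return (a + b - 1) // b

# Old _round_up_closest_multiple_of(x, base) == ceil_div(x, base) * base.
assert ceil_div(83, 16) * 16 == 96

# How the generator uses it: pages needed for kv_len=1328 with page_size=16,
# then padded up to a whole number of kv-page blocks of 16.
assert ceil_div(1328, 16) == 83
assert ceil_div(83, 16) * 16 == 96
```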
@@ -97,64 +98,50 @@ def _ragged_pagedattention_generate_qkv(
       head_dim,
       page_size,
       num_pages,
-      dtype=torch.float32,
-      num_queries_per_block=None,
-      pad_num_q_tokens=False,
+      dtype,
+      *,
+      num_kv_pages_per_block=None,
+      max_num_batched_tokens=None,
+      max_num_seqs=16,
   ):
-    num_seqs = len(seq_lens)
-    # Make sure the q_len is no longer than the kv_len. For example,
-    # seq_lens = [(1, 1328), (5, 18), (506, 463)] is not a valid test case because
-    # the 3rd sequence has q_len(506) > kv_len(463).
-    for i in range(num_seqs):
-      cur_q_len = seq_lens[i][0]
-      cur_kv_len = seq_lens[i][1]
-      assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}"
-
-    query_lens = [seq_len[0] for seq_len in seq_lens]
-    actual_num_q_tokens = sum(query_lens)
-    num_q_tokens = self._round_up_closest_multiple_of(
-        actual_num_q_tokens,
-        num_queries_per_block) if pad_num_q_tokens else actual_num_q_tokens
-    kv_lens = torch.tensor([seq_len[1] for seq_len in seq_lens],
-                           dtype=torch.int32)
-    num_q_heads = num_heads[0]
-    num_kv_heads = num_heads[1]
-    assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0."
-    queries = torch.randn((num_q_tokens, num_q_heads, head_dim), dtype=dtype)
-    k_pages = torch.randn((num_kv_heads, num_pages, page_size, head_dim),
+    cu_q_lens = [0]
+    kv_lens = []
+    for q_len, kv_len in seq_lens:
+      assert q_len <= kv_len
+      cu_q_lens.append(cu_q_lens[-1] + q_len)
+      kv_lens.append(kv_len)
+
+    if max_num_batched_tokens is None:
+      max_num_batched_tokens = cu_q_lens[-1]
+    else:
+      max_num_batched_tokens = max(cu_q_lens[-1], max_num_batched_tokens)
+    if max_num_seqs is None:
+      max_num_seqs = len(seq_lens)
+    else:
+      max_num_seqs = max(len(seq_lens), max_num_seqs)
+    max_kv_len = max(kv_lens)
+    pages_per_seq = self._ceil_div(max_kv_len, page_size)
+    pages_per_seq = (
+        self._ceil_div(pages_per_seq, num_kv_pages_per_block) *
+        num_kv_pages_per_block)
+
+    num_q_heads, num_kv_heads = num_heads
+    cu_q_lens = torch.tensor(cu_q_lens, dtype=torch.int32)
+    kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
+    cu_q_lens = torch.nn.functional.pad(
+        cu_q_lens, (0, max_num_seqs + 1 - cu_q_lens.shape[0]), "constant", 0)
+    kv_lens = torch.nn.functional.pad(kv_lens,
+                                      (0, max_num_seqs - kv_lens.shape[0]),
+                                      "constant", 0)
+    q = torch.randn((max_num_batched_tokens, num_q_heads, head_dim),
+                    dtype=dtype)
+    k_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim),
                           dtype=dtype)
-    v_pages = torch.randn((num_kv_heads, num_pages, page_size, head_dim),
+    v_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim),
                           dtype=dtype)
-
-    # Create a kv_lens: i32[num_tokens]
-    kv_lens_with_paddings = [0] * num_q_tokens
-    for i in range(num_seqs):
-      kv_lens_with_paddings[i] = kv_lens[i]
-    kv_lens_ = torch.tensor(kv_lens_with_paddings, dtype=torch.int32)
-
-    # Create a page_indices i32[num_tokens, pages_per_sequence]
-    max_kv_len = max([seq_len[1] for seq_len in seq_lens])
-    max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size
-
-    # The reason why we need to pad max_num_pages_per_seq is that
-    # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
-    max_num_pages_per_seq = 2**int(np.ceil(np.log2(max_num_pages_per_seq)))
-
-    # The assert below mimics the reality that each page get a unique index.
-    # But for testing, the assert could be omitted.
-    # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}"
-
     page_indices = torch.randint(
-        0, num_pages, (num_q_tokens, max_num_pages_per_seq), dtype=torch.int32)
-
-    # Create a cu_q_lens i32[num_tokens + 1]
-    q_lens_with_paddings = [0] * num_q_tokens
-    for i in range(num_seqs):
-      q_lens_with_paddings[i] = query_lens[i]
-    cu_q_lens = torch.cumsum(
-        torch.tensor([0] + q_lens_with_paddings, dtype=torch.int32),
-        dim=0,
-        dtype=torch.int32)
-    return queries, k_pages, v_pages, page_indices, cu_q_lens, kv_lens_
+        0, num_pages, (max_num_seqs, pages_per_seq), dtype=torch.int32)
+    return q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens
 
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_tpu_custom_call_pallas_add(self):
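For orientation, here is a worked example of the buffer shapes the rewritten generator produces. The numbers are illustrative only (three of the test's (q_len, kv_len) pairs with the defaults above); the snippet mirrors the sizing arithmetic rather than calling the class method:

```python
# Illustrative sizes, reusing three of the test's (q_len, kv_len) pairs.
seq_lens = [(1, 1328), (5, 18), (506, 463)]
page_size = 16
num_kv_pages_per_block = 16
max_num_seqs = 16

num_batched_tokens = sum(q for q, _ in seq_lens)  # 512
pages_per_seq = (max(kv for _, kv in seq_lens) + page_size - 1) // page_size
pages_per_seq = ((pages_per_seq + num_kv_pages_per_block - 1) //
                 num_kv_pages_per_block) * num_kv_pages_per_block

assert num_batched_tokens == 512 and pages_per_seq == 96
# Hence, with num_heads = (32, 8), head_dim = 128, num_pages = 32768:
#   q:                [512, 32, 128]
#   k_pages, v_pages: [32768, 16, 8, 128]  (num_pages, page_size, num_kv_heads, head_dim)
#   page_indices:     [16, 96]             (max_num_seqs, pages_per_seq)
#   cu_q_lens:        [17]  -> [0, 1, 6, 512, 0, ...]
#   kv_lens:          [16]  -> [1328, 18, 463, 0, ...]
```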
@@ -648,7 +635,7 @@ def test_paged_attention_wrapper(self):
                    "This test only works on TPUv4+.")
   def test_ragged_paged_attention_wrapper_without_dynamo(self):
     from torch_xla.experimental.custom_kernel import ragged_paged_attention
-    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention
+    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
 
     seq_lens = [
         (1, 1328),
@@ -663,18 +650,25 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
         (1, 17),
         (99, 123)
     ]  # last 3 physical q blocks [(q_len, kv_len),...]
-    num_heads = (4, 4)
+    num_heads = (32, 8)
     head_dim = 128
     dtype = torch.float32
     page_size = 16
     num_pages = 32768
     num_seqs = len(seq_lens)
-    num_kv_pages_per_block = 128
+    num_kv_pages_per_block = 16
     num_queries_per_block = 8
-    block_kv_size = 256
 
     q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv(
-        seq_lens, num_heads, head_dim, page_size, num_pages, dtype=dtype)
+        seq_lens,
+        num_heads,
+        head_dim,
+        page_size,
+        num_pages,
+        dtype,
+        num_kv_pages_per_block=num_kv_pages_per_block,
+        max_num_batched_tokens=1024,
+        max_num_seqs=16)
 
     q_xla = q.to("xla")
     k_pages_xla = k_pages.to("xla")
@@ -693,7 +687,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
         num_seqs=num_seqs,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
-        use_kernel=True)
+        use_kernel=True)[:cu_q_lens[num_seqs]]
 
     nonkernel_output = ragged_paged_attention(
         q_xla,
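The new `[:cu_q_lens[num_seqs]]` slice reflects that the v2 kernel output is padded to `max_num_batched_tokens` rows, so only the prefix backed by real query tokens is compared. A minimal, self-contained illustration of the pattern, using random data and illustrative sizes (not the full 11-sequence workload of this test):

```python
import torch

max_num_batched_tokens, num_q_heads, head_dim = 1024, 32, 128
num_seqs = 3
# cu_q_lens as built by the generator: cumulative q lengths, zero-padded.
cu_q_lens = torch.tensor([0, 1, 6, 512] + [0] * 13, dtype=torch.int32)

padded_out = torch.randn(max_num_batched_tokens, num_q_heads, head_dim)
valid_out = padded_out[:cu_q_lens[num_seqs]]  # drop the padding rows
assert valid_out.shape == (512, 32, 128)
```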
@@ -726,7 +720,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
                 num_seqs=num_seqs,
                 num_kv_pages_per_block=num_kv_pages_per_block,
                 num_queries_per_block=num_queries_per_block,
-            )[1]))
+            )[:cu_q_lens[num_seqs]]))
 
     self.assertTrue(
         torch.allclose(
@@ -745,19 +739,25 @@ def _verify_ragged_paged_attention_with_dynamo(
       dtype,
       num_kv_pages_per_block,
       num_queries_per_block,
-      pad_num_q_tokens=False,
+      pad_tokens_and_seqs=False,
       sm_scale=1.0,
   ):
     num_seqs = len(seq_lens)
+    max_num_batched_tokens = None
+    max_num_seqs = None
+    if pad_tokens_and_seqs:
+      max_num_batched_tokens = 1024
+      max_num_seqs = 16
 
     q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv(
         seq_lens,
         num_heads,
         head_dim,
         page_size,
         num_pages,
-        dtype=dtype,
-        num_queries_per_block=num_queries_per_block,
-        pad_num_q_tokens=pad_num_q_tokens)
+        dtype,
+        num_kv_pages_per_block=num_kv_pages_per_block,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs)
 
     q_xla = q.to("xla")
     k_pages_xla = k_pages.to("xla")
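The renamed `pad_tokens_and_seqs` flag now drives both padding knobs at once: when it is False the generator sizes buffers exactly, otherwise to the fixed maxima used here. A self-contained sketch of that sizing rule (it mirrors the clamping in the generator rather than calling the test method; the seq_lens values are illustrative):

```python
seq_lens = [(1, 1328), (5, 18), (506, 463)]   # illustrative (q_len, kv_len) pairs
total_q_tokens = sum(q for q, _ in seq_lens)  # 512

def padded_sizes(max_num_batched_tokens=None, max_num_seqs=None):
  # Same max()/None handling as _ragged_pagedattention_generate_qkv.
  tokens = (total_q_tokens if max_num_batched_tokens is None else
            max(total_q_tokens, max_num_batched_tokens))
  seqs = (len(seq_lens) if max_num_seqs is None else
          max(len(seq_lens), max_num_seqs))
  return tokens, seqs

assert padded_sizes() == (512, 3)            # pad_tokens_and_seqs=False
assert padded_sizes(1024, 16) == (1024, 16)  # pad_tokens_and_seqs=True
```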
@@ -766,29 +766,7 @@ def _verify_ragged_paged_attention_with_dynamo(
     page_indices_xla = page_indices.to("xla")
     cu_q_lens_xla = cu_q_lens.to("xla")
 
-    def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
-                                       page_indices, cu_q_lens, num_seqs,
-                                       num_kv_pages_per_block,
-                                       num_queries_per_block, use_kernel,
-                                       sm_scale):
-      return torch.ops.xla.ragged_paged_attention(
-          q,
-          k_pages,
-          v_pages,
-          kv_lens,
-          page_indices,
-          cu_q_lens,
-          num_seqs,
-          num_kv_pages_per_block,
-          num_queries_per_block,
-          use_kernel=use_kernel,
-          sm_scale=sm_scale,
-      )
-
-    compiled_paged_attention = torch.compile(
-        ragged_paged_attention_wrapper, backend="openxla")
-
-    kernel_output = compiled_paged_attention(
+    kernel_output = torch.ops.xla.ragged_paged_attention(
         q_xla,
         k_pages_xla,
         v_pages_xla,
@@ -800,9 +778,9 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
         num_queries_per_block=num_queries_per_block,
         use_kernel=True,
         sm_scale=sm_scale,
-    )
+    )[:cu_q_lens[num_seqs]]
 
-    nonkernel_output = compiled_paged_attention(
+    nonkernel_output = torch.ops.xla.ragged_paged_attention(
         q_xla,
         k_pages_xla,
         v_pages_xla,
@@ -828,7 +806,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
     page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
     cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32)
 
-    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention
+    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
     jax_kernel_output = torch.from_numpy(
         np.array(
             jax_ragged_paged_attention(
@@ -842,34 +820,19 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
                 num_kv_pages_per_block=num_kv_pages_per_block,
                 num_queries_per_block=num_queries_per_block,
                 sm_scale=sm_scale,
-            )[1]))
+            )[:cu_q_lens[num_seqs]]))
     jax_kernel_output_cpu = jax_kernel_output.cpu()
 
-    if pad_num_q_tokens:
-      actual_num_q_tokens = cu_q_lens[num_seqs]
-      self.assertTrue(
-          torch.allclose(
-              kernel_output_cpu[:actual_num_q_tokens],
-              nonkernel_output_cpu[:actual_num_q_tokens],
-              atol=2e-2,
-              rtol=1e-2))
-      self.assertTrue(
-          torch.allclose(
-              kernel_output_cpu[:actual_num_q_tokens],
-              jax_kernel_output_cpu[:actual_num_q_tokens],
-              atol=2e-2,
-              rtol=1e-2))
-    else:
-      self.assertTrue(
-          torch.allclose(
-              kernel_output_cpu, nonkernel_output_cpu, atol=2e-2, rtol=1e-2))
-      self.assertTrue(
-          torch.allclose(
-              kernel_output_cpu, jax_kernel_output_cpu, atol=2e-2, rtol=1e-2))
+    self.assertTrue(
+        torch.allclose(
+            kernel_output_cpu, nonkernel_output_cpu, atol=2e-2, rtol=1e-2))
+    self.assertTrue(
+        torch.allclose(
+            kernel_output_cpu, jax_kernel_output_cpu, atol=2e-2, rtol=1e-2))
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
-  def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
+  def test_ragged_paged_attention_wrapper_no_padding_with_dynamo(self):
     seq_lens = [
         (1, 1328),
         (5, 18),
@@ -883,7 +846,7 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
         (1, 17),
         (99, 123)
     ]  # last 3 physical q blocks [(q_len, kv_len),...]
-    num_heads = (4, 4)
+    num_heads = (32, 8)
     head_dim = 128
     dtype = torch.float32
     page_size = 16
@@ -897,7 +860,7 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
         page_size,
         num_pages,
         dtype,
-        num_kv_pages_per_block=128,
+        num_kv_pages_per_block=16,
         num_queries_per_block=8,
         sm_scale=sm_scale,
     )
@@ -908,12 +871,12 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
   )
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
-  def test_ragged_paged_attention_wrapper_with_query_padding_with_dynamo(
+  def test_ragged_paged_attention_wrapper_with_padding_with_dynamo(
       self,
      seq_lens,
      num_queries_per_block,
   ):
-    num_heads = (4, 4)
+    num_heads = (32, 8)
     head_dim = 128
     dtype = torch.float32
     page_size = 16
@@ -927,9 +890,9 @@ def test_ragged_paged_attention_wrapper_with_query_padding_with_dynamo(
         page_size,
         num_pages,
         dtype,
-        num_kv_pages_per_block=128,
+        num_kv_pages_per_block=16,
         num_queries_per_block=num_queries_per_block,
-        pad_num_q_tokens=True,
+        pad_tokens_and_seqs=True,
         sm_scale=sm_scale,
     )