-
Notifications
You must be signed in to change notification settings - Fork 284
/
Copy pathtest_shared_prefix_kernels.py
256 lines (230 loc) · 9.04 KB
/
test_shared_prefix_kernels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""
Copyright (c) 2023 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy
import pytest
import torch
import flashinfer
def ceil_div(a, b):
    """Integer ceiling division: smallest ``q`` with ``q * b >= a``.

    Computed as ``(a + b - 1) // b``. Valid for non-negative integer ``a`` and
    positive integer ``b``, which is how the tests below use it (page counts).
    """
    padded = a + b - 1
    return padded // b
@pytest.mark.parametrize("batch_size", [12, 17])
@pytest.mark.parametrize("unique_kv_len", [37, 17])
@pytest.mark.parametrize("shared_kv_len", [54, 97, 1979])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
def test_batch_decode_with_shared_prefix_padded_kv_cache(
    batch_size, unique_kv_len, shared_kv_len, num_heads, head_dim
):
    """Check the batched shared-prefix decode on padded KV against per-request decode.

    Every request attends to the same ``k_shared``/``v_shared`` prefix followed by
    its own row of ``k_unique``/``v_unique``. The reference concatenates prefix and
    suffix along dim 0 and calls ``flashinfer.single_decode_with_kv_cache`` per request.
    """

    def fp16_randn(*shape):
        # Draw with torch.randn, then move to CUDA device 0 and cast to fp16.
        return torch.randn(*shape).to(0).half()

    q = fp16_randn(batch_size, num_heads, head_dim)
    k_shared = fp16_randn(shared_kv_len, num_heads, head_dim)
    v_shared = fp16_randn(shared_kv_len, num_heads, head_dim)
    k_unique = fp16_randn(batch_size, unique_kv_len, num_heads, head_dim)
    v_unique = fp16_randn(batch_size, unique_kv_len, num_heads, head_dim)

    o = flashinfer.batch_decode_with_shared_prefix_padded_kv_cache(
        q, k_shared, v_shared, k_unique, v_unique
    )

    # Iterating a tensor walks dim 0, i.e. one request at a time.
    for req, (q_req, k_tail, v_tail) in enumerate(zip(q, k_unique, v_unique)):
        k_full = torch.cat([k_shared, k_tail], dim=0)
        v_full = torch.cat([v_shared, v_tail], dim=0)
        expected = flashinfer.single_decode_with_kv_cache(q_req, k_full, v_full)
        numpy.testing.assert_allclose(
            o[req].cpu().numpy(), expected.cpu().numpy(), rtol=1e-3, atol=1e-3
        )
@pytest.mark.parametrize("batch_size", [12, 17])
@pytest.mark.parametrize("unique_kv_len", [37, 17])
@pytest.mark.parametrize("shared_kv_len", [54, 97, 1979])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("page_size", [1, 16])
def test_batch_decode_with_shared_prefix_paged_kv_cache(
    batch_size, unique_kv_len, shared_kv_len, num_heads, head_dim, page_size
):
    """Check the paged shared-prefix decode wrapper against the padded-KV kernel.

    The per-request unique suffixes are written into a paged KV cache with
    ``flashinfer.append_paged_kv_cache``. The shared prefix is passed to
    ``wrapper.forward`` as dense tensors. The padded kernel on identical inputs
    serves as the reference.
    """
    kv_layout = "NHD"

    def fp16_randn(*shape):
        # Draw with torch.randn, then move to CUDA device 0 and cast to fp16.
        return torch.randn(*shape).to(0).half()

    q = fp16_randn(batch_size, num_heads, head_dim)
    k_shared = fp16_randn(shared_kv_len, num_heads, head_dim)
    v_shared = fp16_randn(shared_kv_len, num_heads, head_dim)
    k_unique = fp16_randn(batch_size, unique_kv_len, num_heads, head_dim)
    v_unique = fp16_randn(batch_size, unique_kv_len, num_heads, head_dim)

    # Request r owns pages [r * pages_per_request, (r + 1) * pages_per_request).
    pages_per_request = ceil_div(unique_kv_len, page_size)
    total_pages = batch_size * pages_per_request
    # Tokens stored in each request's final page; always in [1, page_size].
    last_page_tokens = (unique_kv_len - 1) % page_size + 1

    # Shape (pages, 2, page_size, num_heads, head_dim). The size-2 axis
    # presumably separates K from V -- confirm against FlashInfer layout docs.
    kv_data = (
        torch.zeros(total_pages, 2, page_size, num_heads, head_dim).to(0).half()
    )
    kv_indices = torch.arange(0, total_pages).to(0).int()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * pages_per_request
    kv_last_page_len = torch.full(
        (batch_size,), last_page_tokens, dtype=torch.int32
    ).to(0)

    workspace = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0)
    wrapper = flashinfer.BatchDecodeWithSharedPrefixPagedKVCacheWrapper(
        workspace, kv_layout
    )
    # num_heads is passed twice; presumably the query and KV head counts
    # (equal here) -- verify against the wrapper signature.
    wrapper.begin_forward(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_heads,
        num_heads,
        head_dim,
        page_size,
        kv_data.dtype,
    )

    # Flatten (batch, len, heads, dim) -> (batch * len, heads, dim) and copy
    # each request's tokens into its pages.
    flashinfer.append_paged_kv_cache(
        k_unique.view(-1, num_heads, head_dim),
        v_unique.view(-1, num_heads, head_dim),
        torch.arange(0, batch_size + 1).to(0).int() * unique_kv_len,
        kv_data,
        kv_indices,
        kv_indptr,
        kv_last_page_len,
        kv_layout,
    )

    padded_out = flashinfer.batch_decode_with_shared_prefix_padded_kv_cache(
        q, k_shared, v_shared, k_unique, v_unique
    )
    paged_out = wrapper.forward(q, k_shared, v_shared, kv_data)
    numpy.testing.assert_allclose(
        padded_out.cpu().numpy(), paged_out.cpu().numpy(), rtol=1e-3, atol=1e-3
    )
@pytest.mark.parametrize("batch_size", [12, 17])
@pytest.mark.parametrize("unique_kv_len", [37, 17])
@pytest.mark.parametrize("shared_kv_len", [128, 512, 2048])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("page_size", [1, 16])
def test_batch_prefill_with_shared_prefix_paged_kv_cache(
    batch_size, unique_kv_len, shared_kv_len, num_heads, causal, head_dim, page_size
):
    """Check the cascade (shared-prefix) paged prefill against a plain paged prefill.

    A single paged KV cache stores the shared prefix in its first pages, followed
    by each request's unique pages.

    - The baseline wrapper gets, for every request, the shared pages followed by
      that request's unique pages.
    - The cascade wrapper gets only the unique pages, plus the shared prefix as
      dense ``k_shared``/``v_shared``.

    Both must produce the same output.
    """
    # The baseline lists the shared pages in the middle of every request's page
    # list. That is only sound when the shared prefix fills whole pages.
    assert shared_kv_len % page_size == 0
    kv_layout = "NHD"

    def fp16_randn(*shape):
        # Draw with torch.randn, then move to CUDA device 0 and cast to fp16.
        return torch.randn(*shape).to(0).half()

    def scaled_indptr(count, stride):
        # int32 tensor [0, stride, 2 * stride, ..., count * stride] on device 0.
        return torch.arange(0, count + 1).to(0).int() * stride

    num_shared_pages = ceil_div(shared_kv_len, page_size)
    pages_per_request = ceil_div(unique_kv_len, page_size)

    q = fp16_randn(batch_size * unique_kv_len, num_heads, head_dim)
    q_indptr = scaled_indptr(batch_size, unique_kv_len)
    k_shared = fp16_randn(shared_kv_len, num_heads, head_dim)
    v_shared = fp16_randn(shared_kv_len, num_heads, head_dim)
    k_unique = fp16_randn(batch_size * unique_kv_len, num_heads, head_dim)
    v_unique = fp16_randn(batch_size * unique_kv_len, num_heads, head_dim)

    kv_data = (
        torch.zeros(
            num_shared_pages + batch_size * pages_per_request,
            2,
            page_size,
            num_heads,
            head_dim,
        )
        .to(0)
        .half()
    )

    # Shared prefix: one sequence occupying pages [0, num_shared_pages).
    prefix_page_ids = torch.arange(0, num_shared_pages).to(0).int()
    prefix_append_indptr = scaled_indptr(1, shared_kv_len)
    prefix_page_indptr = scaled_indptr(1, num_shared_pages)
    prefix_last_page_len = torch.full(
        (1,), (shared_kv_len - 1) % page_size + 1, dtype=torch.int32
    ).to(0)
    flashinfer.append_paged_kv_cache(
        k_shared,
        v_shared,
        prefix_append_indptr,
        kv_data,
        prefix_page_ids,
        prefix_page_indptr,
        prefix_last_page_len,
        kv_layout,
    )

    # Unique suffixes: request r owns pages starting at
    # num_shared_pages + r * pages_per_request.
    unique_page_ids = (
        torch.arange(0, batch_size * pages_per_request).to(0).int() + num_shared_pages
    )
    unique_append_indptr = scaled_indptr(batch_size, unique_kv_len)
    unique_page_indptr = scaled_indptr(batch_size, pages_per_request)
    unique_last_page_len = torch.full(
        (batch_size,), (unique_kv_len - 1) % page_size + 1, dtype=torch.int32
    ).to(0)
    flashinfer.append_paged_kv_cache(
        k_unique,
        v_unique,
        unique_append_indptr,
        kv_data,
        unique_page_ids,
        unique_page_indptr,
        unique_last_page_len,
        kv_layout,
    )

    baseline_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
    )
    cascade_wrapper = flashinfer.BatchPrefillWithSharedPrefixPagedKVCacheWrapper(
        torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout
    )

    # Baseline page list per request: all shared pages, then its own unique pages.
    # Built on the host, concatenated, then moved to device 0.
    baseline_page_ids = torch.cat(
        [
            pages
            for req in range(batch_size)
            for pages in (
                torch.arange(0, num_shared_pages).int(),
                torch.arange(
                    num_shared_pages + req * pages_per_request,
                    num_shared_pages + (req + 1) * pages_per_request,
                ).int(),
            )
        ],
        dim=0,
    ).to(0)
    baseline_page_indptr = scaled_indptr(
        batch_size, num_shared_pages + pages_per_request
    )
    # Each baseline page list ends on the request's unique pages, so the
    # last-page lengths are the unique ones.
    baseline_wrapper.begin_forward(
        q_indptr,
        baseline_page_indptr,
        baseline_page_ids,
        unique_last_page_len,
        num_heads,
        num_heads,
        head_dim,
    )
    o_baseline = baseline_wrapper.forward(q, kv_data, causal=causal)

    cascade_wrapper.begin_forward(
        q_indptr,
        unique_page_indptr,
        unique_page_ids,
        unique_last_page_len,
        num_heads,
        num_heads,
        head_dim,
    )
    o_cascade = cascade_wrapper.forward(q, k_shared, v_shared, kv_data, causal=causal)

    numpy.testing.assert_allclose(
        o_baseline.cpu().numpy(), o_cascade.cpu().numpy(), rtol=1e-3, atol=1e-3
    )
if __name__ == "__main__":
    # Run one configuration of each test directly, without pytest.
    test_batch_decode_with_shared_prefix_padded_kv_cache(
        batch_size=12,
        unique_kv_len=37,
        shared_kv_len=54,
        num_heads=8,
        head_dim=128,
    )
    test_batch_decode_with_shared_prefix_paged_kv_cache(
        batch_size=12,
        unique_kv_len=37,
        shared_kv_len=54,
        num_heads=8,
        head_dim=128,
        page_size=16,
    )
    test_batch_prefill_with_shared_prefix_paged_kv_cache(
        batch_size=12,
        unique_kv_len=37,
        shared_kv_len=256,
        num_heads=8,
        causal=True,
        head_dim=128,
        page_size=16,
    )