Skip to content

Commit fd783a6

Browse files
heheda12345shreyankg
authored andcommitted
[V1] Implement sliding window attention in kv_cache_manager (vllm-project#14097)
Signed-off-by: Chen Zhang <[email protected]>
1 parent b06a55c commit fd783a6

15 files changed

+662
-158
lines changed

tests/core/block/e2e/test_correctness_sliding_window.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,16 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
129129
check_answers(indices, answer, test_texts)
130130

131131

132-
def prep_prompts(batch_size: int):
132+
def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)):
133133
"""
134134
Generate prompts which a bunch of assignments,
135135
then asking for the value of one of them.
136136
The prompt is just under 10k tokens; sliding window is 4k
137137
so the answer is outside sliding window, but should still be correct.
138+
139+
Args:
140+
batch_size: number of prompts to generate
141+
ln_range: an argument to control the length of the prompt
138142
"""
139143
prompts: list[str] = []
140144
answer: list[int] = []
@@ -145,7 +149,7 @@ def prep_prompts(batch_size: int):
145149
indices.append(idx)
146150
prompt = "```python\n# We set a number of variables, " + \
147151
f"x{idx} will be important later\n"
148-
ln = random.randint(800, 1100)
152+
ln = random.randint(*ln_range)
149153
for k in range(30, ln):
150154
v = random.randint(10, 99)
151155
if k == idx:
@@ -157,7 +161,10 @@ def prep_prompts(batch_size: int):
157161
return prompts, answer, indices
158162

159163

160-
def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
164+
def check_answers(indices: list[int],
165+
answer: list[int],
166+
outputs: list[str],
167+
accept_rate: float = 0.7):
161168
answer2 = [int(text[0:2].strip()) for text in outputs]
162169
print(list(zip(indices, zip(answer, answer2))))
163170
numok = 0
@@ -166,7 +173,7 @@ def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
166173
numok += 1
167174
frac_ok = numok / len(answer)
168175
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
169-
assert frac_ok > 0.7
176+
assert frac_ok >= accept_rate
170177

171178

172179
def check_window(prompts: list[str]):

0 commit comments

Comments
 (0)