Skip to content

[V1] Implement sliding window attention in kv_cache_manager #14097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 41 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
c209a59
kv cache config refactor
heheda12345 Mar 1, 2025
2b30e35
update comments
heheda12345 Mar 1, 2025
34b283e
setup test for sliding window
heheda12345 Mar 1, 2025
87578e2
Merge branch 'main' of github.com:vllm-project/vllm into sliding_window
heheda12345 Mar 1, 2025
cd54423
can run sliding window, cannot run prefix cache
heheda12345 Mar 1, 2025
a000848
Merge branch 'main' of github.com:vllm-project/vllm into sliding_window
heheda12345 Mar 2, 2025
b9c9e0b
real_null_block, can run prefix cache
heheda12345 Mar 2, 2025
e733bcd
update tests
heheda12345 Mar 2, 2025
7140643
minor updates
heheda12345 Mar 2, 2025
40e0967
hack for interleaved model
heheda12345 Mar 2, 2025
93adab8
address review comments
heheda12345 Mar 9, 2025
9082be5
Merge branch 'main' of github.com:vllm-project/vllm into virtual_layer
heheda12345 Mar 9, 2025
4d05626
ManagerKVLayer
heheda12345 Mar 9, 2025
530d4bf
update names
heheda12345 Mar 18, 2025
b9fd999
Merge remote-tracking branch 'origin/main' into virtual_layer
heheda12345 Mar 18, 2025
56e0b5d
update comments
heheda12345 Mar 18, 2025
19b2589
update the explaination of kv cache groups
heheda12345 Mar 18, 2025
757f350
fix tests
heheda12345 Mar 21, 2025
33860dd
Merge remote-tracking branch 'origin/main' into virtual_layer
heheda12345 Mar 21, 2025
abb64f0
Merge branch 'main' of github.com:vllm-project/vllm into virtual_layer
heheda12345 Mar 21, 2025
7e70b50
Merge branch 'virtual_layer' of github.com:heheda12345/vllm into slid…
heheda12345 Mar 21, 2025
e10360d
Merge branch 'main' of github.com:vllm-project/vllm into sliding_window
heheda12345 Mar 21, 2025
d2618cb
fix tests
heheda12345 Mar 21, 2025
1222eba
small updates
heheda12345 Mar 21, 2025
368ce04
update function name
heheda12345 Mar 21, 2025
39f9d14
fix
heheda12345 Mar 21, 2025
e963864
update based on review
heheda12345 Mar 27, 2025
2b3b64f
fix bug in prefix cache
heheda12345 Mar 27, 2025
a91fdae
Merge branch 'main' of github.com:vllm-project/vllm into sliding_window
heheda12345 Mar 27, 2025
1c2da34
fix bug
heheda12345 Mar 27, 2025
58e2253
fix tests
heheda12345 Mar 27, 2025
2d4645f
fix comment
heheda12345 Mar 27, 2025
9b966b1
exclude null block
heheda12345 Mar 29, 2025
818fb83
Merge branch 'main' of github.com:vllm-project/vllm into sliding_window
heheda12345 Mar 29, 2025
2ca014e
simplify specialized manager interface
heheda12345 Mar 30, 2025
941f770
update based on review
heheda12345 Mar 30, 2025
0451bd3
minor updates
heheda12345 Mar 30, 2025
2d25ef1
change the meaning of sliding window and fix some nits
heheda12345 Mar 31, 2025
8c9e5a5
handle last block problem in kv cache manager
heheda12345 Mar 31, 2025
8eeb8d5
remove empty line
heheda12345 Mar 31, 2025
809fa99
Merge branch 'main' of github.com:vllm-project/vllm into sliding_window
heheda12345 Apr 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions tests/core/block/e2e/test_correctness_sliding_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,16 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
check_answers(indices, answer, test_texts)


def prep_prompts(batch_size: int):
def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)):
"""
Generate prompts which a bunch of assignments,
then asking for the value of one of them.
The prompt is just under 10k tokens; sliding window is 4k
so the answer is outside sliding window, but should still be correct.

Args:
batch_size: number of prompts to generate
ln_range: an argument to control the length of the prompt
"""
prompts: list[str] = []
answer: list[int] = []
Expand All @@ -145,7 +149,7 @@ def prep_prompts(batch_size: int):
indices.append(idx)
prompt = "```python\n# We set a number of variables, " + \
f"x{idx} will be important later\n"
ln = random.randint(800, 1100)
ln = random.randint(*ln_range)
for k in range(30, ln):
v = random.randint(10, 99)
if k == idx:
Expand All @@ -157,7 +161,10 @@ def prep_prompts(batch_size: int):
return prompts, answer, indices


def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
def check_answers(indices: list[int],
answer: list[int],
outputs: list[str],
accept_rate: float = 0.7):
answer2 = [int(text[0:2].strip()) for text in outputs]
print(list(zip(indices, zip(answer, answer2))))
numok = 0
Expand All @@ -166,7 +173,7 @@ def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
numok += 1
frac_ok = numok / len(answer)
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
assert frac_ok > 0.7
assert frac_ok >= accept_rate


def check_window(prompts: list[str]):
Expand Down
Loading