
Commit 12634be

jacobthebanana authored and dbogunowicz committed
Possible fix for conflict between Automated Prefix Caching (vllm-project#2762) and multi-LoRA support (vllm-project#1804) (vllm-project#3263)
1 parent fb3c092 commit 12634be

File tree: 2 files changed (+33, -16 lines)


tests/test_cache_block_hashing.py

Lines changed: 31 additions & 15 deletions
@@ -2,8 +2,11 @@

 Run `pytest tests/test_cache_block_hashing.py`.
 """
+from typing import List, Optional
+
 import pytest

+from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizer import TokenizerGroup
 from vllm.sequence import Sequence

@@ -36,7 +39,10 @@ def flatten_2d(li):
 @pytest.mark.parametrize("model", ["facebook/opt-125m"])
 @pytest.mark.parametrize("block_size", [16])
 @pytest.mark.parametrize("max_num_seqs", [256])
-def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
+@pytest.mark.parametrize("concurrent_lora_int_ids",
+                         [[None], [1], [None, 1], [None, 1, 2], [1, 2]])
+def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
+                             concurrent_lora_int_ids: List[Optional[int]]):

     tokenizer = TokenizerGroup(
         tokenizer_id="facebook/opt-125m",
@@ -48,20 +54,30 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
     hashes = []

     for prefix in prefixes:
-        hashes.append([])
-        prompts = [prefix + prompt for prompt in sample_prompts]
-        seq_id = 0
-        for prompt in prompts:
-            hashes[-1].append([])
-            prompt_token_ids = tokenizer.encode(prompt)
-            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
-                           tokenizer.tokenizer.eos_token_id)
-
-            num_blocks = len(prompt_token_ids) // block_size
-            for idx in range(num_blocks):
-                hashes[-1][-1].append(seq.hash_of_block(idx))
-
-            seq_id += 1
+        for lora_int_id in concurrent_lora_int_ids:
+            lora_request = None
+
+            if lora_int_id is not None:
+                lora_request = LoRARequest(
+                    f"example_lora_{lora_int_id}",
+                    lora_int_id,
+                    f"example/path/to/lora_{lora_int_id}",
+                )
+
+            hashes.append([])
+            prompts = [prefix + prompt for prompt in sample_prompts]
+            seq_id = 0
+            for prompt in prompts:
+                hashes[-1].append([])
+                prompt_token_ids = tokenizer.encode(prompt)
+                seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
+                               tokenizer.tokenizer.eos_token_id, lora_request)
+
+                num_blocks = len(prompt_token_ids) // block_size
+                for idx in range(num_blocks):
+                    hashes[-1][-1].append(seq.hash_of_block(idx))
+
+                seq_id += 1

     # Check that hashes made with two prefixes with different first blocks are
     # different everywhere.
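
The new `concurrent_lora_int_ids` parameter mixes plain requests (`None`) with one or two LoRA adapters, so the test now checks that block hashes stay distinct across adapters as well as across prefixes. Below is a minimal sketch of the same check outside the test harness, assuming a vllm checkout that includes this commit; the prompt text, token ids, and EOS id are made-up placeholders rather than real tokenizer output.

from vllm.lora.request import LoRARequest
from vllm.sequence import Sequence

block_size = 16
prompt_token_ids = list(range(64))   # four full blocks of placeholder token ids
eos_token_id = 2                     # placeholder EOS id, not from a tokenizer

plain = Sequence(0, "same prompt", prompt_token_ids, block_size, eos_token_id)
with_lora = Sequence(1, "same prompt", prompt_token_ids, block_size, eos_token_id,
                     LoRARequest("example_lora_1", 1, "example/path/to/lora_1"))

# With lora_int_id folded into the hash, the same prefix served through a LoRA
# adapter no longer collides with the plain-model prefix in the cache.
assert plain.hash_of_block(0) != with_lora.hash_of_block(0)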

vllm/sequence.py

Lines changed: 2 additions & 1 deletion
@@ -187,7 +187,8 @@ def hash_of_block(self, logical_idx: int) -> int:
         # TODO: The current hashing function is O(L^2). We should optimize
         # this in the future.
         num_tokens = self.num_hashed_tokens_of_block(logical_idx)
-        return hash(tuple(self.data.get_token_ids()[0:num_tokens]))
+        return hash(
+            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))

     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size
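
The fix itself is just a widened hash key: the block hash now covers both the token prefix and `lora_int_id`, so a cache block computed under one adapter is never handed to a request using a different adapter (or no adapter at all). The toy below illustrates the idea with no vllm dependency; `block_hash` is a hypothetical stand-in for `Sequence.hash_of_block`, not vllm code.

def block_hash(token_ids, block_size, logical_idx, lora_int_id=0):
    # Mirrors num_hashed_tokens_of_block: a block's hash covers every token
    # up to and including that block.
    num_tokens = logical_idx * block_size + block_size
    return hash((tuple(token_ids[:num_tokens]), lora_int_id))

tokens = list(range(32))
# Same prefix, different adapters -> different cache keys.
assert block_hash(tokens, 16, 0, lora_int_id=1) != block_hash(tokens, 16, 0, lora_int_id=2)
# Same prefix, same adapter -> same cache key, so prefix caching still works.
assert block_hash(tokens, 16, 0, lora_int_id=1) == block_hash(tokens, 16, 0, lora_int_id=1)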
