Run `pytest tests/test_cache_block_hashing.py`.
"""
+from typing import List, Optional
+
import pytest

+from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import TokenizerGroup
from vllm.sequence import Sequence
@@ -36,7 +39,10 @@ def flatten_2d(li):
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("max_num_seqs", [256])
-def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
+@pytest.mark.parametrize("concurrent_lora_int_ids",
+                         [[None], [1], [None, 1], [None, 1, 2], [1, 2]])
+def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
+                             concurrent_lora_int_ids: List[Optional[int]]):

    tokenizer = TokenizerGroup(
        tokenizer_id="facebook/opt-125m",
@@ -48,20 +54,30 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
    hashes = []

    for prefix in prefixes:
-        hashes.append([])
-        prompts = [prefix + prompt for prompt in sample_prompts]
-        seq_id = 0
-        for prompt in prompts:
-            hashes[-1].append([])
-            prompt_token_ids = tokenizer.encode(prompt)
-            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
-                           tokenizer.tokenizer.eos_token_id)
-
-            num_blocks = len(prompt_token_ids) // block_size
-            for idx in range(num_blocks):
-                hashes[-1][-1].append(seq.hash_of_block(idx))
-
-            seq_id += 1
+        for lora_int_id in concurrent_lora_int_ids:
+            lora_request = None
+
+            if lora_int_id is not None:
+                lora_request = LoRARequest(
+                    f"example_lora_{lora_int_id}",
+                    lora_int_id,
+                    f"example/path/to/lora_{lora_int_id}",
+                )
+
+            hashes.append([])
+            prompts = [prefix + prompt for prompt in sample_prompts]
+            seq_id = 0
+            for prompt in prompts:
+                hashes[-1].append([])
+                prompt_token_ids = tokenizer.encode(prompt)
+                seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
+                               tokenizer.tokenizer.eos_token_id, lora_request)
+
+                num_blocks = len(prompt_token_ids) // block_size
+                for idx in range(num_blocks):
+                    hashes[-1][-1].append(seq.hash_of_block(idx))
+
+                seq_id += 1

    # Check that hashes made with two prefixes with different first blocks are
    # different everywhere.
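
The parametrization above only proves something if the LoRA ID actually feeds into the block hash. That is the contract this test pins down: `Sequence.hash_of_block` must fold the adapter's integer ID into each prefix-block hash so that identical prompts served through different adapters never share a cache block. Below is a minimal sketch of that idea, not the actual vllm implementation; the `block_size`, `prompt_token_ids`, and `lora_request` attributes are assumptions here.

# Sketch only: assumes the sequence stores its prompt token IDs, its
# block size, and an optional lora_request, as the test above suggests.
def hash_of_block(self, logical_idx: int) -> int:
    # Hash every token up to and including this block, together with the
    # adapter's integer ID (0 when no adapter is attached), so the same
    # prefix under different adapters yields different cache-block hashes.
    num_tokens = (logical_idx + 1) * self.block_size
    prefix_tokens = tuple(self.prompt_token_ids[:num_tokens])
    lora_int_id = self.lora_request.lora_int_id if self.lora_request else 0
    return hash((prefix_tokens, lora_int_id))

Under a scheme like this, parameter sets such as `[None, 1]` and `[1, 2]` cover exactly the would-be collision cases: the same prompt with and without an adapter, and with two different adapters.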