Commit 97316d9

updated
Signed-off-by: [email protected] <[email protected]>
1 parent 7dd764b commit 97316d9

File tree

1 file changed: +134, -11 lines changed

tests/v1/core/test_scheduler.py

Lines changed: 134 additions & 11 deletions
@@ -3,13 +3,139 @@
 from unittest.mock import Mock
 
 import pytest
+import torch
 
-from vllm.multimodal.inputs import PlaceholderRange
-from vllm.tests.v1.utils import EOS_TOKEN_ID, create_requests, create_scheduler
+from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
+                         SchedulerConfig, VllmConfig)
+from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec)
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.request import RequestStatus
+from vllm.v1.request import Request, RequestStatus
+from vllm.v1.structured_output import StructuredOutputManager
+
+EOS_TOKEN_ID = 50256
+
+
+def create_scheduler(
+    model: str = "facebook/opt-125m",
+    max_num_seqs: int = 16,
+    max_num_batched_tokens: int = 8192,
+    enable_prefix_caching: Optional[bool] = None,
+    long_prefill_token_threshold: int = 0,
+    disable_chunked_mm_input: bool = False,
+    use_kv_connector: bool = False,
+    num_blocks: int = 10000,
+    block_size: int = 16,
+    max_model_len: Optional[int] = None,
+) -> Scheduler:
+    '''Create scheduler under test.
+
+    Args:
+      model: model under test
+      max_num_seqs: max sequences to schedule
+      max_num_batched_tokens: max num tokens to batch
+      enable_prefix_caching: optionally force APC config
+                             (True/False) or use default
+                             (None)
+
+    Returns:
+      :class:`Scheduler` instance
+    '''
+    if max_model_len is None:
+        max_model_len = max_num_batched_tokens
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=max_num_seqs,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_model_len=max_model_len,
+        long_prefill_token_threshold=long_prefill_token_threshold,
+        disable_chunked_mm_input=disable_chunked_mm_input,
+        enable_chunked_prefill=True,
+    )
+    model_config = ModelConfig(
+        model=model,
+        task="auto",
+        tokenizer=model,
+        tokenizer_mode="auto",
+        trust_remote_code=True,
+        dtype="float16",
+        seed=42,
+    )
+    # Cache config, optionally force APC
+    kwargs_cache = ({} if enable_prefix_caching is None else {
+        'enable_prefix_caching': enable_prefix_caching
+    })
+    cache_config = CacheConfig(
+        block_size=block_size,
+        gpu_memory_utilization=0.9,
+        swap_space=0,
+        cache_dtype="auto",
+        **kwargs_cache,
+    )
+    kv_transfer_config = KVTransferConfig(
+        kv_connector="SharedStorageConnector",
+        kv_role="kv_both",
+        kv_connector_extra_config={"shared_storage_path": "local_storage"},
+    ) if use_kv_connector else None
+
+    vllm_config = VllmConfig(
+        scheduler_config=scheduler_config,
+        model_config=model_config,
+        cache_config=cache_config,
+        kv_transfer_config=kv_transfer_config,
+    )
+    kv_cache_config = KVCacheConfig(
+        num_blocks=num_blocks,  # A large number of blocks to hold all requests
+        tensors={},
+        kv_cache_groups=[
+            KVCacheGroupSpec(['layer'],
+                             FullAttentionSpec(block_size, 1, 1, torch.float32,
+                                               False))
+        ],
+    )
+    cache_config.num_gpu_blocks = num_blocks
+    return Scheduler(
+        vllm_config=vllm_config,
+        kv_cache_config=kv_cache_config,
+        log_stats=True,
+        structured_output_manager=StructuredOutputManager(vllm_config),
+    )
+
+
+def create_requests(num_requests: int,
+                    num_tokens: int = 10,
+                    mm_positions: Optional[list[PlaceholderRange]] = None,
+                    max_tokens: int = 16,
+                    stop_token_ids: Optional[list[int]] = None,
+                    prompt_logprobs: Optional[int] = None):
+    sampling_params = SamplingParams(ignore_eos=False,
+                                     max_tokens=max_tokens,
+                                     stop_token_ids=stop_token_ids,
+                                     prompt_logprobs=prompt_logprobs)
+    requests = []
+    for i in range(num_requests):
+        if mm_positions is not None:
+            mm_position = mm_positions[i]
+            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
+        else:
+            mm_position = None
+            mm_inputs = None
+        request = Request(
+            request_id=f"{i}",
+            prompt=None,
+            prompt_token_ids=[i] * num_tokens,
+            sampling_params=sampling_params,
+            multi_modal_inputs=mm_inputs,
+            multi_modal_placeholders=mm_position,
+            multi_modal_hashes=None,
+            eos_token_id=EOS_TOKEN_ID,
+            arrival_time=0,
+        )
+        requests.append(request)
+    return requests
 
 
 def test_add_requests():
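
For orientation, a minimal sketch (not part of this commit) of how the tests below exercise the two new helpers; the test name and assertion here are illustrative:

    def test_schedule_smoke():
        # Build a scheduler with the defaults above and enqueue a few requests.
        scheduler = create_scheduler(max_num_seqs=4)
        requests = create_requests(num_requests=4, num_tokens=10)
        for request in requests:
            scheduler.add_request(request)
        # Four 10-token prompts fit easily in one 8192-token batch.
        output = scheduler.schedule()
        assert len(output.scheduled_new_reqs) == len(requests)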
@@ -174,6 +300,7 @@ def test_no_mm_input_chunking():
         model="llava-hf/llava-1.5-7b-hf",
         max_num_batched_tokens=1024,
         disable_chunked_mm_input=True,
+        max_model_len=2048,
     )
     mm_positions = [[PlaceholderRange(offset=400, length=800)]]
     requests = create_requests(num_requests=1,
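
A note on the added max_model_len=2048: the placeholder in this test ends at offset 400 + length 800 = 1200 tokens, so the prompt must be at least 1200 tokens long. Because create_scheduler falls back to max_num_batched_tokens (1024 here) when max_model_len is unset, the request would exceed the model context without this explicit override.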
@@ -677,20 +804,17 @@ def _assert_right_kv_cache_manager(
     """Check whether KVCacheManager is correct after allocate."""
 
     # Make sure the request stats are right.
-    EXPECTED_ACTUAL_BLOCKS = num_tokens // block_size
-    EXPECTED_TOTAL_BLOCKS = (EXPECTED_ACTUAL_BLOCKS +
-                             scheduler.kv_cache_manager.num_preallocate_blocks)
+    EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
     for req_id in req_ids:
         blocks = scheduler.kv_cache_manager.req_to_blocks[req_id]
         hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
         assert (scheduler.kv_cache_manager.num_cached_block[req_id] ==
-                EXPECTED_ACTUAL_BLOCKS)
+                EXPECTED_TOTAL_BLOCKS)
         assert len(blocks) == EXPECTED_TOTAL_BLOCKS
-        assert len(hashes) == EXPECTED_ACTUAL_BLOCKS
+        assert len(hashes) == EXPECTED_TOTAL_BLOCKS
 
     # Make sure we actually touched all the blocks.
-    BLOCKS_PER_REQ = (num_tokens / block_size +
-                      scheduler.kv_cache_manager.num_preallocate_blocks)
+    BLOCKS_PER_REQ = num_tokens / block_size
     assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() ==
             num_total_blocks - num_requests * BLOCKS_PER_REQ)
 
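A quick sanity check of the simplified accounting, using hypothetical values (num_tokens=48, block_size=16) chosen for illustration:

    # With preallocation gone, the allocated, cached, and hashed block counts
    # all collapse to the same figure.
    num_tokens, block_size = 48, 16
    EXPECTED_TOTAL_BLOCKS = num_tokens // block_size  # 3
    # Previously the allocation was 3 + num_preallocate_blocks, while the
    # cached and hashed counts stayed at 3.
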
@@ -925,7 +1049,6 @@ def test_kv_connector_handles_preemption():
         block_size=BLOCK_SIZE,
         num_blocks=NUM_BLOCKS,
     )
-    scheduler.kv_cache_manager.num_preallocate_blocks = 0
 
     NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
     scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
