Commit 9a2160f

[V1] TPU CI - Add basic perf regression test (#15414)
Signed-off-by: Alexander Matveev <[email protected]>
1 parent 2de4118 commit 9a2160f

File tree

5 files changed: 192 additions, 20 deletions

.buildkite/run-tpu-v1-test.sh

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,8 @@ docker run --privileged --net host --shm-size=16G -it \
 && python3 -m pip install lm_eval[api]==0.4.4 \
 && export VLLM_USE_V1=1 \
 && export VLLM_XLA_CHECK_RECOMPILATION=1 \
+&& echo TEST_0 \
+&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \
 && echo TEST_1 \
 && pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \
 && echo TEST_2 \

tests/entrypoints/llm/test_accuracy.py

Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
     more_args = None
     if current_platform.is_tpu():
         # Limit compilation time for TPU V1
-        more_args = "max_num_seqs=64"
+        more_args = "max_model_len=2048,max_num_seqs=64"

     # Add TP test (if provided)
     if TPU_TP_TEST_STR:
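
The new value is a comma-separated list of vLLM engine overrides that the accuracy test forwards to the engine. As a rough, hedged sketch (not code from this commit; parse_kv_args and the direct LLM construction below are illustrative assumptions), such a "key=value" string maps onto engine keyword arguments like this:

# Hedged sketch only -- not part of tests/entrypoints/llm/test_accuracy.py.
# It shows how a comma-separated "key=value" string such as the one added
# above can be expanded into vLLM engine keyword arguments.
from vllm import LLM


def parse_kv_args(more_args: str) -> dict[str, int]:
    # "max_model_len=2048,max_num_seqs=64" -> {"max_model_len": 2048, "max_num_seqs": 64}
    return {key: int(value)
            for key, value in (item.split("=", 1) for item in more_args.split(","))}


# Hypothetical usage: cap the context length and batch size to keep TPU
# compilation time down, mirroring the intent of the change above.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          **parse_kv_args("max_model_len=2048,max_num_seqs=64"))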

tests/v1/tpu/test_basic.py

Lines changed: 3 additions & 2 deletions

@@ -32,7 +32,7 @@
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
-def test_models(
+def test_basic(
     vllm_runner: type[VllmRunner],
     monkeypatch: pytest.MonkeyPatch,
     model: str,
@@ -58,4 +58,5 @@ def test_models(
         vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                   max_tokens)
         output = vllm_outputs[0][1]
-        assert "1024" in output
+
+        assert "1024" in output or "0, 1" in output

tests/v1/tpu/test_perf.py

Lines changed: 146 additions & 0 deletions

@@ -0,0 +1,146 @@
+# SPDX-License-Identifier: Apache-2.0
+"""A basic performance regression test for TPUs
+
+Run `pytest tests/v1/tpu/test_perf.py`.
+"""
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pytest
+
+from vllm.platforms import current_platform
+from vllm.sampling_params import SamplingParams
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+if TYPE_CHECKING:
+    from tests.conftest import VllmRunner
+
+
+@dataclass
+class TestParams:
+    model: str
+    num_prompts: int
+    prefix_len: int
+    decode_len: int
+    expected_avg_time: float
+    err_tol: float
+
+
+TEST_PARAMS = [
+    # TODO: Cannot run a series of tests because:
+    # RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed:
+    # open(/dev/vfio/0): Device or resource busy: Device or resource busy;
+    # Couldn't open iommu group /dev/vfio/0
+    # => Investigate
+
+    # TestParams(
+    #     model="Qwen/Qwen2.5-1.5B-Instruct",
+    #     num_prompts=1,
+    #     prefix_len=10,
+    #     decode_len=5,
+    #     expected_avg_time=0.03,
+    #     err_tol=0.01,
+    # ),
+    # TestParams(
+    #     model="Qwen/Qwen2.5-1.5B-Instruct",
+    #     num_prompts=10,
+    #     prefix_len=100,
+    #     decode_len=50,
+    #     expected_avg_time=0.234,
+    #     err_tol=0.020,
+    # ),
+    TestParams(
+        model="Qwen/Qwen2.5-1.5B-Instruct",
+        num_prompts=64,
+        prefix_len=500,
+        decode_len=50,
+
+        # (This is the active CI/CD instance)
+        # commit id: ccb246776d93ef105904a8ec015b3587240a1183
+        # tpu: v5lite (vllm CI/CD)
+        expected_avg_time=1.4,
+        err_tol=0.30,
+
+        # (TODO: There is no v6e in CI/CD currently)
+        # commit id: ccb246776d93ef105904a8ec015b3587240a1183
+        # tpu: v6e
+        # expected_avg_time=1.5,
+        # err_tol=0.20,
+    ),
+]
+
+NUM_WARMUPS = 5
+NUM_RUNS = 10
+
+MAX_MODEL_LEN = 1024
+MAX_NUM_SEQS = 32
+GPU_UTIL = 0.9
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a basic performance test for TPU only")
+@pytest.mark.parametrize("params", TEST_PARAMS)
+def test_perf(
+    vllm_runner: type[VllmRunner],
+    monkeypatch: pytest.MonkeyPatch,
+    params: TestParams,
+) -> None:
+    tokenizer = get_tokenizer(params.model,
+                              tokenizer_mode="auto",
+                              trust_remote_code=True)
+
+    prompts = []
+    for i in range(params.num_prompts):
+        prefix_token_ids = np.random.randint(0,
+                                             tokenizer.vocab_size,
+                                             size=params.prefix_len).tolist()
+        prompt = tokenizer.decode(prefix_token_ids)
+        prompts.append(prompt)
+
+    print(
+        "-- Running: num_prompts = {} prefix_len = {} decode_len = {}".format(
+            len(prompts), params.prefix_len, params.decode_len))
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        sampling_params = SamplingParams(max_tokens=params.decode_len,
+                                         temperature=1.0,
+                                         min_p=0.0)
+
+        with vllm_runner(params.model,
+                         max_num_batched_tokens=MAX_MODEL_LEN,
+                         max_model_len=MAX_MODEL_LEN,
+                         max_num_seqs=MAX_NUM_SEQS,
+                         gpu_memory_utilization=GPU_UTIL,
+                         enforce_eager=False,
+                         tensor_parallel_size=1) as vllm_model:
+            print(" -- Warmup / Compile")
+            for i in range(NUM_WARMUPS):
+                _ = vllm_model.generate(prompts, sampling_params)
+
+            print(" -- Benchmarking... ")
+            times = []
+            for i in range(NUM_RUNS):
+                start_time = time.time()
+                _ = vllm_model.generate(prompts, sampling_params)
+                times.append(time.time() - start_time)
+
+            avg_time = sum(times) / len(times)
+
+            print(" -- avg_time = {}".format(avg_time))
+            print(" -- expected_avg_time = {} with err_tol = {}".format(
+                params.expected_avg_time, params.err_tol))
+            diff = avg_time - params.expected_avg_time
+            ok = diff < params.err_tol
+            if diff < -params.err_tol:
+                print(" !! WARNING !! Performance has improved by {}, "
+                      "it may be necessary to fine-tune the "
+                      "expected_avg_time = {}".format(
+                          -diff, params.expected_avg_time))
+
+            assert ok, " !! ERROR !! Regression detected"
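
Note that the assertion is one-sided: a run may be arbitrarily faster than expected_avg_time without failing, and an improvement larger than err_tol only prints a warning suggesting the baseline be retuned. A standalone illustration of that arithmetic using the v5lite baseline values (not part of the test file):

# Standalone illustration of the pass/fail rule in test_perf.py above.
expected_avg_time, err_tol = 1.4, 0.30  # v5lite baseline from TEST_PARAMS

for avg_time in (1.25, 1.65, 1.75):
    diff = avg_time - expected_avg_time
    ok = diff < err_tol            # regression check: only an upper bound
    improved = diff < -err_tol     # large improvement: warning, not a failure
    print(f"avg_time={avg_time:.2f}s diff={diff:+.2f}s ok={ok} improved={improved}")

# avg_time=1.25s diff=-0.15s ok=True improved=False
# avg_time=1.65s diff=+0.25s ok=True improved=False
# avg_time=1.75s diff=+0.35s ok=False improved=False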

vllm/v1/worker/tpu_model_runner.py

Lines changed: 40 additions & 17 deletions

@@ -77,9 +77,12 @@ def __init__(
         parallel_config = self.parallel_config
         self.device = device
         self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
-        if self.check_recompilation:
-            self.num_xla_graphs = xr.get_num_cached_compilation_graph()
+
         self.enforce_eager = model_config.enforce_eager
+
+        self.num_xla_graphs = 0
+        self._update_num_xla_graphs("init")
+
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
         self._hidden_states_dtype = self.dtype

@@ -180,6 +183,31 @@ def __init__(
             max_token_size=self.max_num_tokens,
             padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)

+    def _update_num_xla_graphs(self, case_str):
+        check_comp = self.check_recompilation and not self.enforce_eager
+        if not check_comp:
+            return
+
+        total_cached_graphs = xr.get_num_cached_compilation_graph()
+        new_compiled_graphs = total_cached_graphs - self.num_xla_graphs
+        if new_compiled_graphs == 0:
+            return
+
+        logger.info("Add new %d compiled XLA graphs due to %s",
+                    new_compiled_graphs, case_str)
+        self.num_xla_graphs += new_compiled_graphs
+
+    def _verify_num_xla_graphs(self, case_str):
+        check_comp = self.check_recompilation and not self.enforce_eager
+        if not check_comp:
+            return
+
+        curr_cached_graph = xr.get_num_cached_compilation_graph()
+        assert self.num_xla_graphs == curr_cached_graph, (
+            "Recompilation after warm up is detected during {}."
+            " num_xla_graphs = {} curr_cached_graph = {}".format(
+                case_str, self.num_xla_graphs, curr_cached_graph))
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         """Update the cached states and the persistent batch with the scheduler
         output.

@@ -694,12 +722,11 @@ def execute_model(
             logprobs=None,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
-        # Check there is no new graph compilation, all the graphs should be
-        # captured and compiled during warming up.
-        if self.check_recompilation and not self.enforce_eager:
-            curr_cached_graph = xr.get_num_cached_compilation_graph()
-            assert self.num_xla_graphs == curr_cached_graph, (
-                "Recompilation after warm up is detected.")
+
+        # Check there are no new graphs compiled - all the graphs should be
+        # captured and compiled during warm up.
+        self._verify_num_xla_graphs("execute_model")
+
         return model_runner_output

     def load_model(self) -> None:

@@ -797,7 +824,9 @@ def capture_model(self) -> None:
             xm.mark_step()
         xm.wait_device_ops()
         end = time.perf_counter()
+
         logger.info("Compilation finished in in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("model")

         logger.info("Compiling sampling with different input shapes.")
         start = time.perf_counter()

@@ -832,15 +861,9 @@ def capture_model(self) -> None:
                 num_reqs_to_sample + 1, self.max_num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in %.2f [secs].", end - start)
-        # Record the number cached XLA graph after warming up, this will be
-        # used for checking there is no additional graph compilation during
-        # runtime execution.
-        if self.check_recompilation:
-            total_cached_graphs = xr.get_num_cached_compilation_graph()
-            num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
-            logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
-            self.num_xla_graphs += num_compiled_graphs
+
+        logger.info("Compilation finished in in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("sampling")

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
