Skip to content

Commit 7e8d3bd

Browse files
committed
Modify minicpmv test
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 2c79295 commit 7e8d3bd

File tree

3 files changed

+46
-99
lines changed

3 files changed

+46
-99
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ steps:
242242
source_file_dependencies:
243243
- vllm/lora
244244
- tests/lora
245-
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
245+
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py
246246
parallelism: 4
247247

248248
- label: "PyTorch Fullgraph Smoke Test" # 9min
@@ -533,6 +533,7 @@ steps:
533533
# requires multi-GPU testing for validation.
534534
- pytest -v -s -x lora/test_chatglm3_tp.py
535535
- pytest -v -s -x lora/test_llama_tp.py
536+
- pytest -v -s -x lora/test_minicpmv_tp.py
536537

537538

538539
- label: Weight Loading Multiple GPU Test # 33min

tests/lora/test_minicpmv.py

Lines changed: 0 additions & 78 deletions
This file was deleted.

tests/lora/test_minicpmv_tp.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import pytest
44

55
import vllm
6+
from tests.utils import fork_new_process_for_each_test
67
from vllm.assets.image import ImageAsset
78
from vllm.lora.request import LoRARequest
8-
9-
from ..utils import multi_gpu_test
9+
from vllm.platforms import current_platform
1010

1111
MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"
1212

@@ -17,13 +17,11 @@
1717

1818
IMAGE_ASSETS = [
1919
ImageAsset("stop_sign"),
20-
ImageAsset("cherry_blossom"),
2120
]
2221

2322
# After fine-tuning with LoRA, all generated content should begin with `A`.
2423
EXPECTED_OUTPUT = [
2524
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
26-
"A pink cherry blossom tree with a blue sky in the background.",
2725
]
2826

2927

@@ -50,37 +48,40 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
5048
# Print the outputs.
5149
generated_texts: List[str] = []
5250
for output in outputs:
53-
prompt = output.prompt
5451
generated_text = output.outputs[0].text.strip()
5552
generated_texts.append(generated_text)
56-
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
53+
print(f"Generated text: {generated_text!r}")
5754
return generated_texts
5855

5956

60-
@multi_gpu_test(num_gpus=2)
61-
@pytest.mark.parametrize("fully_sharded", [True, False])
62-
def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
57+
@pytest.mark.xfail(
58+
current_platform.is_rocm(),
59+
reason="MiniCPM-V dependency xformers incompatible with ROCm")
60+
@fork_new_process_for_each_test
61+
def test_minicpmv_lora(minicpmv_lora_files):
6362
llm = vllm.LLM(
6463
MODEL_PATH,
65-
enable_lora=True,
6664
max_num_seqs=2,
65+
enable_lora=True,
6766
max_loras=2,
6867
max_lora_rank=8,
69-
tensor_parallel_size=2,
68+
enforce_eager=True,
7069
trust_remote_code=True,
71-
fully_sharded_loras=fully_sharded,
7270
enable_chunked_prefill=True,
7371
)
74-
75-
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
76-
72+
output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
7773
for i in range(len(EXPECTED_OUTPUT)):
78-
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
74+
assert EXPECTED_OUTPUT[i].startswith(output1[i])
75+
output2 = do_sample(llm, minicpmv_lora_files, lora_id=2)
76+
for i in range(len(EXPECTED_OUTPUT)):
77+
assert EXPECTED_OUTPUT[i].startswith(output2[i])
7978

8079

81-
@multi_gpu_test(num_gpus=4)
82-
@pytest.mark.parametrize("fully_sharded", [True, False])
83-
def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
80+
@pytest.mark.xfail(
81+
current_platform.is_rocm(),
82+
reason="MiniCPM-V dependency xformers incompatible with ROCm")
83+
@fork_new_process_for_each_test
84+
def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
8485
llm = vllm.LLM(
8586
MODEL_PATH,
8687
enable_lora=True,
@@ -90,9 +91,32 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
9091
tensor_parallel_size=4,
9192
trust_remote_code=True,
9293
enforce_eager=True,
93-
fully_sharded_loras=fully_sharded,
9494
enable_chunked_prefill=True,
9595
)
9696
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
9797
for i in range(len(EXPECTED_OUTPUT)):
9898
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
99+
100+
101+
@pytest.mark.xfail(
102+
current_platform.is_rocm(),
103+
reason="MiniCPM-V dependency xformers incompatible with ROCm")
104+
@fork_new_process_for_each_test
105+
def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files):
106+
llm = vllm.LLM(
107+
MODEL_PATH,
108+
enable_lora=True,
109+
max_num_seqs=2,
110+
max_loras=2,
111+
max_lora_rank=8,
112+
tensor_parallel_size=4,
113+
trust_remote_code=True,
114+
fully_sharded_loras=True,
115+
enable_chunked_prefill=True,
116+
)
117+
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
118+
for i in range(len(EXPECTED_OUTPUT)):
119+
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
120+
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=2)
121+
for i in range(len(EXPECTED_OUTPUT)):
122+
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])

0 commit comments

Comments
 (0)