
[CI/Build] Add test decorator for minimum GPU memory #8925

Merged 5 commits on Sep 29, 2024
Changes from all commits
9 changes: 4 additions & 5 deletions tests/lora/test_baichuan.py
@@ -63,12 +63,11 @@ def test_baichuan_lora(baichuan_lora_files):
assert output2[i] == expected_lora_output[i]


-@pytest.mark.skip("Requires multiple GPUs")
 @pytest.mark.parametrize("fully_sharded", [True, False])
-def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < 4:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
+def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
+                                           num_gpus_available, fully_sharded):
+    if num_gpus_available < 4:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {4}")

llm_tp1 = vllm.LLM(MODEL_PATH,
enable_lora=True,
17 changes: 8 additions & 9 deletions tests/lora/test_quant_model.py
@@ -71,10 +71,10 @@ def format_prompt_tuples(prompt):

@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
-def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < tp_size:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
+def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
+                          tp_size):
+    if num_gpus_available < tp_size:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

llm = vllm.LLM(
model=model.model_path,
@@ -164,11 +164,10 @@ def expect_match(output, expected_output):


@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.skip("Requires multiple GPUs")
-def test_quant_model_tp_equality(tinyllama_lora_files, model):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < 2:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
+def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
+                                 model):
+    if num_gpus_available < 2:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {2}")

llm_tp1 = vllm.LLM(
model=model.model_path,
13 changes: 2 additions & 11 deletions tests/models/decoder_only/language/test_phimoe.py
@@ -7,6 +7,7 @@

from vllm.utils import is_cpu

+from ....utils import large_gpu_test
from ...utils import check_logprobs_close

MODELS = [
@@ -69,20 +70,10 @@ def test_phimoe_routing_function():
assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"])


-def get_gpu_memory():
-    try:
-        props = torch.cuda.get_device_properties(torch.cuda.current_device())
-        gpu_memory = props.total_memory / (1024**3)
-        return gpu_memory
-    except Exception:
-        return 0
-
-
 @pytest.mark.skipif(condition=is_cpu(),
                     reason="This test takes a lot time to run on CPU, "
                     "and vllm CI's disk space is not enough for this model.")
-@pytest.mark.skipif(condition=get_gpu_memory() < 100,
-                    reason="Skip this test if GPU memory is insufficient.")
+@large_gpu_test(min_gb=80)
@DarkLight1337 (Member, Author) commented on Sep 28, 2024:

I'm not aware of any GPUs with over 100 GB of memory, so I'm not sure what the point of this check is. Setting it to 80 GB for now; please tell me if this can't even run on a single A100/H100, since I don't have access to one at the moment.

@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@@ -11,6 +11,7 @@

from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_VideoAssets)
+from ....utils import large_gpu_test
from ...utils import check_logprobs_close

# Video test
@@ -164,9 +165,7 @@ def process(hf_inputs: BatchEncoding):
)


-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
@@ -210,9 +209,7 @@ def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
)


-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
@@ -306,9 +303,7 @@ def process(hf_inputs: BatchEncoding):
)


-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
12 changes: 3 additions & 9 deletions tests/models/decoder_only/vision_language/test_pixtral.py
@@ -17,7 +17,7 @@
from vllm.multimodal import MultiModalDataBuiltins
from vllm.sequence import Logprob, SampleLogprobs

-from ....utils import VLLM_PATH
+from ....utils import VLLM_PATH, large_gpu_test
from ...utils import check_logprobs_close

if TYPE_CHECKING:
@@ -121,10 +121,7 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
for tokens, text, logprobs in json_data]


-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on A100 locally but will OOM on CI machine."
-)
+@large_gpu_test(min_gb=80)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@@ -157,10 +154,7 @@ def test_chat(
name_1="output")


-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on A100 locally but will OOM on CI machine."
-)
+@large_gpu_test(min_gb=80)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
42 changes: 20 additions & 22 deletions tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -9,6 +9,7 @@

from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
+from ....utils import large_gpu_test
from ...utils import check_logprobs_close

_LIMIT_IMAGE_PER_PROMPT = 1
@@ -227,29 +228,26 @@ def process(hf_inputs: BatchEncoding):
)


+SIZES = [
+    # Text only
+    [],
+    # Single-size
+    [(512, 512)],
+    # Single-size, batched
+    [(512, 512), (512, 512), (512, 512)],
+    # Multi-size, batched
+    [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
+     (1024, 1024), (512, 1536), (512, 2028)],
+    # Multi-size, batched, including text only
+    [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
+     (1024, 1024), (512, 1536), (512, 2028), None],
+    # mllama has 8 possible aspect ratios, carefully set the sizes
+    # to cover all of them
+]
+
+
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("sizes", SIZES)
-@pytest.mark.parametrize(
-    "sizes",
-    [
-        # Text only
-        [],
-        # Single-size
-        [(512, 512)],
-        # Single-size, batched
-        [(512, 512), (512, 512), (512, 512)],
-        # Multi-size, batched
-        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
-         (1024, 1024), (512, 1536), (512, 2028)],
-        # Multi-size, batched, including text only
-        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
-         (1024, 1024), (512, 1536), (512, 2028), None],
-        # mllama has 8 possible aspect ratios, carefully set the sizes
-        # to cover all of them
-    ])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
35 changes: 33 additions & 2 deletions tests/utils.py
@@ -24,8 +24,8 @@
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform
-from vllm.utils import (FlexibleArgumentParser, cuda_device_count_stateless,
-                        get_open_port, is_hip)
+from vllm.utils import (FlexibleArgumentParser, GB_bytes,
+                        cuda_device_count_stateless, get_open_port, is_hip)

if current_platform.is_rocm():
from amdsmi import (amdsmi_get_gpu_vram_usage,
@@ -455,6 +455,37 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
return wrapper


+def large_gpu_test(*, min_gb: int):
+    """
+    Decorate a test to be skipped if no GPU is available or if it does not
+    have sufficient memory.
+
+    Currently, the CI machine uses an L4 GPU, which has 24 GB of VRAM.
+    """
+    try:
+        if current_platform.is_cpu():
+            memory_gb = 0
+        else:
+            memory_gb = current_platform.get_device_total_memory() / GB_bytes
+    except Exception as e:
+        warnings.warn(
+            f"An error occurred when finding the available memory: {e}",
+            stacklevel=2,
+        )
+
+        memory_gb = 0
+
+    test_skipif = pytest.mark.skipif(
+        memory_gb < min_gb,
+        reason=f"Need at least {min_gb} GB of GPU memory to run the test.",
+    )
+
+    def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
+        return test_skipif(fork_new_process_for_each_test(f))
+
+    return wrapper


def multi_gpu_test(*, num_gpus: int):
"""
Decorate a test to be run only when multiple GPUs are available.
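For reference, a minimal usage sketch of the new decorator (the test name, model, and prompt are illustrative and not part of this PR): `large_gpu_test` applies a `skipif` marker based on the detected device memory and wraps the test with `fork_new_process_for_each_test`, so CUDA is only initialized in the forked child process.

```python
import pytest

# Relative import as used by the test files above; adjust to the test's location.
from ....utils import large_gpu_test


@large_gpu_test(min_gb=48)  # skipped unless ~48 GB of device memory is detected
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_some_large_model(vllm_runner, dtype: str) -> None:
    # Hypothetical model and prompt, for illustration only.
    with vllm_runner("some-org/some-large-model", dtype=dtype) as vllm_model:
        outputs = vllm_model.generate_greedy(["Hello, my name is"], max_tokens=16)
        assert len(outputs) == 1
```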
5 changes: 5 additions & 0 deletions vllm/platforms/cpu.py
@@ -1,3 +1,4 @@
import psutil
import torch

from .interface import Platform, PlatformEnum
@@ -10,6 +11,10 @@ class CpuPlatform(Platform):
def get_device_name(cls, device_id: int = 0) -> str:
return "cpu"

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
return psutil.virtual_memory().total

@classmethod
def inference_mode(cls):
return torch.no_grad()
12 changes: 12 additions & 0 deletions vllm/platforms/cuda.py
@@ -59,6 +59,13 @@ def get_physical_device_name(device_id: int = 0) -> str:
return pynvml.nvmlDeviceGetName(handle)


@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_total_memory(device_id: int = 0) -> int:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)


@with_nvml_context
def warn_if_different_devices():
device_ids: int = pynvml.nvmlDeviceGetCount()
@@ -107,6 +114,11 @@ def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_name(physical_device_id)

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_total_memory(physical_device_id)

@classmethod
@with_nvml_context
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
6 changes: 6 additions & 0 deletions vllm/platforms/interface.py
@@ -85,6 +85,12 @@ def has_device_capability(

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
"""Get the name of a device."""
raise NotImplementedError

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
"""Get the total memory of a device in bytes."""
raise NotImplementedError

@classmethod
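A small sketch (not part of the PR) of how a caller outside the test helper might consume the new interface method via `current_platform`, treating the `NotImplementedError` default as "unknown"; the `device_memory_gb` helper is hypothetical.

```python
from vllm.platforms import current_platform
from vllm.utils import GB_bytes


def device_memory_gb() -> float:
    """Best-effort total memory of device 0 in decimal GB."""
    try:
        return current_platform.get_device_total_memory() / GB_bytes
    except NotImplementedError:
        # e.g. TpuPlatform below does not implement the method yet.
        return 0.0
```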
5 changes: 5 additions & 0 deletions vllm/platforms/rocm.py
@@ -29,3 +29,8 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
@lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str:
return torch.cuda.get_device_name(device_id)

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
device_props = torch.cuda.get_device_properties(device_id)
return device_props.total_memory
4 changes: 4 additions & 0 deletions vllm/platforms/tpu.py
@@ -10,6 +10,10 @@ class TpuPlatform(Platform):
def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
raise NotImplementedError

@classmethod
def inference_mode(cls):
return torch.no_grad()
14 changes: 8 additions & 6 deletions vllm/platforms/xpu.py
@@ -8,13 +8,15 @@ class XPUPlatform(Platform):

@staticmethod
def get_device_capability(device_id: int = 0) -> DeviceCapability:
-        return DeviceCapability(major=int(
-            torch.xpu.get_device_capability(device_id)['version'].split('.')
-            [0]),
-                                minor=int(
-                                    torch.xpu.get_device_capability(device_id)
-                                    ['version'].split('.')[1]))
+        major, minor, *_ = torch.xpu.get_device_capability(
+            device_id)['version'].split('.')
+        return DeviceCapability(major=int(major), minor=int(minor))

@staticmethod
def get_device_name(device_id: int = 0) -> str:
return torch.xpu.get_device_name(device_id)

+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        device_props = torch.xpu.get_device_properties(device_id)
+        return device_props.total_memory
3 changes: 3 additions & 0 deletions vllm/utils.py
@@ -119,6 +119,9 @@
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

GB_bytes = 1_000_000_000
"""The number of bytes in one gigabyte (GB)."""

GiB_bytes = 1 << 30
"""The number of bytes in one gibibyte (GiB)."""

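Note the decimal/binary distinction between the new `GB_bytes` and the existing `GiB_bytes`; a quick illustration of the difference for thresholds like `min_gb` above (the constants are restated locally so the snippet stands alone):

```python
GB_bytes = 1_000_000_000  # decimal gigabyte, as defined above
GiB_bytes = 1 << 30       # binary gibibyte (1_073_741_824 bytes)

# An L4, advertised as 24 GB, in bytes:
total = 24 * GB_bytes
print(total / GB_bytes)   # 24.0
print(total / GiB_bytes)  # ~22.4, so a GiB-based check would be stricter
```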