[Bugfix] Fix hybrid model tests #17182

Merged (10 commits) on Apr 25, 2025

Changes from 5 commits
167 changes: 66 additions & 101 deletions tests/models/decoder_only/language/test_hybrid.py
@@ -6,63 +6,61 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_outputs_equal
from ...utils import check_logprobs_close, check_outputs_equal

# This test is for the hybrid models
MODELS = [
"ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct",
"pfnet/plamo-2-1b"
"ibm-ai-platform/Bamba-9B",
"ai21labs/Jamba-tiny-dev",
"pfnet/plamo-2-1b",
"Zyphra/Zamba2-1.2B-instruct",
]
# Bamba at Fp32 is too big for the CI (L4 GPU).
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
# Note: Running Plamo2 in transformers implementation requires to install
# NOTE: Running Plamo2 with the transformers implementation requires installing the
# causal-conv1d package, which is not listed as a test dependency as it's
# not compatible with pip-compile.

# Avoid OOM
MAX_NUM_SEQS = 8


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
# numerical error produces different generations
if "Bamba" in model:
example_prompts.pop(3)
with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_batching(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
for_loop_outputs = []
with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
for prompt in example_prompts:
for_loop_outputs.append(
vllm_model.generate_greedy([prompt], max_tokens)[0])
@@ -79,70 +77,50 @@ def test_batching(


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_mamba_prefill_chunking_with_parallel_sampling(
hf_runner, vllm_runner, example_prompts, model: str, dtype: str,
max_tokens: int) -> None:
vllm_runner,
example_prompts,
model: str,
max_tokens: int,
) -> None:
# Tests prefill chunking in conjunction with n>1; in this case,
# prefill is populated with decoding tokens and we test that it
# doesn't fail. This test might fail if the cache is not allocated
# correctly for n > 1 decoding steps inside a
# chunked prefill forward pass (where we have both prefills
# and decoding together )

if 'plamo-2' in model:
dtype = "float" # use a different dtype for plamo
# and decoding together)

sampling_params = SamplingParams(n=3,
temperature=1,
seed=0,
max_tokens=max_tokens)
with vllm_runner(
model,
dtype=dtype,
enable_chunked_prefill=True,
max_num_batched_tokens=30,
max_num_seqs=10 # forces prefill chunks with decoding
max_num_seqs=MAX_NUM_SEQS # forces prefill chunks with decoding
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [7])
def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
model: str, dtype: str,
def test_mamba_prefill_chunking(vllm_runner, example_prompts, model: str,
max_tokens: int) -> None:
# numerical error during prefill chunking produces different generations
# than without prefill chunking for these examples; remove them for now
if "Jamba" in model:
example_prompts.pop(7)
example_prompts.pop(2)
example_prompts.pop(1)
elif "Bamba" in model:
example_prompts.pop(6)
example_prompts.pop(3)
example_prompts.pop(2)
dtype = "half" # use a different dtype for Bamba

elif "Zamba2" in model:
example_prompts.pop(7)
dtype = "half"
elif "plamo-2-1b" in model:
example_prompts.pop(7)

with hf_runner(model, dtype=dtype) as hf_model:
non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model,
dtype=dtype,
enable_chunked_prefill=True,
max_num_batched_tokens=5,
max_num_seqs=2) as vllm_model:
max_num_seqs=MAX_NUM_SEQS) as vllm_model:
chunked = vllm_model.generate_greedy(example_prompts,
max_tokens=max_tokens)

with vllm_runner(model,
enable_chunked_prefill=False,
max_num_seqs=MAX_NUM_SEQS) as vllm_model:
non_chunked = vllm_model.generate_greedy(example_prompts,
max_tokens=max_tokens)

check_outputs_equal(
outputs_0_lst=chunked,
outputs_1_lst=non_chunked,
@@ -152,17 +130,14 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [15])
def test_parallel_sampling(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:

with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
for_loop_outputs = []
for _ in range(10):
for_loop_outputs.append(
@@ -188,15 +163,12 @@ def test_parallel_sampling(
)


@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# This test is for verifying that mamba cache is padded to CG captured
@@ -209,7 +181,7 @@ def test_mamba_cache_cg_padding(
example_prompts.append(example_prompts[0])

try:
with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
except RuntimeError:
pytest.fail(
@@ -219,20 +191,15 @@ def test_mamba_cache_cg_padding(


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# Tests that outputs are identical with and w/o preemptions (recompute)
assert dtype == "float"

with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = True
preempt_vllm_outputs = vllm_model.generate_greedy(
@@ -251,12 +218,10 @@ def test_models_preemption_recompute(


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
vllm_runner,
model: str,
dtype: str,
example_prompts,
model: str,
) -> None:
# This test is for verifying that the hybrid inner state management doesn't
# collapse in case where the number of incoming requests and
@@ -265,61 +230,58 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
# statelessness mechanism where it can cleanup new incoming requests in
# a single step.
try:
with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
except ValueError:
pytest.fail("Hybrid inner state wasn't cleaned up properly between"
"steps finished requests registered unnecessarily ")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
vllm_runner,
model: str,
dtype: str,
example_prompts,
model: str,
) -> None:
# This test is for verifying that the Hybrid state is cleaned up between
# steps. If it's not cleaned up, an error would be expected.
try:
with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
for _ in range(10):
vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
except ValueError:
pytest.fail("Hybrid inner state wasn't cleaned up between states, "
"could be related to finished_requests_ids")


@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_multistep(
vllm_runner,
model: str,
dtype: str,
example_prompts,
model: str,
) -> None:
# This test verifies that multistep works correctly
# on mamba-like models
with vllm_runner(model, num_scheduler_steps=8,
max_num_seqs=2) as vllm_model:
max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 10, 1)


@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
def test_multistep_correctness(vllm_runner, model: str, dtype: str,
max_tokens: int, example_prompts) -> None:
def test_multistep_correctness(
vllm_runner,
example_prompts,
model: str,
max_tokens: int,
) -> None:
with vllm_runner(model, num_scheduler_steps=8,
max_num_seqs=2) as vllm_model:
max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_outputs_multistep = vllm_model.generate_greedy(
example_prompts, max_tokens)

with vllm_runner(model, num_scheduler_steps=1,
max_num_seqs=2) as vllm_model:
max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_outputs_single_step = vllm_model.generate_greedy(
example_prompts, max_tokens)

@@ -333,17 +295,20 @@ def test_multistep_correctness(vllm_runner, model: str, dtype: str,

@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
def test_hybrid_distributed_produces_identical_generation(
vllm_runner, model: str, dtype: str, max_tokens: int,
example_prompts) -> None:

with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model:
vllm_runner,
example_prompts,
model: str,
max_tokens: int,
) -> None:
with vllm_runner(model, tensor_parallel_size=2,
max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts,
max_tokens)

with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model:
with vllm_runner(model, tensor_parallel_size=1,
max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts,
max_tokens)

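The core change in test_models above is replacing exact string/token equality between the HF and vLLM runners with check_logprobs_close, which tolerates a divergence when the mismatching tokens still rank highly in the other run's top-k logprobs. The sketch below only illustrates that comparison idea; it is not the vLLM helper itself, and the function name approx_logprobs_match and the (token_ids, text, top_logprobs) output layout are assumptions made for the example.

from typing import Dict, List, Tuple

# Hypothetical per-prompt output layout: (token_ids, text, top_logprobs).
SampleOutput = Tuple[List[int], str, List[Dict[int, float]]]


def approx_logprobs_match(out_a: SampleOutput, out_b: SampleOutput) -> bool:
    """Return True if two greedy generations agree, allowing a divergence
    at a position where each side's token still appears in the other
    side's top-k logprobs (the spirit of check_logprobs_close)."""
    ids_a, _, logprobs_a = out_a
    ids_b, _, logprobs_b = out_b
    for pos, (tok_a, tok_b) in enumerate(zip(ids_a, ids_b)):
        if tok_a == tok_b:
            continue
        # Tolerate the first mismatch only if each token is still ranked
        # among the other run's top-k candidates; the sequences are
        # allowed to diverge after that point.
        return tok_a in logprobs_b[pos] and tok_b in logprobs_a[pos]
    return True

In the test itself, check_logprobs_close (imported from ...utils alongside check_outputs_equal) applies this kind of tolerance across all prompts rather than requiring byte-identical outputs, which is what lets the hybrid models pass without cherry-picking a dtype per model.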