[Bugfix] handle alignment of encoder_seq_lens in mllama.py #14784

Merged
merged 5 commits on Apr 11, 2025
59 changes: 50 additions & 9 deletions tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -209,14 +209,15 @@ def _run_test(
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     dtype=dtype,
                     max_model_len=8192,
                     max_num_seqs=3,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
                                          }) as vllm_model:
    with vllm_runner(
            model,
            dtype=dtype,
            max_model_len=19212,  # 3 max size images
            max_num_seqs=3,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
            limit_mm_per_prompt={"image":
                                 _LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
@@ -507,7 +508,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
            model,
            dtype=dtype,
            max_model_len=8192,
            max_num_seqs=2,
            max_num_seqs=4,
            tensor_parallel_size=1,
            limit_mm_per_prompt={"image":
                                 _LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
@@ -552,6 +553,23 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
                                            num_logprobs,
                                            images=images)

        # Mixed batch with text and images with different numbers of tiles
        prompts = [
            "<|begin_of_text|>Hello!",
            "<|begin_of_text|>Some text before.<|image|>What is in the image?",  # noqa: E501
            "<|begin_of_text|>Some text before.<|image|>What is in the image?",  # noqa: E501
        ]
        images = [
            None,
            [stop_sign],
            # smaller image must be 2nd for the repro
            [stop_sign.resize((448, 448))],
        ]
        vllm_model.generate_greedy_logprobs(prompts,
                                            max_tokens,
                                            num_logprobs,
                                            images=images)


class DummyModel:
    image_token_id = MLLAMA_IMAGE_TOKEN_ID
@@ -674,3 +692,26 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
            f"full_text_row_masked_out_mask[{idx}] must be " \
            f"'{must_be_masked}' "
        idx += 1


@pytest.mark.core_model
@pytest.mark.parametrize("encoder_seq_lens, num_tiles, expected", [
    ([6404], [[4]], [6404]),
    ([0, 6404], [[4]], [6404]),
    ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]),
    ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]),
])
def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles,
                                         expected) -> None:

    dummy = DummyModel()
    num_tokens_per_tile = 1601
    actual_encoder_seq_lens = MllamaForConditionalGeneration \
        ._get_and_validate_encoder_lens(
            dummy,
            encoder_seq_lens,
            num_tiles,
            num_tokens_per_tile,
        )
    assert actual_encoder_seq_lens == expected, \
        f"Expected {expected} but got {actual_encoder_seq_lens}"
45 changes: 32 additions & 13 deletions vllm/model_executor/models/mllama.py
@@ -1309,6 +1309,31 @@ def _parse_and_validate_image_input(self, **kwargs: object):

        raise AssertionError("This line should be unreachable.")

    def _get_and_validate_encoder_lens(
        self,
        encoder_seq_lens: List[int],
        num_tiles: List[List[int]],
        num_tokens_per_tile: int,
    ) -> List[int]:
        # Compute the actual number of encoder tokens for each sample.
        # attn_metadata.encoder_seq_lens only counts the last group of
        # images for each sample, which is used to trick the block manager
        # into allocating blocks for those images only.
        # See MllamaMultiModalProcessor for more details.
        actual_encoder_seq_lens = [
            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
        ]

        # Drop zero-length entries (text-only requests) before running
        # the assertions below.
        attn_metadata_lens = [x for x in encoder_seq_lens if x > 0]
        assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
        for actual_len, last_group_len in zip(actual_encoder_seq_lens,
                                              attn_metadata_lens):
            assert actual_len >= last_group_len

        return actual_encoder_seq_lens

    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                            attn_metadata: AttentionMetadata,
                            actual_encoder_seq_lens: List[int]):
@@ -1436,20 +1461,14 @@ def forward(
        else:
            skip_cross_attention = False

            # Get the actual number of encoder tokens for each sample.
            # Because attn_metadata.encoder_seq_lens only counts the last
            # group of images for each sample, which is used to cheat the
            # block manager to allocate blocks for those images only.
            # See MllamaMultiModalProcessor for more details.
            num_tiles_tensor = kwargs.pop("num_tiles")
            num_tiles = [t.tolist() for t in num_tiles_tensor]
            num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")]
            num_tokens_per_tile = calc_token_per_chunk(self.image_size)
            actual_encoder_seq_lens = [
                sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
            ]
            for actual_len, last_group_len in zip(
                    actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
                assert actual_len >= last_group_len

            actual_encoder_seq_lens = self._get_and_validate_encoder_lens(
                attn_metadata.encoder_seq_lens,
                num_tiles,
                num_tokens_per_tile,
            )

            cross_attention_states = self.get_cross_attention_states(
                image_inputs, attn_metadata, actual_encoder_seq_lens)
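
Why the alignment matters (a minimal standalone sketch, independent of vLLM; only the list shapes shown in the diff are assumed): attn_metadata.encoder_seq_lens holds one entry per request, including a 0 for each text-only request, while num_tiles only lists the requests that actually carry images. The removed zip in forward() therefore paired per-image token counts with the wrong metadata entries in a mixed batch; the new helper drops the zero entries first so the two lists line up.

# Hypothetical mixed batch: request 0 is text-only, requests 1 and 2 carry images.
encoder_seq_lens = [0, 19212, 3202]   # one entry per request, as in attn_metadata
num_tiles = [[4, 4, 4], [2]]          # only image-bearing requests appear here
num_tokens_per_tile = 1601

actual = [sum(tiles) * num_tokens_per_tile for tiles in num_tiles]

# Old pairing: the first per-image count lines up with the text-only 0 entry,
# and the last image request's metadata entry is never compared at all.
misaligned = list(zip(actual, encoder_seq_lens))   # [(19212, 0), (3202, 19212)]

# Fixed pairing: drop zero-length (text-only) entries before comparing.
aligned = list(zip(actual, [x for x in encoder_seq_lens if x > 0]))
assert aligned == [(19212, 19212), (3202, 3202)]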