
Commit abdd700

tjohnson31415 authored and Mu Huai committed
[Bugfix] handle alignment of encoder_seq_lens in mllama.py (vllm-project#14784)
Signed-off-by: Travis Johnson <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent c38673f · commit abdd700

File tree (2 files changed: 82 additions, 22 deletions)
    tests/models/encoder_decoder/vision_language/test_mllama.py
    vllm/model_executor/models/mllama.py
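
The gist of the fix: attn_metadata.encoder_seq_lens has one entry per sequence in the batch, including zeros for text-only sequences, while num_tiles only has entries for sequences that actually carry images. The two must be filtered into alignment before they can be validated against each other. A minimal standalone sketch of that logic (align_encoder_lens is my name for it; the patch implements it as the _get_and_validate_encoder_lens method shown below):

    from typing import List

    def align_encoder_lens(encoder_seq_lens: List[int],
                           num_tiles: List[List[int]],
                           num_tokens_per_tile: int) -> List[int]:
        # Each tile of each image in a sample contributes
        # num_tokens_per_tile encoder tokens.
        actual = [sum(tiles) * num_tokens_per_tile for tiles in num_tiles]
        # Drop the zero entries that text-only sequences contribute, so
        # both lists line up one-to-one for the sanity checks.
        metadata_lens = [x for x in encoder_seq_lens if x > 0]
        assert len(actual) == len(metadata_lens)
        # The metadata only counts the last image group of a sample, so
        # the actual total can only be greater than or equal to it.
        for actual_len, last_group_len in zip(actual, metadata_lens):
            assert actual_len >= last_group_len
        return actual

    # A case from the new unit test: one text-only sequence (the 0)
    # plus one sequence with a 4-tile image at 1601 tokens per tile.
    print(align_encoder_lens([0, 6404], [[4]], 1601))  # [6404]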

tests/models/encoder_decoder/vision_language/test_mllama.py

Lines changed: 50 additions & 9 deletions
@@ -209,14 +209,15 @@ def _run_test(
     # will hurt multiprocessing backend with fork method (the default method).
 
     # max_model_len should be greater than image_feature_size
-    with vllm_runner(model,
-                     dtype=dtype,
-                     max_model_len=8192,
-                     max_num_seqs=3,
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend,
-                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
-                     }) as vllm_model:
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            max_model_len=19212,  # 3 max size images
+            max_num_seqs=3,
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend=distributed_executor_backend,
+            limit_mm_per_prompt={"image":
+                                 _LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
         vllm_outputs_per_image = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
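
The new max_model_len of 19212 appears to budget for the three-image limit used by these tests; a quick sanity check, assuming (as the new unit test in this file does) 1601 encoder tokens per tile and 4 tiles for a maximum-size image:

    num_tokens_per_tile = 1601   # value used by the new unit test below
    max_tiles_per_image = 4      # assumption: a max-size image spans 4 tiles
    tokens_per_image = max_tiles_per_image * num_tokens_per_tile  # 6404
    print(3 * tokens_per_image)  # 19212, the new max_model_len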
@@ -507,7 +508,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
             model,
             dtype=dtype,
             max_model_len=8192,
-            max_num_seqs=2,
+            max_num_seqs=4,
             tensor_parallel_size=1,
             limit_mm_per_prompt={"image":
                                  _LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
@@ -552,6 +553,23 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
                                             num_logprobs,
                                             images=images)
 
+        # Mixed batch with text and images with different numbers of tiles
+        prompts = [
+            "<|begin_of_text|>Hello!",
+            "<|begin_of_text|>Some text before.<|image|>What is in the image?",  # noqa: E501
+            "<|begin_of_text|>Some text before.<|image|>What is in the image?",  # noqa: E501
+        ]
+        images = [
+            None,
+            [stop_sign],
+            # smaller image must be 2nd for the repro
+            [stop_sign.resize((448, 448))],
+        ]
+        vllm_model.generate_greedy_logprobs(prompts,
+                                            max_tokens,
+                                            num_logprobs,
+                                            images=images)
+
 
 class DummyModel:
     image_token_id = MLLAMA_IMAGE_TOKEN_ID
@@ -674,3 +692,26 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
             f"full_text_row_masked_out_mask[{idx}] must be " \
             f"'{must_be_masked}' "
         idx += 1
+
+
+@pytest.mark.core_model
+@pytest.mark.parametrize("encoder_seq_lens, num_tiles, expected", [
+    ([6404], [[4]], [6404]),
+    ([0, 6404], [[4]], [6404]),
+    ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]),
+    ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]),
+])
+def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles,
+                                         expected) -> None:
+
+    dummy = DummyModel()
+    num_tokens_per_tile = 1601
+    actual_encoder_seq_lens = MllamaForConditionalGeneration \
+        ._get_and_validate_encoder_lens(
+            dummy,
+            encoder_seq_lens,
+            num_tiles,
+            num_tokens_per_tile,
+        )
+    assert actual_encoder_seq_lens == expected, \
+        f"Expected {expected} but got {actual_encoder_seq_lens}"

vllm/model_executor/models/mllama.py

Lines changed: 32 additions & 13 deletions
@@ -1301,6 +1301,31 @@ def _parse_and_validate_image_input(self, **kwargs: object):
 
         raise AssertionError("This line should be unreachable.")
 
+    def _get_and_validate_encoder_lens(
+        self,
+        encoder_seq_lens: List[int],
+        num_tiles: List[List[int]],
+        num_tokens_per_tile: int,
+    ) -> List[int]:
+        # Get the actual number of encoder tokens for each sample.
+        # Because attn_metadata.encoder_seq_lens only counts the last
+        # group of images for each sample, which is used to cheat the
+        # block manager to allocate blocks for those images only.
+        # See MllamaMultiModalProcessor for more details.
+        actual_encoder_seq_lens = [
+            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
+        ]
+
+        # remove 0 encoder len entries for text-only requests for these
+        # assertions
+        attn_metadata_lens = [x for x in encoder_seq_lens if x > 0]
+        assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
+        for actual_len, last_group_len in zip(actual_encoder_seq_lens,
+                                              attn_metadata_lens):
+            assert actual_len >= last_group_len
+
+        return actual_encoder_seq_lens
+
     def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                             attn_metadata: AttentionMetadata,
                             actual_encoder_seq_lens: List[int]):
@@ -1428,20 +1453,14 @@ def forward(
         else:
             skip_cross_attention = False
 
-        # Get the actual number of encoder tokens for each sample.
-        # Because attn_metadata.encoder_seq_lens only counts the last
-        # group of images for each sample, which is used to cheat the
-        # block manager to allocate blocks for those images only.
-        # See MllamaMultiModalProcessor for more details.
-        num_tiles_tensor = kwargs.pop("num_tiles")
-        num_tiles = [t.tolist() for t in num_tiles_tensor]
+        num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")]
         num_tokens_per_tile = calc_token_per_chunk(self.image_size)
-        actual_encoder_seq_lens = [
-            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
-        ]
-        for actual_len, last_group_len in zip(
-                actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
-            assert actual_len >= last_group_len
+
+        actual_encoder_seq_lens = self._get_and_validate_encoder_lens(
+            attn_metadata.encoder_seq_lens,
+            num_tiles,
+            num_tokens_per_tile,
+        )
 
         cross_attention_states = self.get_cross_attention_states(
             image_inputs, attn_metadata, actual_encoder_seq_lens)
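
For context, my reading of why the old in-line check could misfire, illustrated with the last case from the new unit test: zipping the per-sample image totals directly against the raw metadata pairs them with the wrong entries once a text-only sequence contributes a zero:

    encoder_seq_lens = [0, 19212, 0, 3202]  # raw metadata; zeros are text-only
    actual = [19212, 3202]                  # totals for image-bearing samples

    # zip truncates to the shorter list and mispairs the entries:
    print(list(zip(actual, encoder_seq_lens)))  # [(19212, 0), (3202, 19212)]
    # The old `assert actual_len >= last_group_len` then fails on
    # (3202, 19212) even though the batch is valid, which matches the
    # "smaller image must be 2nd for the repro" note in the test.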
