Commit 582552e

little refactor and add CI test

Signed-off-by: Travis Johnson <[email protected]>

1 parent 87144b7 · commit 582552e

2 files changed: +54 −18 lines

tests/models/encoder_decoder/vision_language/test_mllama.py
23 additions, 0 deletions

```diff
@@ -691,3 +691,26 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
             f"full_text_row_masked_out_mask[{idx}] must be " \
             f"'{must_be_masked}' "
         idx += 1
+
+
+@pytest.mark.core_model
+@pytest.mark.parametrize("encoder_seq_lens, num_tiles, expected", [
+    ([6404], [[4]], [6404]),
+    ([0, 6404], [[4]], [6404]),
+    ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]),
+    ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]),
+])
+def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles,
+                                         expected) -> None:
+
+    dummy = DummyModel()
+    num_tokens_per_tile = 1601
+    actual_encoder_seq_lens = MllamaForConditionalGeneration \
+        ._get_and_validate_encoder_lens(
+            dummy,
+            encoder_seq_lens,
+            num_tiles,
+            num_tokens_per_tile,
+        )
+    assert actual_encoder_seq_lens == expected, \
+        f"Expected {expected} but got {actual_encoder_seq_lens}"
```

vllm/model_executor/models/mllama.py
31 additions, 18 deletions

```diff
@@ -1306,6 +1306,31 @@ def _parse_and_validate_image_input(self, **kwargs: object):

         raise AssertionError("This line should be unreachable.")

+    def _get_and_validate_encoder_lens(
+        self,
+        encoder_seq_lens: List[int],
+        num_tiles: List[List[int]],
+        num_tokens_per_tile: int,
+    ) -> List[int]:
+        # Get the actual number of encoder tokens for each sample.
+        # Because attn_metadata.encoder_seq_lens only counts the last
+        # group of images for each sample, which is used to cheat the
+        # block manager to allocate blocks for those images only.
+        # See input_processor_for_mllama() for more details.
+        actual_encoder_seq_lens = [
+            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
+        ]
+
+        # remove 0 encoder len entries for text-only requests for these
+        # assertions
+        attn_metadata_lens = [len for len in encoder_seq_lens if len > 0]
+        assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
+        for actual_len, last_group_len in zip(actual_encoder_seq_lens,
+                                              attn_metadata_lens):
+            assert actual_len >= last_group_len
+
+        return actual_encoder_seq_lens
+
     def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                             attn_metadata: AttentionMetadata,
                             actual_encoder_seq_lens: List[int]):
```
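To see the extracted helper's contract outside of vLLM: every tile of every image contributes num_tokens_per_tile encoder tokens, while the attention metadata reports 0 for text-only samples and only the last image group's length for multimodal ones, so each surviving reported value is a lower bound on the true length. A minimal standalone sketch of the same logic (plain function, simplified names; not the vLLM API):

```python
from typing import List


def get_and_validate_encoder_lens(
    encoder_seq_lens: List[int],
    num_tiles: List[List[int]],
    num_tokens_per_tile: int,
) -> List[int]:
    # True encoder length per sample: every tile of every image
    # contributes num_tokens_per_tile tokens.
    actual = [sum(tiles) * num_tokens_per_tile for tiles in num_tiles]

    # Drop the 0 entries (text-only samples); what remains is the last
    # image group's length per multimodal sample, a lower bound on the
    # true length.
    reported = [n for n in encoder_seq_lens if n > 0]
    assert len(actual) == len(reported)
    for actual_len, last_group_len in zip(actual, reported):
        assert actual_len >= last_group_len
    return actual


# One of the CI test's cases: a text-only request (the leading 0) plus
# two multimodal samples with 1 and 4+1 tiles respectively.
print(get_and_validate_encoder_lens([0, 1601, 8005], [[1], [4, 1]], 1601))
# -> [1601, 8005]
```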
```diff
@@ -1430,26 +1455,14 @@ def forward(
         else:
             skip_cross_attention = False

-        # Get the actual number of encoder tokens for each sample.
-        # Because attn_metadata.encoder_seq_lens only counts the last
-        # group of images for each sample, which is used to cheat the
-        # block manager to allocate blocks for those images only.
-        # See MllamaMultiModalProcessor for more details.
-        num_tiles_tensor = kwargs.pop("num_tiles")
-        num_tiles = [t.tolist() for t in num_tiles_tensor]
+        num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")]
         num_tokens_per_tile = calc_token_per_chunk(self.image_size)
-        actual_encoder_seq_lens = [
-            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
-        ]

-        # remove 0 entries for text-only requests for these assertions
-        attn_metadata_lens = [
-            len for len in attn_metadata.encoder_seq_lens if len > 0
-        ]
-        assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
-        for actual_len, last_group_len in zip(actual_encoder_seq_lens,
-                                              attn_metadata_lens):
-            assert actual_len >= last_group_len
+        actual_encoder_seq_lens = self._get_and_validate_encoder_lens(
+            attn_metadata.encoder_seq_lens,
+            num_tiles,
+            num_tokens_per_tile,
+        )

         cross_attention_states = self.get_cross_attention_states(
             image_inputs, attn_metadata, actual_encoder_seq_lens)
```
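One detail of the new test worth calling out: it never constructs a real MllamaForConditionalGeneration. Because _get_and_validate_encoder_lens does not read any instance state, the test invokes the method through the class and passes a DummyModel as self. A minimal illustration of that Python pattern, with hypothetical names:

```python
class Greeter:
    def shout(self, text: str) -> str:
        # Never touches self, so any object can stand in for it.
        return text.upper() + "!"


class Dummy:
    """Stand-in object playing the role of `self`."""


# Calling through the class makes `self` an explicit first argument.
assert Greeter.shout(Dummy(), "hi") == "HI!"
```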
