Commit fb1d347

little refactor and add CI test
Signed-off-by: Travis Johnson <[email protected]>
1 parent 48dfb3c commit fb1d347

2 files changed (+54, -18 lines changed)

tests/models/encoder_decoder/vision_language/test_mllama.py

Lines changed: 23 additions & 0 deletions
@@ -692,3 +692,26 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
             f"full_text_row_masked_out_mask[{idx}] must be " \
             f"'{must_be_masked}' "
         idx += 1
+
+
[email protected]_model
[email protected]("encoder_seq_lens, num_tiles, expected", [
+    ([6404], [[4]], [6404]),
+    ([0, 6404], [[4]], [6404]),
+    ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]),
+    ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]),
+])
+def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles,
+                                         expected) -> None:
+
+    dummy = DummyModel()
+    num_tokens_per_tile = 1601
+    actual_encoder_seq_lens = MllamaForConditionalGeneration \
+        ._get_and_validate_encoder_lens(
+            dummy,
+            encoder_seq_lens,
+            num_tiles,
+            num_tokens_per_tile,
+        )
+    assert actual_encoder_seq_lens == expected, \
+        f"Expected {expected} but got {actual_encoder_seq_lens}"

vllm/model_executor/models/mllama.py

Lines changed: 31 additions & 18 deletions
@@ -1250,6 +1250,31 @@ def _parse_and_validate_image_input(self, **kwargs: object):
 
         raise AssertionError("This line should be unreachable.")
 
+    def _get_and_validate_encoder_lens(
+        self,
+        encoder_seq_lens: List[int],
+        num_tiles: List[List[int]],
+        num_tokens_per_tile: int,
+    ) -> List[int]:
+        # Get the actual number of encoder tokens for each sample.
+        # Because attn_metadata.encoder_seq_lens only counts the last
+        # group of images for each sample, which is used to cheat the
+        # block manager to allocate blocks for those images only.
+        # See input_processor_for_mllama() for more details.
+        actual_encoder_seq_lens = [
+            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
+        ]
+
+        # remove 0 encoder len entries for text-only requests for these
+        # assertions
+        attn_metadata_lens = [len for len in encoder_seq_lens if len > 0]
+        assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
+        for actual_len, last_group_len in zip(actual_encoder_seq_lens,
+                                              attn_metadata_lens):
+            assert actual_len >= last_group_len
+
+        return actual_encoder_seq_lens
+
     def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                             attn_metadata: AttentionMetadata,
                             actual_encoder_seq_lens: List[int]):
@@ -1374,26 +1399,14 @@ def forward(
         else:
             skip_cross_attention = False
 
-        # Get the actual number of encoder tokens for each sample.
-        # Because attn_metadata.encoder_seq_lens only counts the last
-        # group of images for each sample, which is used to cheat the
-        # block manager to allocate blocks for those images only.
-        # See input_processor_for_mllama() for more details.
-        num_tiles_tensor = kwargs.pop("num_tiles")
-        num_tiles = [t.tolist() for t in num_tiles_tensor]
+        num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")]
         num_tokens_per_tile = calc_token_per_chunk(self.image_size)
-        actual_encoder_seq_lens = [
-            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
-        ]
 
-        # remove 0 entries for text-only requests for these assertions
-        attn_metadata_lens = [
-            len for len in attn_metadata.encoder_seq_lens if len > 0
-        ]
-        assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
-        for actual_len, last_group_len in zip(actual_encoder_seq_lens,
-                                              attn_metadata_lens):
-            assert actual_len >= last_group_len
+        actual_encoder_seq_lens = self._get_and_validate_encoder_lens(
+            attn_metadata.encoder_seq_lens,
+            num_tiles,
+            num_tokens_per_tile,
+        )
 
         cross_attention_states = self.get_cross_attention_states(
             image_inputs, attn_metadata, actual_encoder_seq_lens)
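
The helper checks actual_len >= last_group_len rather than equality because, as the relocated comment notes, attn_metadata.encoder_seq_lens counts only the last group of images for each sample, while the computed length covers all of a sample's groups. A minimal standalone sketch of that invariant, with assumed toy values (one sample carrying two image groups):

    # Assumed toy values: one sample with image groups of 4 and 1 tiles.
    # attn_metadata would report only the last group (1 tile) for it.
    num_tokens_per_tile = 1601
    num_tiles = [[4, 1]]
    actual = [sum(t) * num_tokens_per_tile for t in num_tiles]  # [8005]
    last_group_lens = [1 * num_tokens_per_tile]                 # [1601]
    assert all(a >= b for a, b in zip(actual, last_group_lens))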
