@@ -1306,6 +1306,31 @@ def _parse_and_validate_image_input(self, **kwargs: object):
1306
1306
1307
1307
raise AssertionError ("This line should be unreachable." )
1308
1308
1309
+ def _get_and_validate_encoder_lens (
1310
+ self ,
1311
+ encoder_seq_lens : List [int ],
1312
+ num_tiles : List [List [int ]],
1313
+ num_tokens_per_tile : int ,
1314
+ ) -> List [int ]:
1315
+ # Get the actual number of encoder tokens for each sample.
1316
+ # Because attn_metadata.encoder_seq_lens only counts the last
1317
+ # group of images for each sample, which is used to cheat the
1318
+ # block manager to allocate blocks for those images only.
1319
+ # See input_processor_for_mllama() for more details.
1320
+ actual_encoder_seq_lens = [
1321
+ sum (num_tile ) * num_tokens_per_tile for num_tile in num_tiles
1322
+ ]
1323
+
1324
+ # remove 0 encoder len entries for text-only requests for these
1325
+ # assertions
1326
+ attn_metadata_lens = [len for len in encoder_seq_lens if len > 0 ]
1327
+ assert len (actual_encoder_seq_lens ) == len (attn_metadata_lens )
1328
+ for actual_len , last_group_len in zip (actual_encoder_seq_lens ,
1329
+ attn_metadata_lens ):
1330
+ assert actual_len >= last_group_len
1331
+
1332
+ return actual_encoder_seq_lens
1333
+
1309
1334
def flat_encoder_result (self , cross_attention_states : torch .Tensor ,
1310
1335
attn_metadata : AttentionMetadata ,
1311
1336
actual_encoder_seq_lens : List [int ]):
@@ -1430,26 +1455,14 @@ def forward(
1430
1455
else :
1431
1456
skip_cross_attention = False
1432
1457
1433
- # Get the actual number of encoder tokens for each sample.
1434
- # Because attn_metadata.encoder_seq_lens only counts the last
1435
- # group of images for each sample, which is used to cheat the
1436
- # block manager to allocate blocks for those images only.
1437
- # See MllamaMultiModalProcessor for more details.
1438
- num_tiles_tensor = kwargs .pop ("num_tiles" )
1439
- num_tiles = [t .tolist () for t in num_tiles_tensor ]
1458
+ num_tiles = [t .tolist () for t in kwargs .pop ("num_tiles" )]
1440
1459
num_tokens_per_tile = calc_token_per_chunk (self .image_size )
1441
- actual_encoder_seq_lens = [
1442
- sum (num_tile ) * num_tokens_per_tile for num_tile in num_tiles
1443
- ]
1444
1460
1445
- # remove 0 entries for text-only requests for these assertions
1446
- attn_metadata_lens = [
1447
- len for len in attn_metadata .encoder_seq_lens if len > 0
1448
- ]
1449
- assert len (actual_encoder_seq_lens ) == len (attn_metadata_lens )
1450
- for actual_len , last_group_len in zip (actual_encoder_seq_lens ,
1451
- attn_metadata_lens ):
1452
- assert actual_len >= last_group_len
1461
+ actual_encoder_seq_lens = self ._get_and_validate_encoder_lens (
1462
+ attn_metadata .encoder_seq_lens ,
1463
+ num_tiles ,
1464
+ num_tokens_per_tile ,
1465
+ )
1453
1466
1454
1467
cross_attention_states = self .get_cross_attention_states (
1455
1468
image_inputs , attn_metadata , actual_encoder_seq_lens )
0 commit comments