@@ -1250,6 +1250,31 @@ def _parse_and_validate_image_input(self, **kwargs: object):
1250
1250
1251
1251
raise AssertionError ("This line should be unreachable." )
1252
1252
1253
+ def _get_and_validate_encoder_lens (
1254
+ self ,
1255
+ encoder_seq_lens : List [int ],
1256
+ num_tiles : List [List [int ]],
1257
+ num_tokens_per_tile : int ,
1258
+ ) -> List [int ]:
1259
+ # Get the actual number of encoder tokens for each sample.
1260
+ # Because attn_metadata.encoder_seq_lens only counts the last
1261
+ # group of images for each sample, which is used to cheat the
1262
+ # block manager to allocate blocks for those images only.
1263
+ # See input_processor_for_mllama() for more details.
1264
+ actual_encoder_seq_lens = [
1265
+ sum (num_tile ) * num_tokens_per_tile for num_tile in num_tiles
1266
+ ]
1267
+
1268
+ # remove 0 encoder len entries for text-only requests for these
1269
+ # assertions
1270
+ attn_metadata_lens = [len for len in encoder_seq_lens if len > 0 ]
1271
+ assert len (actual_encoder_seq_lens ) == len (attn_metadata_lens )
1272
+ for actual_len , last_group_len in zip (actual_encoder_seq_lens ,
1273
+ attn_metadata_lens ):
1274
+ assert actual_len >= last_group_len
1275
+
1276
+ return actual_encoder_seq_lens
1277
+
1253
1278
def flat_encoder_result (self , cross_attention_states : torch .Tensor ,
1254
1279
attn_metadata : AttentionMetadata ,
1255
1280
actual_encoder_seq_lens : List [int ]):
@@ -1374,26 +1399,14 @@ def forward(
1374
1399
else :
1375
1400
skip_cross_attention = False
1376
1401
1377
- # Get the actual number of encoder tokens for each sample.
1378
- # Because attn_metadata.encoder_seq_lens only counts the last
1379
- # group of images for each sample, which is used to cheat the
1380
- # block manager to allocate blocks for those images only.
1381
- # See input_processor_for_mllama() for more details.
1382
- num_tiles_tensor = kwargs .pop ("num_tiles" )
1383
- num_tiles = [t .tolist () for t in num_tiles_tensor ]
1402
+ num_tiles = [t .tolist () for t in kwargs .pop ("num_tiles" )]
1384
1403
num_tokens_per_tile = calc_token_per_chunk (self .image_size )
1385
- actual_encoder_seq_lens = [
1386
- sum (num_tile ) * num_tokens_per_tile for num_tile in num_tiles
1387
- ]
1388
1404
1389
- # remove 0 entries for text-only requests for these assertions
1390
- attn_metadata_lens = [
1391
- len for len in attn_metadata .encoder_seq_lens if len > 0
1392
- ]
1393
- assert len (actual_encoder_seq_lens ) == len (attn_metadata_lens )
1394
- for actual_len , last_group_len in zip (actual_encoder_seq_lens ,
1395
- attn_metadata_lens ):
1396
- assert actual_len >= last_group_len
1405
+ actual_encoder_seq_lens = self ._get_and_validate_encoder_lens (
1406
+ attn_metadata .encoder_seq_lens ,
1407
+ num_tiles ,
1408
+ num_tokens_per_tile ,
1409
+ )
1397
1410
1398
1411
cross_attention_states = self .get_cross_attention_states (
1399
1412
image_inputs , attn_metadata , actual_encoder_seq_lens )
0 commit comments