@@ -404,73 +404,62 @@ def replace_text_matches(
404
404
return "" .join (texts )
405
405
406
406
407
def _iter_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    The prompt is scanned from left to right. At every position, the
    modalities are tried in the iteration order of :code:`mm_prompt_repls`,
    and for each modality only the replacement of its next unmatched item
    is compared against the prompt.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in :code:`mm_prompt_repls` takes priority.

    Note that empty matches are ignored.

    Args:
        mm_prompt_repls: Candidate replacements, keyed by modality.
        prompt: The token IDs to search.
        mm_item_counts: Number of items to look for per modality.
            A modality missing from this mapping is treated as having
            zero items, so none of its placeholders are yielded.

    Yields:
        One :code:`PlaceholderInfo` per match, in order of increasing
        :code:`start_idx`.
    """
    prompt_len = len(prompt)

    # Index of the next item to look for, tracked separately per modality.
    next_item_idx = {modality: 0 for modality in mm_prompt_repls}

    # Number of items that can still be matched. Once this hits zero no
    # further match is possible, so the scan can stop instead of walking
    # the remainder of the prompt.
    num_unmatched = sum(
        max(mm_item_counts.get(modality, 0), 0)
        for modality in mm_prompt_repls
    )

    start_idx = 0
    while start_idx < prompt_len and num_unmatched > 0:
        found = False

        for modality, modality_repls in mm_prompt_repls.items():
            item_idx = next_item_idx[modality]
            if item_idx >= mm_item_counts.get(modality, 0):
                # Every item of this modality has already been matched
                continue

            for repl_info in modality_repls:
                repl_tokens = repl_info.get_replacement(item_idx).token_ids
                repl_len = len(repl_tokens)
                end_idx = start_idx + repl_len

                # Ignore empty replacements and ones that cannot fit
                # in what is left of the prompt
                if repl_len == 0 or end_idx > prompt_len:
                    continue

                if prompt[start_idx:end_idx] == repl_tokens:
                    yield PlaceholderInfo(
                        modality=modality,
                        item_idx=item_idx,
                        start_idx=start_idx,
                        replacement=repl_tokens,
                    )

                    # Exclude overlapping matches
                    start_idx = end_idx
                    next_item_idx[modality] = item_idx + 1
                    num_unmatched -= 1
                    found = True
                    break

            if found:
                break  # Go back to the outer while loop

        if not found:
            start_idx += 1


474
463
def find_mm_placeholders (
475
464
mm_prompt_repls : Mapping [str , Sequence [BoundPromptReplacement ]],
476
465
prompt : list [int ],
@@ -1156,7 +1145,7 @@ def apply(
1156
1145
1157
1146
# If HF processor already inserts placeholder tokens,
1158
1147
# there is no need for us to insert them
1159
- if all (len (repls ) == 0 for repls in mm_missing_repls .items ()):
1148
+ if all (len (repls ) == 0 for repls in mm_missing_repls .values ()):
1160
1149
tokenizer = self .info .get_tokenizer ()
1161
1150
prompt = decode_tokens (tokenizer , prompt_ids )
1162
1151
mm_placeholders = hf_mm_placeholders
0 commit comments