From 69fb523305501a452a7ae165ab60715c91974ec6 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 21 Jan 2025 14:37:37 +0000 Subject: [PATCH] Simplify post-processing of replacement info Signed-off-by: DarkLight1337 --- .../multimodal/processing/test_common.py | 2 +- tests/models/registry.py | 3 +- tests/multimodal/test_processing.py | 42 +++--- vllm/model_executor/models/aria.py | 10 +- vllm/model_executor/models/blip2.py | 33 ++--- vllm/model_executor/models/chameleon.py | 42 ++---- vllm/model_executor/models/fuyu.py | 38 ++---- vllm/model_executor/models/phi3v.py | 45 +++---- vllm/model_executor/models/qwen2_audio.py | 43 ++---- vllm/multimodal/processing.py | 125 ++++++++++++------ 10 files changed, 175 insertions(+), 208 deletions(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index d6d3d3b34ad..fe5b733c750 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -35,7 +35,7 @@ def _test_processing_correctness( task="auto", tokenizer=model_id, tokenizer_mode="auto", - trust_remote_code=True, + trust_remote_code=model_info.trust_remote_code, seed=0, dtype="float16", revision=None, diff --git a/tests/models/registry.py b/tests/models/registry.py index e99dbd16c47..0bd06dea0ec 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -261,7 +261,8 @@ def check_available_online( trust_remote_code=True), "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501 "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 - "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"), + "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3", + trust_remote_code=True), # [Encoder-decoder] "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 9e58ed4cfde..13f820d013e 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -7,12 +7,16 @@ from vllm.config import ModelConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.processing import (PlaceholderInfo, PromptReplacement, +# yapf conflicts with isort for this block +# yapf: disable +from vllm.multimodal.processing import (PlaceholderFeaturesInfo, + PromptReplacement, find_mm_placeholders, find_text_matches, find_token_matches, iter_token_matches, replace_text_matches, replace_token_matches) +# yapf: enable from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.utils import cached_get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -433,19 +437,19 @@ def test_find_replace_tokens( [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], { "pattern_1": [ - PlaceholderInfo( + PlaceholderFeaturesInfo( modality="pattern_1", item_idx=0, start_idx=6, - replacement=[32000, 32000], + tokens=[32000, 32000], ), ], "pattern_4": [ - PlaceholderInfo( + PlaceholderFeaturesInfo( modality="pattern_4", item_idx=0, start_idx=3, - replacement=[32000], + tokens=[32000], ), ], } @@ -455,25 +459,25 @@ def test_find_replace_tokens( [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], { "pattern_1": [ - PlaceholderInfo( + PlaceholderFeaturesInfo( modality="pattern_1", item_idx=0, start_idx=1, 
- replacement=[32000, 32000], + tokens=[32000, 32000], ), - PlaceholderInfo( + PlaceholderFeaturesInfo( modality="pattern_1", item_idx=1, start_idx=5, - replacement=[32000, 32000], + tokens=[32000, 32000], ), ], "pattern_3": [ - PlaceholderInfo( + PlaceholderFeaturesInfo( modality="pattern_3", item_idx=0, start_idx=7, - replacement=[1550, 918, 1550], + tokens=[1550, 918, 1550], ), ], # No match for pattern_4 as it has lower priority than pattern_1 @@ -483,33 +487,33 @@ def test_find_replace_tokens( [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550], { "pattern_1": [ - PlaceholderInfo( + PlaceholderFeaturesInfo( modality="pattern_1", item_idx=0, start_idx=1, - replacement=[32000, 32000], + tokens=[32000, 32000], ), - PlaceholderInfo( + PlaceholderFeaturesInfo( modality="pattern_1", item_idx=1, start_idx=3, - replacement=[32000, 32000], + tokens=[32000, 32000], ), ], "pattern_4": [ - PlaceholderInfo( + PlaceholderFeaturesInfo( modality="pattern_4", item_idx=0, start_idx=5, - replacement=[32000], + tokens=[32000], ), ], "pattern_3": [ - PlaceholderInfo( + PlaceholderFeaturesInfo( modality="pattern_3", item_idx=0, start_idx=6, - replacement=[1550, 918, 1550], + tokens=[1550, 918, 1550], ), ], } diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 503d1a38d9e..34ac39b812d 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -342,13 +342,7 @@ def get_vision_config(self): return self.get_hf_config().vision_config def get_hf_processor(self): - processor = self.ctx.get_hf_processor(AriaProcessor) - - # Patch for https://github.com/huggingface/transformers/issues/35768 - processor.tokenizer.image_token = "<|img|>" - processor.image_token = "<|img|>" - - return processor + return self.ctx.get_hf_processor(AriaProcessor) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} @@ -381,7 +375,7 @@ def get_dummy_processor_inputs( } hf_processor = self.info.get_hf_processor() - image_token: str = hf_processor.image_token # type: ignore + image_token: str = hf_processor.tokenizer.image_token # type: ignore return ProcessorInputs( prompt_text=image_token * num_images, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index f5c796b1aca..a8e6f3b6c6d 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -14,12 +14,12 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs, - NestedTensors, PlaceholderRange) +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptReplacement, + PromptReplacementDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -481,30 +481,13 @@ def _get_prompt_replacements( PromptReplacement( modality="image", target="", - replacement="" * num_image_tokens + "", + replacement=PromptReplacementDetails( + full="" * num_image_tokens + "", + features="" * num_image_tokens, + ), ) ] - def apply( - self, - prompt: Union[str, list[int]], - mm_data: 
MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> MultiModalInputs: - result = super().apply(prompt, mm_data, hf_processor_mm_kwargs) - - # Only tokens should be considered as placeholders, - # so we ignore the trailing bos_token - result["mm_placeholders"] = { - modality: [ - PlaceholderRange(offset=p["offset"], length=p["length"] - 1) - for p in ps - ] - for modality, ps in result["mm_placeholders"].items() - } - - return result - @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor, info=Blip2ProcessingInfo, diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index e2207865a69..2df50b3c4ef 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -28,12 +28,12 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs, - NestedTensors, PlaceholderRange) +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptReplacement, + PromptReplacementDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -141,39 +141,23 @@ def _get_prompt_replacements( out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_tokens = processor.image_token * self.info.get_num_image_tokens() return [ PromptReplacement( modality="image", target="", - replacement="".join([ - processor.image_start_token, - processor.image_token * self.info.get_num_image_tokens(), - processor.image_end_token, - ]), + replacement=PromptReplacementDetails( + full="".join([ + processor.image_start_token, + image_tokens, + processor.image_end_token, + ]), + features=image_tokens, + ), ) ] - def apply( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> MultiModalInputs: - result = super().apply(prompt, mm_data, hf_processor_mm_kwargs) - - # Only tokens should be considered as placeholders, - # so we ignore the image_start_token and image_end_token - result["mm_placeholders"] = { - modality: [ - PlaceholderRange(offset=p["offset"] + 1, - length=p["length"] - 2) for p in ps - ] - for modality, ps in result["mm_placeholders"].items() - } - - return result - class ChameleonLayerNorm(nn.LayerNorm): diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 3f16d3ccbd0..4865eb17036 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -16,7 +16,7 @@ """ PyTorch Fuyu model.""" import math from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) + TypedDict) import torch import torch.nn as nn @@ -30,13 +30,13 @@ from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs, - NestedTensors, 
PlaceholderRange) +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptReplacement, + PromptReplacementDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -215,9 +215,13 @@ def get_replacement_fuyu(item_idx: int): image_width=image_size.width, image_height=image_size.height, ) + image_tokens = ([_IMAGE_TOKEN_ID] * ncols + + [_NEWLINE_TOKEN_ID]) * nrows - return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows + - [bos_token_id]) + return PromptReplacementDetails( + full=image_tokens + [bos_token_id], + features=image_tokens, + ) return [ PromptReplacement( @@ -227,26 +231,6 @@ def get_replacement_fuyu(item_idx: int): ) ] - def apply( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> MultiModalInputs: - result = super().apply(prompt, mm_data, hf_processor_mm_kwargs) - - # Only |SPEAKER| (image) tokens should be considered as placeholders, - # so we ignore the trailing bos_token_id - result["mm_placeholders"] = { - modality: [ - PlaceholderRange(offset=p["offset"], length=p["length"] - 1) - for p in ps - ] - for modality, ps in result["mm_placeholders"].items() - } - - return result - @MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor, info=FuyuProcessingInfo, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index dd3b0b35c92..0fcda81da28 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -30,15 +30,19 @@ VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs, - NestedTensors, PlaceholderRange) +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) +# yapf conflicts with isort for this block +# yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, BoundPromptReplacement, - PlaceholderInfo, PromptReplacement) + PlaceholderFeaturesInfo, + PromptReplacement, + PromptReplacementDetails) +# yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -437,7 +441,12 @@ def get_replacement_phi3v(item_idx: int): processor=hf_processor, ) - return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id] + image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens + + return PromptReplacementDetails( + full=image_tokens + [bos_token_id], + features=image_tokens, + ) num_images = mm_items.get_count("image", strict=False) @@ -454,7 +463,7 @@ def _apply_prompt_replacements( token_ids: list[int], mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], mm_item_counts: Mapping[str, int], - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]: + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: token_ids, text, placeholders = 
super()._apply_prompt_replacements( token_ids=token_ids, mm_prompt_repls=mm_prompt_repls, @@ -467,11 +476,11 @@ def _apply_prompt_replacements( token_ids = [token_ids[0], *token_ids[2:]] placeholders = { modality: [ - PlaceholderInfo( + PlaceholderFeaturesInfo( modality=p.modality, item_idx=p.item_idx, start_idx=p.start_idx - 1, - replacement=p.replacement, + tokens=p.tokens, ) for p in ps ] for modality, ps in placeholders.items() @@ -479,26 +488,6 @@ def _apply_prompt_replacements( return token_ids, text, placeholders - def apply( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> MultiModalInputs: - result = super().apply(prompt, mm_data, hf_processor_mm_kwargs) - - # Only <|image|> tokens should be considered as placeholders, - # so we ignore the trailing bos_token_id - result["mm_placeholders"] = { - modality: [ - PlaceholderRange(offset=p["offset"], length=p["length"] - 1) - for p in ps - ] - for modality, ps in result["mm_placeholders"].items() - } - - return result - @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor, info=Phi3VProcessingInfo, diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 9cb8f83ad78..c21cb815b52 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -36,13 +36,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs, - NestedTensors, PlaceholderRange) +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptReplacement, + PromptReplacementDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -216,11 +216,16 @@ def get_replacement_qwen2_audio(item_idx: int): f"The audio {audio} (len={len(audio)}) is too short " "to be represented inside the model") - return "".join([ - audio_bos_token, - audio_token * num_placeholders, - audio_eos_token, - ]) + audio_tokens = audio_token * num_placeholders + + return PromptReplacementDetails( + full="".join([ + audio_bos_token, + audio_tokens, + audio_eos_token, + ]), + features=audio_tokens, + ) return [ PromptReplacement( @@ -240,26 +245,6 @@ def _always_apply_prompt_replacements(self) -> bool: # tokens than the number of audio items) return not hasattr(self.info.get_hf_processor(), "audio_token") - def apply( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> MultiModalInputs: - result = super().apply(prompt, mm_data, hf_processor_mm_kwargs) - - # Only <|AUDIO|> tokens should be considered as placeholders, - # so we ignore the audio_bos_token and audio_eos_token - result["mm_placeholders"] = { - modality: [ - PlaceholderRange(offset=p["offset"] + 1, - length=p["length"] - 2) for p in ps - ] - for modality, ps in result["mm_placeholders"].items() - } - - return result - @MULTIMODAL_REGISTRY.register_processor( Qwen2AudioMultiModalProcessor, diff --git 
a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index ff02bcc8e1f..a0993f377de 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,7 +1,8 @@ import re from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence +from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, + Sequence) from dataclasses import dataclass, field from functools import lru_cache from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, @@ -31,6 +32,24 @@ _PromptSeq = Union[str, list[int]] +@dataclass +class PromptReplacementDetails: + full: _PromptSeq + """The full replacement.""" + + features: _PromptSeq + """ + The part of the replacement that corresponds to placeholder feature tokens. + """ + + @staticmethod + def from_seq(seq: _PromptSeq): + return PromptReplacementDetails(full=seq, features=seq) + + +_PromptRepl = Union[_PromptSeq, PromptReplacementDetails] + + @dataclass class PromptReplacement: """ @@ -43,8 +62,8 @@ class PromptReplacement: target: _PromptSeq """The token sequence (or text) to find and replace.""" - replacement: Union[Callable[[int], _PromptSeq], - _PromptSeq] = field(repr=False) + replacement: Union[Callable[[int], _PromptRepl], + _PromptRepl] = field(repr=False) """ Given the index of the processed item within :attr:`modality`, output the replacement token sequence (or text). @@ -112,6 +131,14 @@ class _BoundPromptSequence: _text: Optional[str] _token_ids: Optional[list[int]] + @staticmethod + def from_seq(tokenizer: AnyTokenizer, seq: _PromptSeq): + return _BoundPromptSequence( + tokenizer=tokenizer, + _text=seq if isinstance(seq, str) else None, + _token_ids=seq if isinstance(seq, list) else None, + ) + def __post_init__(self) -> None: if self._text is None and self._token_ids is None: raise ValueError("At least one of 'text' and 'token_ids' must be " @@ -134,6 +161,12 @@ def token_ids(self) -> list[int]: return self._token_ids +@dataclass +class _BoundPromptReplacementGroup: + full: _BoundPromptSequence + features: _BoundPromptSequence + + @dataclass class BoundPromptReplacement: """ @@ -145,24 +178,18 @@ class BoundPromptReplacement: modality: str _target: _PromptSeq - _replacement: Union[Callable[[int], _PromptSeq], - _PromptSeq] = field(repr=False) + _replacement: Union[Callable[[int], _PromptRepl], + _PromptRepl] = field(repr=False) def __post_init__(self) -> None: - self._replacement_cache = dict[int, _BoundPromptSequence]() + self._replacement_cache = dict[int, _BoundPromptReplacementGroup]() @property def target(self) -> _BoundPromptSequence: """The token sequence (or text) to find and replace.""" - target = self._target + return _BoundPromptSequence.from_seq(self.tokenizer, self._target) - return _BoundPromptSequence( - tokenizer=self.tokenizer, - _text=target if isinstance(target, str) else None, - _token_ids=target if isinstance(target, list) else None, - ) - - def get_replacement(self, item_idx: int) -> _BoundPromptSequence: + def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup: """ Given the index of the processed item within :attr:`modality`, output the replacement token sequence (or text). 
@@ -177,10 +204,16 @@ def get_replacement(self, item_idx: int) -> _BoundPromptSequence: else: cache_key = None - bound_replacement = _BoundPromptSequence( - tokenizer=self.tokenizer, - _text=replacement if isinstance(replacement, str) else None, - _token_ids=replacement if isinstance(replacement, list) else None, + if not isinstance(replacement, PromptReplacementDetails): + replacement = PromptReplacementDetails.from_seq(replacement) + + bound_full = _BoundPromptSequence.from_seq(self.tokenizer, + replacement.full) + bound_features = _BoundPromptSequence.from_seq(self.tokenizer, + replacement.features) + bound_replacement = _BoundPromptReplacementGroup( + full=bound_full, + features=bound_features, ) if cache_key is not None: @@ -197,7 +230,7 @@ class _TokenMatch(NamedTuple): def iter_token_matches( token_ids: list[int], match_ids: list[int], -) -> Iterable[_TokenMatch]: +) -> Generator[_TokenMatch]: """ Yield each occurrence of :code:`match_ids` in :code:`token_ids`. @@ -272,15 +305,15 @@ def end_idx(self) -> int: @dataclass -class PlaceholderInfo: +class PlaceholderFeaturesInfo: modality: str item_idx: int start_idx: int - replacement: list[int] + tokens: list[int] @property def length(self) -> int: - return len(self.replacement) + return len(self.tokens) def to_range(self) -> PlaceholderRange: return PlaceholderRange( @@ -362,10 +395,10 @@ def _replace_matches( replacement = repl_info.get_replacement(item_idx) if isinstance(prompt, str): - repl_seq = replacement.text + repl_seq = replacement.full.text out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq) else: - repl_seq = replacement.token_ids + repl_seq = replacement.full.token_ids out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq) prev_end_idx = end_idx @@ -408,7 +441,7 @@ def _iter_placeholders( mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], prompt: list[int], mm_item_counts: Mapping[str, int], -) -> Iterable[PlaceholderInfo]: +) -> Iterable[PlaceholderFeaturesInfo]: """ Yield each set of placeholder tokens found in :code:`prompt`. 
@@ -432,23 +465,33 @@ def _iter_placeholders( for repl_info in modality_repls: replacement = repl_info.get_replacement(item_idx) - repl_tokens = replacement.token_ids - repl_len = len(repl_tokens) - end_idx = start_idx + repl_len + repl_tokens_full = replacement.full.token_ids + repl_len_full = len(repl_tokens_full) + end_idx_full = start_idx + repl_len_full - if repl_len == 0 or end_idx > prompt_len: + if repl_len_full == 0 or end_idx_full > prompt_len: continue - if prompt[start_idx:end_idx] == repl_tokens: - yield PlaceholderInfo( - modality=modality, - item_idx=item_idx, - start_idx=start_idx, - replacement=repl_tokens, - ) + if prompt[start_idx:end_idx_full] == repl_tokens_full: + repl_tokens_feat = replacement.features.token_ids + + try: + match = next( + iter_token_matches(repl_tokens_full, + repl_tokens_feat)) + yield PlaceholderFeaturesInfo( + modality=modality, + item_idx=item_idx, + start_idx=start_idx + match.start_idx, + tokens=repl_tokens_feat, + ) + except StopIteration: + raise AssertionError( + f"{repl_tokens_feat=} should be a " + f"subsequence of {repl_tokens_full=}") from None # Exclude overlapping matches - start_idx = end_idx + start_idx = end_idx_full item_idx_by_modality[modality] += 1 found = True break @@ -464,7 +507,7 @@ def find_mm_placeholders( mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], prompt: list[int], mm_item_counts: Mapping[str, int], -) -> Mapping[str, list[PlaceholderInfo]]: +) -> Mapping[str, list[PlaceholderFeaturesInfo]]: it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts) return dict(full_groupby_modality(it)) @@ -679,7 +722,7 @@ def _find_mm_placeholders( mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], new_token_ids: list[int], mm_item_counts: Mapping[str, int], - ) -> Mapping[str, list[PlaceholderInfo]]: + ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: return find_mm_placeholders(mm_prompt_repls, new_token_ids, mm_item_counts) @@ -948,7 +991,7 @@ def _apply_prompt_replacements( token_ids: list[int], mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], mm_item_counts: Mapping[str, int], - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]: + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: tokenizer = self.info.get_tokenizer() mm_token_matches = { @@ -1037,7 +1080,7 @@ def _validate_mm_kwargs( def _validate_mm_placeholders( self, - mm_placeholders: Mapping[str, list[PlaceholderInfo]], + mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], mm_item_counts: Mapping[str, int], *, allow_missing: bool = False,
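
Reviewer note: below is a minimal sketch of how a model's _get_prompt_replacements hook looks with the new PromptReplacementDetails, mirroring the Fuyu/Phi-3V hunks above. The token IDs, the num_image_tokens argument, and the function name are hypothetical stand-ins for illustration; only PromptReplacement and PromptReplacementDetails come from this change.

    # Minimal sketch under assumed token IDs; not a drop-in for any model in this patch.
    from vllm.multimodal.processing import (PromptReplacement,
                                            PromptReplacementDetails)

    _IMAGE_TOKEN_ID = 32000  # hypothetical image placeholder token ID
    _BOS_TOKEN_ID = 1        # hypothetical BOS token ID


    def get_image_prompt_replacements(
            num_image_tokens: int) -> list[PromptReplacement]:

        def get_replacement(item_idx: int) -> PromptReplacementDetails:
            image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens

            # `full` is what gets spliced into the prompt; `features` marks the
            # sub-sequence that is reported as the placeholder range, so the
            # trailing BOS token is excluded without a per-model `apply()`
            # override.
            return PromptReplacementDetails(
                full=image_tokens + [_BOS_TOKEN_ID],
                features=image_tokens,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[_BOS_TOKEN_ID],
                replacement=get_replacement,
            )
        ]

The same shape works for text replacements (as in the BLIP-2 and Qwen2-Audio hunks), and plain-sequence replacements remain valid: PromptReplacementDetails.from_seq treats the whole sequence as feature tokens, so processors that do not need to trim surrounding tokens keep their previous behavior.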