diff --git a/tests/models/decoder_only/vision_language/processing/test_llava_next.py b/tests/models/decoder_only/vision_language/processing/test_llava_next.py index 9fa6a8a10a0..689d17be818 100644 --- a/tests/models/decoder_only/vision_language/processing/test_llava_next.py +++ b/tests/models/decoder_only/vision_language/processing/test_llava_next.py @@ -4,24 +4,17 @@ import pytest from PIL import Image from pqdm.threads import pqdm -from transformers import AutoTokenizer -from vllm.inputs import InputProcessingContext +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.parse import ImageSize +from vllm.multimodal.processing import BaseMultiModalProcessor +from vllm.multimodal.utils import cached_get_tokenizer from ....utils import build_model_context -# Fixtures lazy import to avoid initializing CUDA during test collection -@pytest.fixture() -def processor_for_llava_next(): - from vllm.model_executor.models.llava_next import ( - LlavaNextMultiModalProcessor) - return LlavaNextMultiModalProcessor - - def _validate_image_prompt_replacements_one( - processor, + processor: BaseMultiModalProcessor, num_imgs: int, failed_size_excs: list[tuple[ImageSize, Exception]], image_size: ImageSize, @@ -78,20 +71,17 @@ def _test_image_prompt_replacements( @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("num_imgs", [1, 2]) -def test_processor_prompt_replacements_regression( - processor_for_llava_next, - model_id: str, - num_imgs: int, -): +def test_processor_prompt_replacements_regression(model_id, num_imgs): ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, mm_processor_kwargs=None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - ctx = InputProcessingContext(ctx.model_config, tokenizer) - processor = processor_for_llava_next(ctx) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer), + ) image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), (488, 183), (2560, 1669)] @@ -111,20 +101,17 @@ def test_processor_prompt_replacements_regression( "Comment this out to run it manually.") @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("num_imgs", [1]) -def test_processor_prompt_replacements_all( - processor_for_llava_next, - model_id: str, - num_imgs: int, -): +def test_processor_prompt_replacements_all(model_id, num_imgs): ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, mm_processor_kwargs=None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - ctx = InputProcessingContext(ctx.model_config, tokenizer) - processor = processor_for_llava_next(ctx) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer), + ) seen_aspect_ratios = set[float]() image_sizes = list[ImageSize]() diff --git a/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py b/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py index d4cdffa210b..a033354f0e9 100644 --- a/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py +++ b/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py @@ -4,24 +4,17 @@ import pytest from PIL import Image from pqdm.threads import pqdm -from transformers import AutoTokenizer -from vllm.inputs import InputProcessingContext +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.parse import ImageSize +from vllm.multimodal.processing import BaseMultiModalProcessor +from vllm.multimodal.utils import cached_get_tokenizer from ....utils import build_model_context -# Fixtures lazy import to avoid initializing CUDA during test collection -@pytest.fixture() -def processor_for_llava_onevision(): - from vllm.model_executor.models.llava_onevision import ( - LlavaOnevisionMultiModalProcessor) - return LlavaOnevisionMultiModalProcessor - - def _validate_image_prompt_replacements_one( - processor, + processor: BaseMultiModalProcessor, num_imgs: int, failed_size_excs: list[tuple[ImageSize, Exception]], image_size: ImageSize, @@ -77,20 +70,17 @@ def _test_image_prompt_replacements( @pytest.mark.parametrize("model_id", ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) @pytest.mark.parametrize("num_imgs", [1, 2]) -def test_processor_prompt_replacements_regression( - processor_for_llava_onevision, - model_id: str, - num_imgs: int, -): +def test_processor_prompt_replacements_regression(model_id, num_imgs): ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, mm_processor_kwargs=None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - ctx = InputProcessingContext(ctx.model_config, tokenizer) - processor = processor_for_llava_onevision(ctx) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer), + ) image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), (488, 183), (2560, 1669)] @@ -111,20 +101,17 @@ def test_processor_prompt_replacements_regression( @pytest.mark.parametrize("model_id", ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) @pytest.mark.parametrize("num_imgs", [1]) -def test_processor_prompt_replacements_all( - processor_for_llava_onevision, - model_id: str, - num_imgs: int, -): +def test_processor_prompt_replacements_all(model_id, num_imgs): ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, mm_processor_kwargs=None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - ctx = InputProcessingContext(ctx.model_config, tokenizer) - processor = processor_for_llava_onevision(ctx) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer), + ) seen_aspect_ratios = set[float]() image_sizes = list[ImageSize]() diff --git a/tests/models/decoder_only/vision_language/processing/test_phi3v.py b/tests/models/decoder_only/vision_language/processing/test_phi3v.py index 249045b3c04..c5b77260c65 100644 --- a/tests/models/decoder_only/vision_language/processing/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/processing/test_phi3v.py @@ -1,21 +1,13 @@ """Tests for phi3v's multimodal preprocessing kwargs.""" import pytest -from transformers import AutoTokenizer -from vllm.inputs import InputProcessingContext -from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import cached_get_tokenizer from .....conftest import _ImageAssets from ....utils import build_model_context -# Wrap lazy imports to avoid initializing CUDA during test collection -@pytest.fixture() -def processor_for_phi3v(): - from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor - return Phi3VMultiModalProcessor - - @pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"]) # yapf: disable @pytest.mark.parametrize( @@ -29,7 +21,6 @@ def processor_for_phi3v(): # yapf: enable @pytest.mark.parametrize("num_imgs", [1, 2]) def test_processor_override( - processor_for_phi3v, image_assets: _ImageAssets, model_id: str, mm_processor_kwargs: dict[str, int], @@ -37,21 +28,26 @@ def test_processor_override( num_imgs: int, ): """Ensure input_processor_for_phi3v handles num_crops properly.""" + # Avoid initializing CUDA early + from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID + ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, trust_remote_code=True, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - ctx = InputProcessingContext(ctx.model_config, tokenizer) + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=tokenizer, + ) # Build the image str / prompt based on the number of images we pass img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" mm_data = {"image": [image_assets[0].pil_image] * num_imgs} - processor = processor_for_phi3v(ctx) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) # Ensure we have the right number of placeholders per num_crops size diff --git a/tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py index b9ac887edf9..0d54802f2b7 100644 --- a/tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py @@ -1,19 +1,12 @@ import pytest -from transformers import AutoTokenizer -from vllm.inputs import InputProcessingContext +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import cached_get_tokenizer from .....conftest import _ImageAssets from ....utils import build_model_context -# Fixtures lazy import to avoid initializing CUDA during test collection -@pytest.fixture() -def processor_for_qwen2_vl(): - from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor - return Qwen2VLMultiModalProcessor - - @pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # yapf: disable @pytest.mark.parametrize( @@ -24,7 +17,6 @@ def processor_for_qwen2_vl(): # yapf: enable @pytest.mark.parametrize("num_imgs", [1, 2]) def test_processor_override( - processor_for_qwen2_vl, image_assets: _ImageAssets, model_id: str, mm_processor_kwargs: dict[str, object], @@ -39,18 +31,20 @@ def test_processor_override( mm_processor_kwargs=None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - ctx = InputProcessingContext(ctx.model_config, tokenizer) + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=tokenizer, + ) # Build the image str / prompt based on the number of images we pass prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs mm_data = {"image": [image_assets[0].pil_image] * num_imgs} - processor = processor_for_qwen2_vl(ctx) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) # Ensure we have the right number of placeholders per num_crops size - hf_processor = processor._get_hf_processor(**mm_processor_kwargs) + hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs) image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token) img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 75d878217b6..d98bd9736b6 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -10,12 +10,17 @@ from vllm.config import ModelConfig from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.processing import (ProcessingCache, PromptReplacement, - _PlaceholderInfo, find_mm_placeholders, +# yapf conflicts with isort for this block +# yapf: disable +from vllm.multimodal.processing import (PlaceholderInfo, ProcessingCache, + PromptReplacement, + find_mm_placeholders, find_text_matches, find_token_matches, iter_token_matches, replace_text_matches, replace_token_matches) +# yapf: enable +from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.utils import cached_get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import full_groupby @@ -431,7 +436,7 @@ def test_find_replace_tokens( [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], { "pattern_1": [ - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_1", item_idx=0, start_idx=6, @@ -445,13 +450,13 @@ def test_find_replace_tokens( [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], { "pattern_1": [ - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_1", item_idx=0, start_idx=1, replacement=[32000, 32000], ), - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_1", item_idx=1, start_idx=5, @@ -459,7 +464,7 @@ def test_find_replace_tokens( ), ], "pattern_3": [ - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_3", item_idx=0, start_idx=7, @@ -472,13 +477,13 @@ def test_find_replace_tokens( [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550], { "pattern_1": [ - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_1", item_idx=0, start_idx=1, replacement=[32000, 32000], ), - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_1", item_idx=1, start_idx=3, @@ -486,7 +491,7 @@ def test_find_replace_tokens( ), ], "pattern_3": [ - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_3", item_idx=0, start_idx=6, @@ -577,19 +582,15 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): revision=None, limit_mm_per_prompt=limit_mm_per_prompt, ) - model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) - processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls] - ctx = InputProcessingContext( + processor = MULTIMODAL_REGISTRY.create_processor( model_config, tokenizer=cached_get_tokenizer(model_config.tokenizer), ) - - processor = processor_factory(ctx, cache=None) - profiler = processor.profiling_info + profiler = MultiModalProfiler(processor) mock_supported_mm_limits = MagicMock(return_value={"image": num_supported}) - profiler.get_supported_mm_limits = mock_supported_mm_limits + processor.info.get_supported_mm_limits = mock_supported_mm_limits if is_valid: exc_ctx = nullcontext() @@ -597,7 +598,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): exc_ctx = pytest.raises(ValueError, match="this model only supports") with exc_ctx: - profiler.get_mm_limits() + profiler.get_dummy_data(model_config.max_model_len) @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @@ -620,16 +621,12 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): revision=None, limit_mm_per_prompt=limit_mm_per_prompt, ) - model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) - processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls] - ctx = InputProcessingContext( + processor = MULTIMODAL_REGISTRY.create_processor( model_config, tokenizer=cached_get_tokenizer(model_config.tokenizer), ) - processor = processor_factory(ctx, cache=None) - rng = np.random.RandomState(0) image = _rand_img(rng, min_wh=128, max_wh=256) if num_images == 0: @@ -681,9 +678,9 @@ def _test_processing_cache_correctness( hf_overrides=hf_overrides, limit_mm_per_prompt=limit_mm_per_prompt, ) - model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) - processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls] + model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) + factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] ctx = InputProcessingContext( model_config, tokenizer=cached_get_tokenizer(model_config.tokenizer), @@ -691,8 +688,9 @@ def _test_processing_cache_correctness( # Ensure that it can fit all of the data cache = ProcessingCache(capacity=1 << 30) - baseline_processor = processor_factory(ctx, cache=None) - cached_processor = processor_factory(ctx, cache=cache) + baseline_processor = factories.build_processor(ctx, cache=None) + cached_processor = factories.build_processor(ctx, cache=cache) + dummy_inputs = baseline_processor.dummy_inputs rng = np.random.RandomState(0) @@ -724,7 +722,7 @@ def _test_processing_cache_correctness( } mm_counts = {k: len(vs) for k, vs in mm_data.items()} - prompt = baseline_processor.profiling_info.get_dummy_processor_inputs( + prompt = dummy_inputs.get_dummy_processor_inputs( model_config.max_model_len, mm_counts, ).prompt_text diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index 06dfebbb955..ac64edfd4ec 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -2,13 +2,17 @@ import torch -from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, - LlavaMultiModalProcessor) +from vllm.model_executor.models.llava import (LlavaDummyInputsBuilder, + LlavaForConditionalGeneration, + LlavaMultiModalProcessor, + LlavaProcessingInfo) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor, + info=LlavaProcessingInfo, + dummy_inputs=LlavaDummyInputsBuilder) class MyLlava(LlavaForConditionalGeneration): def compute_logits( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index b362ee0cac3..6ddc1eb76f1 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -7,7 +7,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2 +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputsV2 from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.utils import print_info_once, print_warning_once diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 2d9d024e03e..b22b3f1594f 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -323,6 +323,7 @@ def dummy_data_for_profiling( # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture from vllm.multimodal import MultiModalKwargs + from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.utils import cached_get_tokenizer if mm_registry.has_processor(model_config): @@ -331,7 +332,8 @@ def dummy_data_for_profiling( trust_remote_code=model_config.trust_remote_code, ) processor = mm_registry.create_processor(model_config, tokenizer) - dummy_data = processor.get_dummy_data(seq_len) + profiler = MultiModalProfiler(processor) + dummy_data = profiler.get_dummy_data(seq_len) else: model_cls, _ = get_model_architecture(model_config) if is_encoder_data: diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 2e649f10c07..089062ab53f 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -23,10 +23,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, AriaVisionConfig) @@ -445,33 +445,33 @@ def build_mm_projector(config: PretrainedConfig): ) -class AriaProcessingMixin(ProcessingMixin): +class AriaProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return self.ctx.get_hf_config() - def _get_vision_config(self) -> AriaVisionConfig: - return self._get_hf_config().vision_config - - def _get_num_image_tokens(self) -> int: - hf_config = self._get_hf_config() - return max(hf_config.projector_patch_to_query_dict.values()) - - -class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo): + def get_vision_config(self) -> AriaVisionConfig: + return self.get_hf_config().vision_config def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return {"image": self._get_num_image_tokens()} + return {"image": self.get_num_image_tokens()} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + return max(hf_config.projector_patch_to_query_dict.values()) + + +class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]): def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - vision_config = self._get_vision_config() + vision_config = self.info.get_vision_config() max_image_size = vision_config.image_size num_images = mm_counts.get("image", 0) @@ -483,7 +483,7 @@ def get_dummy_processor_inputs( num_images=num_images) } - hf_processor = self._get_hf_processor() + hf_processor = self.info.get_hf_processor() image_token: str = hf_processor.image_token # type: ignore return ProcessorInputs( @@ -492,10 +492,7 @@ def get_dummy_processor_inputs( ) -class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return AriaProfilingInfo(self.ctx) +class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): def _get_mm_fields_config( self, @@ -513,10 +510,10 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index - num_image_tokens = self._get_num_image_tokens() + num_image_tokens = self.info.get_num_image_tokens() return [ PromptReplacement( @@ -527,7 +524,9 @@ def _get_prompt_replacements( ] -@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor, + info=AriaProcessingInfo, + dummy_inputs=AriaDummyInputsBuilder) class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): """ Aria model for conditional generation tasks. diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index fd45783f167..7dfc0b687c6 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -17,10 +17,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) +from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .blip import BlipVisionModel @@ -397,30 +397,30 @@ def forward( return sequence_output -class Blip2ProcessingMixin(ProcessingMixin): +class Blip2ProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return self.ctx.get_hf_config(Blip2Config) - def _get_num_image_tokens(self) -> int: - hf_config = self._get_hf_config() - return hf_config.num_query_tokens - - -class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return {"image": self._get_num_image_tokens()} + return {"image": self.get_num_image_tokens()} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + return hf_config.num_query_tokens + + +class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]): def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() vision_config = hf_config.vision_config max_image_size = vision_config.image_size @@ -439,10 +439,7 @@ def get_dummy_processor_inputs( ) -class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return Blip2ProfilingInfo(self.ctx) +class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): def _get_mm_fields_config( self, @@ -460,7 +457,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - num_image_tokens = self._get_num_image_tokens() + num_image_tokens = self.info.get_num_image_tokens() return [ PromptReplacement( @@ -491,7 +488,9 @@ def apply( return result -@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor, + info=Blip2ProcessingInfo, + dummy_inputs=Blip2DummyInputsBuilder) class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 73ed73b61eb..acff926891b 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -30,10 +30,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) +from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once @@ -49,33 +49,34 @@ class ChameleonImagePixelInputs(TypedDict): """Shape: `(batch_size * num_images, num_channels, height, width)`""" -class ChameleonProcessingMixin(ProcessingMixin): +class ChameleonProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return self.ctx.get_hf_config(ChameleonConfig) - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(ChameleonProcessor) - def _get_num_image_tokens(self) -> int: - processor = self._get_hf_processor() - return processor.image_seq_length - - -class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return {"image": self._get_num_image_tokens()} + return {"image": self.get_num_image_tokens()} + + def get_num_image_tokens(self) -> int: + processor = self.get_hf_processor() + return processor.image_seq_length + + +class ChameleonDummyInputsBuilder( + BaseDummyInputsBuilder[ChameleonProcessingInfo]): def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - config = self._get_hf_config() + config = self.info.get_hf_config() width = height = config.vq_config.resolution num_images = mm_counts.get("image", 0) @@ -93,11 +94,8 @@ def get_dummy_processor_inputs( ) -class ChameleonMultiModalProcessor(ChameleonProcessingMixin, - BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return ChameleonProfilingInfo(self.ctx) +class ChameleonMultiModalProcessor( + BaseMultiModalProcessor[ChameleonProcessingInfo]): def _get_mm_fields_config( self, @@ -112,7 +110,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - processor = self._get_hf_processor(**hf_processor_mm_kwargs) + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) return [ PromptReplacement( @@ -120,7 +118,7 @@ def _get_prompt_replacements( target="", replacement="".join([ processor.image_start_token, - processor.image_token * self._get_num_image_tokens(), + processor.image_token * self.info.get_num_image_tokens(), processor.image_end_token, ]), ) @@ -916,7 +914,10 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor( + ChameleonMultiModalProcessor, + info=ChameleonProcessingInfo, + dummy_inputs=ChameleonDummyInputsBuilder) class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index c937fcb0978..59af5f0b3ae 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -33,11 +33,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) -from vllm.multimodal.parse import ImageProcessorItems, ImageSize +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP @@ -64,24 +64,38 @@ class FuyuImagePatchInputs(TypedDict): """ -class FuyuProcessingMixin(ProcessingMixin): +class FuyuProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return self.ctx.get_hf_config(FuyuConfig) - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(FuyuProcessor) - def _get_image_processor(self) -> FuyuImageProcessor: - return self._get_hf_processor().image_processor + def get_image_processor(self) -> FuyuImageProcessor: + return self.get_hf_processor().image_processor - def _get_image_feature_grid_size( + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + target_width, target_height = self.get_image_size_with_most_features() + + max_ncols, max_nrows = self.get_image_feature_grid_size( + image_width=target_width, + image_height=target_height, + ) + max_image_tokens = (max_ncols + 1) * max_nrows + + return {"image": max_image_tokens} + + def get_image_feature_grid_size( self, *, image_width: int, image_height: int, ) -> tuple[int, int]: - image_processor = self._get_image_processor() + image_processor = self.get_image_processor() target_width = image_processor.size["width"] target_height = image_processor.size["height"] @@ -97,34 +111,21 @@ def _get_image_feature_grid_size( nrows = math.ceil(image_height / 30) return ncols, nrows - -class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo): - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": 1} - - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - target_width, target_height = self._get_image_size_with_most_features() - - max_ncols, max_nrows = self._get_image_feature_grid_size( - image_width=target_width, - image_height=target_height, - ) - max_image_tokens = (max_ncols + 1) * max_nrows - - return {"image": max_image_tokens} - - def _get_image_size_with_most_features(self) -> ImageSize: - image_processor = self._get_image_processor() + def get_image_size_with_most_features(self) -> ImageSize: + image_processor = self.get_image_processor() return ImageSize(width=image_processor.size["width"], height=image_processor.size["height"]) + +class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): + def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = \ + self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) mm_data = { @@ -140,10 +141,7 @@ def get_dummy_processor_inputs( ) -class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return FuyuProfilingInfo(self.ctx) +class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): def _call_hf_processor( self, @@ -156,7 +154,7 @@ def _call_hf_processor( # Avoid warning from HF logger for text-only input # Input_ids format: bos_token_id + prompt_token_ids + boa_token_id # Tokenizer won't add boa_token_id by default, we add it manually. - tokenizer = self._get_tokenizer() + tokenizer = self.info.get_tokenizer() boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore prompt_ids = tokenizer.encode(prompt) + [boa_token_id] return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") @@ -196,10 +194,10 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() bos_token_id = hf_config.bos_token_id - tokenizer = self._get_tokenizer() + tokenizer = self.info.get_tokenizer() eot_token_id = tokenizer.bos_token_id assert isinstance(eot_token_id, int) @@ -207,7 +205,7 @@ def get_replacement_fuyu(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) image_size = images.get_image_size(item_idx) - ncols, nrows = self._get_image_feature_grid_size( + ncols, nrows = self.info.get_image_feature_grid_size( image_width=image_size.width, image_height=image_size.height, ) @@ -244,7 +242,9 @@ def apply( return result -@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor, + info=FuyuProcessingInfo, + dummy_inputs=FuyuDummyInputsBuilder) class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 305f1364dba..8d94acf3b21 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,7 +1,7 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from functools import cached_property from typing import (Final, Iterable, List, Literal, Mapping, Optional, - Protocol, Set, Tuple, TypedDict, Union) + Protocol, Set, Tuple, TypedDict, TypeVar, Union) import torch import torch.nn as nn @@ -25,11 +25,11 @@ MultiModalInputsV2, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize) + ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingCache, - ProcessingMixin, PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs + BaseProcessingInfo, ProcessingCache, + PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .clip import CLIPVisionModel @@ -105,34 +105,23 @@ class LlavaLikeProcessor(Protocol): image_token: Final[str] -class BaseLlavaProcessingMixin(ProcessingMixin, ABC): +class BaseLlavaProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self) -> LlavaLikeConfig: + def get_hf_config(self) -> LlavaLikeConfig: return self.ctx.get_hf_config(LlavaConfig) - def _get_vision_encoder_info(self): - return get_vision_encoder_info(self._get_hf_config()) + def get_vision_encoder_info(self): + return get_vision_encoder_info(self.get_hf_config()) @abstractmethod - def _get_hf_processor(self) -> LlavaLikeProcessor: + def get_hf_processor(self) -> LlavaLikeProcessor: raise NotImplementedError - def _get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, - ) -> int: - hf_config = self._get_hf_config() - vision_encoder_info = self._get_vision_encoder_info() + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} - return self._apply_feature_select_strategy( - hf_config.vision_feature_select_strategy, - vision_encoder_info.get_num_image_tokens( - image_width=image_width, - image_height=image_height, - ), - ) + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} def _apply_feature_select_strategy( self, @@ -147,28 +136,42 @@ def _apply_feature_select_strategy( msg = f"Unexpected feature select strategy: {strategy!r}" raise NotImplementedError(msg) + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self.get_hf_config() + vision_encoder_info = self.get_vision_encoder_info() -class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo): - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None} - - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return {"image": self._get_max_image_tokens()} + return self._apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), + ) - def _get_image_size_with_most_features(self) -> ImageSize: - vision_encoder_info = self._get_vision_encoder_info() + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() width = height = vision_encoder_info.get_image_size() return ImageSize(width=width, height=height) - def _get_max_image_tokens(self) -> int: - target_width, target_height = self._get_image_size_with_most_features() + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() - return self._get_num_image_tokens( + return self.get_num_image_tokens( image_width=target_width, image_height=target_height, ) + +_I = TypeVar("_I", bound=BaseLlavaProcessingInfo) + + +class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]): + def get_dummy_processor_inputs( self, seq_len: int, @@ -176,9 +179,10 @@ def get_dummy_processor_inputs( ) -> ProcessorInputs: num_images = mm_counts.get("image", 0) - processor = self._get_hf_processor() + processor = self.info.get_hf_processor() image_token = processor.image_token - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = \ + self.info.get_image_size_with_most_features() mm_data = { "image": @@ -193,23 +197,13 @@ def get_dummy_processor_inputs( ) -class LlavaProcessingMixin(BaseLlavaProcessingMixin): +class LlavaProcessingInfo(BaseLlavaProcessingInfo): - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(LlavaProcessor) -class LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo): - pass - - -class BaseLlavaMultiModalProcessor(LlavaProcessingMixin, - BaseMultiModalProcessor): - - # Copied from BaseMultiModalProcessor - @abstractmethod - def _get_profiling_info(self) -> BaseProfilingInfo: - raise NotImplementedError +class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): # Copied from BaseMultiModalProcessor @abstractmethod @@ -226,7 +220,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index def get_replacement(item_idx: int): @@ -237,7 +231,7 @@ def get_replacement(item_idx: int): num_image_tokens = images.get_feature_size(item_idx) else: image_size = images.get_image_size(item_idx) - num_image_tokens = self._get_num_image_tokens( + num_image_tokens = self.info.get_num_image_tokens( image_width=image_size.width, image_height=image_size.height, ) @@ -253,10 +247,8 @@ def get_replacement(item_idx: int): ] -class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return LlavaProfilingInfo(self.ctx) +class LlavaMultiModalProcessor( + BaseLlavaMultiModalProcessor[LlavaProcessingInfo]): def _get_mm_fields_config( self, @@ -269,21 +261,14 @@ def _get_mm_fields_config( ) -class PixtralHFProcessingMixin(BaseLlavaProcessingMixin): +class PixtralHFProcessingInfo(BaseLlavaProcessingInfo): - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(PixtralProcessor) -class PixtralHFProfilingInfo(PixtralHFProcessingMixin, BaseLlavaProfilingInfo): - pass - - -class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin, - BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return PixtralHFProfilingInfo(self.ctx) +class PixtralHFMultiModalProcessor( + BaseMultiModalProcessor[PixtralHFProcessingInfo]): def _call_hf_processor( self, @@ -328,10 +313,10 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index - processor = self._get_hf_processor() + processor = self.info.get_hf_processor() image_token = processor.image_token image_break_token = processor.image_break_token image_end_token = processor.image_end_token @@ -363,26 +348,40 @@ def get_replacement(item_idx: int): ] +def _build_llava_or_pixtral_hf_info( + ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo: + hf_config = ctx.get_hf_config(LlavaConfig) + + if isinstance(hf_config.vision_config, PixtralVisionConfig): + return PixtralHFProcessingInfo(ctx) + + return LlavaProcessingInfo(ctx) + + def _build_llava_or_pixtral_hf_processor( - ctx: InputProcessingContext, + info: _I, + dummy_inputs: BaseDummyInputsBuilder[_I], *, cache: Optional[ProcessingCache] = None, enable_sanity_checks: bool = True, ) -> BaseMultiModalProcessor: - hf_config = ctx.get_hf_config(LlavaConfig) - - if isinstance(hf_config.vision_config, PixtralVisionConfig): + if isinstance(info, PixtralHFProcessingInfo): return PixtralHFMultiModalProcessor( - ctx, + info, + dummy_inputs, # type: ignore + cache=cache, + enable_sanity_checks=enable_sanity_checks, + ) + + if isinstance(info, LlavaProcessingInfo): + return LlavaMultiModalProcessor( + info, + dummy_inputs, # type: ignore cache=cache, enable_sanity_checks=enable_sanity_checks, ) - return LlavaMultiModalProcessor( - ctx, - cache=cache, - enable_sanity_checks=enable_sanity_checks, - ) + raise NotImplementedError(type(info)) def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: @@ -460,7 +459,9 @@ def init_vision_tower_for_llava( raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor) +@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor, + info=_build_llava_or_pixtral_hf_info, + dummy_inputs=LlavaDummyInputsBuilder) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # BitandBytes specific attributes bitsandbytes_stacked_params_mapping = { @@ -727,11 +728,11 @@ def apply( mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], ) -> MultiModalInputsV2: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index # Assume that it doesn't depend on the image size - num_image_tokens = self._get_num_image_tokens( + num_image_tokens = self.info.get_num_image_tokens( image_width=-1, image_height=-1, ) @@ -796,6 +797,8 @@ def get_replacement_mantis(item_idx: int): # To use this model, please use # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` -@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor, + info=LlavaProcessingInfo, + dummy_inputs=LlavaDummyInputsBuilder) class MantisForConditionalGeneration(LlavaForConditionalGeneration): pass diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 815456dac2a..fda4f22d366 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,6 +1,7 @@ +from abc import abstractmethod from functools import cached_property from typing import (Final, Iterable, List, Literal, Mapping, Optional, - Protocol, Set, Tuple, TypedDict, Union) + Protocol, Set, Tuple, TypedDict, TypeVar, Union) import torch import torch.nn as nn @@ -16,13 +17,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors from vllm.multimodal.parse import ImageSize -from vllm.multimodal.profiling import BaseProfilingInfo from vllm.sequence import IntermediateTensors from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP -from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingMixin, - BaseLlavaProfilingInfo, LlavaLikeConfig, +from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo, + LlavaDummyInputsBuilder, LlavaLikeConfig, LlavaMultiModalProjector, init_vision_tower_for_llava) from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn, @@ -65,23 +65,23 @@ class LlavaNextLikeConfig(LlavaLikeConfig, Protocol): image_grid_pinpoints: Final[list[list[int]]] -class LlavaNextProcessingMixin(BaseLlavaProcessingMixin): +class LlavaNextProcessingInfo(BaseLlavaProcessingInfo): - def _get_hf_config(self) -> LlavaNextLikeConfig: + def get_hf_config(self) -> LlavaNextLikeConfig: return self.ctx.get_hf_config(LlavaNextConfig) - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(LlavaNextProcessor) # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113 - def _get_num_image_tokens( + def get_num_image_tokens( self, *, image_width: int, image_height: int, ) -> int: - hf_config = self._get_hf_config() - vision_encoder_info = self._get_vision_encoder_info() + hf_config = self.get_hf_config() + vision_encoder_info = self.get_vision_encoder_info() base_feature_size = self._apply_feature_select_strategy( hf_config.vision_feature_select_strategy, @@ -140,16 +140,13 @@ def _get_num_unpadded_features( return (unpadded_features, newline_features) - -class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo): - - def _get_image_size_with_most_features(self) -> ImageSize: - hf_config = self._get_hf_config() + def get_image_size_with_most_features(self) -> ImageSize: + hf_config = self.get_hf_config() largest_feature_size, largest_feature_pinpoint = 0, None for (height, width) in hf_config.image_grid_pinpoints: - feat_size = self._get_num_image_tokens(image_width=width, - image_height=height) + feat_size = self.get_num_image_tokens(image_width=width, + image_height=height) if feat_size > largest_feature_size: largest_feature_size = feat_size largest_feature_pinpoint = ImageSize(width=width, @@ -161,11 +158,23 @@ def _get_image_size_with_most_features(self) -> ImageSize: return largest_feature_pinpoint -class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin, - BaseLlavaMultiModalProcessor): +_I = TypeVar("_I", bound=LlavaNextProcessingInfo) + + +class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]): + + # Copied from BaseMultiModalProcessor + @abstractmethod + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + raise NotImplementedError + - def _get_profiling_info(self) -> BaseProfilingInfo: - return LlavaNextProfilingInfo(self.ctx) +class LlavaNextMultiModalProcessor( + BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]): def _get_mm_fields_config( self, @@ -179,7 +188,9 @@ def _get_mm_fields_config( ) -@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor, + info=LlavaNextProcessingInfo, + dummy_inputs=LlavaDummyInputsBuilder) class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 6e82cee1c95..5be85d7c0f0 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -17,12 +17,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) -from vllm.multimodal.parse import (ImageSize, VideoEmbeddingItems, - VideoProcessorItems) +from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, + VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -47,33 +46,52 @@ class LlavaNextVideoPixelInputs(TypedDict): """ -class LlavaNextVideoProcessingMixin(ProcessingMixin): +class LlavaNextVideoProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return self.ctx.get_hf_config(LlavaNextVideoConfig) - def _get_vision_encoder_info(self): - return get_vision_encoder_info(self._get_hf_config()) + def get_vision_encoder_info(self): + return get_vision_encoder_info(self.get_hf_config()) - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(LlavaNextVideoProcessor) + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"video": 1} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + target_width, target_height = self.get_image_size_with_most_features() + + max_video_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features(seq_len), + ) + + return {"video": max_video_tokens} + + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) + def _get_num_frame_tokens( self, *, image_width: int, image_height: int, ) -> int: - hf_config = self._get_hf_config() + hf_config = self.get_hf_config() spatial_pool_stride = hf_config.spatial_pool_stride - vision_encoder_info = self._get_vision_encoder_info() + vision_encoder_info = self.get_vision_encoder_info() patch_grid_length = vision_encoder_info.get_patch_grid_length() pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride) return pooled_grid_length * pooled_grid_length - def _get_num_video_tokens( + def get_num_video_tokens( self, *, image_width: int, @@ -87,37 +105,14 @@ def _get_num_video_tokens( return num_frame_tokens * num_frames - -class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin, - BaseProfilingInfo): - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"video": 1} - - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - target_width, target_height = self._get_image_size_with_most_features() - - max_video_tokens = self._get_num_video_tokens( - image_width=target_width, - image_height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), - ) - - return {"video": max_video_tokens} - - def _get_image_size_with_most_features(self) -> ImageSize: - vision_encoder_info = self._get_vision_encoder_info() - width = height = vision_encoder_info.get_image_size() - return ImageSize(width=width, height=height) - def _get_max_video_frames(self, max_tokens: int) -> int: - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = self.get_image_size_with_most_features() num_frames = 0 while True: next_num_frames = num_frames + 1 - next_max_tokens = self._get_num_video_tokens( + next_max_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=next_num_frames, @@ -130,7 +125,7 @@ def _get_max_video_frames(self, max_tokens: int) -> int: return num_frames - def _get_dummy_num_frames(self, seq_len: int) -> int: + def get_num_frames_with_most_features(self, seq_len: int) -> int: mm_config = self.ctx.get_mm_config() max_videos = mm_config.limit_per_prompt.get("video", 1) @@ -138,6 +133,10 @@ def _get_dummy_num_frames(self, seq_len: int) -> int: return max(max_total_frames // max(max_videos, 1), 1) + +class LlavaNextVideoDummyInputsBuilder( + BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]): + def get_dummy_processor_inputs( self, seq_len: int, @@ -145,16 +144,20 @@ def get_dummy_processor_inputs( ) -> ProcessorInputs: num_videos = mm_counts.get("video", 0) - processor = self._get_hf_processor() + processor = self.info.get_hf_processor() video_token = processor.video_token - target_width, target_height = self._get_image_size_with_most_features() + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len) mm_data = { "video": self._get_dummy_videos( width=target_width, height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), + num_frames=target_num_frames, num_videos=num_videos, ) } @@ -165,11 +168,8 @@ def get_dummy_processor_inputs( ) -class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin, - BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return LlavaNextVideoProfilingInfo(self.ctx) +class LlavaNextVideoMultiModalProcessor( + BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]): def _get_mm_fields_config( self, @@ -184,7 +184,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() video_token_id = hf_config.video_token_index def get_replacement(item_idx: int): @@ -195,7 +195,7 @@ def get_replacement(item_idx: int): num_video_tokens = videos.get_feature_size(item_idx) else: image_size = videos.get_frame_size(item_idx) - num_video_tokens = self._get_num_video_tokens( + num_video_tokens = self.info.get_num_video_tokens( image_width=image_size.width, image_height=image_size.height, num_frames=videos.get_num_frames(item_idx), @@ -269,7 +269,11 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -@MULTIMODAL_REGISTRY.register_processor(LlavaNextVideoMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor( + LlavaNextVideoMultiModalProcessor, + info=LlavaNextVideoProcessingInfo, + dummy_inputs=LlavaNextVideoDummyInputsBuilder, +) class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index b5e3edba1f0..78a47e64d9a 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -17,19 +17,20 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors -from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, - VideoEmbeddingItems, VideoProcessorItems) -from vllm.multimodal.processing import MultiModalFieldConfig, PromptReplacement -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems, + VideoProcessorItems) +from vllm.multimodal.processing import PromptReplacement +from vllm.multimodal.profiling import ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP -from .llava import BaseLlavaProfilingInfo, init_vision_tower_for_llava -from .llava_next import (LlavaNextLikeConfig, LlavaNextMultiModalProcessor, - LlavaNextProcessingMixin) +from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava +from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig, + LlavaNextProcessingInfo) from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -89,14 +90,23 @@ class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol): video_token_index: Final[int] -class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin): +class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): - def _get_hf_config(self) -> LlavaOnevisionLikeConfig: + def get_hf_config(self) -> LlavaOnevisionLikeConfig: return self.ctx.get_hf_config(LlavaOnevisionConfig) - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(LlavaOnevisionProcessor) + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + return { + "image": self.get_max_image_tokens(), + "video": self.get_max_video_tokens(seq_len), + } + # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 # with additional logic afterwards taken from LlavaOnevisionProcessor def _get_num_unpadded_features( @@ -141,16 +151,16 @@ def _get_num_frame_tokens( image_width: int, image_height: int, ) -> int: - hf_config = self._get_hf_config() + hf_config = self.get_hf_config() spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2) - vision_encoder_info = self._get_vision_encoder_info() + vision_encoder_info = self.get_vision_encoder_info() patch_grid_length = vision_encoder_info.get_patch_grid_length() pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride) return pooled_grid_length * pooled_grid_length - def _get_num_video_tokens( + def get_num_video_tokens( self, *, image_width: int, @@ -164,43 +174,14 @@ def _get_num_video_tokens( return num_frame_tokens * num_frames + 1 # Newline token - -class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin, - BaseLlavaProfilingInfo): - - def _get_image_size_with_most_features(self) -> ImageSize: - hf_config = self._get_hf_config() - largest_feature_size, largest_feature_pinpoint = 0, None - for (height, width) in hf_config.image_grid_pinpoints: - feat_size = self._get_num_image_tokens(image_width=width, - image_height=height) - if feat_size > largest_feature_size: - largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, - height=height) - - if largest_feature_size == 0 or largest_feature_pinpoint is None: - raise ValueError("Cannot have a largest feature size of 0!") - - return largest_feature_pinpoint - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None, "video": None} - - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return { - "image": self._get_max_image_tokens(), - "video": self._get_max_video_tokens(seq_len), - } - def _get_max_video_frames(self, max_tokens: int) -> int: - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = self.get_image_size_with_most_features() num_frames = 0 while True: next_num_frames = num_frames + 1 - next_max_tokens = self._get_num_video_tokens( + next_max_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=next_num_frames, @@ -213,12 +194,12 @@ def _get_max_video_frames(self, max_tokens: int) -> int: return num_frames - def _get_dummy_num_frames(self, seq_len: int) -> int: + def get_num_frames_with_most_features(self, seq_len: int) -> int: mm_config = self.ctx.get_mm_config() max_images = mm_config.limit_per_prompt.get("image", 1) max_videos = mm_config.limit_per_prompt.get("video", 1) - max_image_tokens = self._get_max_image_tokens() * max_images + max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) max_frames_per_video = min(max_total_frames // max(max_videos, 1), @@ -226,15 +207,19 @@ def _get_dummy_num_frames(self, seq_len: int) -> int: return max(max_frames_per_video, 1) - def _get_max_video_tokens(self, seq_len: int) -> int: - target_width, target_height = self._get_image_size_with_most_features() + def get_max_video_tokens(self, seq_len: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() - return self._get_num_video_tokens( + return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), + num_frames=self.get_num_frames_with_most_features(seq_len), ) + +class LlavaOnevisionDummyInputsBuilder( + LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]): + def get_dummy_processor_inputs( self, seq_len: int, @@ -243,10 +228,14 @@ def get_dummy_processor_inputs( num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - processor = self._get_hf_processor() + processor = self.info.get_hf_processor() image_token = processor.image_token video_token = processor.video_token - target_width, target_height = self._get_image_size_with_most_features() + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len) mm_data = { "image": @@ -257,7 +246,7 @@ def get_dummy_processor_inputs( self._get_dummy_videos( width=target_width, height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), + num_frames=target_num_frames, num_videos=num_videos, ) } @@ -268,11 +257,8 @@ def get_dummy_processor_inputs( ) -class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin, - LlavaNextMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return LlavaOnevisionProfilingInfo(self.ctx) +class LlavaOnevisionMultiModalProcessor( + BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]): def _get_mm_fields_config( self, @@ -303,7 +289,7 @@ def _call_hf_processor( mm_kwargs=mm_kwargs, ) - processor = self._get_hf_processor() + processor = self.info.get_hf_processor() video_token = processor.video_token # LLaVA-OneVision processor doesn't support multiple videos @@ -345,7 +331,7 @@ def _get_prompt_replacements( out_mm_kwargs=out_mm_kwargs, ) - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() video_token_id = hf_config.video_token_index def get_video_replacement(item_idx: int): @@ -356,7 +342,7 @@ def get_video_replacement(item_idx: int): num_video_tokens = videos.get_feature_size(item_idx) else: image_size = videos.get_frame_size(item_idx) - num_video_tokens = self._get_num_video_tokens( + num_video_tokens = self.info.get_num_video_tokens( image_width=image_size.width, image_height=image_size.height, num_frames=videos.get_num_frames(item_idx), @@ -393,7 +379,10 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor( + LlavaOnevisionMultiModalProcessor, + info=LlavaOnevisionProcessingInfo, + dummy_inputs=LlavaOnevisionDummyInputsBuilder) class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index c8418c14e5f..a1b1af35604 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -34,13 +34,12 @@ MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize) + ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement, - _BoundPromptReplacement, - _PlaceholderInfo) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs + BaseProcessingInfo, + BoundPromptReplacement, + PlaceholderInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -302,9 +301,9 @@ def add_image_newline(self, image_features_hd): return image_features_hd_newline -class Phi3VProcessingMixin(ProcessingMixin): +class Phi3VProcessingInfo(BaseProcessingInfo): - def _get_hf_processor( + def get_hf_processor( self, *, num_crops: Optional[int] = None, @@ -314,39 +313,42 @@ def _get_hf_processor( return self.ctx.get_hf_processor() - def _get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, - ) -> int: - processor = self._get_hf_processor() - - return processor.calc_num_image_tokens_from_image_size( # type: ignore - width=image_width, - height=image_height, - ) - - -class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = self.get_image_size_with_most_features() - max_image_tokens = self._get_num_image_tokens( + max_image_tokens = self.get_num_image_tokens( image_width=target_width, image_height=target_height, + processor=None, ) return {"image": max_image_tokens} - def _get_image_size_with_most_features(self) -> ImageSize: + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[ProcessorMixin], + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + return processor.calc_num_image_tokens_from_image_size( # type: ignore + width=image_width, + height=image_height, + ) + + def get_image_size_with_most_features(self) -> ImageSize: # Result in the max possible feature size (h:w = 16:1) return ImageSize(height=8000, width=50) + +class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]): + def get_dummy_processor_inputs( self, seq_len: int, @@ -354,7 +356,8 @@ def get_dummy_processor_inputs( ) -> ProcessorInputs: num_images = mm_counts.get("image", 0) - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = \ + self.info.get_image_size_with_most_features() mm_data = { "image": @@ -363,7 +366,7 @@ def get_dummy_processor_inputs( num_images=num_images) } - hf_processor = self._get_hf_processor() + hf_processor = self.info.get_hf_processor() image_tokens: list[str] = hf_processor.img_tokens # type: ignore return ProcessorInputs( @@ -372,10 +375,7 @@ def get_dummy_processor_inputs( ) -class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return Phi3VProfilingInfo(self.ctx) +class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): def _call_hf_processor( self, @@ -416,10 +416,10 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_tokens: list[str] = hf_processor.img_tokens # type: ignore - tokenizer = self._get_tokenizer() + tokenizer = self.info.get_tokenizer() bos_token_id = tokenizer.bos_token_id assert isinstance(bos_token_id, int) @@ -431,9 +431,10 @@ def get_replacement_phi3v(item_idx: int): num_image_tokens = images.get_feature_size(item_idx) else: image_size = images.get_image_size(item_idx) - num_image_tokens = self._get_num_image_tokens( + num_image_tokens = self.info.get_num_image_tokens( image_width=image_size.width, image_height=image_size.height, + processor=hf_processor, ) return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id] @@ -451,9 +452,9 @@ def get_replacement_phi3v(item_idx: int): def _apply_prompt_replacements( self, token_ids: list[int], - mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], + mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], mm_item_counts: Mapping[str, int], - ) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]: + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]: token_ids, text, placeholders = super()._apply_prompt_replacements( token_ids=token_ids, mm_prompt_repls=mm_prompt_repls, @@ -466,7 +467,7 @@ def _apply_prompt_replacements( token_ids = [token_ids[0], *token_ids[2:]] placeholders = { modality: [ - _PlaceholderInfo( + PlaceholderInfo( modality=p.modality, item_idx=p.item_idx, start_idx=p.start_idx - 1, @@ -499,7 +500,9 @@ def apply( return result -@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor, + info=Phi3VProcessingInfo, + dummy_inputs=Phi3VDummyInputsBuilder) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 7012ddc66cd..0dff9595c6c 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -38,11 +38,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) -from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser +from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, + MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP @@ -80,12 +80,12 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor): return feat_lengths, output_lengths -class Qwen2AudioProcessingMixin(ProcessingMixin): +class Qwen2AudioProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return self.ctx.get_hf_config(Qwen2AudioConfig) - def _get_hf_processor( + def get_hf_processor( self, *, # Ignored in initialization @@ -93,36 +93,37 @@ def _get_hf_processor( ) -> Qwen2AudioProcessor: return self.ctx.get_hf_processor(Qwen2AudioProcessor) - def _get_feature_extractor( + def get_feature_extractor( self, *, # Ignored in initialization sampling_rate: Optional[int] = None, ) -> WhisperFeatureExtractor: - hf_processor = self._get_hf_processor(sampling_rate=sampling_rate) + hf_processor = self.get_hf_processor(sampling_rate=sampling_rate) feature_extractor = hf_processor.feature_extractor # type: ignore assert isinstance(feature_extractor, WhisperFeatureExtractor) return feature_extractor - -class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - hf_config = self._get_hf_config() + hf_config = self.get_hf_config() max_source_positions = hf_config.audio_config.max_source_positions max_output_lengths = (max_source_positions - 2) // 2 + 1 return {"audio": max_output_lengths} + +class Qwen2AudioDummyInputsBuilder( + BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]): + def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - feature_extractor = self._get_feature_extractor() + feature_extractor = self.info.get_feature_extractor() sampling_rate = feature_extractor.sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate @@ -139,14 +140,11 @@ def get_dummy_processor_inputs( ) -class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin, - BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return Qwen2AudioProfilingInfo(self.ctx) +class Qwen2AudioMultiModalProcessor( + BaseMultiModalProcessor[Qwen2AudioProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: - feature_extractor = self._get_feature_extractor() + feature_extractor = self.info.get_feature_extractor() return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) def _call_hf_processor( @@ -161,7 +159,7 @@ def _call_hf_processor( if audios: mm_data["audios"] = audios - feature_extractor = self._get_feature_extractor(**mm_kwargs) + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) mm_kwargs = dict( **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, @@ -194,7 +192,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() placeholder = hf_config.audio_token_index feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") @@ -234,10 +232,13 @@ def _always_apply_prompt_replacements(self) -> bool: # has already performed processing for multi-audio input when the input # audios are short (the corresponding placeholders may take up fewer # tokens than the number of audio items) - return not hasattr(self._get_hf_processor(), "audio_token") + return not hasattr(self.info.get_hf_processor(), "audio_token") -@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor( + Qwen2AudioMultiModalProcessor, + info=Qwen2AudioProcessingInfo, + dummy_inputs=Qwen2AudioDummyInputsBuilder) class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index a5c2fb9e84d..8537fec854b 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -57,11 +57,10 @@ MultiModalFieldConfig, MultiModalKwargs, NestedTensors, VideoItem) from vllm.multimodal.parse import (ImageSize, ModalityDataItems, - MultiModalDataParser) + MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -709,12 +708,12 @@ def _parse_video_data( return super()._parse_video_data(data) -class Qwen2VLProcessingMixin(ProcessingMixin): +class Qwen2VLProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return self.ctx.get_hf_config(Qwen2VLConfig) - def _get_hf_processor( + def get_hf_processor( self, *, min_pixels: Optional[int] = None, @@ -736,18 +735,27 @@ def _get_hf_processor( return hf_processor - def _get_image_processor( + def get_image_processor( self, *, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, ): - hf_processor = self._get_hf_processor(min_pixels=min_pixels, - max_pixels=max_pixels) + hf_processor = self.get_hf_processor(min_pixels=min_pixels, + max_pixels=max_pixels) image_processor = hf_processor.image_processor # type: ignore assert isinstance(image_processor, Qwen2VLImageProcessor) return image_processor + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + return { + "image": self.get_max_image_tokens(), + "video": self.get_max_video_tokens(seq_len), + } + def _get_vision_info( self, *, @@ -755,15 +763,17 @@ def _get_vision_info( image_height: int, num_frames: int = 1, do_resize: bool = True, + image_processor: Optional[Qwen2VLImageProcessor], ) -> tuple[ImageSize, int]: - hf_config = self._get_hf_config() + if image_processor is None: + image_processor = self.get_image_processor() + + hf_config = self.get_hf_config() vision_config = hf_config.vision_config patch_size = vision_config.patch_size merge_size = vision_config.spatial_merge_size temporal_patch_size = vision_config.temporal_patch_size - image_processor = self._get_image_processor() - if do_resize: resized_height, resized_width = smart_resize( height=image_height, @@ -787,70 +797,65 @@ def _get_vision_info( return preprocessed_size, num_vision_tokens - def _get_num_image_tokens( + def get_num_image_tokens( self, *, image_width: int, image_height: int, + image_processor: Optional[Qwen2VLImageProcessor], ) -> int: _, num_image_tokens = self._get_vision_info( image_width=image_width, image_height=image_height, + image_processor=image_processor, ) return num_image_tokens - def _get_num_video_tokens( + def get_num_video_tokens( self, *, image_width: int, image_height: int, num_frames: int, + image_processor: Optional[Qwen2VLImageProcessor], ) -> int: _, num_video_tokens = self._get_vision_info( image_width=image_width, image_height=image_height, num_frames=num_frames, + image_processor=image_processor, ) return num_video_tokens - -class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo): - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None, "video": None} - - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return { - "image": self._get_max_image_tokens(), - "video": self._get_max_video_tokens(seq_len), - } - - def _get_image_size_with_most_features(self) -> ImageSize: + def get_image_size_with_most_features(self) -> ImageSize: max_image_size, _ = self._get_vision_info( image_width=9999999, image_height=9999999, + image_processor=None, ) return max_image_size - def _get_max_image_tokens(self) -> int: - target_width, target_height = self._get_image_size_with_most_features() + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() - return self._get_num_image_tokens( + return self.get_num_image_tokens( image_width=target_width, image_height=target_height, + image_processor=None, ) def _get_max_video_frames(self, max_tokens: int) -> int: - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = self.get_image_size_with_most_features() num_frames = 0 while True: next_num_frames = num_frames + 1 - next_max_tokens = self._get_num_video_tokens( + next_max_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=next_num_frames, + image_processor=None, ) if next_max_tokens > max_tokens: @@ -860,12 +865,12 @@ def _get_max_video_frames(self, max_tokens: int) -> int: return num_frames - def _get_dummy_num_frames(self, seq_len: int) -> int: + def get_num_frames_with_most_features(self, seq_len: int) -> int: mm_config = self.ctx.get_mm_config() max_images = mm_config.limit_per_prompt.get("image", 1) max_videos = mm_config.limit_per_prompt.get("video", 1) - max_image_tokens = self._get_max_image_tokens() * max_images + max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) @@ -877,15 +882,19 @@ def _get_dummy_num_frames(self, seq_len: int) -> int: return num_frames - def _get_max_video_tokens(self, seq_len: int) -> int: - target_width, target_height = self._get_image_size_with_most_features() + def get_max_video_tokens(self, seq_len: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() - return self._get_num_video_tokens( + return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), + num_frames=self.get_num_frames_with_most_features(seq_len), + image_processor=None, ) + +class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): + def get_dummy_processor_inputs( self, seq_len: int, @@ -894,10 +903,14 @@ def get_dummy_processor_inputs( num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - hf_processor = self._get_hf_processor() + hf_processor = self.info.get_hf_processor() image_token: str = hf_processor.image_token video_token: str = hf_processor.video_token - target_width, target_height = self._get_image_size_with_most_features() + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len) mm_data = { "image": @@ -908,7 +921,7 @@ def get_dummy_processor_inputs( self._get_dummy_videos( width=target_width, height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), + num_frames=target_num_frames, num_videos=num_videos, ) } @@ -919,11 +932,8 @@ def get_dummy_processor_inputs( ) -class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin, - BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return Qwen2VLProfilingInfo(self.ctx) +class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] + ): def _get_data_parser(self) -> MultiModalDataParser: return Qwen2MultiModalDataParser() @@ -934,8 +944,9 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self._get_image_processor(**hf_processor_mm_kwargs) + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has # image_token and video_token registered @@ -991,7 +1002,9 @@ def _get_mm_fields_config( ) -@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, + info=Qwen2VLProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder) class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): packed_modules_mapping = { diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index ecafd157b1d..fada22d685d 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -24,11 +24,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) -from vllm.multimodal.parse import MultiModalDataParser +from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig @@ -59,9 +58,9 @@ class UltravoxAudioEmbeddingInputs(TypedDict): UltravoxAudioEmbeddingInputs] -class UltravoxProcessingMixin(ProcessingMixin): +class UltravoxProcessingInfo(BaseProcessingInfo): - def _get_hf_processor( + def get_hf_processor( self, *, # Ignored in initialization @@ -76,37 +75,38 @@ def _get_hf_processor( hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE return hf_processor - def _get_feature_extractor( + def get_feature_extractor( self, *, # Ignored in initialization sampling_rate: Optional[int] = None, ) -> WhisperFeatureExtractor: - hf_processor = self._get_hf_processor(sampling_rate=sampling_rate) + hf_processor = self.get_hf_processor(sampling_rate=sampling_rate) audio_processor = hf_processor.audio_processor # type: ignore feature_extractor = audio_processor.feature_extractor # type: ignore assert isinstance(feature_extractor, WhisperFeatureExtractor) return feature_extractor - -class UltravoxProfilingInfo(UltravoxProcessingMixin, BaseProfilingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - feature_extractor = self._get_feature_extractor() + feature_extractor = self.get_feature_extractor() max_audio_tokens = math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) return {"audio": max_audio_tokens} + +class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] + ): + def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - feature_extractor = self._get_feature_extractor() + feature_extractor = self.info.get_feature_extractor() sampling_rate = feature_extractor.sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate @@ -123,14 +123,11 @@ def get_dummy_processor_inputs( ) -class UltravoxMultiModalProcessor(UltravoxProcessingMixin, - BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return UltravoxProfilingInfo(self.ctx) +class UltravoxMultiModalProcessor( + BaseMultiModalProcessor[UltravoxProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: - feature_extractor = self._get_feature_extractor() + feature_extractor = self.info.get_feature_extractor() return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) def _call_hf_processor( @@ -141,7 +138,7 @@ def _call_hf_processor( ) -> BatchFeature: # Text-only input not supported in composite processor if not mm_data: - tokenizer = self._get_tokenizer() + tokenizer = self.info.get_tokenizer() prompt_ids = tokenizer.encode( prompt, @@ -160,7 +157,7 @@ def _call_hf_processor( mm_kwargs=mm_kwargs, ) - feature_extractor = self._get_feature_extractor() + feature_extractor = self.info.get_feature_extractor() mm_kwargs = dict( **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, @@ -208,7 +205,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) placeholder = hf_processor.audio_token_replacement # type: ignore def get_replacement_ultravox(item_idx: int): @@ -342,7 +339,10 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor, + info=UltravoxProcessingInfo, + dummy_inputs=UltravoxDummyInputsBuilder + ) class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): hf_to_vllm_mapper = WeightsMapper( diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 41113cd85bd..c6a30cacebd 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -4,12 +4,13 @@ from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass, field from functools import lru_cache -from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union +from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, + TypeVar, Union) from transformers import BatchFeature, PretrainedConfig, ProcessorMixin -from vllm import envs -from vllm.inputs import DummyData, InputProcessingContext +import vllm.envs as envs +from vllm.inputs import InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, encode_tokens) @@ -20,7 +21,9 @@ MultiModalInputsV2, MultiModalKwargs, MultiModalKwargsItem, PlaceholderRange) from .parse import MultiModalDataItems, MultiModalDataParser -from .profiling import BaseProfilingInfo + +if TYPE_CHECKING: + from .profiling import BaseDummyInputsBuilder logger = init_logger(__name__) @@ -46,8 +49,8 @@ class PromptReplacement: if it does not depend on the input. """ - def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement": - return _BoundPromptReplacement( + def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement": + return BoundPromptReplacement( tokenizer=tokenizer, modality=self.modality, _target=self.target, @@ -128,7 +131,7 @@ def token_ids(self) -> list[int]: @dataclass -class _BoundPromptReplacement: +class BoundPromptReplacement: tokenizer: AnyTokenizer = field(repr=False) modality: str @@ -207,7 +210,7 @@ def iter_token_matches( @dataclass(repr=False) class _PromptReplacementMatch(ABC): - prompt_repl: _BoundPromptReplacement + prompt_repl: BoundPromptReplacement @property def modality(self) -> str: @@ -255,7 +258,7 @@ def end_idx(self) -> int: @dataclass -class _PlaceholderInfo: +class PlaceholderInfo: modality: str item_idx: int start_idx: int @@ -274,7 +277,7 @@ def to_range(self) -> PlaceholderRange: def find_token_matches( prompt: list[int], - prompt_repls: Sequence[_BoundPromptReplacement], + prompt_repls: Sequence[BoundPromptReplacement], ) -> list[_PromptReplacementTokenMatch]: """Return each target of :code:`prompt_repls` found in :code:`prompt`.""" return [ @@ -286,7 +289,7 @@ def find_token_matches( def find_text_matches( prompt: str, - prompt_repls: Sequence[_BoundPromptReplacement], + prompt_repls: Sequence[BoundPromptReplacement], ) -> list[_PromptReplacementTextMatch]: """Return each target of :code:`prompt_repls` found in :code:`prompt`.""" return [ @@ -390,9 +393,9 @@ def replace_text_matches( def _iter_modality_placeholders( prompt: list[int], modality: str, - modality_repls: Sequence[_BoundPromptReplacement], + modality_repls: Sequence[BoundPromptReplacement], modal_item_count: int, -) -> Iterable[_PlaceholderInfo]: +) -> Iterable[PlaceholderInfo]: if modal_item_count == 0: return @@ -413,7 +416,7 @@ def _iter_modality_placeholders( continue if prompt[start_idx:end_idx] == repl_tokens: - yield _PlaceholderInfo( + yield PlaceholderInfo( modality=modality, item_idx=item_idx, start_idx=start_idx, @@ -434,10 +437,10 @@ def _iter_modality_placeholders( def _iter_placeholders( - mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], + mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], prompt: list[int], mm_item_counts: Mapping[str, int], -) -> Iterable[_PlaceholderInfo]: +) -> Iterable[PlaceholderInfo]: """ For each modality, yield each set of placeholder tokens found in :code:`prompt`. @@ -455,10 +458,10 @@ def _iter_placeholders( def find_mm_placeholders( - mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], + mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], prompt: list[int], mm_item_counts: Mapping[str, int], -) -> Mapping[str, list[_PlaceholderInfo]]: +) -> Mapping[str, list[PlaceholderInfo]]: it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts) return dict(full_groupby_modality(it)) @@ -524,29 +527,59 @@ def put( self._cache.put(cache_key, output_kwargs) -class ProcessingMixin: - """ - Contains helper functions to perform processing. +class BaseProcessingInfo: + """Base class containing information to perform processing.""" - Not to be confused with :class:`transformers.ProcessorMixin`. - """ - ctx: InputProcessingContext + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__() - def _get_tokenizer(self) -> AnyTokenizer: + self.ctx = ctx + + @property + def model_id(self) -> str: + return self.ctx.model_config.model + + def get_tokenizer(self) -> AnyTokenizer: return self.ctx.tokenizer - def _get_hf_config(self) -> PretrainedConfig: + def get_hf_config(self) -> PretrainedConfig: return self.ctx.get_hf_config() - def _get_hf_processor(self, **kwargs: object) -> ProcessorMixin: + def get_hf_processor(self, **kwargs: object) -> ProcessorMixin: """ Subclasses can override this method to handle specific kwargs from model config or user inputs. """ return self.ctx.get_hf_processor(**kwargs) + @abstractmethod + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + """ + Return the maximum supported number of items for each modality. + + A value of `None` means unlimited number of items. + + Omitting a modality from the returned dictionary means that + it is not supported at all. + """ + raise NotImplementedError + + @abstractmethod + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + """ + Get the maximum possible number of tokens per data item + for each modality. + + The dictionary returned by this method should have the same + keys as that returned by :meth:`get_supported_mm_limits`. + """ + raise NotImplementedError + + +_I = TypeVar("_I", bound=BaseProcessingInfo) -class BaseMultiModalProcessor(ProcessingMixin, ABC): + +class BaseMultiModalProcessor(ABC, Generic[_I]): """ Abstract base class to process multi-modal inputs to be used in vLLM. @@ -554,18 +587,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): """ def __init__(self, - ctx: InputProcessingContext, + info: _I, + dummy_inputs: "BaseDummyInputsBuilder[_I]", *, cache: Optional[ProcessingCache] = None, enable_sanity_checks: bool = True) -> None: super().__init__() - self.ctx = ctx + self.info = info + self.dummy_inputs = dummy_inputs self.cache = cache self.enable_sanity_checks = enable_sanity_checks self.data_parser = self._get_data_parser() - self.profiling_info = self._get_profiling_info() def __call__( self, @@ -585,13 +619,6 @@ def _get_data_parser(self) -> MultiModalDataParser: """ return MultiModalDataParser() - def _get_profiling_info(self) -> BaseProfilingInfo: - """ - Get the profiling information to find the worst-case memory usage of - the model. - """ - raise NotImplementedError - def _to_mm_items( self, mm_data: MultiModalDataDict, @@ -602,7 +629,7 @@ def _to_mm_items( """ mm_items = self.data_parser.parse_mm_data(mm_data) - mm_limits = self.ctx.get_mm_config().limit_per_prompt + mm_limits = self.info.ctx.get_mm_config().limit_per_prompt for modality, items in mm_items.items(): limit = mm_limits.get(modality, 1) if len(items) > limit: @@ -646,19 +673,19 @@ def _get_prompt_replacements( def _find_mm_placeholders( self, - mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], + mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], new_token_ids: list[int], mm_item_counts: Mapping[str, int], - ) -> Mapping[str, list[_PlaceholderInfo]]: + ) -> Mapping[str, list[PlaceholderInfo]]: return find_mm_placeholders(mm_prompt_repls, new_token_ids, mm_item_counts) def _get_hf_mm_data( self, mm_items: MultiModalDataItems, - ) -> tuple[dict[str, Any], dict[str, Any]]: - processor_data = dict[str, Any]() - passthrough_data = dict[str, Any]() + ) -> tuple[Mapping[str, object], Mapping[str, object]]: + processor_data = dict[str, object]() + passthrough_data = dict[str, object]() for items in mm_items.values(): processor_data.update(items.get_processor_data()) @@ -678,8 +705,8 @@ def _call_hf_processor( Call the HF processor on the prompt text and associated multi-modal data. """ - return self.ctx.call_hf_processor( - self._get_hf_processor(**mm_kwargs), + return self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), dict(text=prompt, **mm_data), mm_kwargs, ) @@ -738,8 +765,8 @@ def _apply_hf_processor_missing( # Some HF processors (e.g. Qwen2-VL) expect corresponding # multi-modal tokens to be in the prompt text - dummy_inputs = self.profiling_info.get_dummy_processor_inputs( - self.ctx.model_config.max_model_len, + dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs( + self.info.ctx.model_config.max_model_len, mm_missing_counts, ) @@ -762,7 +789,7 @@ def _cached_apply_hf_processor( caching the results and reusing cached results. """ cache = self.cache - model_id = self.ctx.model_config.model + model_id = self.info.model_id _, passthrough_data = self._get_hf_mm_data(mm_data_items) if cache is None or passthrough_data: @@ -838,8 +865,8 @@ def _cached_apply_hf_processor( def _bind_and_group_repls( self, prompt_repls: list[PromptReplacement], - ) -> dict[str, list[_BoundPromptReplacement]]: - tokenizer = self._get_tokenizer() + ) -> dict[str, list[BoundPromptReplacement]]: + tokenizer = self.info.get_tokenizer() it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls) return dict(full_groupby_modality(it)) @@ -859,10 +886,10 @@ def _always_apply_prompt_replacements(self) -> bool: def _apply_prompt_replacements( self, token_ids: list[int], - mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], + mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], mm_item_counts: Mapping[str, int], - ) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]: - tokenizer = self._get_tokenizer() + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]: + tokenizer = self.info.get_tokenizer() mm_token_matches = { modality: find_token_matches(token_ids, prompt_repls) @@ -950,7 +977,7 @@ def _validate_mm_kwargs( def _validate_mm_placeholders( self, - mm_placeholders: Mapping[str, list[_PlaceholderInfo]], + mm_placeholders: Mapping[str, list[PlaceholderInfo]], mm_item_counts: Mapping[str, int], *, allow_missing: bool = False, @@ -1001,7 +1028,7 @@ def apply( # instead of rehashing. if envs.VLLM_USE_V1: - model_id = self.ctx.model_config.model + model_id = self.info.model_id mm_hashes = { modality: [ MultiModalHasher.hash_kwargs(model_id=model_id, @@ -1046,7 +1073,7 @@ def apply( allow_missing=True, ) - mm_missing_repls = dict[str, list[_BoundPromptReplacement]]() + mm_missing_repls = dict[str, list[BoundPromptReplacement]]() for modality, missing_repl_count in mm_missing_repl_counts.items(): if missing_repl_count == 0: mm_missing_repls[modality] = [] @@ -1059,7 +1086,7 @@ def apply( # If HF processor already inserts placeholder tokens, # there is no need for us to insert them if all(len(repls) == 0 for repls in mm_missing_repls.items()): - tokenizer = self._get_tokenizer() + tokenizer = self.info.get_tokenizer() prompt_text = decode_tokens(tokenizer, prompt_ids) mm_placeholders = hf_mm_placeholders else: @@ -1090,79 +1117,3 @@ def apply( mm_hashes=mm_hashes, mm_placeholders=mm_placeholder_ranges, ) - - def _get_dummy_mm_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> MultiModalInputsV2: - profiling = self.profiling_info - processor_inputs = profiling.get_dummy_processor_inputs( - seq_len, mm_counts) - - return self.apply( - prompt_text=processor_inputs.prompt_text, - mm_data=processor_inputs.mm_data, - hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, - ) - - def get_dummy_data(self, seq_len: int) -> DummyData: - # Avoid circular import - from vllm.sequence import SequenceData - - profiling = self.profiling_info - mm_counts = profiling.get_mm_limits() - mm_max_tokens_per_item = profiling.get_mm_max_tokens_per_item(seq_len) - if mm_counts.keys() != mm_max_tokens_per_item.keys(): - raise AssertionError( - "The keys returned by `get_supported_mm_limits`" - f"({set(mm_counts.keys())}) should be the same as those " - "returned by `get_mm_max_tokens_per_item` " - f"({set(mm_max_tokens_per_item.keys())})") - - mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) - prompt_token_ids = mm_inputs["prompt_token_ids"] - placeholders_by_modality = mm_inputs["mm_placeholders"] - - total_placeholders_by_modality = { - modality: sum(item["length"] for item in placeholders) - for modality, placeholders in placeholders_by_modality.items() - } - expected_placeholders_by_modality = { - modality: mm_max_tokens_per_item[modality] * mm_counts[modality] - for modality in placeholders_by_modality - } - if total_placeholders_by_modality != expected_placeholders_by_modality: - raise AssertionError( - f"The processed dummy data has a total of " - f"{total_placeholders_by_modality} placeholder tokens, which " - f"is not the expected {expected_placeholders_by_modality} " - "tokens.") - - total_len = len(prompt_token_ids) - - # V0 does not support chunked prefill. - if total_len > seq_len and not envs.VLLM_USE_V1: - logger.warning( - "The context length (%d) of the model is too short " - "to hold the multi-modal embeddings in the worst case " - "(%d tokens in total, out of which %s are reserved for " - "multi-modal embeddings). This may cause certain multi-modal " - "inputs to fail during inference, even when the input text is " - "short. To avoid this, you should increase `max_model_len`, " - "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len, - total_len, total_placeholders_by_modality) - - return DummyData( - seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), - multi_modal_data=None, - multi_modal_placeholders=None, - ) - - prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) - - return DummyData( - seq_data=SequenceData.from_seqs(prompt_token_ids), - multi_modal_data=mm_inputs["mm_kwargs"], - multi_modal_placeholders=placeholders_by_modality, - ) diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 2ecf0db1a48..2ac3a6bcf3d 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -1,16 +1,18 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from dataclasses import dataclass, field -from typing import Optional +from typing import Generic, TypeVar import numpy as np import numpy.typing as npt from PIL import Image -from vllm.inputs import InputProcessingContext +import vllm.envs as envs +from vllm.inputs import DummyData from vllm.logger import init_logger -from .inputs import MultiModalDataDict +from .inputs import MultiModalDataDict, MultiModalInputsV2 +from .processing import BaseMultiModalProcessor, BaseProcessingInfo logger = init_logger(__name__) @@ -23,39 +25,19 @@ class ProcessorInputs: hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) -class BaseProfilingInfo(ABC): +_I = TypeVar("_I", bound=BaseProcessingInfo) + + +class BaseDummyInputsBuilder(ABC, Generic[_I]): """ - Abstract base class that provides the information necessary to profile + Abstract base class that constructs the dummy data to profile multi-modal models. """ - def __init__(self, ctx: InputProcessingContext) -> None: + def __init__(self, info: _I) -> None: super().__init__() - self.ctx = ctx - - @abstractmethod - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - """ - Return the maximum supported number of items for each modality. - - A value of `None` means unlimited number of items. - - Omitting a modality from the returned dictionary means that - it is not supported at all. - """ - raise NotImplementedError - - @abstractmethod - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - """ - Get the maximum possible number of tokens per data item - for each modality. - - The dictionary returned by this method should have the same - keys as that returned by :meth:`get_supported_mm_limits`. - """ - raise NotImplementedError + self.info = info @abstractmethod def get_dummy_processor_inputs( @@ -64,8 +46,8 @@ def get_dummy_processor_inputs( mm_counts: Mapping[str, int], ) -> ProcessorInputs: """ - Build the multi-modal portion of the input which, after processing, - results in `mm_max_tokens` in :meth:`get_mm_max_tokens_per_item`. + Build the input which, after processing, results in + `self.info.get_mm_max_tokens_per_item()` placeholder tokens. """ raise NotImplementedError @@ -99,11 +81,33 @@ def _get_dummy_videos( video = np.zeros((num_frames, width, height, 3)) return [video] * num_videos - def get_mm_limits(self) -> Mapping[str, int]: - mm_config = self.ctx.get_mm_config() + +class MultiModalProfiler(Generic[_I]): + """ + Contains code for running memory profiling for multi-modal models. + """ + + def __init__( + self, + processor: BaseMultiModalProcessor[_I], + ) -> None: + super().__init__() + + self.processor = processor + + @property + def processing_info(self) -> BaseProcessingInfo: + return self.processor.info + + @property + def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]: + return self.processor.dummy_inputs + + def _get_mm_limits(self) -> Mapping[str, int]: + mm_config = self.processing_info.ctx.get_mm_config() mm_limit_per_prompt = mm_config.limit_per_prompt - supported_mm_limits = self.get_supported_mm_limits() + supported_mm_limits = self.processing_info.get_supported_mm_limits() mm_limits = { modality: mm_limit_per_prompt.get(modality, 1) @@ -119,3 +123,81 @@ def get_mm_limits(self) -> Mapping[str, int]: f"at most {supported_limit} {modality} items.") return mm_limits + + def _get_dummy_mm_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalInputsV2: + factory = self.dummy_inputs + processor_inputs = factory.get_dummy_processor_inputs( + seq_len, mm_counts) + + return self.processor.apply( + prompt_text=processor_inputs.prompt_text, + mm_data=processor_inputs.mm_data, + hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, + ) + + def get_dummy_data(self, seq_len: int) -> DummyData: + # Avoid circular import + from vllm.sequence import SequenceData + + mm_counts = self._get_mm_limits() + + info = self.processing_info + mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len) + + if mm_counts.keys() != mm_max_tokens_per_item.keys(): + raise AssertionError( + "The keys returned by `get_supported_mm_limits`" + f"({set(mm_counts.keys())}) should be the same as those " + "returned by `get_mm_max_tokens_per_item` " + f"({set(mm_max_tokens_per_item.keys())})") + + mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) + prompt_token_ids = mm_inputs["prompt_token_ids"] + placeholders_by_modality = mm_inputs["mm_placeholders"] + + total_placeholders_by_modality = { + modality: sum(item["length"] for item in placeholders) + for modality, placeholders in placeholders_by_modality.items() + } + expected_placeholders_by_modality = { + modality: mm_max_tokens_per_item[modality] * mm_counts[modality] + for modality in placeholders_by_modality + } + if total_placeholders_by_modality != expected_placeholders_by_modality: + raise AssertionError( + f"The processed dummy data has a total of " + f"{total_placeholders_by_modality} placeholder tokens, which " + f"is not the expected {expected_placeholders_by_modality} " + "tokens.") + + total_len = len(prompt_token_ids) + + # V0 does not support chunked prefill. + if total_len > seq_len and not envs.VLLM_USE_V1: + logger.warning( + "The context length (%d) of the model is too short " + "to hold the multi-modal embeddings in the worst case " + "(%d tokens in total, out of which %s are reserved for " + "multi-modal embeddings). This may cause certain multi-modal " + "inputs to fail during inference, even when the input text is " + "short. To avoid this, you should increase `max_model_len`, " + "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len, + total_len, total_placeholders_by_modality) + + return DummyData( + seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), + multi_modal_data=None, + multi_modal_placeholders=None, + ) + + prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) + + return DummyData( + seq_data=SequenceData.from_seqs(prompt_token_ids), + multi_modal_data=mm_inputs["mm_kwargs"], + multi_modal_placeholders=placeholders_by_modality, + ) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index f75a594a4c4..5f01eac4eda 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,7 +1,8 @@ import functools from collections import UserDict -from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol, - Sequence, Type, TypeVar) +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Dict, Generic, Mapping, Optional, + Protocol, Sequence, Type, TypeVar) import torch.nn as nn @@ -14,7 +15,9 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import BaseMultiModalProcessor, ProcessingCache +from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, + ProcessingCache) +from .profiling import BaseDummyInputsBuilder from .utils import cached_get_tokenizer from .video import VideoPlugin @@ -27,20 +30,59 @@ MM_CACHE_SIZE = 256 N = TypeVar("N", bound=Type[nn.Module]) +_I = TypeVar("_I", bound=BaseProcessingInfo) +_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True) -class MultiModalProcessorFactory(Protocol): +class ProcessingInfoFactory(Protocol[_I_co]): """Constructs a :class:`MultiModalProcessor` instance from the context.""" def __call__( self, ctx: InputProcessingContext, + ) -> _I_co: + ... + + +class DummyInputsBuilderFactory(Protocol[_I]): + """ + Constructs a :class:`BaseDummyInputsBuilder` instance from the context. + """ + + def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: + ... + + +class MultiModalProcessorFactory(Protocol[_I]): + """Constructs a :class:`MultiModalProcessor` instance from the context.""" + + def __call__( + self, + info: _I, + dummy_inputs: BaseDummyInputsBuilder[_I], *, cache: Optional[ProcessingCache] = None, - ) -> BaseMultiModalProcessor: + ) -> BaseMultiModalProcessor[_I]: ... +@dataclass(frozen=True) +class _ProcessorFactories(Generic[_I]): + info: ProcessingInfoFactory[_I] + processor: MultiModalProcessorFactory[_I] + dummy_inputs: DummyInputsBuilderFactory[_I] + + def build_processor( + self, + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + ): + info = self.info(ctx) + dummy_inputs_builder = self.dummy_inputs(info) + return self.processor(info, dummy_inputs_builder, cache=cache) + + class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]): """ Wraps `_limits_by_model` for a more informative error message @@ -71,7 +113,7 @@ def __init__( self._plugins = {p.get_data_key(): p for p in plugins} self._processor_factories = ClassRegistry[nn.Module, - MultiModalProcessorFactory]() + _ProcessorFactories]() # This is used for non-multimodal models self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} @@ -224,7 +266,7 @@ def get_max_tokens_per_item_by_modality( tokenizer = cached_get_tokenizer(model_config.tokenizer) processor = self.create_processor(model_config, tokenizer) seq_len = model_config.max_model_len - return processor.profiling_info.get_mm_max_tokens_per_item(seq_len) + return processor.info.get_mm_max_tokens_per_item(seq_len) return { key: plugin.get_max_multimodal_tokens(model_config) @@ -315,7 +357,10 @@ def get_mm_limits_per_prompt( def register_processor( self, - factory: MultiModalProcessorFactory, + processor: MultiModalProcessorFactory[_I], + *, + info: ProcessingInfoFactory[_I], + dummy_inputs: DummyInputsBuilderFactory[_I], ): """ Register a multi-modal processor to a model class. The processor @@ -336,7 +381,11 @@ def wrapper(model_cls: N) -> N: "registered to %s. It is overwritten by the new one.", model_cls, self) - self._processor_factories[model_cls] = factory + self._processor_factories[model_cls] = _ProcessorFactories( + info=info, + dummy_inputs=dummy_inputs, + processor=processor, + ) return model_cls @@ -359,15 +408,15 @@ def create_processor( self, model_config: "ModelConfig", tokenizer: AnyTokenizer, - ) -> BaseMultiModalProcessor: + ) -> BaseMultiModalProcessor[BaseProcessingInfo]: """ Create a multi-modal processor for a specific model and tokenizer. """ model_cls = self._get_model_cls(model_config) - processor_factory = self._processor_factories[model_cls] + factories = self._processor_factories[model_cls] ctx = InputProcessingContext(model_config, tokenizer) cache = (None if model_config.disable_mm_preprocessor_cache else self._processing_cache) - return processor_factory(ctx, cache=cache) + return factories.build_processor(ctx, cache=cache)