Commit 8c38ee7

[VLM] Merged multi-modal processor for LLaVA-NeXT (#11682)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent b6087a6 commit 8c38ee7

14 files changed (+605 −551 lines)

tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py

Lines changed: 0 additions & 70 deletions
This file was deleted.

tests/multimodal/test_mapper.py

Lines changed: 0 additions & 118 deletions
This file was deleted.

tests/multimodal/test_processing.py

Lines changed: 97 additions & 0 deletions
@@ -1,5 +1,7 @@
+from contextlib import nullcontext
 from functools import partial
 from typing import cast
+from unittest.mock import MagicMock
 
 import numpy as np
 import pytest
@@ -526,6 +528,100 @@ def _rand_audio(
     return rng.rand(audio_len), sr
 
 
+@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
+@pytest.mark.parametrize(
+    ("limit", "num_supported", "is_valid"),
+    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
+     (2, 1, False), (2, 2, True)],
+)
+def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
+    limit_mm_per_prompt = {"image": limit}
+
+    model_config = ModelConfig(
+        model=model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="half",
+        revision=None,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+    )
+    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
+
+    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
+    ctx = InputProcessingContext(
+        model_config,
+        tokenizer=cached_get_tokenizer(model_config.tokenizer),
+    )
+
+    processor = processor_factory(ctx, cache=None)
+
+    mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
+    processor.get_supported_mm_limits = mock_supported_mm_limits
+
+    if is_valid:
+        exc_ctx = nullcontext()
+    else:
+        exc_ctx = pytest.raises(ValueError, match="this model only supports")
+
+    with exc_ctx:
+        processor._get_and_validate_dummy_mm_counts()
+
+
+@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
+@pytest.mark.parametrize(
+    ("num_images", "limit", "is_valid"),
+    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
+     (2, 1, False), (2, 2, True)],
+)
+def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
+    limit_mm_per_prompt = {"image": limit}
+
+    model_config = ModelConfig(
+        model=model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="half",
+        revision=None,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+    )
+    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
+
+    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
+    ctx = InputProcessingContext(
+        model_config,
+        tokenizer=cached_get_tokenizer(model_config.tokenizer),
+    )
+
+    processor = processor_factory(ctx, cache=None)
+
+    rng = np.random.RandomState(0)
+    image = _rand_img(rng, min_wh=128, max_wh=256)
+    if num_images == 0:
+        mm_data = {}
+    elif num_images == 1:
+        mm_data = {"image": image}
+    else:
+        mm_data = {"image": [image] * num_images}
+
+    if is_valid:
+        exc_ctx = nullcontext()
+    else:
+        exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image")
+
+    with exc_ctx:
+        processor.apply(
+            "<image>" * num_images,
+            mm_data=mm_data,
+            hf_processor_mm_kwargs={},
+        )
+
+
 def _test_processing_cache_correctness(
     model_id: str,
     modalities: dict[str, bool],
@@ -631,6 +727,7 @@ def _test_processing_cache_correctness(
     ("facebook/chameleon-7b", {"image": False}),
     ("adept/fuyu-8b", {"image": False}),
     ("llava-hf/llava-1.5-7b-hf", {"image": True}),
+    ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
     ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
     ("mistral-community/pixtral-12b", {"image": True}),
     ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),

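The two tests above drive the merged processor directly, but the limit they exercise is the same user-facing limit_mm_per_prompt engine argument. A rough usage sketch of that argument from the public API (the prompt template and image loading are illustrative, not taken from this commit):

from vllm import LLM
from PIL import Image

# Illustrative sketch: cap each prompt at 2 images for LLaVA-NeXT.
llm = LLM(
    model="llava-hf/llava-v1.6-mistral-7b-hf",
    limit_mm_per_prompt={"image": 2},
)

image = Image.open("example.jpg")  # placeholder path
outputs = llm.generate({
    "prompt": "[INST] <image>\nWhat is shown in this image? [/INST]",
    "multi_modal_data": {"image": image},
})

# Supplying more images than the configured limit is rejected by the merged
# processor with a ValueError, which is what test_limit_mm_per_prompt_apply
# asserts via pytest.raises(ValueError, match=f"passed {num_images} image").
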
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py

Lines changed: 1 addition & 3 deletions
@@ -3,13 +3,11 @@
 import torch
 
 from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
-                                              LlavaMultiModalProcessor,
-                                              get_max_llava_image_tokens)
+                                              LlavaMultiModalProcessor)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
 
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
 class MyLlava(LlavaForConditionalGeneration):

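The plugin stub shrinks because a merged multi-modal processor now supplies the per-item token maximum that register_max_image_tokens(get_max_llava_image_tokens) used to provide. A condensed sketch of building such a processor through the registry and querying it, following the same flow as the new tests in tests/multimodal/test_processing.py (import paths are assumed from vLLM's layout; the underscore-prefixed registry attributes are internal and may change):

from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
model_config = ModelConfig(
    model=model_id,
    task="auto",
    tokenizer=model_id,
    tokenizer_mode="auto",
    trust_remote_code=False,
    seed=0,
    dtype="half",
    revision=None,
    limit_mm_per_prompt={"image": 1},
)

# Same construction path as the new tests: resolve the model class, look up
# its processor factory, and build the merged multi-modal processor.
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
    model_config,
    tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
processor = factory(ctx, cache=None)

# The per-item token maximum that register_max_image_tokens used to provide
# is now reported by the processor itself (method name visible in the
# fuyu.py hunk further down).
print(processor.get_mm_max_tokens_per_item())
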
vllm/model_executor/models/clip.py

Lines changed: 25 additions & 0 deletions
@@ -24,6 +24,8 @@
                                      resolve_visual_encoder_outputs)
 from vllm.sequence import SequenceData
 
+from .vision import VisionEncoderInfo
+
 
 def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
     assert image_size % patch_size == 0
@@ -149,6 +151,29 @@ def input_processor_for_clip(
                         multi_modal_placeholders={"image": ranges})
 
 
+class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
+
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> int:
+        return get_clip_image_feature_size(self.vision_config)
+
+    def get_max_image_tokens(self) -> int:
+        return get_max_clip_image_tokens(self.vision_config)
+
+    def get_num_patches(self) -> int:
+        return get_clip_patch_grid_length(
+            image_size=self.vision_config.image_size,
+            patch_size=self.vision_config.patch_size,
+        )
+
+    def get_image_size(self) -> int:
+        return self.vision_config.image_size
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
 class CLIPVisionEmbeddings(nn.Module):

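CLIPEncoderInfo is a thin adapter that lets vision-tower-agnostic code (such as the merged LLaVA-NeXT processor this commit adds) query the encoder through a common VisionEncoderInfo interface instead of calling CLIP-specific helpers directly. A minimal usage sketch, assuming VisionEncoderInfo's constructor simply stores the HF vision config as self.vision_config (the base class lives in the new vision.py module and is not shown in this diff):

from transformers import CLIPVisionConfig

from vllm.model_executor.models.clip import CLIPEncoderInfo

# Config matching the CLIP-ViT-L/14-336px tower commonly used by LLaVA.
vision_config = CLIPVisionConfig(image_size=336, patch_size=14)
info = CLIPEncoderInfo(vision_config)  # assumed constructor signature

print(info.get_image_size())   # 336, straight from the config
print(info.get_num_patches())  # 336 // 14 = 24 patches per side
# get_num_image_tokens()/get_max_image_tokens() delegate to the existing
# get_clip_image_feature_size()/get_max_clip_image_tokens() helpers defined
# earlier in clip.py, so the per-image token count is unchanged by this commit.
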
vllm/model_executor/models/fuyu.py

Lines changed: 3 additions & 3 deletions
@@ -76,7 +76,7 @@ def _get_image_target_size(self) -> ImageSize:
         return ImageSize(width=target_size["width"],
                          height=target_size["height"])
 
-    def _get_image_grid_size(
+    def _get_image_feature_grid_size(
         self,
         *,
         image_width: int,
@@ -99,7 +99,7 @@ def _get_image_grid_size(
     def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
         target_width, target_height = self._get_image_target_size()
 
-        max_ncols, max_nrows = self._get_image_grid_size(
+        max_ncols, max_nrows = self._get_image_feature_grid_size(
            image_width=target_width,
            image_height=target_height,
        )
@@ -172,7 +172,7 @@ def get_replacement_fuyu(item_idx: int):
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)
 
-            ncols, nrows = self._get_image_grid_size(
+            ncols, nrows = self._get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )

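The fuyu.py change is a pure rename: _get_image_grid_size becomes _get_image_feature_grid_size, making explicit that the returned (ncols, nrows) is the grid of image feature tokens, not pixel dimensions. As a rough back-of-the-envelope illustration, assuming Fuyu's 30x30-pixel patches and its 1920x1080 target canvas (values taken from the HF Fuyu image processor, not from this diff):

import math

patch_side = 30                   # assumed Fuyu patch size in pixels
target_w, target_h = 1920, 1080   # assumed target canvas (width, height)

ncols = math.ceil(target_w / patch_side)  # 64 feature tokens per row
nrows = math.ceil(target_h / patch_side)  # 36 rows of feature tokens
print(ncols, nrows)  # 64 36

This worst case over the target size is what get_mm_max_tokens_per_item above computes when it calls the renamed helper with the target width and height.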