Commit 893a2cf

[Bugfix] Fix embedding assignment for InternVL-related models on V1

Signed-off-by: DarkLight1337 <[email protected]>
1 parent 61f4121 commit 893a2cf

5 files changed: +123 −103 lines


vllm/model_executor/models/gemma3_mm.py

Lines changed: 1 addition & 1 deletion

@@ -183,7 +183,7 @@ def get_image_repl(
         image_width: int,
         image_height: int,
         processor: Optional[Gemma3Processor],
-    ) -> PromptUpdateDetails:
+    ) -> PromptUpdateDetails[str]:
         if processor is None:
             processor = self.get_hf_processor()

vllm/model_executor/models/h2ovl.py

Lines changed: 6 additions & 16 deletions

@@ -249,20 +249,15 @@ def __init__(
     def image_token_id(self) -> int:
         return self.tokenizer.get_vocab()[IMG_CONTEXT]

-    def get_image_repl_features(
+    def get_image_repl(
         self,
         feature_size: int,
         num_patches: Optional[int],
-    ) -> str:
-        return IMG_CONTEXT * feature_size
+    ) -> PromptUpdateDetails[str]:
+        repl_features = IMG_CONTEXT * feature_size
+        repl_full = IMG_START + repl_features + IMG_END

-    def get_image_repl_full(
-        self,
-        feature_size: int,
-        num_patches: Optional[int],
-    ) -> str:
-        features = self.get_image_repl_features(feature_size, num_patches)
-        return IMG_START + features + IMG_END
+        return PromptUpdateDetails(full=repl_full, features=repl_features)

     def resolve_min_max_num(
         self,

@@ -501,12 +496,7 @@ def get_replacement_internvl(item_idx: int):
             if num_patches is not None:
                 assert isinstance(num_patches, int)

-            return PromptUpdateDetails(
-                full=hf_processor.get_image_repl_full(feature_size,
-                                                      num_patches),
-                features=hf_processor.get_image_repl_features(
-                    feature_size, num_patches),
-            )
+            return hf_processor.get_image_repl(feature_size, num_patches)

         return [
             PromptReplacement(
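
Note: both processors now build the prompt replacement in one step. A minimal standalone sketch of the resulting behavior, with PromptUpdateDetails stood in by a plain dataclass; only the IMG_* constants are from this file:

    from dataclasses import dataclass

    IMG_START, IMG_END, IMG_CONTEXT = '<img>', '</img>', '<IMG_CONTEXT>'

    @dataclass
    class Repl:  # stand-in for PromptUpdateDetails[str]
        full: str      # text substituted into the prompt
        features: str  # sub-span whose tokens carry image embeddings

    def get_image_repl(feature_size: int) -> Repl:
        repl_features = IMG_CONTEXT * feature_size
        return Repl(full=IMG_START + repl_features + IMG_END,
                    features=repl_features)

    repl = get_image_repl(2)
    assert repl.full == '<img><IMG_CONTEXT><IMG_CONTEXT></img>'
    # IMG_START/IMG_END belong to the prompt replacement but are not patch
    # positions; keeping both strings together lets later stages tell the
    # two apart.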

vllm/model_executor/models/internvl.py

Lines changed: 104 additions & 69 deletions
@@ -9,14 +9,13 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar,
-                    Union)
+from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union

 import torch
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import BatchFeature, PretrainedConfig, TensorType
+from transformers import BatchEncoding, PretrainedConfig, TensorType

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig

@@ -36,10 +35,12 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import flatten_2d_lists

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features

 IMG_START = '<img>'
 IMG_END = '</img>'

@@ -51,16 +52,26 @@

 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
+    pixel_values_flat: torch.Tensor
     """
     Shape:
     `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
     """
-    patches_per_image: List[int]
+
+    num_patches: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
-    List of number of total patches for each image in the batch.
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_images, num_embeds)`
     """

+    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
+    """Shape: `(batch_size, num_images)`"""
+

 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]

@@ -286,19 +297,11 @@ def image_token_id(self) -> int:
         raise NotImplementedError

     @abstractmethod
-    def get_image_repl_features(
+    def get_image_repl(
         self,
         feature_size: int,
         num_patches: Optional[int],
-    ) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_image_repl_full(
-        self,
-        feature_size: int,
-        num_patches: Optional[int],
-    ) -> str:
+    ) -> PromptUpdateDetails[str]:
         raise NotImplementedError

     def resolve_min_max_num(

@@ -394,7 +397,7 @@ def __call__(
         max_dynamic_patch: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
-    ) -> BatchFeature:
+    ) -> Mapping[str, NestedTensors]:
         if text is None:
             text = []
         if not isinstance(text, list):

@@ -413,28 +416,41 @@ def __call__(
                 max_dynamic_patch=max_dynamic_patch,
                 dynamic_image_size=dynamic_image_size,
             )
-            image_inputs = {
-                "pixel_values_flat": torch.cat(pixel_values_lst),
-                "image_num_patches": list(map(len, pixel_values_lst)),
+            image_inputs: dict[str, NestedTensors] = {
+                "pixel_values_flat":
+                torch.cat(pixel_values_lst),
+                "image_num_patches":
+                torch.tensor([len(item) for item in pixel_values_lst]),
             }

+            tokenizer = self.tokenizer
+            image_token_id = self.image_token_id
+
+            num_embeds = list[int]()
+            embed_is_patch = list[torch.Tensor]()
+
             for pixel_values in pixel_values_lst:
                 num_patches = pixel_values.shape[0]
                 feature_size = num_patches * self.num_image_token

-                image_repl = self.get_image_repl_full(feature_size,
-                                                      num_patches)
-                text = [t.replace('<image>', image_repl, 1) for t in text]
+                image_repl = self.get_image_repl(feature_size, num_patches)
+                feature_tokens = tokenizer.encode(image_repl.features,
+                                                  add_special_tokens=False)
+
+                text = [t.replace('<image>', image_repl.full, 1) for t in text]
+                num_embeds.append(len(feature_tokens))
+                embed_is_patch.append(
+                    torch.tensor(feature_tokens) == image_token_id)
+
+            image_inputs["num_embeds"] = torch.tensor(num_embeds)
+            image_inputs["embed_is_patch"] = embed_is_patch

         text_inputs = self.tokenizer(text)

-        return BatchFeature(
-            {
-                **text_inputs,
-                **image_inputs,
-            },
-            tensor_type=return_tensors,
-        )
+        return {
+            **BatchEncoding(text_inputs, tensor_type=return_tensors),
+            **image_inputs,
+        }


 class InternVLProcessor(BaseInternVLProcessor):
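
Note: the per-image bookkeeping above is the heart of the fix: the feature string is tokenized, and each resulting token is compared against image_token_id, so the model later knows which placeholder positions actually hold patch embeddings. A rough illustration with hypothetical token IDs (real IDs come from the model's tokenizer):

    import torch

    image_token_id = 7             # hypothetical ID for IMG_CONTEXT
    feature_tokens = [7, 7, 9, 7]  # pretend tokenizer.encode(...) output

    embed_is_patch = torch.tensor(feature_tokens) == image_token_id
    # tensor([ True,  True, False,  True])
    num_embeds = len(feature_tokens)  # 4 placeholder slots, 3 are patches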
@@ -443,20 +459,15 @@ class InternVLProcessor(BaseInternVLProcessor):
     def image_token_id(self) -> int:
         return self.tokenizer.get_vocab()[IMG_CONTEXT]

-    def get_image_repl_features(
+    def get_image_repl(
         self,
         feature_size: int,
         num_patches: Optional[int],
-    ) -> str:
-        return IMG_CONTEXT * feature_size
+    ) -> PromptUpdateDetails[str]:
+        repl_features = IMG_CONTEXT * feature_size
+        repl_full = IMG_START + repl_features + IMG_END

-    def get_image_repl_full(
-        self,
-        feature_size: int,
-        num_patches: Optional[int],
-    ) -> str:
-        features = self.get_image_repl_features(feature_size, num_patches)
-        return IMG_START + features + IMG_END
+        return PromptUpdateDetails(full=repl_full, features=repl_features)


 class BaseInternVLProcessingInfo(BaseProcessingInfo):

@@ -566,16 +577,15 @@ def _call_hf_processor(
         prompt: str,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
+    ) -> Mapping[str, NestedTensors]:
         processed_outputs = super()._call_hf_processor(
             prompt=prompt,
             mm_data=mm_data,
             mm_kwargs=mm_kwargs,
         )

-        image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
-        image_data = mm_data.get("images", [])
-        assert isinstance(image_data, list)
+        hf_processor = self.info.get_hf_processor(**mm_kwargs)
+        image_token_id = hf_processor.image_token_id

         # Since there may be extra tokens in the feature placeholders,
         # we need to pass the image token ID to the model to select the

586596

587597
def _get_mm_fields_config(
588598
self,
589-
hf_inputs: BatchFeature,
599+
hf_inputs: Mapping[str, NestedTensors],
590600
hf_processor_mm_kwargs: Mapping[str, object],
591601
) -> Mapping[str, MultiModalFieldConfig]:
592602
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -596,6 +606,8 @@
             pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                 "image", image_num_patches),
             image_num_patches=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
+            num_embeds=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
             image_token_id=MultiModalFieldConfig.shared("image", num_images),
         )

@@ -637,12 +649,7 @@ def get_replacement_internvl(item_idx: int):
             if num_patches is not None:
                 assert isinstance(num_patches, int)

-            return PromptUpdateDetails(
-                full=hf_processor.get_image_repl_full(feature_size,
-                                                      num_patches),
-                features=hf_processor.get_image_repl_features(
-                    feature_size, num_patches),
-            )
+            return hf_processor.get_image_repl(feature_size, num_patches)

         return [
             PromptReplacement(

@@ -832,6 +839,8 @@ def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[InternVLImageInputs]:
         pixel_values_flat = kwargs.pop("pixel_values_flat", None)
         image_num_patches = kwargs.pop("image_num_patches", None)
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
+        num_embeds = kwargs.pop("num_embeds", None)
         image_embeds = kwargs.pop("image_embeds", None)

         if pixel_values_flat is None and image_embeds is None:

@@ -858,46 +867,57 @@

             if not isinstance(image_num_patches, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(pixel_values_flat)}")
+                                 f"Got type: {type(image_num_patches)}")
+
+            if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of embed_is_patch. "
+                                 f"Got type: {type(embed_is_patch)}")
+
+            if not isinstance(num_embeds, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of num_embeds. "
+                                 f"Got type: {type(num_embeds)}")
+
+            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
+            image_num_patches = flatten_bn(image_num_patches, concat=True)

             return InternVLImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(
-                    flatten_bn(pixel_values_flat, concat=True)),
-                patches_per_image=flatten_bn(image_num_patches,
-                                             concat=True).tolist())
+                pixel_values_flat=self._validate_pixel_values(
+                    pixel_values_flat),
+                num_patches=image_num_patches,
+                embed_is_patch=embed_is_patch,
+                num_embeds=num_embeds,
+            )

         raise AssertionError("This line should be unreachable.")

     def _process_image_input(
         self,
         image_input: InternVLImageInputs,
-    ) -> tuple[torch.Tensor, ...]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]

         assert self.vision_model is not None

-        image_embeds = self.extract_feature(image_input["data"])
+        image_embeds = self.extract_feature(image_input["pixel_values_flat"])

-        patches_per_image = image_input["patches_per_image"]
+        num_patches = image_input["num_patches"]

         # Only one image in the current batch
-        if len(patches_per_image) == 1:
-            image_embeds = image_embeds.view(
+        if len(num_patches) == 1:
+            return image_embeds.view(
                 -1, self.config.text_config.hidden_size).unsqueeze(0)
-            return image_embeds

         # NOTE: Image embeddings are split into separate tensors for each image
         # by the size of each embedding.
         feature_size = image_embeds.shape[1]
         image_embeds = image_embeds.view(-1,
                                          self.config.text_config.hidden_size)
         image_feature_sizes = [
-            num_patches * feature_size for num_patches in patches_per_image
+            num_patches * feature_size for num_patches in num_patches
         ]
-        image_embeds = image_embeds.split(image_feature_sizes)
-        return image_embeds
+        return image_embeds.split(image_feature_sizes)

     def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
         if self.is_mono:
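
Note: the multi-image path splits the flat embedding tensor by per-image token counts. A small shape check under assumed sizes (hidden size 8, two images with 1 and 2 patches, 4 tokens per patch), mirroring the split logic above:

    import torch

    hidden_size, feature_size = 8, 4   # feature_size = image_embeds.shape[1]
    num_patches = [1, 2]               # patches per image

    image_embeds = torch.randn(sum(num_patches), feature_size, hidden_size)
    flat = image_embeds.view(-1, hidden_size)         # (12, 8)
    sizes = [n * feature_size for n in num_patches]   # [4, 8]
    per_image = flat.split(sizes)
    assert [tuple(t.shape) for t in per_image] == [(4, 8), (8, 8)]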
@@ -911,8 +931,19 @@ def get_multimodal_embeddings(
911931
image_input = self._parse_and_validate_image_input(**kwargs)
912932
if image_input is None:
913933
return None
914-
vision_embeddings = self._process_image_input(image_input)
915-
return vision_embeddings
934+
935+
image_features = self._process_image_input(image_input)
936+
937+
if (kwargs.get("v0_path", False)
938+
or image_input["type"] != "pixel_values"):
939+
return image_features
940+
941+
return flatten_2d_lists(
942+
scatter_patch_features(*args) for args in zip(
943+
image_features,
944+
image_input["num_embeds"],
945+
image_input["embed_is_patch"],
946+
))
916947

917948
def get_input_embeddings(
918949
self,
@@ -924,8 +955,11 @@ def get_input_embeddings(
924955
assert self.img_context_token_id is not None
925956
self._set_visual_token_mask(input_ids)
926957
inputs_embeds = merge_multimodal_embeddings(
927-
input_ids, inputs_embeds, multimodal_embeddings,
928-
self.img_context_token_id)
958+
input_ids,
959+
inputs_embeds,
960+
select_patch_features(multimodal_embeddings),
961+
self.img_context_token_id,
962+
)
929963
return inputs_embeds
930964

931965
def forward(
@@ -944,6 +978,7 @@ def forward(
944978
# NOTE: In v1, inputs_embeds is always generated at model runner, this
945979
# condition is for v0 compatibility.
946980
elif inputs_embeds is None:
981+
kwargs.update({"v0_path": True})
947982
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
948983
inputs_embeds = self.get_input_embeddings(input_ids,
949984
vision_embeddings)
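
Note: scatter_patch_features and select_patch_features come from the new .vision helpers, which are among the five changed files but not shown in this view. Below is a sketch of the semantics the call sites appear to assume (names suffixed _sketch; inferred behavior, not the actual implementation): scatter expands each image's patch embeddings into the full run of placeholder slots using embed_is_patch, and select later recovers only the patch slots for merge_multimodal_embeddings.

    import torch

    def scatter_patch_features_sketch(patches: torch.Tensor,
                                      num_embeds: int,
                                      embed_is_patch: torch.Tensor):
        # Place patch embeddings at the masked positions; non-patch slots
        # (e.g. the <img>/</img> wrapper tokens) stay zero.
        out = patches.new_zeros(num_embeds, patches.shape[-1])
        out[embed_is_patch] = patches
        return out

    def select_patch_features_sketch(embeds: torch.Tensor,
                                     embed_is_patch: torch.Tensor):
        # Inverse of scatter: keep only the positions holding patches.
        return embeds[embed_is_patch]

    mask = torch.tensor([False, True, True, False])  # <img> CTX CTX </img>
    patches = torch.randn(2, 8)
    scattered = scatter_patch_features_sketch(patches, 4, mask)
    assert torch.equal(select_patch_features_sketch(scattered, mask), patches)

With the v0_path flag set in forward, get_multimodal_embeddings returns the features unscattered for the legacy V0 merge; on V1 the scattered per-image features are consumed by the model runner, which always generates inputs_embeds there.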
