Skip to content

[Bugfix] Added embed_is_patch mask for fuyu model #15731

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit on Mar 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions vllm/model_executor/models/fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict
from typing import Literal, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn
Expand All @@ -39,10 +39,12 @@
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features

# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
Expand All @@ -64,6 +66,11 @@ class FuyuImagePatchInputs(TypedDict):
This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
"""


class FuyuProcessingInfo(BaseProcessingInfo):
Expand Down Expand Up @@ -183,6 +190,19 @@ def _call_hf_processor(

processed_outputs["image_patches"] = image_patches[0]

# get patch grid size for each image
embed_is_patch = []
for image in images:
ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image.width,
image_height=image.height,
)

mask = torch.tensor(([True] * ncols + [False]) * nrows)
embed_is_patch.append(mask)

processed_outputs["embed_is_patch"] = embed_is_patch

return processed_outputs

def _apply_hf_processor_tokens_only(
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """Declare how each multimodal field in the HF processor output maps
    onto individual input items.

    Both ``image_patches`` and the ``embed_is_patch`` mask are batched
    per image, so each image's patch tensor stays paired with the mask
    that marks which of its embeddings are patch tokens.
    """
    # NOTE(review): the diff view carried both the old single-field
    # return and the new two-field return; only the post-change version
    # (with embed_is_patch) is kept here.
    return dict(
        image_patches=MultiModalFieldConfig.batched("image"),
        embed_is_patch=MultiModalFieldConfig.batched("image"),
    )

def _get_prompt_updates(
self,
Expand Down Expand Up @@ -301,18 +322,23 @@ def _validate_shape(d: torch.Tensor):
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
    """Extract and validate image inputs from the multimodal kwargs.

    Returns ``None`` when no ``image_patches`` were supplied; otherwise
    builds a ``FuyuImagePatchInputs`` holding the flattened patch pixel
    data, the per-image patch counts, and the ``embed_is_patch`` mask
    used later to scatter patch embeddings back into the sequence.

    Raises:
        ValueError: if ``image_patches`` or ``embed_is_patch`` is not a
            ``torch.Tensor`` or a list.
    """
    image_patches = kwargs.pop("image_patches", None)
    embed_is_patch = kwargs.pop("embed_is_patch", None)

    # No image input in this batch.
    if image_patches is None:
        return None

    if not isinstance(image_patches, (torch.Tensor, list)):
        raise ValueError("Incorrect type of image patches. "
                         f"Got type: {type(image_patches)}")

    if not isinstance(embed_is_patch, (torch.Tensor, list)):
        raise ValueError("Incorrect type of embed_is_patch. "
                         f"Got type: {type(embed_is_patch)}")

    # Drop the batch dimension so each entry holds one image's patches.
    patches_by_image = flatten_bn(image_patches)

    return FuyuImagePatchInputs(
        type="image_patches",
        flat_data=self._validate_pixel_values(
            flatten_bn(patches_by_image, concat=True)),
        patches_per_image=[p.size(0) for p in patches_by_image],
        embed_is_patch=embed_is_patch,
    )
Expand All @@ -333,7 +359,12 @@ def get_multimodal_embeddings(
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
#return vision_embeddings
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
vision_embeddings,
image_input["embed_is_patch"],
))

def get_input_embeddings(
self,
Expand All @@ -343,8 +374,8 @@ def get_input_embeddings(
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
_IMAGE_TOKEN_ID)
input_ids, inputs_embeds,
select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
return inputs_embeds

def forward(
Expand Down