
Commit 8ca4598

kylehh authored and kylesayrs committed
[Bugfix] Added embed_is_patch mask for fuyu model (vllm-project#15731)
Signed-off-by: Kyle Huang <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
1 parent c34e9ae commit 8ca4598

File tree: 1 file changed (+36, −5 lines)

vllm/model_executor/models/fuyu.py

Lines changed: 36 additions & 5 deletions
@@ -18,7 +18,7 @@
 """ PyTorch Fuyu model."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, Set, Tuple, TypedDict
+from typing import Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
@@ -39,10 +39,12 @@
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
+from vllm.utils import flatten_2d_lists
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
 
 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
@@ -64,6 +66,11 @@ class FuyuImagePatchInputs(TypedDict):
     This is used to split the embeddings which has the first two dimensions
     flattened just like `flat_data`.
     """
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+    """
 
 
 class FuyuProcessingInfo(BaseProcessingInfo):
@@ -183,6 +190,19 @@ def _call_hf_processor(
 
         processed_outputs["image_patches"] = image_patches[0]
 
+        # get patch grid size for each image
+        embed_is_patch = []
+        for image in images:
+            ncols, nrows = self.info.get_image_feature_grid_size(
+                image_width=image.width,
+                image_height=image.height,
+            )
+
+            mask = torch.tensor(([True] * ncols + [False]) * nrows)
+            embed_is_patch.append(mask)
+
+        processed_outputs["embed_is_patch"] = embed_is_patch
+
         return processed_outputs
 
     def _apply_hf_processor_tokens_only(
@@ -202,7 +222,8 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(image_patches=MultiModalFieldConfig.batched("image"))
+        return dict(image_patches=MultiModalFieldConfig.batched("image"),
+                    embed_is_patch=MultiModalFieldConfig.batched("image"))
 
     def _get_prompt_updates(
         self,
@@ -301,18 +322,23 @@ def _validate_shape(d: torch.Tensor):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
         image_patches = kwargs.pop("image_patches", None)
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
         if image_patches is not None:
             if not isinstance(image_patches, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image patches. "
                                  f"Got type: {type(image_patches)}")
 
+            if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of embed_is_patch. "
+                                 f"Got type: {type(embed_is_patch)}")
             image_patches_flat = flatten_bn(image_patches)
 
             return FuyuImagePatchInputs(
                 type="image_patches",
                 flat_data=self._validate_pixel_values(
                     flatten_bn(image_patches_flat, concat=True)),
                 patches_per_image=[x.size(0) for x in image_patches_flat],
+                embed_is_patch=embed_is_patch,
             )
 
         return None
@@ -333,7 +359,12 @@ def get_multimodal_embeddings(
         if image_input is None:
             return None
         vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
+        #return vision_embeddings
+        return flatten_2d_lists(
+            scatter_patch_features(*args) for args in zip(
+                vision_embeddings,
+                image_input["embed_is_patch"],
+            ))
 
     def get_input_embeddings(
         self,
@@ -343,8 +374,8 @@ def get_input_embeddings(
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                _IMAGE_TOKEN_ID)
+                input_ids, inputs_embeds,
+                select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
         return inputs_embeds
 
     def forward(
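
For context on the mask layout: Fuyu's prompt places a newline placeholder token after each row of image patches, so a ncols × nrows patch grid occupies (ncols + 1) * nrows embedding slots, of which only ncols * nrows are real patches. Below is a plain-torch sketch of the mask construction and of the scatter/select round trip that this diff wires through get_multimodal_embeddings and get_input_embeddings. The 3×2 grid and the hidden size are hypothetical, and the indexing here only approximates what scatter_patch_features / select_patch_features from .vision actually do:

import torch

# Hypothetical patch grid; in the real code ncols/nrows come from
# FuyuProcessingInfo.get_image_feature_grid_size(image_width, image_height).
ncols, nrows = 3, 2

# Same construction as the diff: one False per row marks the trailing
# newline slot, which occupies an embedding position but is not a patch.
embed_is_patch = torch.tensor(([True] * ncols + [False]) * nrows)
# tensor([ True,  True,  True, False,  True,  True,  True, False])

hidden = 8                                         # hypothetical hidden size
patch_embeds = torch.randn(ncols * nrows, hidden)  # 6 vision embeddings

# Rough equivalent of the scatter/select round trip: spread the patch
# embeddings over the full placeholder span, then recover exactly the
# patch positions when merging into the text embeddings.
full_span = torch.zeros(embed_is_patch.numel(), hidden)
full_span[embed_is_patch] = patch_embeds   # cf. scatter_patch_features
recovered = full_span[embed_is_patch]      # cf. select_patch_features
assert torch.equal(recovered, patch_embeds)

Threading embed_is_patch through FuyuImagePatchInputs is what lets get_input_embeddings merge only true patch features at _IMAGE_TOKEN_ID positions, rather than mis-aligning the per-row newline slots.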
