@@ -18,7 +18,7 @@
 """ PyTorch Fuyu model."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, Set, Tuple, TypedDict
+from typing import Literal, Optional, Set, Tuple, TypedDict, Union

 import torch
 import torch.nn as nn
@@ -39,10 +39,12 @@
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
+from vllm.utils import flatten_2d_lists

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features

 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
@@ -64,6 +66,11 @@ class FuyuImagePatchInputs(TypedDict):
     This is used to split the embeddings which has the first two dimensions
     flattened just like `flat_data`.
     """
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+    """


 class FuyuProcessingInfo(BaseProcessingInfo):
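For orientation: `flat_data` concatenates the patches of every image in the batch, `patches_per_image` records how to split them apart again, and the new `embed_is_patch` mask marks which of the resulting embedding positions hold real patch features (as opposed to the per-row newline tokens). A toy illustration of the split, with made-up sizes:

import torch

# Made-up sizes: two images contributing 6 and 4 patches, feature dim 8.
flat_data = torch.randn(10, 8)
patches_per_image = [6, 4]

# The same split the model performs to recover per-image patch embeddings.
per_image = flat_data.split(patches_per_image)
assert [p.size(0) for p in per_image] == patches_per_image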
@@ -183,6 +190,19 @@ def _call_hf_processor(

             processed_outputs["image_patches"] = image_patches[0]

+            # get patch grid size for each image
+            embed_is_patch = []
+            for image in images:
+                ncols, nrows = self.info.get_image_feature_grid_size(
+                    image_width=image.width,
+                    image_height=image.height,
+                )
+
+                mask = torch.tensor(([True] * ncols + [False]) * nrows)
+                embed_is_patch.append(mask)
+
+            processed_outputs["embed_is_patch"] = embed_is_patch
+
         return processed_outputs

     def _apply_hf_processor_tokens_only(
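Each grid row produces `ncols` patch embeddings followed by one image-newline embedding, which is why the mask appends a single `False` per row. For a hypothetical 3x2 grid the mask looks like this (a sanity check, not part of the PR):

import torch

ncols, nrows = 3, 2  # hypothetical grid size from get_image_feature_grid_size
mask = torch.tensor(([True] * ncols + [False]) * nrows)
# tensor([ True,  True,  True, False,  True,  True,  True, False])

assert mask.numel() == (ncols + 1) * nrows  # patches plus one newline per row
assert int(mask.sum()) == ncols * nrows     # only patch positions are True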
@@ -202,7 +222,8 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(image_patches=MultiModalFieldConfig.batched("image"))
+        return dict(image_patches=MultiModalFieldConfig.batched("image"),
+                    embed_is_patch=MultiModalFieldConfig.batched("image"))

     def _get_prompt_updates(
         self,
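`MultiModalFieldConfig.batched("image")` declares that the field carries one element per image item, so `embed_is_patch` is sliced, cached, and batched alongside `image_patches` by vLLM's multimodal input machinery rather than being passed through as one opaque tensor.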
@@ -301,18 +322,23 @@ def _validate_shape(d: torch.Tensor):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
         image_patches = kwargs.pop("image_patches", None)
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
         if image_patches is not None:
             if not isinstance(image_patches, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image patches. "
                                  f"Got type: {type(image_patches)}")

+            if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of embed_is_patch. "
+                                 f"Got type: {type(embed_is_patch)}")
             image_patches_flat = flatten_bn(image_patches)

             return FuyuImagePatchInputs(
                 type="image_patches",
                 flat_data=self._validate_pixel_values(
                     flatten_bn(image_patches_flat, concat=True)),
                 patches_per_image=[x.size(0) for x in image_patches_flat],
+                embed_is_patch=embed_is_patch,
             )

         return None
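Note that `embed_is_patch` is only type-checked when `image_patches` is present, mirroring the existing validation. For shape intuition, `flatten_bn` from `.utils` flattens the leading batch dimension, and with `concat=True` joins everything into the single tensor that becomes `flat_data`; a rough sketch with invented sizes:

import torch

# Invented sizes: two images' patch tensors of shape (Pn, Px * Py * C).
image_patches_flat = [torch.randn(6, 2700), torch.randn(4, 2700)]

patches_per_image = [x.size(0) for x in image_patches_flat]
flat_data = torch.cat(image_patches_flat)  # what flatten_bn(..., concat=True)
                                           # amounts to for a list of tensors
assert flat_data.size(0) == sum(patches_per_image)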
@@ -333,7 +359,11 @@ def get_multimodal_embeddings(
         if image_input is None:
             return None
         vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
+        return flatten_2d_lists(
+            scatter_patch_features(*args) for args in zip(
+                vision_embeddings,
+                image_input["embed_is_patch"],
+            ))

     def get_input_embeddings(
         self,
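`scatter_patch_features` and `select_patch_features` come from the shared `.vision` helpers. The idea is that the scattered embeddings cover every image-related position (patches and newlines alike), so downstream code can index them one-to-one against the prompt's image tokens; the newline slots hold filler values that `select_patch_features` strips back out. A simplified sketch of the mechanism, assuming NaN is used as the filler (the real helpers may differ in detail):

import torch

def scatter_patch_features_sketch(patches: torch.Tensor,
                                  embed_is_patch: torch.Tensor) -> torch.Tensor:
    # Place patch embeddings at True positions; fill newline slots with NaN.
    out = patches.new_full((embed_is_patch.numel(), patches.size(-1)),
                           float("nan"))
    out[embed_is_patch] = patches
    return out

def select_patch_features_sketch(embeds: torch.Tensor) -> torch.Tensor:
    # Keep only the rows that hold real patch features.
    return embeds[~embeds.isnan().any(dim=-1)]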
@@ -343,8 +373,8 @@ def get_input_embeddings(
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                _IMAGE_TOKEN_ID)
+                input_ids, inputs_embeds,
+                select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
         return inputs_embeds

     def forward(
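`merge_multimodal_embeddings` writes embeddings into `inputs_embeds` at exactly the positions where `input_ids == _IMAGE_TOKEN_ID`, and those positions correspond only to patch tokens; hence `select_patch_features` is applied first to drop the scattered newline filler so the counts line up again.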