@@ -9,14 +9,13 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar,
-                    Union)
+from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
 
 import torch
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import BatchFeature, PretrainedConfig, TensorType
+from transformers import BatchEncoding, PretrainedConfig, TensorType
 
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -36,10 +35,12 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import flatten_2d_lists
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
 
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -51,16 +52,26 @@
 
 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
+    pixel_values_flat: torch.Tensor
     """
     Shape:
     `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
     """
-    patches_per_image: List[int]
+
+    num_patches: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
-    List of number of total patches for each image in the batch.
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_images, num_embeds)`
     """
 
+    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
+    """Shape: `(batch_size, num_images)`"""
+
 
 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
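The new `embed_is_patch` field exists because the tokenized feature placeholder for an image may contain more than just patch tokens. A minimal sketch of the idea, using made-up token IDs (the real IDs come from the tokenizer's vocabulary):

```python
import torch

# Hypothetical token IDs for illustration only.
IMG_CONTEXT_ID = 101  # the patch-placeholder token
TILE_TAG_ID = 102     # an extra non-patch token some variants interleave

# Feature placeholder for one image: two tiles of two patch tokens each.
feature_tokens = [TILE_TAG_ID, IMG_CONTEXT_ID, IMG_CONTEXT_ID,
                  TILE_TAG_ID, IMG_CONTEXT_ID, IMG_CONTEXT_ID]

embed_is_patch = torch.tensor(feature_tokens) == IMG_CONTEXT_ID
# tensor([False,  True,  True, False,  True,  True])
```

For the base InternVL placeholder the mask is all True; variants that interleave extra tokens in the feature text are where it earns its keep.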
@@ -286,19 +297,11 @@ def image_token_id(self) -> int:
         raise NotImplementedError
 
     @abstractmethod
-    def get_image_repl_features(
+    def get_image_repl(
         self,
         feature_size: int,
         num_patches: Optional[int],
-    ) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_image_repl_full(
-        self,
-        feature_size: int,
-        num_patches: Optional[int],
-    ) -> str:
+    ) -> PromptUpdateDetails[str]:
         raise NotImplementedError
 
     def resolve_min_max_num(
@@ -394,7 +397,7 @@ def __call__(
         max_dynamic_patch: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
-    ) -> BatchFeature:
+    ) -> Mapping[str, NestedTensors]:
         if text is None:
             text = []
         if not isinstance(text, list):
@@ -413,28 +416,41 @@ def __call__(
                 max_dynamic_patch=max_dynamic_patch,
                 dynamic_image_size=dynamic_image_size,
             )
-            image_inputs = {
-                "pixel_values_flat": torch.cat(pixel_values_lst),
-                "image_num_patches": list(map(len, pixel_values_lst)),
+            image_inputs: dict[str, NestedTensors] = {
+                "pixel_values_flat":
+                torch.cat(pixel_values_lst),
+                "image_num_patches":
+                torch.tensor([len(item) for item in pixel_values_lst]),
             }
 
+            tokenizer = self.tokenizer
+            image_token_id = self.image_token_id
+
+            num_embeds = list[int]()
+            embed_is_patch = list[torch.Tensor]()
+
             for pixel_values in pixel_values_lst:
                 num_patches = pixel_values.shape[0]
                 feature_size = num_patches * self.num_image_token
 
-                image_repl = self.get_image_repl_full(feature_size,
-                                                      num_patches)
-                text = [t.replace('<image>', image_repl, 1) for t in text]
+                image_repl = self.get_image_repl(feature_size, num_patches)
+                feature_tokens = tokenizer.encode(image_repl.features,
+                                                  add_special_tokens=False)
+
+                text = [t.replace('<image>', image_repl.full, 1) for t in text]
+                num_embeds.append(len(feature_tokens))
+                embed_is_patch.append(
+                    torch.tensor(feature_tokens) == image_token_id)
+
+            image_inputs["num_embeds"] = torch.tensor(num_embeds)
+            image_inputs["embed_is_patch"] = embed_is_patch
 
         text_inputs = self.tokenizer(text)
 
-        return BatchFeature(
-            {
-                **text_inputs,
-                **image_inputs,
-            },
-            tensor_type=return_tensors,
-        )
+        return {
+            **BatchEncoding(text_inputs, tensor_type=return_tensors),
+            **image_inputs,
+        }
 
 
 class InternVLProcessor(BaseInternVLProcessor):
@@ -443,20 +459,15 @@ class InternVLProcessor(BaseInternVLProcessor):
     def image_token_id(self) -> int:
         return self.tokenizer.get_vocab()[IMG_CONTEXT]
 
-    def get_image_repl_features(
+    def get_image_repl(
         self,
         feature_size: int,
         num_patches: Optional[int],
-    ) -> str:
-        return IMG_CONTEXT * feature_size
+    ) -> PromptUpdateDetails[str]:
+        repl_features = IMG_CONTEXT * feature_size
+        repl_full = IMG_START + repl_features + IMG_END
 
-    def get_image_repl_full(
-        self,
-        feature_size: int,
-        num_patches: Optional[int],
-    ) -> str:
-        features = self.get_image_repl_features(feature_size, num_patches)
-        return IMG_START + features + IMG_END
+        return PromptUpdateDetails(full=repl_full, features=repl_features)
 
 
 class BaseInternVLProcessingInfo(BaseProcessingInfo):
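For a concrete sense of what `get_image_repl` now returns, here is a hedged sketch, assuming `IMG_CONTEXT` is `'<IMG_CONTEXT>'` as defined elsewhere in this file and `processor` is a hypothetical `InternVLProcessor` instance:

```python
repl = processor.get_image_repl(feature_size=3, num_patches=1)

assert repl.features == '<IMG_CONTEXT>' * 3
assert repl.full == '<img>' + '<IMG_CONTEXT>' * 3 + '</img>'
```

Folding the two old methods into one `PromptUpdateDetails` keeps the "full" replacement and its "features" subspan in a single object, so callers can no longer pair a full string with mismatched features.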
@@ -566,16 +577,15 @@ def _call_hf_processor(
         prompt: str,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
+    ) -> Mapping[str, NestedTensors]:
         processed_outputs = super()._call_hf_processor(
             prompt=prompt,
             mm_data=mm_data,
             mm_kwargs=mm_kwargs,
         )
 
-        image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
-        image_data = mm_data.get("images", [])
-        assert isinstance(image_data, list)
+        hf_processor = self.info.get_hf_processor(**mm_kwargs)
+        image_token_id = hf_processor.image_token_id
 
         # Since there may be extra tokens in the feature placeholders,
         # we need to pass the image token ID to the model to select the
@@ -586,7 +596,7 @@ def _call_hf_processor(
 
     def _get_mm_fields_config(
         self,
-        hf_inputs: BatchFeature,
+        hf_inputs: Mapping[str, NestedTensors],
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -596,6 +606,8 @@ def _get_mm_fields_config(
             pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                 "image", image_num_patches),
             image_num_patches=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
+            num_embeds=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
             image_token_id=MultiModalFieldConfig.shared("image", num_images),
         )
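`pixel_values_flat` is one concatenation over every patch of every image, so it is registered with `flat_from_sizes`, which recovers per-image chunks from `image_num_patches`; the new `embed_is_patch` and `num_embeds` fields carry one entry per image, hence `batched`. Roughly the slicing that `flat_from_sizes` encodes, as a standalone sketch (the 448x448 tile shape is an assumption, not taken from this diff):

```python
import torch

pixel_values_flat = torch.randn(5, 3, 448, 448)  # 2 + 3 patch tiles, flat
image_num_patches = torch.tensor([2, 3])

per_image = pixel_values_flat.split(image_num_patches.tolist())
# per_image[0].shape == (2, 3, 448, 448)  -> image 0
# per_image[1].shape == (3, 3, 448, 448)  -> image 1
```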
@@ -637,12 +649,7 @@ def get_replacement_internvl(item_idx: int):
             if num_patches is not None:
                 assert isinstance(num_patches, int)
 
-            return PromptUpdateDetails(
-                full=hf_processor.get_image_repl_full(feature_size,
-                                                      num_patches),
-                features=hf_processor.get_image_repl_features(
-                    feature_size, num_patches),
-            )
+            return hf_processor.get_image_repl(feature_size, num_patches)
 
         return [
             PromptReplacement(
@@ -832,6 +839,8 @@ def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[InternVLImageInputs]:
         pixel_values_flat = kwargs.pop("pixel_values_flat", None)
         image_num_patches = kwargs.pop("image_num_patches", None)
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
+        num_embeds = kwargs.pop("num_embeds", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values_flat is None and image_embeds is None:
@@ -858,46 +867,57 @@ def _parse_and_validate_image_input(
 
             if not isinstance(image_num_patches, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(pixel_values_flat)}")
+                                 f"Got type: {type(image_num_patches)}")
+
+            if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of embed_is_patch. "
+                                 f"Got type: {type(embed_is_patch)}")
+
+            if not isinstance(num_embeds, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of num_embeds. "
+                                 f"Got type: {type(num_embeds)}")
+
+            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
+            image_num_patches = flatten_bn(image_num_patches, concat=True)
 
             return InternVLImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(
-                    flatten_bn(pixel_values_flat, concat=True)),
-                patches_per_image=flatten_bn(image_num_patches,
-                                             concat=True).tolist())
+                pixel_values_flat=self._validate_pixel_values(
+                    pixel_values_flat),
+                num_patches=image_num_patches,
+                embed_is_patch=embed_is_patch,
+                num_embeds=num_embeds,
+            )
 
         raise AssertionError("This line should be unreachable.")
 
     def _process_image_input(
         self,
         image_input: InternVLImageInputs,
-    ) -> tuple[torch.Tensor, ...]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
         assert self.vision_model is not None
 
-        image_embeds = self.extract_feature(image_input["data"])
+        image_embeds = self.extract_feature(image_input["pixel_values_flat"])
 
-        patches_per_image = image_input["patches_per_image"]
+        num_patches = image_input["num_patches"]
 
         # Only one image in the current batch
-        if len(patches_per_image) == 1:
-            image_embeds = image_embeds.view(
+        if len(num_patches) == 1:
+            return image_embeds.view(
                 -1, self.config.text_config.hidden_size).unsqueeze(0)
-            return image_embeds
 
         # NOTE: Image embeddings are split into separate tensors for each image
         # by the size of each embedding.
         feature_size = image_embeds.shape[1]
         image_embeds = image_embeds.view(-1,
                                          self.config.text_config.hidden_size)
         image_feature_sizes = [
-            num_patches * feature_size for num_patches in patches_per_image
+            num_patches * feature_size for num_patches in num_patches
         ]
-        image_embeds = image_embeds.split(image_feature_sizes)
-        return image_embeds
+        return image_embeds.split(image_feature_sizes)
 
     def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
         if self.is_mono:
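A toy walk-through of the multi-image branch above, with made-up sizes, showing why the split sizes are `num_patches * feature_size`:

```python
import torch

hidden_size, feature_size = 8, 4   # hypothetical model dimensions
patches = [2, 3]                   # patches per image in the batch

# extract_feature output: feature_size embeddings per patch, flattened
# across all patches of all images.
image_embeds = torch.randn(sum(patches), feature_size, hidden_size)

flat = image_embeds.view(-1, hidden_size)        # (20, 8)
sizes = [n * feature_size for n in patches]      # [8, 12]
per_image = flat.split(sizes)
# per_image[0]: (8, 8) for image 0; per_image[1]: (12, 8) for image 1
```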
@@ -911,8 +931,19 @@ def get_multimodal_embeddings(
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
-        vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
+
+        image_features = self._process_image_input(image_input)
+
+        if (kwargs.get("v0_path", False)
+                or image_input["type"] != "pixel_values"):
+            return image_features
+
+        return flatten_2d_lists(
+            scatter_patch_features(*args) for args in zip(
+                image_features,
+                image_input["num_embeds"],
+                image_input["embed_is_patch"],
+            ))
 
     def get_input_embeddings(
         self,
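The `.vision` helpers themselves are not shown in this diff; the sketch below is a toy model of the behavior implied by the call sites, not the actual implementation. `scatter_patch_features` plausibly expands each image's patch embeddings to cover the full placeholder span, leaving sentinel rows for non-patch tokens, and `select_patch_features` drops those sentinels again before merging:

```python
import torch

def scatter_patch_features_toy(features: torch.Tensor, num_embeds: int,
                               embed_is_patch: torch.Tensor) -> torch.Tensor:
    # features: (num_patch_tokens, hidden)
    # embed_is_patch: (num_embeds,) bool with num_patch_tokens True entries
    out = features.new_full((num_embeds, features.shape[-1]), torch.nan)
    out[embed_is_patch] = features  # patch rows filled, others stay NaN
    return out

def select_patch_features_toy(embeds: torch.Tensor) -> torch.Tensor:
    # Keep only the rows that hold real patch embeddings.
    return embeds[~embeds.isnan().any(dim=-1)]
```

Under this model, the v1 path hands the runner one tensor per image whose length matches the full feature placeholder, so embedding positions line up with prompt tokens even when extra tokens sit between patches.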
@@ -924,8 +955,11 @@ def get_input_embeddings(
         assert self.img_context_token_id is not None
         self._set_visual_token_mask(input_ids)
         inputs_embeds = merge_multimodal_embeddings(
-            input_ids, inputs_embeds, multimodal_embeddings,
-            self.img_context_token_id)
+            input_ids,
+            inputs_embeds,
+            select_patch_features(multimodal_embeddings),
+            self.img_context_token_id,
+        )
         return inputs_embeds
 
     def forward(
@@ -944,6 +978,7 @@ def forward(
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
+            kwargs.update({"v0_path": True})
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
             inputs_embeds = self.get_input_embeddings(input_ids,
                                                       vision_embeddings)