
Commit 6c7ae8f

lk-chen authored and DamonFool committed
[V1][VLM][Pixtral-HF] Support Pixtral-HF on V1 (vllm-project#14275)
Signed-off-by: Linkun Chen <[email protected]>
1 parent 3c65fb9 commit 6c7ae8f

4 files changed: +175 −16 lines

docs/source/models/supported_models.md

Lines changed: 1 addition & 5 deletions
@@ -866,7 +866,7 @@ See [this page](#generative-models) for more information on how to use generativ
 - * `PixtralForConditionalGeneration`
 * Pixtral
 * T + I<sup>+</sup>
-* `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b` (see note), etc.
+* `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b`, etc.
 *
 * ✅︎
 * ✅︎
@@ -930,10 +930,6 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
 Currently the PaliGemma model series is implemented without PrefixLM attention mask. This model series may be deprecated in a future release.
 :::
 
-:::{note}
-`mistral-community/pixtral-12b` does not support V1 yet.
-:::
-
 :::{note}
 To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
 :::
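
With the "does not support V1 yet" note removed, the HF-format checkpoint can be run like any other V1 multimodal model. Below is a minimal offline-inference sketch, not part of this commit: the `VLLM_USE_V1=1` toggle, the local image path, and the `[INST]…[IMG][/INST]` prompt template are illustrative assumptions; in practice the model's own chat template should be applied.

```python
# Illustrative sketch only: running the HF-format Pixtral checkpoint that
# this change enables on the V1 engine. Paths and prompt format are assumed.
import os

os.environ["VLLM_USE_V1"] = "1"  # assumption: V1 is opt-in via this env var

from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="mistral-community/pixtral-12b", max_model_len=8192)

image = Image.open("example.jpg")  # hypothetical local image
prompt = "<s>[INST]Describe this image.\n[IMG][/INST]"  # assumed template

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```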

vllm/model_executor/models/llava.py

Lines changed: 164 additions & 7 deletions
@@ -4,7 +4,7 @@
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
 from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
-                    TypedDict, TypeVar, Union)
+                    TypedDict, TypeVar, Union, cast)
 
 import torch
 import torch.nn as nn
@@ -35,6 +35,7 @@
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
+from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
 
 from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
@@ -56,6 +57,25 @@ class LlavaImagePixelInputs(TypedDict):
     in which case the data is passed as a list instead of a batched tensor.
     """
 
+    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    A boolean mask indicating which image features correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_crops, num_patch)`
+    """
+
+    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_embeds)`
+    """
+
+    num_crops: torch.Tensor
+    """Shape: `(batch_size, num_images)`"""
+
 
 class LlavaImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
@@ -65,6 +85,25 @@ class LlavaImageEmbeddingInputs(TypedDict):
     `hidden_size` must match the hidden size of language model backbone.
     """
 
+    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    A boolean mask indicating which image features correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_crops, num_patch)`
+    """
+
+    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_embeds)`
+    """
+
+    num_crops: torch.Tensor
+    """Shape: `(batch_size, num_images)`"""
+
 
 LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
 
@@ -317,14 +356,40 @@ def _call_hf_processor(
                 for p, (h, w) in zip(pixel_values, image_sizes)
             ]
 
+            hf_config = self.info.get_hf_config()
+
+            tile_sizes = [
+                get_pixtral_hf_image_feature_grid_size(
+                    hf_config.vision_config,
+                    image_width=pixel_value.shape[-1],
+                    image_height=pixel_value.shape[-2])
+                for pixel_value in processed_outputs["pixel_values"]
+            ]
+            num_crops = torch.tensor([(ncols + 1) * nrows
+                                      for ncols, nrows in tile_sizes])
+            # Each image may result to masks of different sizes, so we need to
+            # flatten the list and later use `num_crops` to get per-image masks.
+            embed_is_patch = torch.tensor(
+                flatten_2d_lists([([True] * ncols + [False]) * nrows
+                                  for ncols, nrows in tile_sizes]))
+            processed_outputs["num_crops"] = num_crops
+            processed_outputs["embed_is_patch"] = embed_is_patch
+            processed_outputs["feat_is_patch"] = embed_is_patch
+
         return processed_outputs
 
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
         return dict(
+            feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_crops),
+            embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_crops),
+            num_crops=MultiModalFieldConfig.batched("image"),
             pixel_values=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
         )
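
The processor hunk above encodes the Pixtral-HF tile geometry: each grid row contributes `ncols` patch embeddings plus one non-patch slot (the row break/end token), giving `(ncols + 1) * nrows` entries per image and a mask with a `False` at every row boundary. A standalone sketch of just that arithmetic, under the assumption that the real code obtains `(ncols, nrows)` from `get_pixtral_hf_image_feature_grid_size`:

```python
# Standalone illustration of the mask/crop arithmetic used in the diff above.
import torch


def masks_for_grid(ncols: int, nrows: int) -> tuple[int, torch.Tensor]:
    # One extra non-patch slot per row -> (ncols + 1) * nrows entries.
    num_crops = (ncols + 1) * nrows
    # True for patch tokens, False for the slot that ends each row.
    embed_is_patch = torch.tensor(([True] * ncols + [False]) * nrows)
    return num_crops, embed_is_patch


num_crops, mask = masks_for_grid(ncols=3, nrows=2)
print(num_crops)        # 8
print(mask.tolist())    # [True, True, True, False, True, True, True, False]
print(int(mask.sum()))  # 6 patch tokens = ncols * nrows
```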
@@ -562,6 +627,23 @@ def _parse_and_validate_image_input(
         if pixel_values is None and image_embeds is None:
             return None
 
+        feat_is_patch = kwargs.pop("feat_is_patch", None)
+        if feat_is_patch is not None and not isinstance(
+                feat_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of feat_is_patch. "
+                             f"Got type: {type(feat_is_patch)}")
+
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
+        if embed_is_patch is not None and not isinstance(
+                embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of embed_is_patch. "
+                             f"Got type: {type(embed_is_patch)}")
+
+        num_crops = kwargs.pop("num_crops", None)
+        if num_crops is not None and not isinstance(num_crops, torch.Tensor):
+            raise ValueError("Incorrect type of num_crops. "
+                             f"Got type: {type(num_crops)}")
+
         if pixel_values is not None:
             if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
@@ -571,12 +653,18 @@ def _parse_and_validate_image_input(
                 return LlavaImagePixelInputs(
                     type="pixel_values",
                     data=flatten_bn(pixel_values),
+                    feat_is_patch=feat_is_patch,
+                    embed_is_patch=embed_is_patch,
+                    num_crops=num_crops,
                 )
 
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
                     flatten_bn(pixel_values, concat=True)),
+                feat_is_patch=feat_is_patch,
+                embed_is_patch=embed_is_patch,
+                num_crops=num_crops,
             )
 
         if image_embeds is not None:
@@ -587,6 +675,9 @@ def _parse_and_validate_image_input(
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
                 data=flatten_bn(image_embeds, concat=True),
+                feat_is_patch=feat_is_patch,
+                embed_is_patch=embed_is_patch,
+                num_crops=num_crops,
             )
 
         raise AssertionError("This line should be unreachable.")
@@ -633,16 +724,74 @@ def _process_image_input(self,
 
         assert self.vision_tower is not None
         image_features = self._process_image_pixels(image_input)
-        return self.multi_modal_projector(image_features)
 
-    def get_multimodal_embeddings(
-            self, **kwargs
-    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+        if isinstance(image_features, torch.Tensor):
+            return self.multi_modal_projector(image_features)
+
+        feature_sizes = [
+            image_feature.shape[0] for image_feature in image_features
+        ]
+
+        image_embeds = self.multi_modal_projector(torch.cat(image_features))
+        image_embeds = torch.split(image_embeds, feature_sizes)
+        return image_embeds
+
+    def _get_mm_embeds(
+            self,
+            features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
+            feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
+            num_crops: torch.Tensor,  # Shape: (num_images,)
+            embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
+    ) -> list[torch.Tensor]:
+        """Scatter the patch features into a contiguous tensor that corresponds
+        to the embedding tokens defined by the multimodal processor.
+
+        Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
+        """
+
+        # Insert columns of nan values according to `feat_is_patch`. This work
+        # ideally should be done in `_process_image_input`, but
+        # `_process_image_input` is used in both V0 and V1 path. It's safer to
+        # put the logic here.
+        # FIXME: Move this logic to `_process_image_input` when v0 is
+        # deprecated. Merge this function with `Molmo._get_mm_embeds`.
+        feat_is_patch = feat_is_patch.view(-1)
+        embed_is_patch = embed_is_patch.view(-1)
+        expanded_embedding = torch.full(
+            (sum(num_crops), *features.shape[1:]),
+            torch.nan,
+            dtype=features.dtype).to(features.device)
+        expanded_embedding[feat_is_patch] = features
+
+        num_crops_per_image = num_crops.tolist()
+        feats_per_image = expanded_embedding.split(num_crops_per_image)
+        f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
+
+        embed_dim = expanded_embedding.shape[-1]
+        num_embeds = embed_is_patch.shape[0]
+
+        embeds_in_batch = list[torch.Tensor]()
+        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
+            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
+            embeds[embed_is_patch] = feats[f_is_patch]
+            embeds_in_batch.append(embeds)
+
+        return embeds_in_batch
+
+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
         vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
+        if kwargs.get("v0_path", False):
+            return vision_embeddings
+        else:
+            nested_emb = [
+                self._get_mm_embeds(*args) for args in zip(
+                    vision_embeddings, image_input["feat_is_patch"],
+                    image_input["num_crops"], image_input["embed_is_patch"])
+            ]
+            return flatten_2d_lists(nested_emb)
 
     def get_input_embeddings(
         self,
@@ -651,8 +800,15 @@ def get_input_embeddings(
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
+            # Extract the patch tokens
+            patch_embeddings = json_map_leaves(
+                lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
+                cast(JSONTree[torch.Tensor], multimodal_embeddings),
+            )
+
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
+                input_ids, inputs_embeds, cast(NestedTensors,
+                                               patch_embeddings),
                 self.config.image_token_index)
         return inputs_embeds
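
Taken together, the two hunks above round-trip patch features through NaN padding: `_get_mm_embeds` scatters real features into a tensor covering every embedding slot, filling non-patch slots with NaN, and `get_input_embeddings` strips those rows back out with `x[~x.isnan()]` before merging into the text embeddings. A self-contained sketch of that pattern with plain tensors (shapes are made up for illustration):

```python
# Toy round trip of the NaN scatter/recover pattern from the diff above.
# 4 embedding slots, 3 of them patches, embedding dim 2.
import torch

embed_dim = 2
embed_is_patch = torch.tensor([True, True, True, False])
patch_feats = torch.arange(6, dtype=torch.float32).reshape(3, embed_dim)

# Scatter: NaN everywhere, real features only at patch positions.
expanded = torch.full((embed_is_patch.numel(), embed_dim), torch.nan)
expanded[embed_is_patch] = patch_feats

# Recover: drop the NaN rows, mirroring the lambda passed to json_map_leaves.
recovered = expanded[~expanded.isnan()].view(-1, embed_dim)
assert torch.equal(recovered, patch_feats)
print(expanded)
print(recovered)
```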

@@ -705,6 +861,7 @@ def forward(
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
+            kwargs.update({"v0_path": True})
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
             inputs_embeds = self.get_input_embeddings(input_ids,
                                                       vision_embeddings)

vllm/model_executor/models/molmo.py

Lines changed: 2 additions & 2 deletions
@@ -1484,8 +1484,8 @@ def _parse_and_validate_image_input(
 
         img_patch_id = kwargs.pop("img_patch_id", None)
         if not isinstance(img_patch_id, torch.Tensor):
-            raise ValueError("Incorrect type of num_crops. "
-                             f"Got type: {type(num_crops)}")
+            raise ValueError("Incorrect type of img_patch_id. "
+                             f"Got type: {type(img_patch_id)}")
         self.img_patch_id = img_patch_id.flatten().unique().item()
 
         return MolmoImageInputs(

vllm/model_executor/models/pixtral.py

Lines changed: 8 additions & 2 deletions
@@ -1042,9 +1042,13 @@ def forward(
             for img in pixel_values
         ]
 
+        patch_embeds = [
+            p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
+        ]
+        embed_sizes = [p.shape[1] for p in patch_embeds]
+
         # flatten to a single sequence
-        patch_embeds = torch.cat(
-            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
+        patch_embeds = torch.cat(patch_embeds, dim=1)
         patch_embeds = self.ln_pre(patch_embeds)
 
         # positional embeddings
@@ -1075,6 +1079,8 @@ def forward(
         out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
                                              self.config.num_hidden_layers)
 
+        # squeeze dim 0 and split into separate tensors for each image
+        out = torch.split(torch.squeeze(out), embed_sizes)
         return out
 
     # (TODO) Add prefix argument for filtering out weights to be loaded
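
The encoder change above records per-image patch counts before concatenating everything into one flat sequence, then splits the transformer output back into one tensor per image. A minimal sketch of that concat → shared processing → split pattern, with a `LayerNorm` standing in for the attention stack (sizes are arbitrary):

```python
# Sketch of the concat -> shared processing -> per-image split pattern.
import torch
import torch.nn as nn

d = 4
per_image = [torch.randn(1, n, d) for n in (5, 3, 7)]  # (1, num_patches, d)
embed_sizes = [p.shape[1] for p in per_image]

seq = torch.cat(per_image, dim=1)     # (1, 15, d): single flat sequence
processed = nn.LayerNorm(d)(seq)      # placeholder for the encoder layers

# Split back into one tensor per image, as the vision tower now returns.
outs = torch.split(processed.squeeze(0), embed_sizes)
print([o.shape for o in outs])        # [(5, 4), (3, 4), (7, 4)]
```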
