Skip to content

Commit 9e4e312

Browse files
committed
[V1] Scatter and gather placeholders in the model runner
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 239b7be commit 9e4e312

23 files changed

+297
-621
lines changed

Diff for: docs/source/contributing/model/multimodal.md

+8-8
Original file line numberDiff line numberDiff line change
@@ -860,8 +860,8 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
860860
)
861861
```
862862

863-
To accommodate this, instead of a string you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails`
864-
with different `full` and `feature` attributes:
863+
To assign the vision embeddings only to the image tokens, instead of a string
864+
you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails`:
865865

866866
```python
867867
hf_config = self.info.get_hf_config()
@@ -879,9 +879,9 @@ def get_replacement_fuyu(item_idx: int):
879879
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
880880
[_NEWLINE_TOKEN_ID]) * nrows
881881

882-
return PromptUpdateDetails(
883-
full=image_tokens + [bos_token_id],
884-
features=image_tokens,
882+
return PromptUpdateDetails.select_token_id(
883+
image_tokens + [bos_token_id],
884+
embed_token_id=_IMAGE_TOKEN_ID,
885885
)
886886
```
887887

@@ -914,9 +914,9 @@ def _get_prompt_updates(
914914
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
915915
[_NEWLINE_TOKEN_ID]) * nrows
916916

917-
return PromptUpdateDetails(
918-
full=image_tokens + [bos_token_id],
919-
features=image_tokens,
917+
return PromptUpdateDetails.select_token_id(
918+
image_tokens + [bos_token_id],
919+
embed_token_id=_IMAGE_TOKEN_ID,
920920
)
921921

922922
return [

Diff for: vllm/model_executor/models/chameleon.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ def _get_prompt_updates(
161161
PromptReplacement(
162162
modality="image",
163163
target=[image_token_id],
164-
replacement=PromptUpdateDetails(
165-
full=([image_start_id] + image_tokens + [image_end_id]),
166-
features=image_tokens,
164+
replacement=PromptUpdateDetails.select_token_id(
165+
[image_start_id] + image_tokens + [image_end_id],
166+
embed_token_id=image_token_id,
167167
),
168168
)
169169
]

Diff for: vllm/model_executor/models/fuyu.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,9 @@ def get_replacement_fuyu(item_idx: int):
252252
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
253253
[_NEWLINE_TOKEN_ID]) * nrows
254254

255-
return PromptUpdateDetails(
256-
full=image_tokens + [bos_token_id],
257-
features=image_tokens,
255+
return PromptUpdateDetails.select_token_id(
256+
image_tokens + [bos_token_id],
257+
embed_token_id=_IMAGE_TOKEN_ID,
258258
)
259259

260260
return [

Diff for: vllm/model_executor/models/gemma3_mm.py

+14-55
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from .siglip import SiglipVisionModel
3737
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
3838
maybe_prefix, merge_multimodal_embeddings)
39-
from .vision import scatter_patch_features, select_patch_features
4039

4140
logger = init_logger(__name__)
4241

@@ -54,14 +53,6 @@ class Gemma3ImagePixelInputs(TypedDict):
5453
num_patches: torch.Tensor
5554
"""Shape: `(batch_size * num_images)`"""
5655

57-
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
58-
"""
59-
A boolean mask indicating which image embeddings correspond
60-
to patch tokens.
61-
62-
Shape: `(batch_size * num_images, num_embeds)`
63-
"""
64-
6556

6657
Gemma3ImageInputs = Gemma3ImagePixelInputs
6758

@@ -183,7 +174,7 @@ def get_image_repl(
183174
if processor is None:
184175
processor = self.get_hf_processor()
185176

186-
image_token = processor.boi_token
177+
boi_token = processor.boi_token
187178

188179
num_crops = self.get_num_crops(
189180
image_width=image_width,
@@ -192,19 +183,21 @@ def get_image_repl(
192183
)
193184

194185
if num_crops == 0:
195-
image_text = image_token
186+
image_text = boi_token
196187
else:
197-
crops_image_tokens = " ".join(image_token
198-
for _ in range(num_crops))
188+
crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
199189
image_text = (
200-
f"Here is the original image {image_token} and here are some "
190+
f"Here is the original image {boi_token} and here are some "
201191
f"crops to help you see better {crops_image_tokens}")
202192

203-
repl_full = image_text.replace(image_token,
193+
repl_full = image_text.replace(boi_token,
204194
processor.full_image_sequence)
205-
repl_features = repl_full.strip("\n")
206195

207-
return PromptUpdateDetails(full=repl_full, features=repl_features)
196+
tokenizer = processor.tokenizer
197+
vocab = tokenizer.get_vocab()
198+
image_token_id = vocab[tokenizer.image_token]
199+
200+
return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
208201

209202
def get_num_image_tokens(
210203
self,
@@ -222,7 +215,7 @@ def get_num_image_tokens(
222215

223216
image_repl_tokens = encode_tokens(
224217
tokenizer,
225-
image_repl.features,
218+
image_repl.full,
226219
add_special_tokens=False,
227220
)
228221
return len(image_repl_tokens)
@@ -301,28 +294,6 @@ def _call_hf_processor(
301294
]
302295
hf_processor = self.info.get_hf_processor(**mm_kwargs)
303296

304-
image_repl_features = [
305-
self.info.get_image_repl(image_width=size.width,
306-
image_height=size.height,
307-
processor=hf_processor).features
308-
for size in image_sizes
309-
]
310-
311-
tokenizer = self.info.get_tokenizer()
312-
image_repls_feature_tokens = [
313-
tokenizer.encode(image_repl, add_special_tokens=False)
314-
for image_repl in image_repl_features
315-
]
316-
317-
vocab = tokenizer.get_vocab()
318-
image_token_id = vocab[tokenizer.image_token]
319-
320-
embed_is_patch = [
321-
torch.tensor(image_repl_tokens) == image_token_id
322-
for image_repl_tokens in image_repls_feature_tokens
323-
]
324-
processed_outputs["embed_is_patch"] = embed_is_patch
325-
326297
num_crops = [
327298
self.info.get_num_crops(image_width=size.width,
328299
image_height=size.height,
@@ -344,7 +315,6 @@ def _get_mm_fields_config(
344315
pixel_values=MultiModalFieldConfig.flat_from_sizes(
345316
"image", num_crops + 1),
346317
num_crops=MultiModalFieldConfig.batched("image"),
347-
embed_is_patch=MultiModalFieldConfig.batched("image"),
348318
)
349319

350320
def _get_prompt_updates(
@@ -454,6 +424,7 @@ def get_repl_toks(tok: int) -> list[int]:
454424
item_idx=p.item_idx,
455425
start_idx=repl_orig_idxs[p.start_idx],
456426
tokens=p.tokens,
427+
is_embed=p.is_embed,
457428
) for p in placeholders
458429
]
459430
for modality, placeholders in repls.items()
@@ -572,7 +543,6 @@ def _parse_and_validate_image_input(
572543
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
573544
pixel_values = kwargs.pop("pixel_values", None)
574545
num_crops = kwargs.pop("num_crops", None)
575-
embed_is_patch = kwargs.pop("embed_is_patch", None)
576546
image_embeds = kwargs.pop("image_embeds", None)
577547
assert image_embeds is None, "Gemma3 does not support image_embeds."
578548
if pixel_values is None:
@@ -586,19 +556,13 @@ def _parse_and_validate_image_input(
586556
raise ValueError("Incorrect type of num_crops. "
587557
f"Got type: {type(num_crops)}")
588558

589-
if not isinstance(embed_is_patch, (torch.Tensor, list)):
590-
raise ValueError("Incorrect type of embed_is_patch. "
591-
f"Got type: {type(embed_is_patch)}")
592-
593559
pixel_values = flatten_bn(pixel_values, concat=True)
594560
num_crops = flatten_bn(num_crops, concat=True)
595-
embed_is_patch = flatten_bn(embed_is_patch)
596561

597562
return Gemma3ImagePixelInputs(
598563
type="pixel_values",
599564
pixel_values=self._validate_pixel_values(pixel_values),
600565
num_patches=num_crops + 1,
601-
embed_is_patch=embed_is_patch,
602566
)
603567

604568
def _image_pixels_to_features(
@@ -635,12 +599,7 @@ def get_multimodal_embeddings(
635599
if image_input is None:
636600
return None
637601

638-
image_features = self._process_image_input(image_input)
639-
640-
return scatter_patch_features(
641-
image_features,
642-
image_input["embed_is_patch"],
643-
)
602+
return self._process_image_input(image_input)
644603

645604
def get_input_embeddings(
646605
self,
@@ -652,7 +611,7 @@ def get_input_embeddings(
652611
inputs_embeds = merge_multimodal_embeddings(
653612
input_ids,
654613
inputs_embeds,
655-
select_patch_features(multimodal_embeddings),
614+
multimodal_embeddings,
656615
self.config.image_token_index,
657616
)
658617
return inputs_embeds

Diff for: vllm/model_executor/models/h2ovl.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def get_image_repl(
257257
repl_features = IMG_CONTEXT * feature_size
258258
repl_full = IMG_START + repl_features + IMG_END
259259

260-
return PromptUpdateDetails(full=repl_full, features=repl_features)
260+
return PromptUpdateDetails.select_token_text(repl_full, IMG_CONTEXT)
261261

262262
def resolve_min_max_num(
263263
self,

Diff for: vllm/model_executor/models/idefics3.py

+10-59
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
MultiModalDataItems,
4242
MultiModalFieldConfig,
4343
PromptReplacement, PromptUpdate,
44-
encode_tokens)
44+
PromptUpdateDetails, encode_tokens)
4545
# yapf: enable
4646
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
4747
from vllm.sequence import IntermediateTensors
@@ -54,7 +54,6 @@
5454
from .llama import LlamaModel
5555
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
5656
merge_multimodal_embeddings)
57-
from .vision import scatter_patch_features, select_patch_features
5857

5958

6059
class Idefics3ImagePixelInputs(TypedDict):
@@ -69,14 +68,6 @@ class Idefics3ImagePixelInputs(TypedDict):
6968
num_patches: torch.Tensor
7069
"""Shape: `(batch_size * num_images)`"""
7170

72-
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
73-
"""
74-
A boolean mask indicating which image embeddings correspond
75-
to patch tokens.
76-
77-
Shape: `(batch_size * num_images, num_embeds)`
78-
"""
79-
8071

8172
class Idefics3ImageEmbeddingInputs(TypedDict):
8273
type: Literal["image_embeds"]
@@ -86,14 +77,6 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
8677
`hidden_size` must match the hidden size of language model backbone.
8778
"""
8879

89-
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
90-
"""
91-
A boolean mask indicating which image embeddings correspond
92-
to patch tokens.
93-
94-
Shape: `(batch_size * num_images, num_embeds)`
95-
"""
96-
9780

9881
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
9982

@@ -364,28 +347,6 @@ def _call_hf_processor(
364347
]
365348
hf_processor = self.info.get_hf_processor(**mm_kwargs)
366349

367-
image_repl_features = [
368-
self.info.get_image_repl(image_width=size.width,
369-
image_height=size.height,
370-
processor=hf_processor)
371-
for size in image_sizes
372-
]
373-
374-
tokenizer = self.info.get_tokenizer()
375-
image_repls_feature_tokens = [
376-
tokenizer.encode(image_repl, add_special_tokens=False)
377-
for image_repl in image_repl_features
378-
]
379-
380-
vocab = tokenizer.get_vocab()
381-
image_token_id = vocab[hf_processor.image_token.content]
382-
383-
embed_is_patch = [
384-
torch.tensor(image_repl_tokens) == image_token_id
385-
for image_repl_tokens in image_repls_feature_tokens
386-
]
387-
processed_outputs["embed_is_patch"] = embed_is_patch
388-
389350
num_patches = [
390351
self.info.get_num_patches(
391352
image_width=size.width,
@@ -415,7 +376,6 @@ def _get_mm_fields_config(
415376
"image", num_patches),
416377
image_embeds=MultiModalFieldConfig.batched("image"),
417378
num_patches=MultiModalFieldConfig.batched("image"),
418-
embed_is_patch=MultiModalFieldConfig.batched("image"),
419379
)
420380

421381
def _get_prompt_updates(
@@ -427,17 +387,22 @@ def _get_prompt_updates(
427387
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
428388
image_token = hf_processor.image_token.content
429389

430-
def get_replacement_idefics3(item_idx: int) -> str:
390+
def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
431391
images = mm_items.get_items("image", ImageProcessorItems)
432392

433393
image_size = images.get_image_size(item_idx)
434394

435-
return self.info.get_image_repl(
395+
image_repl = self.info.get_image_repl(
436396
image_width=image_size.width,
437397
image_height=image_size.height,
438398
processor=hf_processor,
439399
)
440400

401+
return PromptUpdateDetails.select_token_text(
402+
image_repl,
403+
embed_token_text=image_token,
404+
)
405+
441406
return [
442407
PromptReplacement(
443408
modality="image",
@@ -675,13 +640,6 @@ def _parse_and_validate_image_input(
675640
if pixel_values is None and image_embeds is None:
676641
return None
677642

678-
embed_is_patch = kwargs.pop("embed_is_patch")
679-
if not isinstance(embed_is_patch, (torch.Tensor, list)):
680-
raise ValueError("Incorrect type of embed_is_patch. "
681-
f"Got type: {type(embed_is_patch)}")
682-
683-
embed_is_patch = flatten_bn(embed_is_patch)
684-
685643
if image_embeds is not None:
686644
if not isinstance(image_embeds, (torch.Tensor, list)):
687645
raise ValueError("Incorrect type of image embeddings. "
@@ -690,7 +648,6 @@ def _parse_and_validate_image_input(
690648
return Idefics3ImageEmbeddingInputs(
691649
type="image_embeds",
692650
data=flatten_bn(image_embeds, concat=True),
693-
embed_is_patch=embed_is_patch,
694651
)
695652

696653
if pixel_values is not None:
@@ -718,7 +675,6 @@ def _parse_and_validate_image_input(
718675
pixel_values=self._validate_pixel_values(pixel_values),
719676
pixel_attention_mask=pixel_attention_mask,
720677
num_patches=num_patches,
721-
embed_is_patch=embed_is_patch,
722678
)
723679

724680
raise AssertionError("This line should be unreachable.")
@@ -754,12 +710,7 @@ def get_multimodal_embeddings(
754710
if image_input is None:
755711
return None
756712

757-
image_features = self._process_image_input(image_input)
758-
759-
return scatter_patch_features(
760-
image_features,
761-
image_input["embed_is_patch"],
762-
)
713+
return self._process_image_input(image_input)
763714

764715
def get_input_embeddings(
765716
self,
@@ -771,7 +722,7 @@ def get_input_embeddings(
771722
inputs_embeds = merge_multimodal_embeddings(
772723
input_ids,
773724
inputs_embeds,
774-
select_patch_features(multimodal_embeddings),
725+
multimodal_embeddings,
775726
self.config.image_token_id,
776727
)
777728
return inputs_embeds

0 commit comments

Comments (0)