
Commit a115ac4

[VLM] Move supported limits and max tokens to merged multi-modal processor (#11669)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
1 parent 7300144 commit a115ac4

File tree: 16 files changed, +340 −350 lines

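For orientation (not part of the diff): this commit moves per-modality limits and max-token counts from standalone registry hooks into the merged multi-modal processor, so each model's processor now answers both questions itself. Below is a minimal sketch of the two new hooks, modeled on the AriaMultiModalProcessor changes further down; the class name and the fixed token count are placeholders, the import path is assumed to match what the model files in this diff use, and a real subclass would also implement the other abstract processor methods (field config, prompt replacements, dummy inputs):

from typing import Mapping, Optional

from vllm.multimodal.processing import BaseMultiModalProcessor


class MyMultiModalProcessor(BaseMultiModalProcessor):
    # Hypothetical processor showing only the two hooks introduced here.

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # None means "no fixed upper bound"; the effective per-prompt limit
        # then comes from limit_mm_per_prompt on the ModelConfig.
        return {"image": None}

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        # Placeholder count; e.g. Aria derives this from
        # hf_config.projector_patch_to_query_dict in the diff below.
        return {"image": 256}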

tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py

Lines changed: 2 additions & 37 deletions
@@ -4,7 +4,7 @@
 import pytest
 from transformers import AutoTokenizer
 
-from vllm.inputs import InputContext, InputProcessingContext
+from vllm.inputs import InputProcessingContext
 from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
 
 from .....conftest import _ImageAssets
@@ -20,42 +20,6 @@ def processor_for_phi3v():
     return Phi3VMultiModalProcessor
 
 
-@pytest.fixture()
-def get_max_phi3v_image_tokens():
-    from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens
-    return get_max_phi3v_image_tokens
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("num_crops,expected_max_tokens", [
-    (4, 781),
-    (16, 2653),
-])
-def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
-                             num_crops: int, expected_max_tokens: int):
-    """Ensure get_max_phi3v_image_tokens handles num_crops properly."""
-    # NOTE: mm_processor_kwargs on the context in this test is unused, since
-    # this is testing the mapper directly. In practice, the processor kwargs
-    # are wrapped in a closure when calling the max tokens func. We explicitly
-    # do NOT use the mm_processor_kwargs in the model context here to ensure
-    # that the max image tokens implementation is referencing a mix of the
-    # kwargs to the function and the original mm_processor_kwargs in case
-    # values are somehow updated and end up in a bad state.
-    ctx = build_model_context(
-        model_name=model,
-        tokenizer_name=model,
-        trust_remote_code=True,
-        mm_processor_kwargs=None,
-    )
-
-    actual_max_tokens = get_max_phi3v_image_tokens(
-        InputContext(ctx.model_config),
-        num_crops=num_crops,
-    )
-
-    assert expected_max_tokens == actual_max_tokens
-
-
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "num_crops,expected_toks_per_img",
@@ -77,6 +41,7 @@ def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets,
         model_name=model,
         tokenizer_name=model,
         trust_remote_code=True,
+        limit_mm_per_prompt={"image": num_imgs},
     )
     tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
     ctx = InputProcessingContext(ctx.model_config, tokenizer)

tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py

Lines changed: 2 additions & 34 deletions
@@ -3,7 +3,7 @@
 import pytest
 from transformers import AutoTokenizer
 
-from vllm.inputs import InputContext, InputProcessingContext
+from vllm.inputs import InputProcessingContext
 
 from .....conftest import _ImageAssets
 from ....utils import build_model_context
@@ -22,39 +22,6 @@ def processor_for_qwen2_vl():
     return Qwen2VLMultiModalProcessor
 
 
-@pytest.fixture()
-def get_max_qwen2_vl_image_tokens():
-    from vllm.model_executor.models.qwen2_vl import (
-        get_max_qwen2_vl_image_tokens)
-    return get_max_qwen2_vl_image_tokens
-
-
-@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
-    ({}, 16384),
-    ({
-        MIN_PIXELS: 64**2,
-        MAX_PIXELS: 512**2
-    }, 324),
-])
-@pytest.mark.parametrize("model", [MODEL])
-def test_qwen2_vl_max_image_tokens(
-    get_max_qwen2_vl_image_tokens,
-    model: str,
-    mm_processor_kwargs: Dict[str, Any],
-    expected_max_tokens: int,
-):
-    """Ensure that the max token calc handles min/max pixels properly."""
-    ctx = build_model_context(
-        model_name=model,
-        tokenizer_name=model,
-        mm_processor_kwargs=None,
-    )
-
-    actual_max_tokens = get_max_qwen2_vl_image_tokens(
-        InputContext(ctx.model_config), **mm_processor_kwargs)
-    assert actual_max_tokens == expected_max_tokens
-
-
 @pytest.mark.parametrize(
     "mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [
         ({}, 1426, (5704, 1176)),
@@ -82,6 +49,7 @@ def test_processor_override(
         model_name=model,
        tokenizer_name=model,
         mm_processor_kwargs=None,
+        limit_mm_per_prompt={"image": num_imgs},
     )
     tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
     ctx = InputProcessingContext(ctx.model_config, tokenizer)

tests/multimodal/test_processing.py

Lines changed: 8 additions & 6 deletions
@@ -538,6 +538,11 @@ def _test_processing_cache_correctness(
     else:
         hf_overrides = {}
 
+    limit_mm_per_prompt = {
+        modality: 3 if supports_multi else 1
+        for modality, supports_multi in modalities.items()
+    }
+
     model_config = ModelConfig(
         model_id,
         task="auto",
@@ -548,6 +553,7 @@ def _test_processing_cache_correctness(
         dtype="float16",
         revision=None,
         hf_overrides=hf_overrides,
+        limit_mm_per_prompt=limit_mm_per_prompt,
     )
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
 
@@ -580,18 +586,14 @@ def _test_processing_cache_correctness(
                  min_wh=128,
                  max_wh=256),
         "audio":
-        partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000),
-    }
-    input_max_count = {
-        modality: 3 if supports_multi else 1
-        for modality, supports_multi in modalities.items()
+        partial(_rand_audio, rng, min_len=512, max_len=1024, sr=16000),
     }
 
     for batch_idx in range(num_batches):
         mm_data = {
             k:
             [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
-             for _ in range(rng.randint(input_max_count[k]))]
+             for _ in range(rng.randint(limit_mm_per_prompt[k]))]
             for k in modalities
         }
 

vllm/inputs/registry.py

Lines changed: 1 addition & 7 deletions
@@ -331,13 +331,7 @@ def dummy_data_for_profiling(
                 trust_remote_code=model_config.trust_remote_code,
             )
             processor = mm_registry.create_processor(model_config, tokenizer)
-
-            mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
-            mm_max_tokens = mm_registry.get_max_tokens_by_modality(
-                model_config)
-
-            dummy_data = processor.get_dummy_data(seq_len, mm_counts,
-                                                  mm_max_tokens)
+            dummy_data = processor.get_dummy_data(seq_len)
         else:
             model_cls, _ = get_model_architecture(model_config)
             if is_encoder_data:

vllm/model_executor/models/aria.py

Lines changed: 42 additions & 33 deletions
@@ -1,5 +1,5 @@
-from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
-                    Union)
+from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
+                    TypedDict, Union)
 
 import torch
 import torch.nn as nn
@@ -9,7 +9,6 @@
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_rank
-from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -87,8 +86,8 @@ def __init__(
     def forward(
         self,
         pixel_values: torch.Tensor,
-        pixel_mask: Optional[torch.BoolTensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]:
+        pixel_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
 
         vit_oup = self.vision_model(
@@ -100,7 +99,8 @@ def forward(
 
         return vit_oup, image_atts
 
-    def _create_patch_attention_mask(self, pixel_mask):
+    def _create_patch_attention_mask(
+            self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
         if pixel_mask is None:
             return None
 
@@ -115,7 +115,8 @@ def _create_patch_attention_mask(self, pixel_mask):
         )
         return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
 
-    def _create_image_attention_mask(self, patch_attention_mask):
+    def _create_image_attention_mask(
+            self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
         if patch_attention_mask is None:
             return None
 
@@ -125,13 +126,13 @@ def _create_image_attention_mask(self, patch_attention_mask):
 
 class FFN(nn.Module):
 
-    def __init__(self, embed_dim, ff_dim, output_dim):
+    def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
         super().__init__()
         self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False)
         self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False)
         self.act = get_act_fn("gelu_new")
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.linear_in(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states, _ = self.linear_out(hidden_states)
@@ -140,7 +141,7 @@ def forward(self, hidden_states):
 
 class CrossAttention(nn.Module):
 
-    def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
+    def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
@@ -149,12 +150,16 @@ def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
 
         self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
         self.linear = nn.Linear(embed_dim, embed_dim)
-        self.dropout = nn.Dropout(drop_out_rate)
 
         self.layer_norm = nn.LayerNorm(embed_dim)
         self.ln_kv = nn.LayerNorm(kv_dim)
 
-    def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
+    def forward(
+        self,
+        x: torch.Tensor,
+        hidden_states: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         normed_hidden_states = self.layer_norm(hidden_states)
         query = self.q_proj(normed_hidden_states).permute(1, 0, 2)
 
@@ -169,11 +174,7 @@ def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
 
         attn_output = attn_output.permute(1, 0, 2)
 
-        if add_residual:
-            attn_output = hidden_states + self.dropout(
-                self.linear(attn_output))
-        else:
-            attn_output = self.dropout(self.linear(attn_output))
+        attn_output = self.linear(attn_output)
 
         return attn_output
 
@@ -201,14 +202,14 @@ class AriaProjector(nn.Module):
 
     def __init__(
         self,
-        patch_to_query_dict,
-        embed_dim,
-        num_heads,
-        kv_dim,
-        ff_dim,
-        output_dim,
-        norm_layer=nn.LayerNorm,
-    ):
+        patch_to_query_dict: dict[int, int],
+        embed_dim: int,
+        num_heads: int,
+        kv_dim: int,
+        ff_dim: int,
+        output_dim: int,
+        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
+    ) -> None:
         super().__init__()
         self.patch_to_query_dict = patch_to_query_dict
         self.embed_dim = embed_dim
@@ -224,7 +225,11 @@ def __init__(
         self.ln_ffn = norm_layer(embed_dim)
         self.ffn = FFN(embed_dim, ff_dim, output_dim)
 
-    def forward(self, x, attn_mask=None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         bs = x.shape[0]
         queries = self.query.unsqueeze(0).repeat(bs, 1, 1)
 
@@ -442,12 +447,17 @@ def build_mm_projector(config: PretrainedConfig):
     )
 
 
-def get_max_aria_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config()
-    return max(hf_config.projector_patch_to_query_dict.values())
+class AriaMultiModalProcessor(BaseMultiModalProcessor):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
 
+    def _get_num_image_tokens(self) -> int:
+        hf_config = self.ctx.get_hf_config()
+        return max(hf_config.projector_patch_to_query_dict.values())
 
-class AriaMultiModalProcessor(BaseMultiModalProcessor):
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        return {"image": self._get_num_image_tokens()}
 
     def _get_mm_fields_config(
         self,
@@ -468,13 +478,13 @@ def _get_prompt_replacements(
         hf_config = self.ctx.get_hf_config()
         image_token_id = hf_config.image_token_index
 
-        max_image_tokens = get_max_aria_image_tokens(self.ctx)
+        num_image_tokens = self._get_num_image_tokens()
 
         return [
             PromptReplacement(
                 modality="image",
                 target=[image_token_id],
-                replacement=[image_token_id] * max_image_tokens,
+                replacement=[image_token_id] * num_image_tokens,
             )
         ]
 
@@ -504,7 +514,6 @@ def _get_dummy_mm_inputs(
         )
 
 
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
 class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     """
