
[Bugfix] Fix profiling dummy data for Pixtral #18677


Merged: 3 commits, merged on May 25, 2025
261 changes: 105 additions & 156 deletions tests/models/multimodal/processing/test_common.py
@@ -9,15 +9,15 @@
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
cached_tokenizer_from_config,
encode_tokens)

from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS
@@ -28,7 +28,6 @@ def _test_processing_correctness(
hit_rate: float,
num_batches: int,
simplify_rate: float,
ignore_mm_keys: Optional[set[str]] = None,
):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_available_online(on_fail="skip")
@@ -99,10 +98,23 @@
}

mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = dummy_inputs.get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
).prompt_text

# Mistral chat outputs tokens directly, rather than text prompts
if isinstance(tokenizer, MistralTokenizer):
images = mm_data.get("image", [])
request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=""),
*(ImageChunk(image=image) for image in images),
]),
])
res = tokenizer.mistral.encode_chat_completion(request)
prompt = res.tokens
else:
prompt = dummy_inputs.get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
).prompt

# Drop unnecessary keys and test single -> multi conversion
if rng.rand() < simplify_rate:
@@ -112,124 +124,66 @@
elif len(mm_data[k]) == 1:
mm_data[k] = mm_data[k][0]

if isinstance(tokenizer, MistralTokenizer):
_test_processing_correctness_mistral(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
ignore_mm_keys=ignore_mm_keys,
)
else:
_test_processing_correctness_hf(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
ignore_mm_keys=ignore_mm_keys,
)


def _test_processing_correctness_hf(
_test_processing_correctness_one(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
)


# For some multimodal models, the tokenizer always adds bos_token
# at the beginning of the prompt by default, causing hf_processor to output
# incorrect token ids. So we need to use `add_special_tokens=False` here
# to leave the bos_token to be added by the processor.
_ADD_SPECIAL_TOKENS_OVERRIDES = {
"mllama": False,
"ovis": False,
"ultravox": False,
"whisper": False,
}

_IGNORE_MM_KEYS = {
# In Ultravox, the audio_features can be different depending on padding
# The slight difference should not be a problem though, since
# attention_mask lets us ignore the difference.
"ultravox": {"audio_features"},
}


def _test_processing_correctness_one(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
prompt: str,
tokenizer: AnyTokenizer,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[set[str]] = None,
):
if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox",
"whisper"):
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
model_type = model_config.hf_config.model_type
ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())

if isinstance(prompt, str):
text_prompt = prompt
token_prompt = encode_tokens(
tokenizer,
prompt,
add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type),
)
else:
token_prompt = tokenizer.encode(prompt)

baseline_result = baseline_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_result = cached_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

_assert_inputs_equal(
baseline_result,
cached_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
# Mistral does not support decode_tokens with skip_special_tokens=False
text_prompt = None
token_prompt = prompt

baseline_tokenized_result = baseline_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

_assert_inputs_equal(
baseline_result,
baseline_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)

cached_tokenized_result = cached_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

_assert_inputs_equal(
cached_result,
cached_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)


def _test_processing_correctness_mistral(
model_config: ModelConfig,
tokenizer: MistralTokenizer,
prompt: str,
mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[set[str]] = None,
):
images = mm_data.get("image", [])
if not isinstance(images, list):
images = [images]

request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=prompt),
*(ImageChunk(image=image) for image in images),
]),
])
res = tokenizer.mistral.encode_chat_completion(request)
token_prompt = res.tokens

# Mistral chat outputs tokens directly, rather than text prompts
baseline_tokenized_result = baseline_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_tokenized_result = cached_processor.apply(
token_prompt,
mm_data=mm_data,
@@ -240,9 +194,44 @@ def _test_processing_correctness_mistral(
baseline_tokenized_result,
cached_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
msg=f"Failed ({batch_idx=}, {token_prompt=}, {mm_data=})",
)

if text_prompt is not None:
baseline_text_result = baseline_processor.apply(
text_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_text_result = cached_processor.apply(
text_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

_assert_inputs_equal(
baseline_text_result,
cached_text_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {text_prompt=}, {mm_data=})",
)

_assert_inputs_equal(
baseline_text_result,
baseline_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {text_prompt=}, "
f"{token_prompt=}, {mm_data=})",
)

_assert_inputs_equal(
cached_text_result,
cached_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {text_prompt=}, "
f"{token_prompt=}, {mm_data=})",
)


# yapf: disable
@pytest.mark.parametrize("model_id", [
@@ -281,6 +270,7 @@ def _test_processing_correctness_mistral(
"AIDC-AI/Ovis2-1B",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
"microsoft/Phi-3.5-vision-instruct",
"microsoft/Phi-4-multimodal-instruct",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
@@ -303,41 +293,6 @@ def test_processing_correctness(
num_batches: int,
simplify_rate: float,
):
ignore_mm_keys = None
if 'ultravox' in model_id:
# In Ultravox, the audio_features can be different depending on padding
# The slight difference should not be a problem though, since
# attention_mask lets us ignore the difference.
ignore_mm_keys = {"audio_features"}

_test_processing_correctness(
model_id,
hit_rate=hit_rate,
num_batches=num_batches,
simplify_rate=simplify_rate,
ignore_mm_keys=ignore_mm_keys,
)


# yapf: disable
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_correctness_phi3v(
model_id: str,
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
# HACK - this is an attempted workaround for the following bug
# https://github.com/huggingface/transformers/issues/34307
from transformers import AutoImageProcessor # noqa: F401
from transformers import AutoProcessor # noqa: F401

AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)

_test_processing_correctness(
model_id,
hit_rate=hit_rate,
@@ -356,16 +311,10 @@ def _assert_inputs_equal(
if ignore_mm_keys is None:
ignore_mm_keys = set()

if msg is None:
assert "mm_kwargs" in a and "mm_kwargs" in b
else:
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
assert "mm_kwargs" in a and "mm_kwargs" in b, msg

for key in ignore_mm_keys:
a["mm_kwargs"].pop(key, None)
b["mm_kwargs"].pop(key, None)

if msg is None:
assert a == b
else:
assert a == b, msg
assert a == b, msg
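
For readers skimming the diff, here is a minimal sketch of the prompt handling the consolidated helper now performs. The names mirror the test code above, but the helper function itself is hypothetical and only illustrates the branching between text and token prompts, not the full test flow.

# Hypothetical condensation of the branching added above (not part of the PR).
from typing import Optional, Union
from vllm.transformers_utils.tokenizer import AnyTokenizer, encode_tokens

def split_prompt(
    tokenizer: AnyTokenizer,
    prompt: Union[str, list[int]],
) -> tuple[Optional[str], list[int]]:
    if isinstance(prompt, str):
        # HF-style tokenizers: keep the text prompt and derive token ids from it.
        return prompt, encode_tokens(tokenizer, prompt)
    # Mistral tokenizers already hand back token ids, and decoding them with
    # special tokens preserved is unsupported, so only the token form is compared.
    return None, prompt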
2 changes: 1 addition & 1 deletion tests/models/multimodal/processing/test_mllama.py
@@ -49,7 +49,7 @@ def test_profiling(
] * max_num_seqs

mm_kwargs = processor.apply(
prompt=dummy_mm_data.prompt_text,
prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)["mm_kwargs"]
9 changes: 5 additions & 4 deletions tests/models/registry.py
@@ -8,6 +8,8 @@
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION

from vllm.config import TokenizerMode


@dataclass(frozen=True)
class _HfExamplesInfo:
@@ -20,7 +22,7 @@ class _HfExamplesInfo:
tokenizer: Optional[str] = None
"""Set the tokenizer to load for this architecture."""

tokenizer_mode: str = "auto"
tokenizer_mode: TokenizerMode = "auto"
"""Set the tokenizer type for this architecture."""

speculative_model: Optional[str] = None
@@ -388,8 +390,7 @@ def check_available_online(
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
tokenizer_mode="mistral",
v0_only=True),
tokenizer_mode="mistral"),
"QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
extras={"chat": "Qwen/Qwen-VL-Chat"}, # noqa: E501
trust_remote_code=True,
@@ -400,7 +401,7 @@
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B",
min_transformers_version="4.52"),
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ", # noqa: E501
min_transformers_version="4.52"),
min_transformers_version="4.52"),
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
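
The registry change above replaces the plain string annotation with the TokenizerMode type imported from vllm.config. A hedged sketch of declaring an entry with it, assuming TokenizerMode is a string alias of the supported mode names as its usage in the diff suggests:

# Hypothetical example entry; only the Pixtral model id and the "mistral"
# mode are taken from the diff above.
from vllm.config import TokenizerMode

def pixtral_example() -> _HfExamplesInfo:
    mode: TokenizerMode = "mistral"
    return _HfExamplesInfo("mistralai/Pixtral-12B-2409", tokenizer_mode=mode)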