Re-submit: Fix: Proper RGBA -> RGB conversion for PIL images. #18569

Merged (2 commits) on May 23, 2025
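For context, a minimal sketch (not part of the PR) of the behavior being fixed: PIL's `Image.convert("RGB")` simply drops the alpha channel, so fully transparent pixels keep whatever RGB values they happen to carry (often black). Compositing onto a white canvas through the alpha mask, which is the approach the new helper takes, yields white instead.

```python
from PIL import Image

# Fully transparent pixels whose underlying RGB happens to be black.
rgba = Image.new("RGBA", (2, 2), (0, 0, 0, 0))

# Naive conversion discards alpha, so the pixels come out black.
naive = rgba.convert("RGB")
print(naive.getpixel((0, 0)))  # (0, 0, 0)

# Composite onto an opaque white canvas using the alpha channel as the mask
# (the approach taken by the helper added in this PR).
white = Image.new("RGB", rgba.size, (255, 255, 255))
white.paste(rgba, mask=rgba.split()[3])
print(white.getpixel((0, 0)))  # (255, 255, 255)
```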
3 changes: 2 additions & 1 deletion benchmarks/benchmark_dataset.py
@@ -35,6 +35,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal.image import convert_image_mode
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
 
 logger = logging.getLogger(__name__)
@@ -257,7 +258,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
     if isinstance(image, dict) and "bytes" in image:
         image = Image.open(BytesIO(image["bytes"]))
     if isinstance(image, Image.Image):
-        image = image.convert("RGB")
+        image = convert_image_mode(image, "RGB")
         with io.BytesIO() as image_data:
             image.save(image_data, format="JPEG")
             image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
4 changes: 3 additions & 1 deletion examples/offline_inference/qwen2_5_omni/only_thinker.py
@@ -11,6 +11,7 @@
 from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
+from vllm.multimodal.image import convert_image_mode
 from vllm.utils import FlexibleArgumentParser
 
 
@@ -45,7 +46,8 @@ def get_mixed_modalities_query() -> QueryResult:
                 "audio":
                 AudioAsset("mary_had_lamb").audio_and_sample_rate,
                 "image":
-                ImageAsset("cherry_blossom").pil_image.convert("RGB"),
+                convert_image_mode(
+                    ImageAsset("cherry_blossom").pil_image, "RGB"),
                 "video":
                 VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
             },
5 changes: 3 additions & 2 deletions examples/offline_inference/vision_language.py
@@ -19,6 +19,7 @@
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
 from vllm.lora.request import LoRARequest
+from vllm.multimodal.image import convert_image_mode
 from vllm.utils import FlexibleArgumentParser
 
 
@@ -1096,8 +1097,8 @@ def get_multi_modal_input(args):
     """
     if args.modality == "image":
         # Input image and question
-        image = ImageAsset("cherry_blossom") \
-            .pil_image.convert("RGB")
+        image = convert_image_mode(
+            ImageAsset("cherry_blossom").pil_image, "RGB")
         img_questions = [
             "What is the content of this image?",
             "Describe the content of this image in detail.",
6 changes: 4 additions & 2 deletions tests/models/multimodal/generation/test_interleaved.py
@@ -4,6 +4,7 @@
 
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
+from vllm.multimodal.image import convert_image_mode
 
 models = ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]
 
@@ -26,8 +27,9 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None:
     give the same result.
     """
 
-    image_cherry = ImageAsset("cherry_blossom").pil_image.convert("RGB")
-    image_stop = ImageAsset("stop_sign").pil_image.convert("RGB")
+    image_cherry = convert_image_mode(
+        ImageAsset("cherry_blossom").pil_image, "RGB")
+    image_stop = convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB")
     images = [image_cherry, image_stop]
     video = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays
 
4 changes: 2 additions & 2 deletions tests/models/multimodal/generation/test_phi4mm.py
@@ -12,7 +12,7 @@
 
 from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
-from vllm.multimodal.image import rescale_image_size
+from vllm.multimodal.image import convert_image_mode, rescale_image_size
 from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs
 
@@ -267,7 +267,7 @@ def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
 
     # use the example speech question so that the model outputs are reasonable
     audio = librosa.load(speech_question, sr=None)
-    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+    image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
 
     inputs_vision_speech = [
         (
3 changes: 2 additions & 1 deletion tests/models/test_oot_registration.py
@@ -4,6 +4,7 @@
 
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
+from vllm.multimodal.image import convert_image_mode
 
 from ..utils import create_new_process_for_each_test
 
@@ -58,7 +59,7 @@ def test_oot_registration_embedding(
     assert all(v == 0 for v in output.outputs.embedding)
 
 
-image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
 
 
 @create_new_process_for_each_test()
Binary file added tests/multimodal/assets/rgba.png
36 changes: 36 additions & 0 deletions tests/multimodal/test_image.py
@@ -0,0 +1,36 @@
+# SPDX-License-Identifier: Apache-2.0
+from pathlib import Path
+
+import numpy as np
+from PIL import Image, ImageChops
+
+from vllm.multimodal.image import convert_image_mode
+
+ASSETS_DIR = Path(__file__).parent / "assets"
+assert ASSETS_DIR.exists()
+
+
+def test_rgb_to_rgb():
+    # Start with an RGB image.
+    original_image = Image.open(ASSETS_DIR / "image1.png").convert("RGB")
+    converted_image = convert_image_mode(original_image, "RGB")
+
+    # RGB to RGB should be a no-op.
+    diff = ImageChops.difference(original_image, converted_image)
+    assert diff.getbbox() is None
+
+
+def test_rgba_to_rgb():
+    original_image = Image.open(ASSETS_DIR / "rgba.png")
+    original_image_numpy = np.array(original_image)
+
+    converted_image = convert_image_mode(original_image, "RGB")
+    converted_image_numpy = np.array(converted_image)
+
+    for i in range(original_image_numpy.shape[0]):
+        for j in range(original_image_numpy.shape[1]):
+            # Verify that all transparent pixels are converted to white.
+            if original_image_numpy[i][j][3] == 0:
+                assert converted_image_numpy[i][j][0] == 255
+                assert converted_image_numpy[i][j][1] == 255
+                assert converted_image_numpy[i][j][2] == 255
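As an aside (not part of the diff), the per-pixel loop in `test_rgba_to_rgb` could also be written as a vectorized NumPy check over the same `original_image_numpy` and `converted_image_numpy` arrays:

```python
# Mask of fully transparent pixels (alpha channel == 0).
transparent = original_image_numpy[..., 3] == 0

# Every RGB channel of those pixels should have been filled with white.
assert (converted_image_numpy[transparent] == 255).all()
```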
3 changes: 2 additions & 1 deletion tests/multimodal/test_utils.py
@@ -10,6 +10,7 @@
 import pytest
 from PIL import Image, ImageChops
 
+from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import (MediaConnector,
                                    merge_and_sort_multimodal_metadata)
@@ -53,7 +54,7 @@ def get_supported_suffixes() -> tuple[str, ...]:
 
 
 def _image_equals(a: Image.Image, b: Image.Image) -> bool:
-    return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()
+    return (np.asarray(a) == np.asarray(convert_image_mode(b, a.mode))).all()
 
 
 @pytest.mark.asyncio
4 changes: 2 additions & 2 deletions vllm/benchmarks/datasets.py
@@ -13,7 +13,6 @@
 TODO: Implement CustomDataset to parse a JSON file and convert its contents into
 SampleRequest instances, similar to the approach used in ShareGPT.
 """
-
 import base64
 import io
 import json
@@ -33,6 +32,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal.image import convert_image_mode
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
 
 logger = logging.getLogger(__name__)
@@ -259,7 +259,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
     if isinstance(image, dict) and 'bytes' in image:
         image = Image.open(BytesIO(image['bytes']))
     if isinstance(image, Image.Image):
-        image = image.convert("RGB")
+        image = convert_image_mode(image, "RGB")
         with io.BytesIO() as image_data:
             image.save(image_data, format="JPEG")
             image_base64 = base64.b64encode(
3 changes: 2 additions & 1 deletion vllm/model_executor/models/internvl.py
@@ -23,6 +23,7 @@
                                                     InternVisionPatchModel)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs, NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
@@ -77,7 +78,7 @@ class InternVLImageEmbeddingInputs(TypedDict):
 def build_transform(input_size: int):
     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
     return T.Compose([
-        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+        T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
         T.Resize((input_size, input_size),
                  interpolation=T.InterpolationMode.BICUBIC),
         T.ToTensor(),
3 changes: 2 additions & 1 deletion vllm/model_executor/models/skyworkr1v.py
@@ -24,6 +24,7 @@
                                                     InternVisionPatchModel)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs, NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
@@ -78,7 +79,7 @@ class SkyworkR1VImageEmbeddingInputs(TypedDict):
 def build_transform(input_size: int):
     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
     return T.Compose([
-        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+        T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
         T.Resize((input_size, input_size),
                  interpolation=T.InterpolationMode.BICUBIC),
         T.ToTensor(),
4 changes: 3 additions & 1 deletion vllm/multimodal/hasher.py
@@ -10,6 +10,7 @@
 from PIL import Image
 
 from vllm.logger import init_logger
+from vllm.multimodal.image import convert_image_mode
 
 if TYPE_CHECKING:
     from vllm.inputs import TokensPrompt
@@ -35,7 +36,8 @@ def serialize_item(cls, obj: object) -> bytes:
             return np.array(obj).tobytes()
 
         if isinstance(obj, Image.Image):
-            return cls.item_to_bytes("image", np.array(obj.convert("RGBA")))
+            return cls.item_to_bytes("image",
+                                     np.array(convert_image_mode(obj, "RGBA")))
         if isinstance(obj, torch.Tensor):
            return cls.item_to_bytes("tensor", obj.numpy())
         if isinstance(obj, np.ndarray):
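A side note (not from the PR) on the serialization path above: taking bytes from an RGBA view means an opaque image hashes the same whether it was loaded as RGB or RGBA. A small sketch, assuming the `convert_image_mode` helper added in this PR:

```python
import numpy as np
from PIL import Image

from vllm.multimodal.image import convert_image_mode

rgb = Image.new("RGB", (2, 2), (1, 2, 3))
rgba = rgb.convert("RGBA")  # same pixels, alpha fixed at 255

# Both serialize to identical RGBA arrays, so their hash inputs match.
a = np.array(convert_image_mode(rgb, "RGBA"))
b = np.array(convert_image_mode(rgba, "RGBA"))
assert (a == b).all()
```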
25 changes: 22 additions & 3 deletions vllm/multimodal/image.py
@@ -22,6 +22,25 @@ def rescale_image_size(image: Image.Image,
     return image
 
 
+# TODO: Support customizable background color to fill in.
+def rgba_to_rgb(
+        image: Image.Image, background_color=(255, 255, 255)) -> Image.Image:
+    """Convert an RGBA image to RGB with filled background color."""
+    assert image.mode == "RGBA"
+    converted = Image.new("RGB", image.size, background_color)
+    converted.paste(image, mask=image.split()[3])  # 3 is the alpha channel
+    return converted
+
+
+def convert_image_mode(image: Image.Image, to_mode: str):
+    if image.mode == to_mode:
+        return image
+    elif image.mode == "RGBA" and to_mode == "RGB":
+        return rgba_to_rgb(image)
+    else:
+        return image.convert(to_mode)
+
+
 class ImageMediaIO(MediaIO[Image.Image]):
 
     def __init__(self, *, image_mode: str = "RGB") -> None:
@@ -32,15 +51,15 @@ def __init__(self, *, image_mode: str = "RGB") -> None:
     def load_bytes(self, data: bytes) -> Image.Image:
         image = Image.open(BytesIO(data))
         image.load()
-        return image.convert(self.image_mode)
+        return convert_image_mode(image, self.image_mode)
 
     def load_base64(self, media_type: str, data: str) -> Image.Image:
         return self.load_bytes(base64.b64decode(data))
 
     def load_file(self, filepath: Path) -> Image.Image:
         image = Image.open(filepath)
         image.load()
-        return image.convert(self.image_mode)
+        return convert_image_mode(image, self.image_mode)
 
     def encode_base64(
         self,
@@ -51,7 +70,7 @@ def encode_base64(
         image = media
 
         with BytesIO() as buffer:
-            image = image.convert(self.image_mode)
+            image = convert_image_mode(image, self.image_mode)
             image.save(buffer, image_format)
             data = buffer.getvalue()
 
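A brief usage sketch of the new helper (not part of the diff; behavior inferred from the definitions above):

```python
from PIL import Image

from vllm.multimodal.image import convert_image_mode

rgb = Image.new("RGB", (4, 4), (10, 20, 30))
assert convert_image_mode(rgb, "RGB") is rgb  # same mode: returned unchanged

rgba = Image.new("RGBA", (4, 4), (10, 20, 30, 0))  # fully transparent
out = convert_image_mode(rgba, "RGB")
assert out.getpixel((0, 0)) == (255, 255, 255)  # composited onto white

gray = Image.new("L", (4, 4), 128)
assert convert_image_mode(gray, "RGB").mode == "RGB"  # other modes: plain PIL convert
```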
6 changes: 4 additions & 2 deletions vllm/transformers_utils/processors/ovis.py
@@ -33,6 +33,8 @@
                                             Unpack)
 from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
 
+from vllm.multimodal.image import convert_image_mode
+
 __all__ = ['OvisProcessor']
 IGNORE_ID = -100
 
@@ -361,8 +363,8 @@ def _get_best_grid(img, side):
             # pick the partition with maximum covering_ratio and break the tie using #sub_images
             return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]
 
-        if convert_to_rgb and image.mode != 'RGB':
-            image = image.convert('RGB')
+        if convert_to_rgb:
+            image = convert_image_mode(image, 'RGB')
 
 
         sides = self.get_image_size()