
Commit edf309e
[VLM] Support multimodal inputs for Florence-2 models (#13320)
1 parent 788f284

File tree: 13 files changed, +1078 -117 lines

docs/source/models/supported_models.md
Lines changed: 7 additions & 0 deletions

@@ -715,6 +715,13 @@ See [this page](#generative-models) for more information on how to use generative models.
    *
    * ✅︎
    * ✅︎
+- * `Florence2ForConditionalGeneration`
+  * Florence-2
+  * T + I
+  * `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc.
+  *
+  *
+  *
 - * `FuyuForCausalLM`
   * Fuyu
   * T + I
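The new table row registers Florence-2 as an encoder/decoder model that accepts text plus image ("T + I") inputs. A minimal construction sketch, assembled from the checkpoint names in this row and the tokenizer workaround noted in this commit's registry change (not a verbatim snippet from the diff):

from vllm import LLM

# Florence-2 ships a BartFastTokenizer that AutoTokenizer can't load,
# so the plain BART tokenizer is borrowed (see tests/models/registry.py below).
llm = LLM(
    model="microsoft/Florence-2-base",
    tokenizer="facebook/bart-base",
    trust_remote_code=True,
)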
Lines changed: 23 additions & 16 deletions

@@ -1,34 +1,45 @@
 # SPDX-License-Identifier: Apache-2.0
-'''
+"""
 Demonstrate prompting of text-to-text
 encoder/decoder models, specifically Florence-2
-'''
+"""
 # TODO(Isotr0py):
 # Move to offline_inference/vision_language.py
 # after porting vision backbone
 from vllm import LLM, SamplingParams
-
-dtype = "float"
+from vllm.assets.image import ImageAsset
 
 # Create a Florence-2 encoder/decoder model instance
 llm = LLM(
-    model="microsoft/Florence-2-base",
-    tokenizer="facebook/bart-base",
-    dtype=dtype,
+    model="microsoft/Florence-2-large",
+    tokenizer="facebook/bart-large",
+    max_num_seqs=8,
     trust_remote_code=True,
 )
 
 prompts = [
-    "<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
-    "<CAPTION_TO_PHRASE_GROUNDING>", "<OD>", "<DENSE_REGION_CAPTION>",
-    "<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>"
+    {  # implicit prompt with task token
+        "prompt": "<DETAILED_CAPTION>",
+        "multi_modal_data": {
+            "image": ImageAsset("stop_sign").pil_image
+        },
+    },
+    {  # explicit encoder/decoder prompt
+        "encoder_prompt": {
+            "prompt": "Describe in detail what is shown in the image.",
+            "multi_modal_data": {
+                "image": ImageAsset("cherry_blossom").pil_image
+            },
+        },
+        "decoder_prompt": "",
+    },
 ]
 # Create a sampling params object.
 sampling_params = SamplingParams(
     temperature=0,
     top_p=1.0,
     min_tokens=0,
-    max_tokens=20,
+    max_tokens=128,
 )
 
 # Generate output tokens from the prompts. The output is a list of
@@ -38,9 +49,5 @@
 
 # Print the outputs.
 for output in outputs:
-    prompt = output.prompt
-    encoder_prompt = output.encoder_prompt
     generated_text = output.outputs[0].text
-    print(f"Encoder prompt: {encoder_prompt!r}, "
-          f"Decoder prompt: {prompt!r}, "
-          f"Generated text: {generated_text!r}")
+    print(f"Generated text: {generated_text!r}")
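The updated example keeps only two prompts, one per input format. The task tokens from the removed prompt list still apply; below is a hedged sketch of batching several of them against a single image using the new implicit-prompt format (the token names and asset name come from this diff, the loop itself is illustrative):

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

llm = LLM(
    model="microsoft/Florence-2-large",
    tokenizer="facebook/bart-large",
    max_num_seqs=8,
    trust_remote_code=True,
)

image = ImageAsset("stop_sign").pil_image
task_tokens = ["<CAPTION>", "<DETAILED_CAPTION>", "<OD>", "<OCR>"]

# One implicit prompt per task token, all sharing the same image.
prompts = [{
    "prompt": task,
    "multi_modal_data": {"image": image},
} for task in task_tokens]

outputs = llm.generate(prompts, SamplingParams(temperature=0, max_tokens=128))
for task, output in zip(task_tokens, outputs):
    print(task, "->", output.outputs[0].text)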

examples/offline_inference/vision_language.py
Lines changed: 17 additions & 0 deletions

@@ -82,6 +82,22 @@ def run_deepseek_vl2(question: str, modality: str):
     return llm, prompt, stop_token_ids
 
 
+# Florence2
+def run_florence2(question: str, modality: str):
+    assert modality == "image"
+
+    llm = LLM(model="microsoft/Florence-2-large",
+              tokenizer="facebook/bart-large",
+              max_num_seqs=8,
+              trust_remote_code=True,
+              dtype="bfloat16",
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
+
+    prompt = "<MORE_DETAILED_CAPTION>"
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 # Fuyu
 def run_fuyu(question: str, modality: str):
     assert modality == "image"
@@ -571,6 +587,7 @@ def run_qwen2_5_vl(question: str, modality: str):
     "blip-2": run_blip2,
     "chameleon": run_chameleon,
     "deepseek_vl_v2": run_deepseek_vl2,
+    "florence2": run_florence2,
     "fuyu": run_fuyu,
     "glm4v": run_glm4v,
     "h2ovl_chat": run_h2ovl,

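`run_florence2` follows the same contract as the other `run_*` helpers in this script: it returns an `LLM`, a prompt template, and optional stop token IDs, which the script's main loop pairs with the selected image. A rough sketch of that consumption, assuming the script's module-level `args` is in scope (so the `disable_mm_preprocessor_cache` reference resolves) and using a vLLM image asset purely for illustration:

from vllm import SamplingParams
from vllm.assets.image import ImageAsset

llm, prompt, stop_token_ids = run_florence2("Describe the image.", "image")

outputs = llm.generate(
    {
        "prompt": prompt,  # "<MORE_DETAILED_CAPTION>"
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    },
    SamplingParams(temperature=0, max_tokens=64, stop_token_ids=stop_token_ids),
)
print(outputs[0].outputs[0].text)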
tests/conftest.py
Lines changed: 3 additions & 3 deletions

@@ -600,8 +600,8 @@ def generate_encoder_decoder_greedy_logprobs_limit(
         if images is not None and images[i] is not None:
             processor_kwargs["images"] = images[i]
 
-        encoder_input_ids = self.wrap_device(
-            self.processor(**processor_kwargs).input_ids,
+        encoder_inputs = self.wrap_device(
+            self.processor(**processor_kwargs),
             device=self.model.device.type,
         )
 
@@ -615,13 +615,13 @@ def generate_encoder_decoder_greedy_logprobs_limit(
         )
 
         output = self.model.generate(
-            encoder_input_ids,
             decoder_input_ids=decoder_input_ids,
             use_cache=True,
             do_sample=False,
             max_new_tokens=max_tokens,
             output_hidden_states=True,
             return_dict_in_generate=True,
+            **encoder_inputs,
             **kwargs,
        )
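This conftest change forwards the processor's full output instead of only `input_ids`, so multimodal encoder/decoder models receive their image features during the HF reference run. A small illustration of why that matters, not taken from the commit (the exact set of keys is an assumption about Florence-2's processor):

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base",
                                          trust_remote_code=True)
batch = processor(text="<CAPTION>",
                  images=Image.new("RGB", (64, 64)),
                  return_tensors="pt")

# Expect something like input_ids plus pixel_values; passing only
# batch.input_ids to model.generate() would silently drop the image.
print(list(batch.keys()))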

tests/models/decoder_only/audio_language/test_ultravox.py
Lines changed: 2 additions & 2 deletions

@@ -15,7 +15,7 @@
 from ....utils import RemoteOpenAIServer
 from ...utils import check_logprobs_close
 
-MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
+MODEL_NAME = "fixie-ai/ultravox-v0_4"
 
 AudioTuple = Tuple[np.ndarray, int]
 
@@ -187,7 +187,7 @@ def run_multi_audio_test(
 
 @pytest.mark.core_model
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("vllm_kwargs", [

tests/models/encoder_decoder/vision_language/test_florence2.py
Lines changed: 88 additions & 51 deletions

@@ -1,52 +1,59 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from functools import partial
-from typing import List, Optional, Tuple, Type
+from typing import Optional, Type
 
 import pytest
 from PIL import Image
 
-from vllm.inputs.data import ExplicitEncoderDecoderPrompt
+from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt
+from vllm.multimodal.image import rescale_image_size
 from vllm.sequence import SampleLogprobs
 
-from ....conftest import HfRunner, VllmRunner
+from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
 from ...utils import check_logprobs_close
 
-Florence2Prompt = partial(ExplicitEncoderDecoderPrompt,
-                          decoder_prompt=None,
-                          mm_processor_kwargs=None)
-
 MODELS = ["microsoft/Florence-2-base"]
 # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
 # Therefore, we borrow the BartTokenizer from the original Bart model
 TOKENIZER = "facebook/bart-base"
-PROMPTS = [
-    Florence2Prompt(encoder_prompt="<CAPTION>"),
-    Florence2Prompt(encoder_prompt="<DETAILED_CAPTION>"),
-    Florence2Prompt(encoder_prompt="<MORE_DETAILED_CAPTION>"),
-    Florence2Prompt(encoder_prompt="<CAPTION_TO_PHRASE_GROUNDING>"),
-    Florence2Prompt(encoder_prompt="<DENSE_REGION_CAPTION>"),
-    Florence2Prompt(encoder_prompt="<REGION_PROPOSAL>"),
-    Florence2Prompt(encoder_prompt="<OCR_WITH_REGION>"),
-    Florence2Prompt(encoder_prompt="<OCR>"),
-    Florence2Prompt(encoder_prompt="<OD>"),
-]
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+    "stop_sign":
+    "<CAPTION>",  # special task token
+    "cherry_blossom":
+    "Describe in detail what is shown in the image.",
+})
+
 
+def get_hf_images_prompts(
+    prompts_: list[ExplicitEncoderDecoderPrompt[str, TextPrompt]],
+) -> tuple[list[ExplicitEncoderDecoderPrompt[str, str]], list[Image.Image]]:
+    prompts, images = [], []
+    for prompt in prompts_:
+        encoder_prompt = prompt["encoder_prompt"]
+        prompts.append(
+            ExplicitEncoderDecoderPrompt(
+                encoder_prompt=encoder_prompt["prompt"],
+                decoder_prompt=None,
+            ))
+        images.append(encoder_prompt["multi_modal_data"]["image"])
+    return prompts, images
 
-def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
-                                         Optional[SampleLogprobs]], ):
-    """Sanitize vllm output to be comparable with hf output."""
-    output_ids, output_str, out_logprobs = vllm_output
 
-    hf_output_str = "</s><s>" + output_str + "</s>"
+def hf_to_vllm_output(hf_output: tuple[list[int], str,
+                                       Optional[SampleLogprobs]]):
+    """Sanitize hf output to be comparable with vllm output."""
+    output_ids, output_str, out_logprobs = hf_output
 
-    return output_ids, hf_output_str, out_logprobs
+    output_str = output_str.replace("</s>", "").replace("<s>", "")
+    output_ids = [ids for ids in output_ids if ids not in [0, 2]]
+
+    return output_ids, output_str, out_logprobs
 
 
 def run_test(
     hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
-    prompts: List[ExplicitEncoderDecoderPrompt],
+    inputs: list[list[ExplicitEncoderDecoderPrompt]],
     model: str,
     *,
     dtype: str,
@@ -56,46 +63,76 @@ def run_test(
     distributed_executor_backend: Optional[str] = None,
 ) -> None:
     with vllm_runner(model,
+                     max_num_seqs=8,
                      tokenizer_name=TOKENIZER,
                      dtype=dtype,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
                      enforce_eager=True) as vllm_model:
-        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
-            prompts, max_tokens, num_logprobs)
+        vllm_outputs_per_case = [
+            vllm_model.generate_encoder_decoder_greedy_logprobs(
+                prompts, max_tokens, num_logprobs=num_logprobs)
+            for prompts in inputs
+        ]
+
+    hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs]
 
-    # Florence-2 processors require image inputs
-    dummy_image = Image.new(mode="RGB", size=(2, 2))
     with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model:
         hf_model.model.get_output_embeddings = lambda: \
             hf_model.model.language_model.lm_head
-        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
-            prompts,
-            max_tokens,
-            num_logprobs,
-            images=[dummy_image] * len(prompts),
-        ))
-
-    check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=[
-            vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs
-        ],
-        name_0="hf",
-        name_1="vllm",
-    )
-
-
+        hf_outputs_per_case = [
+            hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+                prompts, max_tokens, num_logprobs=num_logprobs, images=images)
+            for prompts, images in hf_inputs
+        ]
+
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
+                                        vllm_outputs_per_case):
+        check_logprobs_close(
+            outputs_0_lst=[hf_to_vllm_output(output) for output in hf_outputs],
+            outputs_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
+        )
+
+
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", ["float"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
-                num_logprobs) -> None:
+def test_models(hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner],
+                image_assets: _ImageAssets, model: str,
+                size_factors: list[int], dtype: str, max_tokens: int,
+                num_logprobs: int) -> None:
+    images = [asset.pil_image for asset in image_assets]
+
+    inputs_per_image = [[
+        ExplicitEncoderDecoderPrompt(
+            encoder_prompt=TextPrompt(
+                prompt=prompt,
+                multi_modal_data={"image": rescale_image_size(image, factor)}),
+            decoder_prompt=None,
+        ) for factor in size_factors
+    ] for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+
     run_test(
         hf_runner,
         vllm_runner,
-        PROMPTS,
+        inputs_per_image,
         model,
         dtype=dtype,
         max_tokens=max_tokens,
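A quick worked example of the `hf_to_vllm_output` helper added above, with made-up token IDs (BART uses 0 for `<s>` and 2 for `</s>`), showing how the special tokens are stripped so HF and vLLM outputs compare cleanly:

hf_output = ([2, 0, 100, 200, 2], "</s><s>A stop sign.</s>", None)
print(hf_to_vllm_output(hf_output))
# ([100, 200], 'A stop sign.', None)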

tests/models/multimodal/processing/test_common.py
Lines changed: 3 additions & 2 deletions

@@ -29,8 +29,8 @@ def _test_processing_correctness(
     model_config = ModelConfig(
         model_id,
         task="auto",
-        tokenizer=model_id,
-        tokenizer_mode="auto",
+        tokenizer=model_info.tokenizer or model_id,
+        tokenizer_mode=model_info.tokenizer_mode,
         trust_remote_code=model_info.trust_remote_code,
         seed=0,
         dtype="float16",
@@ -151,6 +151,7 @@ def _test_processing_correctness(
         "Salesforce/blip2-opt-2.7b",
         "facebook/chameleon-7b",
         "deepseek-ai/deepseek-vl2-tiny",
+        "microsoft/Florence-2-base",
         "adept/fuyu-8b",
         "THUDM/glm-4v-9b",
         "h2oai/h2ovl-mississippi-800m",

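`model_info.tokenizer or model_id` lets a registry entry override the tokenizer while every other model keeps its own. For Florence-2 this is the BART workaround described in the registry comments; a brief sketch of what the override resolves to (illustrative values only):

from transformers import AutoTokenizer

model_id = "microsoft/Florence-2-base"
tokenizer_override = "facebook/bart-base"  # model_info.tokenizer for Florence-2

# Falls back to the model's own tokenizer when no override is registered.
tokenizer_name = tokenizer_override or model_id
tok = AutoTokenizer.from_pretrained(tokenizer_name)
print(type(tok).__name__)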
tests/models/registry.py
Lines changed: 5 additions & 5 deletions

@@ -193,11 +193,6 @@ def check_available_online(
     # [Encoder-decoder]
     "BartModel": _HfExamplesInfo("facebook/bart-base"),
     "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
-    # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
-    # Therefore, we borrow the BartTokenizer from the original Bart model
-    "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base",  # noqa: E501
-                                                         tokenizer="facebook/bart-base",
-                                                         trust_remote_code=True),  # noqa: E501
 }
 
 _EMBEDDING_EXAMPLE_MODELS = {
@@ -288,6 +283,11 @@ def check_available_online(
                                           extras={"v0.5": "fixie-ai/ultravox-v0_5-llama-3_2-1b"},  # noqa: E501
                                           trust_remote_code=True),
     # [Encoder-decoder]
+    # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
+    # Therefore, we borrow the BartTokenizer from the original Bart model
+    "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base",  # noqa: E501
+                                                         tokenizer="facebook/bart-base",
+                                                         trust_remote_code=True),  # noqa: E501
     "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"),  # noqa: E501
     "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"),  # noqa: E501
 }
