Commit d669fe7

heheda12345, simon-mo, CatherineSue, and ywang96 authored and committed
[Model] Add support for the multi-modal Llama 3.2 model (vllm-project#8811)
Co-authored-by: simon-mo <[email protected]>
Co-authored-by: Chang Su <[email protected]>
Co-authored-by: Simon Mo <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
1 parent e52ad9c commit d669fe7

File tree: 24 files changed (+1647 lines, −45 lines)


docs/source/models/supported_models.rst

Lines changed: 5 additions & 0 deletions

@@ -254,6 +254,11 @@ Multimodal Language Models
     - Image\ :sup:`+`
     - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
     -
+  * - :code:`MllamaForConditionalGeneration`
+    - Llama 3.2
+    - Image
+    - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
+    -
   * - :code:`PaliGemmaForConditionalGeneration`
     - PaliGemma
     - Image\ :sup:`E`
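As context for the new table entry, the following is a minimal offline-inference sketch against the newly listed checkpoint. It is not part of the commit; the engine arguments mirror the example script added later in this diff, and the local image path is a made-up placeholder.

from PIL import Image

from vllm import LLM, SamplingParams

# Engine arguments follow the example added in this commit: a small
# max_num_seqs and eager mode keep memory usage manageable on one GPU.
llm = LLM(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    max_num_seqs=16,
    enforce_eager=True,
)

image = Image.open("example.jpg")  # placeholder path, not from the commit
outputs = llm.generate(
    {
        "prompt": "<|image|><|begin_of_text|>What is shown in this image?",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)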

examples/offline_inference_vision_language.py

Lines changed: 24 additions & 0 deletions

@@ -242,6 +242,29 @@ def run_qwen2_vl(question, modality):
     return llm, prompt, stop_token_ids
 
 
+# LLama
+def run_mllama(question, modality):
+    assert modality == "image"
+
+    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+
+    # Note: The default setting of max_num_seqs (256) and
+    # max_model_len (131072) for this model may cause OOM.
+    # You may lower either to run this example on lower-end GPUs.
+
+    # The configuration below has been confirmed to launch on a
+    # single H100 GPU.
+    llm = LLM(
+        model=model_name,
+        max_num_seqs=16,
+        enforce_eager=True,
+    )
+
+    prompt = f"<|image|><|begin_of_text|>{question}"
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 model_example_map = {
     "llava": run_llava,
     "llava-next": run_llava_next,
@@ -256,6 +279,7 @@ def run_qwen2_vl(question, modality):
     "internvl_chat": run_internvl,
     "qwen_vl": run_qwen_vl,
     "qwen2_vl": run_qwen2_vl,
+    "mllama": run_mllama,
 }
 
 
examples/openai_vision_api_client.py

Lines changed: 2 additions & 2 deletions

@@ -38,7 +38,7 @@
         "content": [
             {
                 "type": "text",
-                "text": "Whats in this image?"
+                "text": "What's in this image?"
             },
             {
                 "type": "image_url",
@@ -75,7 +75,7 @@ def encode_image_base64_from_url(image_url: str) -> str:
         "content": [
             {
                 "type": "text",
-                "text": "Whats in this image?"
+                "text": "What's in this image?"
             },
             {
                 "type": "image_url",

requirements-common.txt

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ numpy < 2.0.0
 requests
 tqdm
 py-cpuinfo
-transformers >= 4.43.2  # Required for Chameleon and Llama 3.1 hotfix.
+transformers >= 4.45.0  # Required for Llama 3.2.
 tokenizers >= 0.19.1  # Required for Llama 3.
 protobuf # Required by LlamaTokenizer.
 fastapi < 0.113.0; python_version < '3.9'

tests/models/encoder_decoder/vision_language/__init__.py

Whitespace-only changes.

tests/models/encoder_decoder/vision_language/test_mllama.py

Lines changed: 283 additions & 0 deletions

@@ -0,0 +1,283 @@
+from typing import List, Optional, Tuple, Type, overload
+
+import pytest
+from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
+                          BatchEncoding)
+
+from vllm.multimodal.utils import rescale_image_size
+from vllm.sequence import SampleLogprobs
+
+from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
+                          _ImageAssets)
+from ....utils import multi_gpu_test
+from ...utils import check_logprobs_close
+
+_LIMIT_IMAGE_PER_PROMPT = 1
+
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+    "stop_sign":
+    "<|image|><|begin_of_text|>The meaning of the image is",
+    "cherry_blossom":
+    "<|image|><|begin_of_text|>The city is",
+})
+
+text_only_prompts = [
+    "The color of the sky is blue but sometimes it can also be",
+]
+
+models = [
+    "meta-llama/Llama-3.2-11B-Vision-Instruct",
+]
+
+
+def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
+                                         Optional[SampleLogprobs]],
+                      model: str):
+    """Sanitize vllm output to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    config = AutoConfig.from_pretrained(model)
+    image_token_id = config.image_token_index
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    eos_token_id = tokenizer.eos_token_id
+
+    hf_output_ids = [
+        token_id for idx, token_id in enumerate(output_ids)
+        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
+    ]
+
+    assert output_str[0] == " "
+    hf_output_str = output_str[1:]
+    if hf_output_ids[-1] == eos_token_id:
+        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
+
+    return hf_output_ids, hf_output_str, out_logprobs
+
+
+@overload
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: List[float],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    ...
+
+
+@overload
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    sizes: List[Tuple[int, int]],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    ...
+
+
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: Optional[List[float]] = None,
+    sizes: Optional[List[Tuple[int, int]]] = None,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    images = [asset.pil_image for asset in image_assets]
+
+    if size_factors is not None:
+        inputs_per_image = [(
+            [prompt for _ in size_factors],
+            [rescale_image_size(image, factor) for factor in size_factors],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    elif sizes is not None:
+        inputs_per_image = [(
+            [
+                prompt if size is not None else text_only_prompts[0]
+                for size in sizes
+            ],
+            [
+                image.resize(size) if size is not None else None
+                for size in sizes
+            ],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+        if len(sizes) == 0:
+            inputs_per_image.append(
+                (text_only_prompts, [None] * len(text_only_prompts)))
+    else:
+        raise ValueError("You must provide either `size_factors` or `sizes`")
+
+    _run_test(hf_runner,
+              vllm_runner,
+              inputs_per_image,
+              model,
+              dtype=dtype,
+              max_tokens=max_tokens,
+              num_logprobs=num_logprobs,
+              tensor_parallel_size=tensor_parallel_size,
+              distributed_executor_backend=distributed_executor_backend)
+
+
+def _run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    inputs: List[Tuple[List[str], PromptImageInput]],
+    model: str,
+    *,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test are from IMAGE_ASSETS.
+    For huggingface runner, we provide the PIL images as input.
+    For vllm runner, we provide MultiModalDataDict objects
+    and corresponding MultiModalConfig as input.
+    Note, the text input is also adjusted to abide by vllm contract.
+    The text output is sanitized to be able to compare with hf.
+    """
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(model,
+                     dtype=dtype,
+                     max_num_seqs=16,
+                     max_model_len=4096,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True,
+                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
+                                          }) as vllm_model:
+        vllm_outputs_per_image = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images)
+            for prompts, images in inputs
+        ]
+
+    def process(hf_inputs: BatchEncoding):
+        return hf_inputs
+
+    from transformers import AutoConfig
+    from transformers.models.mllama import MllamaConfig as MllamaConfigHf
+
+    # use transformer's MllamaConfig for hf_runner
+    # and vllm's MllamaConfig for vllm_runner
+    AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True)
+    with hf_runner(model,
+                   dtype=dtype,
+                   postprocess_inputs=process,
+                   auto_cls=AutoModelForVision2Seq) as hf_model:
+        hf_outputs_per_image = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=images)
+            for prompts, images in inputs
+        ]
+
+    from vllm.transformers_utils.configs.mllama import MllamaConfig
+    AutoConfig.register("mllama", MllamaConfig, exist_ok=True)
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
+                                        vllm_outputs_per_image):
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
+                vllm_to_hf_output(vllm_output, model)
+                for vllm_output in vllm_outputs
+            ],
+            name_0="hf",
+            name_1="vllm",
+        )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "sizes",
+    [
+        # Text only
+        [],
+        # Single-size
+        [(512, 512)],
+        # Single-size, batched
+        [(512, 512), (512, 512), (512, 512)],
+        # Multi-size, batched
+        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
+         (1024, 1024), (512, 1536), (512, 2028)],
+        # Multi-size, batched, including text only
+        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
+         (1024, 1024), (512, 1536), (512, 2028), None],
+        # mllama has 8 possible aspect ratios, carefully set the sizes
+        # to cover all of them
+    ],
+)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
+                max_tokens, num_logprobs) -> None:
+    run_test(
+        hf_runner,
+        vllm_runner,
+        image_assets,
+        model,
+        sizes=sizes,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "sizes",
+    [
+        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
+         (1024, 1024), (512, 1536), (512, 2028), None],
+    ],
+)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes,
+                            dtype, max_tokens, num_logprobs) -> None:
+    run_test(
+        hf_runner,
+        vllm_runner,
+        image_assets,
+        model,
+        sizes=sizes,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=2,
+    )
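The comment in the parametrization above notes that mllama has 8 possible aspect ratios. A small standalone sketch (assuming tile-based preprocessing with at most 4 tiles, as in the Hugging Face Mllama image processor) shows where that count comes from:

# Enumerate (width_tiles, height_tiles) pairs whose product fits in the
# assumed 4-tile budget; this yields the 8 aspect-ratio buckets the test
# sizes above are chosen to cover.
MAX_IMAGE_TILES = 4  # assumption, mirroring the HF Mllama default

aspect_ratios = [(w, h)
                 for w in range(1, MAX_IMAGE_TILES + 1)
                 for h in range(1, MAX_IMAGE_TILES + 1)
                 if w * h <= MAX_IMAGE_TILES]

print(aspect_ratios)
# [(1, 1), (1, 2), (1, 3), (1, 4), (2, 1), (2, 2), (3, 1), (4, 1)]
print(len(aspect_ratios))  # 8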

vllm/config.py

Lines changed: 3 additions & 1 deletion

@@ -576,7 +576,9 @@ def get_multimodal_config(self) -> "MultiModalConfig":
     @property
     def is_encoder_decoder_model(self) -> bool:
         """Extract the HF encoder/decoder model flag."""
-        return getattr(self.hf_config, "is_encoder_decoder", False)
+        return getattr(self.hf_config, "is_encoder_decoder", False) or (
+            (hasattr(self.hf_config, "text_config") and getattr(
+                self.hf_config.text_config, "is_encoder_decoder", False)))
 
     @property
     def is_embedding_model(self) -> bool:
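To illustrate the updated property in isolation, here is a self-contained sketch using stand-in config objects. It is not vLLM code; the nested text_config carrying is_encoder_decoder is an assumption about the Mllama-style configs this change targets.

from types import SimpleNamespace


def is_encoder_decoder_model(hf_config) -> bool:
    # Same expression as the updated property above, minus the class context.
    return getattr(hf_config, "is_encoder_decoder", False) or (
        hasattr(hf_config, "text_config")
        and getattr(hf_config.text_config, "is_encoder_decoder", False))


# A plain decoder-only config: no flag anywhere.
decoder_only = SimpleNamespace()

# An Mllama-style config where the flag is assumed to live on text_config.
mllama_like = SimpleNamespace(
    text_config=SimpleNamespace(is_encoder_decoder=True))

print(is_encoder_decoder_model(decoder_only))  # False
print(is_encoder_decoder_model(mllama_like))   # True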

vllm/engine/llm_engine.py

Lines changed: 5 additions & 1 deletion

@@ -1734,7 +1734,11 @@ def is_embedding_model(self):
 
     def _validate_model_inputs(self, inputs: Union[LLMInputs,
                                                    EncoderDecoderLLMInputs]):
-        if self.is_encoder_decoder_model():
+        if self.model_config.is_multimodal_model:
+            # For encoder-decoder multimodal models, the max_prompt_len
+            # restricts the decoder prompt length
+            prompt_ids = inputs.get("prompt_token_ids")
+        elif self.is_encoder_decoder_model():
             prompt_ids = inputs.get("encoder_prompt_token_ids")
         else:
             prompt_ids = inputs.get("prompt_token_ids")
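A minimal, standalone sketch of the new branch order (stand-in flags and dict inputs, not the engine's real objects): multimodal models are checked first, so their decoder prompt_token_ids are what gets length-validated.

# Hypothetical helper reproducing only the branch order shown in the diff.
def pick_prompt_ids(inputs: dict, is_multimodal: bool, is_encoder_decoder: bool):
    if is_multimodal:
        # Multimodal encoder-decoder models: validate the decoder prompt.
        return inputs.get("prompt_token_ids")
    elif is_encoder_decoder:
        return inputs.get("encoder_prompt_token_ids")
    else:
        return inputs.get("prompt_token_ids")


inputs = {
    "prompt_token_ids": [1, 2, 3],
    "encoder_prompt_token_ids": [7, 8, 9, 10],
}
print(pick_prompt_ids(inputs, is_multimodal=True, is_encoder_decoder=True))   # [1, 2, 3]
print(pick_prompt_ids(inputs, is_multimodal=False, is_encoder_decoder=True))  # [7, 8, 9, 10]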
