Commit d54859c

farzadab authored and kwang1012 committed

[Model] Ultravox Model: Support v0.5 Release (vllm-project#12912)

Signed-off-by: Farzad Abdolhosseini <[email protected]>

1 parent 2b10721 commit d54859c

12 files changed: +36 −22 lines changed

docs/source/models/supported_models.md (+1 −1)

@@ -856,7 +856,7 @@ See [this page](#generative-models) for more information on how to use generativ
 - * `UltravoxModel`
   * Ultravox
   * T + A<sup>E+</sup>
-  * `fixie-ai/ultravox-v0_3`
+  * `fixie-ai/ultravox-v0_5-llama-3_2-1b`
   * ✅︎
   * ✅︎
   * ✅︎

docs/source/serving/multimodal_inputs.md (+2 −2)

@@ -359,12 +359,12 @@ export VLLM_VIDEO_FETCH_TIMEOUT=<timeout>
 ### Audio
 
 Audio input is supported according to [OpenAI Audio API](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in).
-Here is a simple example using Ultravox-v0.3.
+Here is a simple example using Ultravox-v0.5-1B.
 
 First, launch the OpenAI-compatible server:
 
 ```bash
-vllm serve fixie-ai/ultravox-v0_3
+vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b
 ```
 
 Then, you can use the OpenAI client as follows:
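
The client snippet itself lies outside this hunk's context, but for orientation, here is a minimal sketch of such a request, assuming the server launched above is listening on localhost:8000; the `audio_url` content part is vLLM's multimodal extension of the OpenAI chat schema, and the URL and question below are invented for the example:

```python
# Hedged sketch, not the docs' verbatim snippet (which is outside this hunk).
from openai import OpenAI

# vLLM's OpenAI-compatible server accepts any placeholder API key.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    model="fixie-ai/ultravox-v0_5-llama-3_2-1b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is happening in this clip?"},
            # "audio_url" mirrors "image_url" and is a vLLM extension.
            {"type": "audio_url",
             "audio_url": {"url": "https://example.com/sample.wav"}},
        ],
    }],
)
print(chat_completion.choices[0].message.content)
```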

examples/offline_inference/audio_language.py (+2 −2)

@@ -24,9 +24,9 @@
 # Unless specified, these settings have been tested to work on a single L4.
 
 
-# Ultravox 0.3
+# Ultravox 0.5-1B
 def run_ultravox(question: str, audio_count: int):
-    model_name = "fixie-ai/ultravox-v0_3"
+    model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     messages = [{
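
The hunk cuts off inside `run_ultravox`. As a rough, hypothetical sketch of how an offline Ultravox call is typically wired up in vLLM (the prompt construction and sampling values below are assumptions, not the remainder of this file):

```python
# Illustrative sketch only; not the rest of the diffed example file.
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Ultravox expects an <|audio|> placeholder where the clip should be heard.
messages = [{"role": "user", "content": "<|audio|>\nWhat is in this audio?"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True)

llm = LLM(model=model_name, trust_remote_code=True, max_model_len=4096)
outputs = llm.generate(
    {
        "prompt": prompt,
        # Audio is passed as a (waveform, sample_rate) tuple.
        "multi_modal_data": {
            "audio": AudioAsset("winning_call").audio_and_sample_rate
        },
    },
    SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```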

examples/online_serving/openai_chat_completion_client_for_multimodal.py (+1 −1)

@@ -12,7 +12,7 @@
 --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
 
 (audio inference with Ultravox)
-vllm serve fixie-ai/ultravox-v0_3 --max-model-len 4096
+vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096
 """
 import base64
 

tests/distributed/test_pipeline_parallel.py (+2 −2)

@@ -232,7 +232,7 @@ def iter_params(self, model_name: str):
     "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
     "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
     "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
-    "fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True),
+    "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
     # [Encoder-decoder]
     # TODO: Implement PP
     # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
@@ -251,7 +251,7 @@ def iter_params(self, model_name: str):
     # [MULTIMODAL GENERATION]
     "OpenGVLab/InternVL2-1B",
     "microsoft/Phi-3-vision-128k-instruct",
-    "fixie-ai/ultravox-v0_3",
+    "fixie-ai/ultravox-v0_5-llama-3_2-1b",
     # [LANGUAGE GENERATION - HYBRID ARCH]
     "ai21labs/Jamba-tiny-dev",
 ]

tests/entrypoints/openai/test_audio.py (+1 −1)

@@ -11,7 +11,7 @@
 
 from ...utils import RemoteOpenAIServer
 
-MODEL_NAME = "fixie-ai/ultravox-v0_3"
+MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 TEST_AUDIO_URLS = [
     AudioAsset("winning_call").url,
 ]

tests/entrypoints/test_chat_utils.py (+1 −1)

@@ -21,7 +21,7 @@
 EXAMPLES_DIR = VLLM_PATH / "examples"
 
 PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
-ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_3"
+ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
 MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"

tests/models/decoder_only/audio_language/test_ultravox.py (+1 −1)

@@ -15,7 +15,7 @@
 from ....utils import RemoteOpenAIServer
 from ...utils import check_logprobs_close
 
-MODEL_NAME = "fixie-ai/ultravox-v0_3"
+MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 
 AudioTuple = Tuple[np.ndarray, int]
 

tests/models/multimodal/processing/test_common.py (+1 −1)

@@ -164,7 +164,7 @@ def _test_processing_correctness(
     "Qwen/Qwen2-VL-2B-Instruct",
     "Qwen/Qwen2.5-VL-3B-Instruct",
     "Qwen/Qwen2-Audio-7B-Instruct",
-    "fixie-ai/ultravox-v0_3",
+    "fixie-ai/ultravox-v0_5-llama-3_2-1b",
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])

tests/models/registry.py (+1 −1)

@@ -267,7 +267,7 @@ def check_available_online(
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"),  # noqa: E501
     "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct",  # noqa: E501
                                                           min_transformers_version="4.49"),  # noqa: E501
-    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
+    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b",
                                      trust_remote_code=True),
     # [Encoder-decoder]
     "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"),  # noqa: E501

vllm/model_executor/models/ultravox.py (+17 −9)

@@ -258,27 +258,35 @@ def __init__(self, config: UltravoxConfig):
         super().__init__()
         self.hidden_dim = config.hidden_size
         self._pad_and_stack = StackAudioFrames(config.stack_factor)
-        dim = config.audio_config.hidden_size * config.stack_factor
-        self.ln_pre = RMSNorm(dim)
-        self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
-        dim = self.hidden_dim
+        dim_in = config.audio_config.hidden_size * config.stack_factor
+        self.ln_pre = RMSNorm(dim_in)
+        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
+        dim_mid = self.hidden_dim
 
         if config.projector_act == "swiglu":
             self.act = MulAndSilu()
-            dim = dim // 2
+            dim_mid = dim_mid // 2
         else:
             self.act = get_act_fn(config.projector_act)
 
-        self.linear_2 = nn.Linear(dim,
-                                  config.text_config.hidden_size,
-                                  bias=False)
-        self.ln_post = RMSNorm(config.text_config.hidden_size)
+        dim_out = config.text_config.hidden_size
+        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
+
+        # Ultravox v0.4.1 and below use layer_norm after the second linear
+        # layer, while v0.5.0 and above use layer_norm after the first.
+        if config.projector_ln_mid:
+            self.ln_mid: nn.Module = RMSNorm(dim_mid)
+            self.ln_post = nn.Identity()
+        else:
+            self.ln_mid = nn.Identity()
+            self.ln_post = RMSNorm(dim_out)
 
     def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
         hidden_states = self.linear_1(audio_features)
         hidden_states = self.act(hidden_states)
+        hidden_states = self.ln_mid(hidden_states)
         hidden_states = self.linear_2(hidden_states)
         hidden_states = self.ln_post(hidden_states)
         return hidden_states
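
To make the switch concrete, here is a standalone toy sketch of the two dataflows. This is not vLLM's class: the dimensions are invented, `nn.RMSNorm` assumes PyTorch 2.4+, and plain SiLU stands in for the SwiGLU projection:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyProjector(nn.Module):
    """Toy stand-in for UltravoxProjector's norm-placement switch."""

    def __init__(self, dim_in: int, dim_mid: int, dim_out: int,
                 projector_ln_mid: bool):
        super().__init__()
        self.linear_1 = nn.Linear(dim_in, dim_mid, bias=False)
        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
        # v0.5+ normalizes between the linears; v0.4.1- after the second.
        self.ln_mid = nn.RMSNorm(dim_mid) if projector_ln_mid else nn.Identity()
        self.ln_post = nn.Identity() if projector_ln_mid else nn.RMSNorm(dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.linear_1(x))  # SiLU stands in for MulAndSilu (SwiGLU)
        x = self.ln_mid(x)            # active only when projector_ln_mid=True
        x = self.linear_2(x)
        return self.ln_post(x)        # active only when projector_ln_mid=False


x = torch.randn(2, 8, 64)
print(ToyProjector(64, 32, 16, projector_ln_mid=True)(x).shape)   # v0.5 path
print(ToyProjector(64, 32, 16, projector_ln_mid=False)(x).shape)  # legacy path
```

Either way the output shape matches the text model's hidden size; only where the RMSNorm sits changes, which is why the flag must follow the checkpoint version.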

vllm/transformers_utils/configs/ultravox.py (+6 −0)

@@ -37,6 +37,10 @@ class UltravoxConfig(transformers.PretrainedConfig):
             The LoRA configuration for finetuning the text model.
         audio_model_lora_config (`LoraConfigSimplified`, *optional*):
             The LoRA configuration for finetuning the audio model.
+        projector_ln_mid (`bool`, *optional*, defaults to `False`):
+            Whether to apply layer normalization in the middle of the
+            projector or at the end. Versions v0.4.1 and below
+            use `False`, but v0.5 and above use `True`.
     """
 
     model_type = "ultravox"
@@ -56,6 +60,7 @@ def __init__(
         projector_act: str = "swiglu",
         text_model_lora_config: Optional[Dict[str, Any]] = None,
         audio_model_lora_config: Optional[Dict[str, Any]] = None,
+        projector_ln_mid: bool = False,
         **kwargs,
     ):
         self.ignore_index = ignore_index
@@ -68,6 +73,7 @@ def __init__(
         self.stack_factor = stack_factor
         self.norm_init = norm_init
         self.projector_act = projector_act
+        self.projector_ln_mid = projector_ln_mid
 
         if text_model_id is not None:
             # Avoid circular import
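
Because `projector_ln_mid` defaults to `False`, configs shipped with pre-v0.5 checkpoints keep their original behavior without modification, while v0.5 checkpoints opt in through their `config.json`. A hypothetical sketch of the two settings (assuming vLLM is importable and that `UltravoxConfig` constructs with its defaults; not taken from the repo):

```python
# Hypothetical illustration; assumes UltravoxConfig() works with defaults.
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

legacy_cfg = UltravoxConfig()                    # v0.4.1 and below
v05_cfg = UltravoxConfig(projector_ln_mid=True)  # v0.5 and above

# False -> RMSNorm after linear_2 (ln_post); True -> after linear_1 (ln_mid).
print(legacy_cfg.projector_ln_mid, v05_cfg.projector_ln_mid)
```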
