diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 91e6c42d526..55b3f52356c 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -856,7 +856,7 @@ See [this page](#generative-models) for more information on how to use generative models.
- * `UltravoxModel`
* Ultravox
* T + A<sup>E+</sup>
- * `fixie-ai/ultravox-v0_3`
+ * `fixie-ai/ultravox-v0_5-llama-3_2-1b`
* ✅︎
* ✅︎
* ✅︎
diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/serving/multimodal_inputs.md
index 217b531e837..ade59e37738 100644
--- a/docs/source/serving/multimodal_inputs.md
+++ b/docs/source/serving/multimodal_inputs.md
@@ -359,12 +359,12 @@ export VLLM_VIDEO_FETCH_TIMEOUT=<timeout>
### Audio
Audio input is supported according to [OpenAI Audio API](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in).
-Here is a simple example using Ultravox-v0.3.
+Here is a simple example using Ultravox-v0.5-1B.
First, launch the OpenAI-compatible server:
```bash
-vllm serve fixie-ai/ultravox-v0_3
+vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b
```
Then, you can use the OpenAI client as follows:
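(The client code that follows in the doc is unchanged by this patch. For reference, a minimal call against the server launched above might look like the sketch below; the endpoint `http://localhost:8000/v1`, the placeholder API key, the `audio.wav` file name, and the prompt text are all assumptions for illustration.)

```python
import base64

from openai import OpenAI

# vLLM's OpenAI-compatible server accepts any placeholder API key.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Base64-encode a local audio file, as expected by the OpenAI Audio API.
with open("audio.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

chat_completion = client.chat.completions.create(
    model="fixie-ai/ultravox-v0_5-llama-3_2-1b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is being said in this audio?"},
            {
                "type": "input_audio",
                "input_audio": {"data": audio_b64, "format": "wav"},
            },
        ],
    }],
    max_tokens=64,
)
print(chat_completion.choices[0].message.content)
```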
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index 707ca9f8789..3e3034a02f0 100644
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -24,9 +24,9 @@
# Unless specified, these settings have been tested to work on a single L4.
-# Ultravox 0.3
+# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int):
- model_name = "fixie-ai/ultravox-v0_3"
+ model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
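(The remainder of `run_ultravox` is unchanged by this patch. For orientation, an end-to-end offline run with the new checkpoint might look roughly like the sketch below; the `mary_had_lamb` asset, the prompt text, and the sampling/engine arguments are illustrative choices, not part of the original example.)

```python
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

# Build the prompt with the model's chat template; <|audio|> marks the clip.
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{"role": "user", "content": "<|audio|>\nWhat is in this audio?"}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

# (np.ndarray, sampling_rate) tuple from one of vLLM's bundled audio assets.
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate

llm = LLM(model=model_name, max_model_len=4096, trust_remote_code=True)
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"audio": audio}},
    SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```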
diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py
index d5f798a8dae..ecfcf05a90d 100644
--- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py
+++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py
@@ -12,7 +12,7 @@
--trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
(audio inference with Ultravox)
-vllm serve fixie-ai/ultravox-v0_3 --max-model-len 4096
+vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096
"""
import base64
diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 5b6741d74ef..5d7cb9e4089 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -215,7 +215,7 @@ def iter_params(self, model_name: str):
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
"Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
"Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
- "fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True),
+ "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
# [Encoder-decoder]
# TODO: Implement PP
# "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
@@ -234,7 +234,7 @@ def iter_params(self, model_name: str):
# [MULTIMODAL GENERATION]
"OpenGVLab/InternVL2-1B",
"microsoft/Phi-3-vision-128k-instruct",
- "fixie-ai/ultravox-v0_3",
+ "fixie-ai/ultravox-v0_5-llama-3_2-1b",
# [LANGUAGE GENERATION - HYBRID ARCH]
"ai21labs/Jamba-tiny-dev",
]
diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py
index 3459f24834d..fe7299a48e6 100644
--- a/tests/entrypoints/openai/test_audio.py
+++ b/tests/entrypoints/openai/test_audio.py
@@ -11,7 +11,7 @@
from ...utils import RemoteOpenAIServer
-MODEL_NAME = "fixie-ai/ultravox-v0_3"
+MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
TEST_AUDIO_URLS = [
AudioAsset("winning_call").url,
]
diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py
index 5c469007af2..c52fa905c80 100644
--- a/tests/entrypoints/test_chat_utils.py
+++ b/tests/entrypoints/test_chat_utils.py
@@ -21,7 +21,7 @@
EXAMPLES_DIR = VLLM_PATH / "examples"
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
-ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_3"
+ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py
index fe9361d1261..d1f643a8fdb 100644
--- a/tests/models/decoder_only/audio_language/test_ultravox.py
+++ b/tests/models/decoder_only/audio_language/test_ultravox.py
@@ -15,7 +15,7 @@
from ....utils import RemoteOpenAIServer
from ...utils import check_logprobs_close
-MODEL_NAME = "fixie-ai/ultravox-v0_3"
+MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
AudioTuple = Tuple[np.ndarray, int]
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index a56a9e2beef..6244056c747 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -164,7 +164,7 @@ def _test_processing_correctness(
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
- "fixie-ai/ultravox-v0_3",
+ "fixie-ai/ultravox-v0_5-llama-3_2-1b",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 3fd94b89c8a..66b7d3c2e77 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -267,7 +267,7 @@ def check_available_online(
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
min_transformers_version="4.49"), # noqa: E501
- "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
+ "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b",
trust_remote_code=True),
# [Encoder-decoder]
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 9da0682cfa8..063997a14a6 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -258,27 +258,35 @@ def __init__(self, config: UltravoxConfig):
super().__init__()
self.hidden_dim = config.hidden_size
self._pad_and_stack = StackAudioFrames(config.stack_factor)
- dim = config.audio_config.hidden_size * config.stack_factor
- self.ln_pre = RMSNorm(dim)
- self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
- dim = self.hidden_dim
+ dim_in = config.audio_config.hidden_size * config.stack_factor
+ self.ln_pre = RMSNorm(dim_in)
+ self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
+ dim_mid = self.hidden_dim
if config.projector_act == "swiglu":
self.act = MulAndSilu()
- dim = dim // 2
+ dim_mid = dim_mid // 2
else:
self.act = get_act_fn(config.projector_act)
- self.linear_2 = nn.Linear(dim,
- config.text_config.hidden_size,
- bias=False)
- self.ln_post = RMSNorm(config.text_config.hidden_size)
+ dim_out = config.text_config.hidden_size
+ self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
+
+        # Ultravox v0.4.1 and below use layer_norm after the second linear
+        # layer, while v0.5.0 and above use it after the first linear layer.
+ if config.projector_ln_mid:
+ self.ln_mid: nn.Module = RMSNorm(dim_mid)
+ self.ln_post = nn.Identity()
+ else:
+ self.ln_mid = nn.Identity()
+ self.ln_post = RMSNorm(dim_out)
def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
audio_features = self._pad_and_stack(audio_features)
audio_features = self.ln_pre(audio_features)
hidden_states = self.linear_1(audio_features)
hidden_states = self.act(hidden_states)
+ hidden_states = self.ln_mid(hidden_states)
hidden_states = self.linear_2(hidden_states)
hidden_states = self.ln_post(hidden_states)
return hidden_states
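(To make the branching above concrete, here is a toy standalone sketch of the two projector layouts. It uses plain `torch.nn.LayerNorm` as a stand-in for vLLM's `RMSNorm`, omits the activation, and uses made-up dimensions; `toy_projector` is a hypothetical helper, not part of the patch.)

```python
import torch
import torch.nn as nn


def toy_projector(dim_in: int, dim_mid: int, dim_out: int,
                  ln_mid: bool) -> nn.Sequential:
    """Norm placement toggles with ln_mid, mirroring projector_ln_mid."""
    return nn.Sequential(
        nn.LayerNorm(dim_in),                                # ln_pre
        nn.Linear(dim_in, dim_mid, bias=False),              # linear_1
        nn.LayerNorm(dim_mid) if ln_mid else nn.Identity(),  # ln_mid (v0.5+)
        nn.Linear(dim_mid, dim_out, bias=False),             # linear_2
        nn.Identity() if ln_mid else nn.LayerNorm(dim_out),  # ln_post (v0.4.1-)
    )


x = torch.randn(2, 7, 256)  # (batch, frames, stacked audio features)
old = toy_projector(256, 128, 64, ln_mid=False)  # norm after linear_2
new = toy_projector(256, 128, 64, ln_mid=True)   # norm after linear_1
assert old(x).shape == new(x).shape == (2, 7, 64)
```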
diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py
index 99715ba6d0b..6b2765db94e 100644
--- a/vllm/transformers_utils/configs/ultravox.py
+++ b/vllm/transformers_utils/configs/ultravox.py
@@ -37,6 +37,10 @@ class UltravoxConfig(transformers.PretrainedConfig):
The LoRA configuration for finetuning the text model.
audio_model_lora_config (`LoraConfigSimplified`, *optional*):
The LoRA configuration for finetuning the audio model.
+ projector_ln_mid (`bool`, *optional*, defaults to `False`):
+            Whether to apply layer normalization in the middle of the
+ projector or at the end. Versions v0.4.1 and below
+ use `False`, but v0.5 and above use `True`.
"""
model_type = "ultravox"
@@ -56,6 +60,7 @@ def __init__(
projector_act: str = "swiglu",
text_model_lora_config: Optional[Dict[str, Any]] = None,
audio_model_lora_config: Optional[Dict[str, Any]] = None,
+ projector_ln_mid: bool = False,
**kwargs,
):
self.ignore_index = ignore_index
@@ -68,6 +73,7 @@ def __init__(
self.stack_factor = stack_factor
self.norm_init = norm_init
self.projector_act = projector_act
+ self.projector_ln_mid = projector_ln_mid
if text_model_id is not None:
# Avoid circular import
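(Downstream, the new flag is just another constructor argument on `UltravoxConfig`, so configs that predate it keep the old behaviour via the `False` default. A quick sketch, assuming the default sub-configs can be built without network access:)

```python
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

cfg_old = UltravoxConfig()                       # projector_ln_mid defaults to False
cfg_new = UltravoxConfig(projector_ln_mid=True)  # v0.5-style checkpoints set True

print(cfg_old.projector_ln_mid)  # False -> RMSNorm after linear_2 (ln_post)
print(cfg_new.projector_ln_mid)  # True  -> RMSNorm after linear_1 (ln_mid)
```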