
Commit d927dbc

[Model] Refactor Ultravox to use merged input processor (#11198)
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent bddbbcb commit d927dbc

File tree

7 files changed: +129 -154 lines changed

examples/offline_inference_audio_language.py

Lines changed: 5 additions & 5 deletions

@@ -25,16 +25,16 @@ def run_ultravox(question: str, audio_count: int):

     tokenizer = AutoTokenizer.from_pretrained(model_name)
     messages = [{
-        'role':
-        'user',
-        'content':
-        "<|reserved_special_token_0|>\n" * audio_count + question
+        'role': 'user',
+        'content': "<|audio|>\n" * audio_count + question
     }]
     prompt = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)

-    llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
+    llm = LLM(model=model_name,
+              trust_remote_code=True,
+              limit_mm_per_prompt={"audio": audio_count})
     stop_token_ids = None
     return llm, prompt, stop_token_ids
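The example change boils down to two things: the prompt placeholder is now the literal <|audio|> string, and the engine must be constructed with trust_remote_code=True so that the Ultravox Hugging Face processor can be loaded. A minimal end-to-end sketch of the updated flow is shown below; the AudioAsset sample, question text, and sampling settings are illustrative additions, not part of the diff:

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_name = "fixie-ai/ultravox-v0_3"
audio_count = 1
question = "What is recorded in this audio?"

# Build the chat prompt with one <|audio|> placeholder per clip.
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
    'role': 'user',
    'content': "<|audio|>\n" * audio_count + question
}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

# The merged input processor runs the model's HF processor inside vLLM,
# so the checkpoint's remote code has to be trusted.
llm = LLM(model=model_name,
          trust_remote_code=True,
          limit_mm_per_prompt={"audio": audio_count})

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {
            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
        },
    },
    sampling_params=SamplingParams(temperature=0, max_tokens=64),
)
print(outputs[0].outputs[0].text)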

tests/distributed/test_pipeline_parallel.py

Lines changed: 1 addition & 1 deletion

@@ -214,7 +214,7 @@ def iter_params(self, model_name: str):
     "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
     "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
     "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
-    "fixie-ai/ultravox-v0_3": PPTestSettings.fast(),
+    "fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True),
     # [Encoder-decoder]
     # TODO: Implement PP
     # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),

tests/entrypoints/openai/test_audio.py

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ def server():
         "--max-num-seqs",
         "5",
         "--enforce-eager",
+        "--trust-remote-code",
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

tests/models/decoder_only/audio_language/test_ultravox.py

Lines changed: 3 additions & 2 deletions

@@ -16,7 +16,7 @@

 AudioTuple = Tuple[np.ndarray, int]

-VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
+VLLM_PLACEHOLDER = "<|audio|>"
 HF_PLACEHOLDER = "<|audio|>"

 CHUNKED_PREFILL_KWARGS = {
@@ -46,7 +46,8 @@ def audio(request):
 def server(request, audio_assets):
     args = [
         "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
-        f"--limit-mm-per-prompt=audio={len(audio_assets)}"
+        f"--limit-mm-per-prompt=audio={len(audio_assets)}",
+        "--trust-remote-code"
     ] + [
         f"--{key.replace('_','-')}={value}"
         for key, value in request.param.items()

vllm/entrypoints/chat_utils.py

Lines changed: 1 addition & 1 deletion

@@ -418,7 +418,7 @@ def _placeholder_str(self, modality: ModalityStr,
             raise TypeError(f"Unknown {modality} model type: {model_type}")
         elif modality == "audio":
             if model_type == "ultravox":
-                return "<|reserved_special_token_0|>"
+                return "<|audio|>"
             if model_type == "qwen2_audio":
                 return (f"Audio {current_count}: "
                         f"<|audio_bos|><|AUDIO|><|audio_eos|>")

vllm/model_executor/models/ultravox.py

Lines changed: 104 additions & 140 deletions

@@ -3,41 +3,39 @@

 import math
 from functools import cached_property, lru_cache
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union, cast)
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
+                    Tuple, TypedDict, Union)

 import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
+from transformers import BatchFeature
 from transformers.models.whisper import WhisperFeatureExtractor
 from transformers.models.whisper.modeling_whisper import WhisperEncoder

 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
+from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
-                             NestedTensors)
-from vllm.multimodal.utils import (cached_get_tokenizer,
-                                   consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens)
-from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        MultiModalDataDict,
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement)
+from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
-from vllm.utils import is_list_of

 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings_from_map)

-_AUDIO_PLACEHOLDER_TOKEN = 128002
 _AUDIO_TOKENS_PER_SECOND = 6.25


@@ -72,64 +70,18 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
     return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)


-def dummy_seq_data_for_ultravox(
-    ctx: InputContext,
-    seq_len: int,
-    audio_count: int,
-):
-    audio_length = min(get_ultravox_max_audio_tokens(ctx),
-                       seq_len // audio_count)
+class UltravoxMultiModalProcessor(BaseMultiModalProcessor):

-    return SequenceData.from_prompt_token_counts(
-        (_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count),
-        (0, seq_len - audio_length * audio_count)), {
-            "audio":
-            consecutive_placeholder_ranges(num_items=audio_count,
-                                           item_size=audio_length)
-        }
-
-
-def dummy_audio_for_ultravox(
-    ctx: InputContext,
-    audio_count: int,
-):
-    feature_extractor = whisper_feature_extractor(ctx)
-    audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
-    return {"audio": [audio_and_sr] * audio_count}
-
-
-def dummy_data_for_ultravox(
-    ctx: InputContext,
-    seq_len: int,
-    mm_counts: Mapping[str, int],
-):
-    audio_count = mm_counts["audio"]
-    seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
-    mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
-
-    return DummyData(seq_data, mm_dict, ranges)
-
-
-def input_mapper_for_ultravox(ctx: InputContext, data: object):
-    if not isinstance(data, list):
-        data = [data]
-
-    if len(data) == 0:
-        return MultiModalKwargs()
-
-    # If the audio inputs are embeddings, no need for preprocessing
-    if is_list_of(data, torch.Tensor, check="all"):
-        return MultiModalKwargs({"audio_embeds": data})
-
-    audio_features = []
-    for audio_input in data:
-        if not isinstance(audio_input, tuple):
-            raise NotImplementedError(
-                f"Unsupported data type: {type(audio_input)}")
-
-        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
-        feature_extractor = whisper_feature_extractor(ctx)
+    def _get_feature_extractor(self) -> WhisperFeatureExtractor:
+        return self._get_hf_processor().audio_processor.feature_extractor

+    def _resample_audio(
+        self,
+        audio: np.ndarray,
+        sr: int,
+    ) -> Dict[str, Union[np.ndarray, int]]:
+        # resample audio to the model's sampling rate
+        feature_extractor = self._get_feature_extractor()
         if sr != feature_extractor.sampling_rate:
             try:
                 import librosa
@@ -140,78 +92,92 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
                                     orig_sr=sr,
                                     target_sr=feature_extractor.sampling_rate)
             sr = feature_extractor.sampling_rate
+        return {"audio": audio, "sampling_rate": sr}

-        minimum_audio_length = feature_extractor.n_fft // 2 + 1
-        if len(audio) < minimum_audio_length:
-            # Not enough audio; pad it.
-            audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
-
-        single_audio_features = feature_extractor(
-            audio, sampling_rate=sr, padding="longest",
-            return_tensors="pt")["input_features"]
-
-        # Remove the batch dimension because we're wrapping it in a list.
-        audio_features.append(single_audio_features.squeeze(0))
-
-    return MultiModalKwargs({"audio_features": audio_features})
-
-
-def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "audio" not in multi_modal_data:
-        return inputs
+    def _apply_hf_processor(
+        self,
+        prompt: str,
+        mm_data: MultiModalDataDict,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        if not mm_data or not mm_data.get("audio", None):
+            return super()._apply_hf_processor(prompt, mm_data,
+                                               mm_processor_kwargs)
+
+        audio_data = mm_data["audio"]
+        if not isinstance(audio_data, list):
+            audio_data = [audio_data]
+
+        # Ultravox processor doesn't support multiple inputs,
+        # therefore we need to input text and audio one by one
+        tokenizer = self._get_tokenizer()
+        audio_features, audio_token_len = [], []
+        processed_inputs = {}
+        for audio, sr in audio_data:
+            data = self._resample_audio(audio, sr)
+            processed_inputs = super()._apply_hf_processor(
+                prompt, data, mm_processor_kwargs)
+            prompt = tokenizer.decode(processed_inputs["input_ids"][0],
+                                      skip_special_tokens=False)
+            audio_features.append(
+                processed_inputs.pop("audio_values").squeeze(0))
+            audio_token_len.append(
+                processed_inputs.pop("audio_token_len").item())
+
+        return dict(
+            **processed_inputs,
+            audio_features=audio_features,
+            audio_token_len=audio_token_len,
+        )

-    if "multi_modal_placeholders" in inputs and "audio" in inputs[
-            "multi_modal_placeholders"]:
-        # The inputs already have placeholders.
-        return inputs
+    def _get_processor_data(
+        self,
+        mm_data: MultiModalDataDict,
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        # Ultravox uses "audio" instead of "audios" as calling keyword
+        processor_data, passthrough_data = super()._get_processor_data(mm_data)
+        if "audios" in processor_data:
+            processor_data["audio"] = processor_data.pop("audios")
+        return processor_data, passthrough_data
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_inputs: BatchFeature,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> list[PromptReplacement]:
+        hf_processor = self._get_hf_processor()
+        placeholder = hf_processor.audio_token_replacement
+
+        def get_replacement_ultravox(item_idx: int):
+            audio_token_len = hf_inputs["audio_token_len"][item_idx]
+            return placeholder * audio_token_len
+
+        return [
+            PromptReplacement(
+                modality="audio",
+                target="<|audio|>",
+                replacement=get_replacement_ultravox,
+            )
+        ]

-    feature_extractor = whisper_feature_extractor(ctx)
-    audios = multi_modal_data["audio"]
-    if not isinstance(audios, list):
-        audios = [audios]
-
-    audio_token_counts = []
-    for audio in audios:
-        if isinstance(audio, torch.Tensor):
-            audio_num_tokens = audio.shape[1]
-            audio_token_counts.append(audio_num_tokens)
-        else:
-            audio_data, sample_rate = audio
-            audio_length = audio_data.shape[0]
-            if sample_rate != feature_extractor.sampling_rate:
-                # Account for resampling.
-                adjustment = feature_extractor.sampling_rate / sample_rate
-                audio_length = math.ceil(adjustment * audio_length)
-
-            feature_extractor_output_length = math.ceil(
-                (audio_length - (feature_extractor.hop_length - 1)) /
-                feature_extractor.hop_length)
-
-            uv_config = ctx.get_hf_config(UltravoxConfig)
-            audio_num_tokens = min(
-                max(
-                    1,
-                    math.ceil(feature_extractor_output_length /
-                              (uv_config.stack_factor * 2))),
-                get_ultravox_max_audio_tokens(ctx))
-            audio_token_counts.append(audio_num_tokens)
-
-    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
-
-    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
-        tokenizer,
-        inputs.get("prompt"),
-        inputs["prompt_token_ids"],
-        placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
-        repeat_count=audio_token_counts,
-    )
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"audio": ranges})
+    def _get_dummy_mm_inputs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        feature_extractor = self._get_feature_extractor()
+        sampling_rate = feature_extractor.sampling_rate
+        audio_len = feature_extractor.chunk_length * sampling_rate
+
+        audio_count = mm_counts["audio"]
+        audio = np.zeros(audio_len)
+        data = {"audio": [(audio, sampling_rate)] * audio_count}
+
+        return ProcessorInputs(
+            prompt_text="<|audio|>" * audio_count,
+            mm_data=data,
+            mm_processor_kwargs={},
+        )


 class StackAudioFrames(nn.Module):
@@ -332,11 +298,9 @@ def forward(
         return hidden_states


-@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
 @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
     "audio", get_ultravox_max_audio_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
+@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
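The heart of the refactor is _get_prompt_replacements: the legacy input mapper, input processor, and dummy-data hooks keyed to the hard-coded _AUDIO_PLACEHOLDER_TOKEN = 128002 are replaced by a single PromptReplacement whose target is the literal "<|audio|>" string, with the per-item expansion length taken from the HF processor's audio_token_len output. A small self-contained illustration of that expansion follows; the token lengths here are made up for the example:

# Illustrative only: mimic the replacement callback built in
# UltravoxMultiModalProcessor._get_prompt_replacements.
# Assume the HF processor reported audio_token_len = [11, 7] for two clips
# and exposes "<|audio|>" as its audio_token_replacement.
audio_token_len = [11, 7]
placeholder = "<|audio|>"


def get_replacement_ultravox(item_idx: int) -> str:
    # Each audio item expands to as many placeholder tokens as the audio
    # tower will emit embeddings for that clip.
    return placeholder * audio_token_len[item_idx]


assert get_replacement_ultravox(0) == "<|audio|>" * 11
assert get_replacement_ultravox(1) == "<|audio|>" * 7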
