[Model] Refactor Ultravox to use merged input processor #11198

Merged: 27 commits merged into vllm-project:main from Isotr0py's ultravox-refactor branch on Dec 16, 2024.
Changes from 20 commits
Commits (27)
4c3e5d5
refactor ultravox process
Isotr0py Dec 10, 2024
d160560
fix processor inputs
Isotr0py Dec 10, 2024
91384bf
fix ultravox processor
Isotr0py Dec 13, 2024
6d31c3d
Merge branch 'vllm-project:main' into ultravox-refactor
Isotr0py Dec 13, 2024
782bd61
fix placeholder padding
Isotr0py Dec 13, 2024
5350918
Merge branch 'vllm-project:main' into ultravox-refactor
Isotr0py Dec 14, 2024
57c7ec9
add comments
Isotr0py Dec 14, 2024
c1a9cef
update example
Isotr0py Dec 14, 2024
89416a8
code format
Isotr0py Dec 14, 2024
9693691
remove unused code
Isotr0py Dec 14, 2024
8254384
Merge branch 'main' into ultravox-refactor
Isotr0py Dec 15, 2024
e0ef4bc
Merge branch 'vllm-project:main' into ultravox-refactor
Isotr0py Dec 15, 2024
08a3422
clean up
Isotr0py Dec 15, 2024
d72fe45
refactor
Isotr0py Dec 15, 2024
0b8aa47
code format
Isotr0py Dec 15, 2024
d5b7cf7
fix prompt replacement
Isotr0py Dec 15, 2024
980c731
code format
Isotr0py Dec 15, 2024
5cb6362
fix audio_token truncation
Isotr0py Dec 15, 2024
0854a67
fix mm_data
Isotr0py Dec 15, 2024
146fc63
fix audio_token_len and online inference
Isotr0py Dec 15, 2024
342048c
rename
Isotr0py Dec 15, 2024
daba237
Update vllm/model_executor/models/ultravox.py
Isotr0py Dec 15, 2024
7813d47
clean up
Isotr0py Dec 15, 2024
6e7b138
handle no audio data
Isotr0py Dec 15, 2024
ca58f8b
Update vllm/model_executor/models/ultravox.py
Isotr0py Dec 15, 2024
e1fdd36
cleanup replacement
Isotr0py Dec 15, 2024
8ef0b23
fix audio entrypoint and pp test
Isotr0py Dec 16, 2024
10 changes: 5 additions & 5 deletions examples/offline_inference_audio_language.py
@@ -25,16 +25,16 @@ def run_ultravox(question: str, audio_count: int):

tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
'role':
'user',
'content':
"<|reserved_special_token_0|>\n" * audio_count + question
'role': 'user',
'content': "<|audio|>\n" * audio_count + question
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
llm = LLM(model=model_name,
trust_remote_code=True,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
return llm, prompt, stop_token_ids

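For reference, here is a minimal end-to-end sketch of how the updated example is meant to be driven with the new `<|audio|>` placeholder and `trust_remote_code=True`. The checkpoint name, the `AudioAsset` helper, and the sampling settings are illustrative assumptions and are not part of this diff.

```python
# Sketch only: the model name and AudioAsset usage are assumed, not taken from this PR.
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_name = "fixie-ai/ultravox-v0_3"  # assumed Ultravox checkpoint
audio_and_sr = AudioAsset("mary_had_lamb").audio_and_sample_rate  # (np.ndarray, int)

tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
    "role": "user",
    "content": "<|audio|>\nWhat is recited in the audio?",
}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

llm = LLM(model=model_name,
          trust_remote_code=True,
          limit_mm_per_prompt={"audio": 1})
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"audio": audio_and_sr}},
    SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```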
5 changes: 3 additions & 2 deletions tests/models/decoder_only/audio_language/test_ultravox.py
@@ -16,7 +16,7 @@

AudioTuple = Tuple[np.ndarray, int]

VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
VLLM_PLACEHOLDER = "<|audio|>"
HF_PLACEHOLDER = "<|audio|>"

CHUNKED_PREFILL_KWARGS = {
@@ -46,7 +46,8 @@ def audio(request):
def server(request, audio_assets):
args = [
"--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
f"--limit-mm-per-prompt=audio={len(audio_assets)}"
f"--limit-mm-per-prompt=audio={len(audio_assets)}",
"--trust-remote-code"
] + [
f"--{key.replace('_','-')}={value}"
for key, value in request.param.items()
2 changes: 1 addition & 1 deletion vllm/entrypoints/chat_utils.py
@@ -418,7 +418,7 @@ def _placeholder_str(self, modality: ModalityStr,
raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
return "<|audio|>"
if model_type == "qwen2_audio":
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
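Because `_placeholder_str` now returns `<|audio|>` for Ultravox, chat requests that include an `audio_url` content part have this placeholder injected into the rendered prompt automatically. Below is a hedged client-side sketch against an OpenAI-compatible vLLM server; the server address, model name, and audio URL are assumptions, not part of this diff.

```python
# Sketch of an online request; all concrete values below are assumed.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="fixie-ai/ultravox-v0_3",  # assumed checkpoint served by vLLM
    messages=[{
        "role": "user",
        "content": [
            {"type": "audio_url",
             "audio_url": {"url": "https://example.com/sample.wav"}},
            {"type": "text", "text": "What is said in this audio?"},
        ],
    }],
    max_tokens=64,
)
print(response.choices[0].message.content)
```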
253 changes: 112 additions & 141 deletions vllm/model_executor/models/ultravox.py
@@ -2,42 +2,43 @@
"""PyTorch Ultravox model."""

import math
from collections import defaultdict
from functools import cached_property, lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union, cast)
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from transformers import BatchFeature, ProcessorMixin
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataDict,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings_from_map)

_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_PLACEHOLDER_STR = "<|reserved_special_token_0|>"
_AUDIO_TOKENS_PER_SECOND = 6.25


@@ -72,64 +73,15 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)


def dummy_seq_data_for_ultravox(
ctx: InputContext,
seq_len: int,
audio_count: int,
):
audio_length = min(get_ultravox_max_audio_tokens(ctx),
seq_len // audio_count)

return SequenceData.from_prompt_token_counts(
(_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count),
(0, seq_len - audio_length * audio_count)), {
"audio":
consecutive_placeholder_ranges(num_items=audio_count,
item_size=audio_length)
}


def dummy_audio_for_ultravox(
ctx: InputContext,
audio_count: int,
):
feature_extractor = whisper_feature_extractor(ctx)
audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
return {"audio": [audio_and_sr] * audio_count}


def dummy_data_for_ultravox(
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
):
audio_count = mm_counts["audio"]
seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
mm_dict = dummy_audio_for_ultravox(ctx, audio_count)

return DummyData(seq_data, mm_dict, ranges)


def input_mapper_for_ultravox(ctx: InputContext, data: object):
if not isinstance(data, list):
data = [data]

if len(data) == 0:
return MultiModalKwargs()

# If the audio inputs are embeddings, no need for preprocessing
if is_list_of(data, torch.Tensor, check="all"):
return MultiModalKwargs({"audio_embeds": data})

audio_features = []
for audio_input in data:
if not isinstance(audio_input, tuple):
raise NotImplementedError(
f"Unsupported data type: {type(audio_input)}")

(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
feature_extractor = whisper_feature_extractor(ctx)
class UltravoxProcessor(BaseMultiModalProcessor):

def _resample_audio(
self,
audio: np.ndarray,
sr: int,
) -> Dict[str, Union[np.ndarray, int]]:
# resample audio to the model's sampling rate
feature_extractor = whisper_feature_extractor(self.ctx)
if sr != feature_extractor.sampling_rate:
try:
import librosa
@@ -140,78 +92,99 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
orig_sr=sr,
target_sr=feature_extractor.sampling_rate)
sr = feature_extractor.sampling_rate
return {"audio": audio, "sampling_rate": sr}

minimum_audio_length = feature_extractor.n_fft // 2 + 1
if len(audio) < minimum_audio_length:
# Not enough audio; pad it.
audio = np.pad(audio, (0, minimum_audio_length - len(audio)))

single_audio_features = feature_extractor(
audio, sampling_rate=sr, padding="longest",
return_tensors="pt")["input_features"]

# Remove the batch dimension because we're wrapping it in a list.
audio_features.append(single_audio_features.squeeze(0))

return MultiModalKwargs({"audio_features": audio_features})


def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data:
return inputs

if "multi_modal_placeholders" in inputs and "audio" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
def _apply_hf_processor(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
audio_data = mm_data["audio"]
if not isinstance(audio_data, list):
audio_data = [audio_data]

        # The Ultravox processor doesn't support multiple audio inputs per call,
        # so we run it over the text and each audio clip one at a time
tokenizer = self._get_tokenizer()
hf_inputs = defaultdict(list)
for audio, sr in audio_data:
data = self._resample_audio(audio, sr)
processed_inputs = super()._apply_hf_processor(
prompt, data, mm_processor_kwargs)
prompt = tokenizer.decode(processed_inputs["input_ids"][0],
skip_special_tokens=False)
hf_inputs["audio_features"].append(
processed_inputs["audio_values"].squeeze(0))
hf_inputs["input_ids"] = processed_inputs["input_ids"]
return hf_inputs

def _get_hf_processor(
self,
**mm_processor_kwargs: Mapping[str, object],
) -> ProcessorMixin:
        # The Ultravox processor uses eot_token_id as the audio placeholder token;
        # we replace it with <|reserved_special_token_0|> for convenience.
hf_processor = super()._get_hf_processor(**mm_processor_kwargs)
hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_STR
return hf_processor

def _get_processor_data(
self,
mm_data: MultiModalDataDict,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Ultravox uses "audio" instead of "audios" as calling keyword
processor_data, passthrough_data = super()._get_processor_data(mm_data)
if "audios" in processor_data:
processor_data["audio"] = processor_data.pop("audios")
return processor_data, passthrough_data

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
stack_factor = hf_processor.stack_factor
encoder_ds_factor = hf_processor.encoder_ds_factor

def get_replacement_ultravox(item_idx: int):
audio_data, sr = mm_items.audio[item_idx]
audio_data = self._resample_audio(audio_data, sr)["audio"]
audio_len = audio_data.shape[0]
nb_encoder_frames = int(round(audio_len / encoder_ds_factor +
1e-4))
audio_token_len = int(np.ceil(nb_encoder_frames / stack_factor))
max_audio_token_len = get_ultravox_max_audio_tokens(self.ctx)
audio_token_len = min(max(1, audio_token_len), max_audio_token_len)
return [_AUDIO_PLACEHOLDER_TOKEN] * audio_token_len

return [
PromptReplacement(
modality="audio",
target="<|audio|>",
replacement=get_replacement_ultravox,
)
]

feature_extractor = whisper_feature_extractor(ctx)
audios = multi_modal_data["audio"]
if not isinstance(audios, list):
audios = [audios]

audio_token_counts = []
for audio in audios:
if isinstance(audio, torch.Tensor):
audio_num_tokens = audio.shape[1]
audio_token_counts.append(audio_num_tokens)
else:
audio_data, sample_rate = audio
audio_length = audio_data.shape[0]
if sample_rate != feature_extractor.sampling_rate:
# Account for resampling.
adjustment = feature_extractor.sampling_rate / sample_rate
audio_length = math.ceil(adjustment * audio_length)

feature_extractor_output_length = math.ceil(
(audio_length - (feature_extractor.hop_length - 1)) /
feature_extractor.hop_length)

uv_config = ctx.get_hf_config(UltravoxConfig)
audio_num_tokens = min(
max(
1,
math.ceil(feature_extractor_output_length /
(uv_config.stack_factor * 2))),
get_ultravox_max_audio_tokens(ctx))
audio_token_counts.append(audio_num_tokens)

tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)

new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
repeat_count=audio_token_counts,
)

# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"audio": ranges})
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = whisper_feature_extractor(self.ctx)
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate

audio_count = mm_counts["audio"]
audio = np.array([0.0] * audio_len)
data = {"audio": [(audio, sampling_rate)] * audio_count}

return ProcessorInputs(
prompt_text="<|audio|>" * audio_count,
mm_data=data,
mm_processor_kwargs={},
)


class StackAudioFrames(nn.Module):
@@ -332,11 +305,9 @@ def forward(
return hidden_states


@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_ultravox_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
@MULTIMODAL_REGISTRY.register_processor(UltravoxProcessor)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
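To make the replacement logic above easier to follow: for each audio item, `get_replacement_ultravox` turns the resampled sample count into a run of `_AUDIO_PLACEHOLDER_TOKEN` tokens using the encoder downsampling and stacking factors, clamped to at most one Whisper chunk. The standalone sketch below mirrors that computation; the default factor values are illustrative assumptions (in the PR they come from the HF processor and feature extractor).

```python
import math

import numpy as np


def ultravox_audio_token_len(num_samples: int,
                             encoder_ds_factor: int = 320,
                             stack_factor: int = 8,
                             chunk_length_s: float = 30.0) -> int:
    """Sketch of the per-audio replacement length in _get_prompt_replacements;
    the default factors are assumptions for illustration."""
    nb_encoder_frames = int(round(num_samples / encoder_ds_factor + 1e-4))
    audio_token_len = int(np.ceil(nb_encoder_frames / stack_factor))
    # Equivalent of get_ultravox_max_audio_tokens: 6.25 tokens per second
    # of the feature extractor's chunk length.
    max_audio_token_len = math.ceil(chunk_length_s * 6.25)
    return min(max(1, audio_token_len), max_audio_token_len)


# 5 seconds of 16 kHz audio -> 80000 samples -> 250 encoder frames -> 32 tokens
print(ultravox_audio_token_len(5 * 16000))
```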
19 changes: 14 additions & 5 deletions vllm/multimodal/processing.py
@@ -594,14 +594,10 @@ def _find_placeholders(
return list(
iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))

def _apply_hf_processor(
def _get_processor_data(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
hf_processor = self._get_hf_processor(**mm_processor_kwargs)

processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
for k, v in mm_data.items():
@@ -619,6 +615,19 @@ def _apply_hf_processor(
processor_data[f"{k}s"] = v
else:
processor_data[k] = v
return processor_data, passthrough_data

def _apply_hf_processor(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
        # Some mm_processor_kwargs may be used in processor initialization
        # rather than in the processor call
hf_processor = self._get_hf_processor(**mm_processor_kwargs)

processor_data, passthrough_data = self._get_processor_data(mm_data)

assert callable(hf_processor)
mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
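The base-class change factors the keyword preparation out of `_apply_hf_processor` into a separate `_get_processor_data` hook, which is what lets `UltravoxProcessor` rename `audios` back to `audio` without re-implementing the whole preprocessing path. A hypothetical subclass (not part of this PR) illustrating the override point:

```python
from typing import Any, Dict, Tuple

from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataDict)


class MyAudioProcessor(BaseMultiModalProcessor):

    def _get_processor_data(
        self,
        mm_data: MultiModalDataDict,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        processor_data, passthrough_data = super()._get_processor_data(mm_data)
        # The base class pluralizes list-valued keys ("audio" -> "audios");
        # undo that for HF processors that expect the singular keyword.
        if "audios" in processor_data:
            processor_data["audio"] = processor_data.pop("audios")
        return processor_data, passthrough_data
```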