import math
from functools import cached_property, lru_cache
- from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                     TypedDict, Union, cast)
+ from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
+                     Tuple, TypedDict, Union)

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
+ from transformers import BatchFeature
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
- from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                          InputContext, token_inputs)
+ from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata
- from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
-                              NestedTensors)
- from vllm.multimodal.utils import (cached_get_tokenizer,
-                                    consecutive_placeholder_ranges,
-                                    repeat_and_pad_placeholder_tokens)
- from vllm.sequence import IntermediateTensors, SequenceData
+ from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+ from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                         MultiModalDataDict,
+                                         MultiModalDataItems, ProcessorInputs,
+                                         PromptReplacement)
+ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
- from vllm.utils import is_list_of

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings_from_map)

- _AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25
@@ -72,64 +70,18 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
    return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
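As a rough sanity check on this token budget, here is a small worked example. It assumes Whisper's usual 30-second chunk length; the real value is read from the feature extractor at runtime, so 30 is illustrative only:

# Illustrative only: chunk_length comes from WhisperFeatureExtractor;
# 30 s is Whisper's common default, not a value fixed by this file.
import math

_AUDIO_TOKENS_PER_SECOND = 6.25
chunk_length_s = 30  # assumed default
max_audio_tokens = math.ceil(chunk_length_s * _AUDIO_TOKENS_PER_SECOND)
print(max_audio_tokens)  # 188 placeholder tokens per audio chunk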


- def dummy_seq_data_for_ultravox(
-     ctx: InputContext,
-     seq_len: int,
-     audio_count: int,
- ):
-     audio_length = min(get_ultravox_max_audio_tokens(ctx),
-                        seq_len // audio_count)
+ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
+
-
-     return SequenceData.from_prompt_token_counts(
-         (_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count),
-         (0, seq_len - audio_length * audio_count)), {
-             "audio":
-             consecutive_placeholder_ranges(num_items=audio_count,
-                                            item_size=audio_length)
-         }
-
-
- def dummy_audio_for_ultravox(
-     ctx: InputContext,
-     audio_count: int,
- ):
-     feature_extractor = whisper_feature_extractor(ctx)
-     audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
-     return {"audio": [audio_and_sr] * audio_count}
-
-
- def dummy_data_for_ultravox(
-     ctx: InputContext,
-     seq_len: int,
-     mm_counts: Mapping[str, int],
- ):
-     audio_count = mm_counts["audio"]
-     seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
-     mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
-
-     return DummyData(seq_data, mm_dict, ranges)
-
-
- def input_mapper_for_ultravox(ctx: InputContext, data: object):
-     if not isinstance(data, list):
-         data = [data]
-
-     if len(data) == 0:
-         return MultiModalKwargs()
-
-     # If the audio inputs are embeddings, no need for preprocessing
-     if is_list_of(data, torch.Tensor, check="all"):
-         return MultiModalKwargs({"audio_embeds": data})
-
-     audio_features = []
-     for audio_input in data:
-         if not isinstance(audio_input, tuple):
-             raise NotImplementedError(
-                 f"Unsupported data type: {type(audio_input)}")
-
-         (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
-         feature_extractor = whisper_feature_extractor(ctx)
+     def _get_feature_extractor(self) -> WhisperFeatureExtractor:
+         return self._get_hf_processor().audio_processor.feature_extractor
+
+     def _resample_audio(
+         self,
+         audio: np.ndarray,
+         sr: int,
+     ) -> Dict[str, Union[np.ndarray, int]]:
+         # resample audio to the model's sampling rate
+         feature_extractor = self._get_feature_extractor()
        if sr != feature_extractor.sampling_rate:
            try:
                import librosa
@@ -140,78 +92,92 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
                                     orig_sr=sr,
                                     target_sr=feature_extractor.sampling_rate)
            sr = feature_extractor.sampling_rate
+         return {"audio": audio, "sampling_rate": sr}
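For context, a minimal self-contained sketch of the same resampling step outside the processor. librosa is an optional dependency (installed via vLLM's audio extra), and the 16 kHz target below is the Whisper feature extractor's usual sampling rate, assumed here for illustration:

import librosa
import numpy as np

def resample_to_extractor_rate(audio: np.ndarray, sr: int,
                               target_sr: int = 16000) -> np.ndarray:
    # No-op when the clip is already at the target rate.
    if sr == target_sr:
        return audio
    return librosa.resample(audio, orig_sr=sr, target_sr=target_sr)

# e.g. a one-second 44.1 kHz clip becomes ~16000 samples
clip = np.zeros(44100, dtype=np.float32)
resampled = resample_to_extractor_rate(clip, 44100)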

-
-         minimum_audio_length = feature_extractor.n_fft // 2 + 1
-         if len(audio) < minimum_audio_length:
-             # Not enough audio; pad it.
-             audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
-
-         single_audio_features = feature_extractor(
-             audio, sampling_rate=sr, padding="longest",
-             return_tensors="pt")["input_features"]
-
-         # Remove the batch dimension because we're wrapping it in a list.
-         audio_features.append(single_audio_features.squeeze(0))
-
-     return MultiModalKwargs({"audio_features": audio_features})
-
-
- def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
-     multi_modal_data = inputs.get("multi_modal_data")
-     if multi_modal_data is None or "audio" not in multi_modal_data:
-         return inputs
+
+     def _apply_hf_processor(
+         self,
+         prompt: str,
+         mm_data: MultiModalDataDict,
+         mm_processor_kwargs: Mapping[str, object],
+     ) -> BatchFeature:
+         if not mm_data or not mm_data.get("audio", None):
+             return super()._apply_hf_processor(prompt, mm_data,
+                                                mm_processor_kwargs)
+
+         audio_data = mm_data["audio"]
+         if not isinstance(audio_data, list):
+             audio_data = [audio_data]
+
+         # Ultravox processor doesn't support multiple inputs,
+         # therefore we need to input text and audio one by one
+         tokenizer = self._get_tokenizer()
+         audio_features, audio_token_len = [], []
+         processed_inputs = {}
+         for audio, sr in audio_data:
+             data = self._resample_audio(audio, sr)
+             processed_inputs = super()._apply_hf_processor(
+                 prompt, data, mm_processor_kwargs)
+             prompt = tokenizer.decode(processed_inputs["input_ids"][0],
+                                       skip_special_tokens=False)
+             audio_features.append(
+                 processed_inputs.pop("audio_values").squeeze(0))
+             audio_token_len.append(
+                 processed_inputs.pop("audio_token_len").item())
+
+         return dict(
+             **processed_inputs,
+             audio_features=audio_features,
+             audio_token_len=audio_token_len,
+         )
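To make the per-clip loop above concrete, here is a minimal sketch of the control flow with a stand-in processor. `fake_process` and its return values are hypothetical, not part of vLLM or HF Transformers; the point is why the prompt is re-decoded on every pass:

# Each call expands one "<|audio|>" marker into that clip's placeholder
# tokens, so the next clip must be processed against the updated prompt.
def fake_process(prompt: str, clip_id: int) -> dict:
    token_len = 10 * (clip_id + 1)  # pretend clip i needs 10*(i+1) tokens
    return {
        "prompt": prompt.replace("<|audio|>", "<placeholder>" * token_len, 1),
        "audio_token_len": token_len,
    }

prompt = "Transcribe <|audio|> and compare it with <|audio|>."
token_lens = []
for clip_id in range(2):
    out = fake_process(prompt, clip_id)
    prompt = out["prompt"]          # carry the expanded prompt forward
    token_lens.append(out["audio_token_len"])

print(token_lens)  # [10, 20]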

-
-     if "multi_modal_placeholders" in inputs and "audio" in inputs[
-             "multi_modal_placeholders"]:
-         # The inputs already have placeholders.
-         return inputs
+
+     def _get_processor_data(
+         self,
+         mm_data: MultiModalDataDict,
+     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+         # Ultravox uses "audio" instead of "audios" as calling keyword
+         processor_data, passthrough_data = super()._get_processor_data(mm_data)
+         if "audios" in processor_data:
+             processor_data["audio"] = processor_data.pop("audios")
+         return processor_data, passthrough_data
+
+     def _get_prompt_replacements(
+         self,
+         mm_items: MultiModalDataItems,
+         hf_inputs: BatchFeature,
+         mm_processor_kwargs: Mapping[str, object],
+     ) -> list[PromptReplacement]:
+         hf_processor = self._get_hf_processor()
+         placeholder = hf_processor.audio_token_replacement
+
+         def get_replacement_ultravox(item_idx: int):
+             audio_token_len = hf_inputs["audio_token_len"][item_idx]
+             return placeholder * audio_token_len
+
+         return [
+             PromptReplacement(
+                 modality="audio",
+                 target="<|audio|>",
+                 replacement=get_replacement_ultravox,
+             )
+         ]
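The effect of this replacement can be illustrated with a small, self-contained sketch. The placeholder string and token lengths below are made up; in vLLM the actual substitution is applied by the multimodal processing machinery, not by string replacement:

# Toy illustration of the PromptReplacement above: every "<|audio|>" target
# is swapped, per item, for that item's run of placeholder tokens.
def expand_audio_placeholders(prompt: str, placeholder: str,
                              audio_token_lens: list[int]) -> str:
    for token_len in audio_token_lens:
        prompt = prompt.replace("<|audio|>", placeholder * token_len, 1)
    return prompt

expanded = expand_audio_placeholders(
    "Summarize <|audio|> and <|audio|>.",
    placeholder="<|reserved|>",   # assumed stand-in for audio_token_replacement
    audio_token_lens=[3, 5],
)
# first audio expands to 3 placeholder tokens, the second to 5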

-
-     feature_extractor = whisper_feature_extractor(ctx)
-     audios = multi_modal_data["audio"]
-     if not isinstance(audios, list):
-         audios = [audios]
-
-     audio_token_counts = []
-     for audio in audios:
-         if isinstance(audio, torch.Tensor):
-             audio_num_tokens = audio.shape[1]
-             audio_token_counts.append(audio_num_tokens)
-         else:
-             audio_data, sample_rate = audio
-             audio_length = audio_data.shape[0]
-             if sample_rate != feature_extractor.sampling_rate:
-                 # Account for resampling.
-                 adjustment = feature_extractor.sampling_rate / sample_rate
-                 audio_length = math.ceil(adjustment * audio_length)
-
-             feature_extractor_output_length = math.ceil(
-                 (audio_length - (feature_extractor.hop_length - 1)) /
-                 feature_extractor.hop_length)
-
-             uv_config = ctx.get_hf_config(UltravoxConfig)
-             audio_num_tokens = min(
-                 max(
-                     1,
-                     math.ceil(feature_extractor_output_length /
-                               (uv_config.stack_factor * 2))),
-                 get_ultravox_max_audio_tokens(ctx))
-             audio_token_counts.append(audio_num_tokens)
-
-     tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
-
-     new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
-         tokenizer,
-         inputs.get("prompt"),
-         inputs["prompt_token_ids"],
-         placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
-         repeat_count=audio_token_counts,
-     )
-
-     # NOTE: Create a defensive copy of the original inputs
-     return token_inputs(prompt_token_ids=new_token_ids,
-                         prompt=new_prompt,
-                         multi_modal_data=multi_modal_data,
-                         multi_modal_placeholders={"audio": ranges})
+
+     def _get_dummy_mm_inputs(
+         self,
+         mm_counts: Mapping[str, int],
+     ) -> ProcessorInputs:
+         feature_extractor = self._get_feature_extractor()
+         sampling_rate = feature_extractor.sampling_rate
+         audio_len = feature_extractor.chunk_length * sampling_rate
+
+         audio_count = mm_counts["audio"]
+         audio = np.zeros(audio_len)
+         data = {"audio": [(audio, sampling_rate)] * audio_count}
+
+         return ProcessorInputs(
+             prompt_text="<|audio|>" * audio_count,
+             mm_data=data,
+             mm_processor_kwargs={},
+         )
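For a sense of scale, the dummy input built here is one full feature-extractor chunk of silence per audio item. Assuming Whisper's usual 30-second chunk length and 16 kHz sampling rate (both are read from the extractor at runtime, so these numbers are illustrative):

import numpy as np

# Assumed Whisper defaults; the processor reads the real values from
# WhisperFeatureExtractor at runtime.
chunk_length_s = 30
sampling_rate = 16_000

audio_len = chunk_length_s * sampling_rate   # 480_000 samples of silence
dummy_audio = np.zeros(audio_len)

audio_count = 2
dummy_data = {"audio": [(dummy_audio, sampling_rate)] * audio_count}
dummy_prompt = "<|audio|>" * audio_count     # one placeholder per clip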


class StackAudioFrames(nn.Module):
@@ -332,11 +298,9 @@ def forward(
        return hidden_states


- @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_ultravox_max_audio_tokens)
- @INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
- @INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
+ @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):