|
1 | 1 | # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
|
2 | 2 | """PyTorch Ultravox model."""
|
3 |
| - |
4 | 3 | import math
|
5 | 4 | from functools import cached_property
|
6 | 5 | from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
|
|
14 | 13 | from transformers.models.whisper import WhisperFeatureExtractor
|
15 | 14 | from transformers.models.whisper.modeling_whisper import WhisperEncoder
|
16 | 15 |
|
| 16 | +from vllm import envs |
17 | 17 | from vllm.attention import AttentionMetadata
|
18 | 18 | from vllm.config import VllmConfig
|
19 | 19 | from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
|
|
35 | 35 | from .interfaces import SupportsMultiModal, SupportsPP
|
36 | 36 | from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
37 | 37 | init_vllm_registered_model, maybe_prefix,
|
| 38 | + merge_multimodal_embeddings, |
38 | 39 | merge_multimodal_embeddings_from_map)
|
39 | 40 |
|
| 41 | +_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>" |
| 42 | +_AUDIO_PLACEHOLDER_TOKEN = 128002 |
40 | 43 | _AUDIO_TOKENS_PER_SECOND = 6.25
|
41 | 44 |
|
42 | 45 |
|
@@ -64,7 +67,14 @@ def _get_hf_processor(
|
64 | 67 | # Ignored in initialization
|
65 | 68 | sampling_rate: Optional[int] = None,
|
66 | 69 | ) -> ProcessorMixin:
|
67 |
| - return self.ctx.get_hf_processor() |
| 70 | + hf_processor = self.ctx.get_hf_processor() |
| 71 | + |
| 72 | + # NOTE: Ultravox processing definition uses '<|eot_id|>' as the |
| 73 | + # placeholder that will cause confusion with the actual end of turn |
| 74 | + # token, thus we override placeholder with a reserved special |
| 75 | + # token. |
| 76 | + hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE |
| 77 | + return hf_processor |
68 | 78 |
|
69 | 79 | def _get_feature_extractor(
|
70 | 80 | self,
|
@@ -465,11 +475,15 @@ def get_input_embeddings(
|
465 | 475 | inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
466 | 476 | if multimodal_embeddings is not None:
|
467 | 477 |
|
468 |
| - # TODO(ywang96): use merge_multimodal_embeddings after |
469 |
| - # v0 is deprecated |
470 |
| - merge_multimodal_embeddings_from_map( |
471 |
| - inputs_embeds, multimodal_embeddings, |
472 |
| - attn_metadata.multi_modal_placeholder_index_maps["audio"]) |
| 478 | + # TODO(ywang96): remove this block after v0 is deprecated. |
| 479 | + if not envs.VLLM_USE_V1: |
| 480 | + merge_multimodal_embeddings_from_map( |
| 481 | + inputs_embeds, multimodal_embeddings, |
| 482 | + attn_metadata.multi_modal_placeholder_index_maps["audio"]) |
| 483 | + else: |
| 484 | + inputs_embeds = merge_multimodal_embeddings( |
| 485 | + input_ids, inputs_embeds, multimodal_embeddings, |
| 486 | + _AUDIO_PLACEHOLDER_TOKEN) |
473 | 487 | return inputs_embeds
|
474 | 488 |
|
475 | 489 | def forward(self,
|
|
0 commit comments