Commit 5108119

Initial prototype for multi-modal processor
Signed-off-by: DarkLight1337 <[email protected]>
1 parent: 93dee88


46 files changed (+943, -351 lines)

docs/source/dev/multimodal/multimodal_index.rst

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ Base Classes

 .. autodata:: vllm.multimodal.MultiModalDataDict

-.. autoclass:: vllm.multimodal.MultiModalInputs
+.. autoclass:: vllm.multimodal.MultiModalKwargs
     :members:
     :show-inheritance:
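
As context for the rename, a minimal sketch of the class in use (hedged: it assumes MultiModalKwargs keeps the dict-like interface of the former MultiModalInputs, and the tensor shape is made up; the keys are ultimately forwarded to the model's forward() as keyword arguments, hence "Kwargs"):

import torch

from vllm.multimodal import MultiModalKwargs

# Hypothetical mapper output: a dict-like container of tensors, keyed by the
# keyword-argument names the model's forward() expects.
mm_kwargs = MultiModalKwargs({"pixel_values": torch.rand(1, 3, 336, 336)})
assert "pixel_values" in mm_kwargs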

docs/source/models/enabling_multimodal_inputs.rst

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i
 3. Register maximum number of multi-modal tokens
 ------------------------------------------------

-For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance
+For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item
 and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.

 .. code-block:: diff
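
As a companion to the wording change above, a sketch of registering the per-item token budget for a model (hedged: get_max_llava_image_tokens and the decorated class are placeholders, and the image-specific convenience decorator is shown rather than the generic method linked in the docs):

from vllm.inputs import InputContext
from vllm.multimodal import MULTIMODAL_REGISTRY


def get_max_llava_image_tokens(ctx: InputContext) -> int:
    # Hypothetical: the largest number of placeholder tokens a single image
    # item can expand to for this model.
    return 576


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
class LlavaForConditionalGeneration:  # placeholder model class
    ...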

tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 from PIL.Image import Image

 from vllm.inputs import InputContext, token_inputs
-from vllm.multimodal.base import MultiModalInputs
+from vllm.multimodal import MultiModalKwargs
 from vllm.multimodal.utils import cached_get_tokenizer

 from .....conftest import IMAGE_ASSETS
@@ -96,7 +96,7 @@ def test_input_mapper_valid_mm_data(input_mapper_for_qwen,
     mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data)
     # Ensure that we get the appropriately shaped pixel_values
     # for images and image embeddings, respectively.
-    assert isinstance(mapped_img_data, MultiModalInputs)
+    assert isinstance(mapped_img_data, MultiModalKwargs)
     assert "pixel_values" in mapped_img_data
     assert mapped_img_data["pixel_values"].shape == expected_shape

tests/multimodal/test_base.py

Lines changed: 11 additions & 11 deletions
@@ -1,6 +1,6 @@
 import torch

-from vllm.multimodal.base import MultiModalInputs, NestedTensors
+from vllm.multimodal.base import MultiModalKwargs, NestedTensors


 def assert_nested_tensors_equal(expected: NestedTensors,
@@ -13,40 +13,40 @@ def assert_nested_tensors_equal(expected: NestedTensors,
         assert_nested_tensors_equal(expected_item, actual_item)


-def assert_multimodal_inputs_equal(expected: MultiModalInputs,
-                                   actual: MultiModalInputs):
+def assert_multimodal_inputs_equal(expected: MultiModalKwargs,
+                                   actual: MultiModalKwargs):
     assert set(expected.keys()) == set(actual.keys())
     for key in expected:
         assert_nested_tensors_equal(expected[key], actual[key])


 def test_multimodal_input_batch_single_tensor():
     t = torch.rand([1, 2])
-    result = MultiModalInputs.batch([{"image": t}])
+    result = MultiModalKwargs.batch([{"image": t}])
     assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)})


 def test_multimodal_input_batch_multiple_tensors():
     a = torch.rand([1, 1, 2])
     b = torch.rand([1, 1, 2])
     c = torch.rand([1, 1, 2])
-    result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
+    result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}])
     assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])})


 def test_multimodal_input_batch_multiple_heterogeneous_tensors():
     a = torch.rand([1, 2, 2])
     b = torch.rand([1, 3, 2])
     c = torch.rand([1, 4, 2])
-    result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
+    result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}])
     assert_multimodal_inputs_equal(result, {"image": [a, b, c]})


 def test_multimodal_input_batch_nested_tensors():
     a = torch.rand([2, 3])
     b = torch.rand([2, 3])
     c = torch.rand([2, 3])
-    result = MultiModalInputs.batch([{
+    result = MultiModalKwargs.batch([{
         "image": [a]
     }, {
         "image": [b]
@@ -65,7 +65,7 @@ def test_multimodal_input_batch_heterogeneous_lists():
     a = torch.rand([1, 2, 3])
     b = torch.rand([1, 2, 3])
     c = torch.rand([1, 2, 3])
-    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
+    result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}])
     assert_multimodal_inputs_equal(
         result,
         {"image": [torch.stack([a, b]), c.unsqueeze(0)]})
@@ -76,7 +76,7 @@ def test_multimodal_input_batch_multiple_batchable_lists():
     b = torch.rand([1, 2, 3])
     c = torch.rand([1, 2, 3])
     d = torch.rand([1, 2, 3])
-    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}])
+    result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c, d]}])
     assert_multimodal_inputs_equal(
         result,
         {"image": torch.stack([torch.stack([a, b]),
@@ -88,8 +88,8 @@ def test_multimodal_input_batch_mixed_stacking_depths():
     b = torch.rand([1, 3, 3])
     c = torch.rand([1, 4, 3])

-    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
+    result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}])
     assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]})

-    result = MultiModalInputs.batch([{"image": [a]}, {"image": [b, c]}])
+    result = MultiModalKwargs.batch([{"image": [a]}, {"image": [b, c]}])
     assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]})
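
The tests above pin down the batching rule: per-item entries whose shapes match are stacked into one tensor, while mismatched shapes are kept as a list. A condensed restatement (hedged, mirroring the expectations in the tests):

import torch

from vllm.multimodal import MultiModalKwargs

same = [{"image": torch.rand(1, 2, 2)} for _ in range(3)]
mixed = [{"image": torch.rand(1, n, 2)} for n in (2, 3, 4)]

# Homogeneous shapes stack into a single (3, 1, 2, 2) tensor.
assert MultiModalKwargs.batch(same)["image"].shape == (3, 1, 2, 2)

# Heterogeneous shapes remain a list of per-item tensors.
assert isinstance(MultiModalKwargs.batch(mixed)["image"], list)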

vllm/config.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ class ModelConfig:
            matches the model name exposed via the APIs. If multiple model
            names provided, the first name will be used. If not specified,
            the model name will be the same as `model`.
-        limit_mm_per_prompt: Maximum number of data instances per modality
+        limit_mm_per_prompt: Maximum number of data items per modality
            per prompt. Only applicable for multimodal models.
        override_neuron_config: Initialize non default neuron config or
            override default neuron config that are specific to Neuron devices,
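
For reference, the renamed option as it is typically passed when constructing an LLM (a small sketch; the model name is illustrative):

from vllm import LLM

# Allow up to two image items per prompt; unspecified modalities keep the
# default limit of one item each.
llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",  # illustrative multimodal model
    limit_mm_per_prompt={"image": 2},
)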

vllm/engine/async_llm_engine.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,7 @@
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import PromptType
+from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
@@ -721,6 +722,9 @@ def _error_callback(self, exc: Exception) -> None:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)

+    async def get_input_preprocessor(self) -> InputPreprocessor:
+        return self.engine.input_preprocessor
+
     async def get_tokenizer(
         self,
         lora_request: Optional[LoRARequest] = None,
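
A minimal sketch of the new accessor from the caller's side (hedged: `engine` stands for an already-constructed AsyncLLMEngine). This is the same pattern that beam_search in vllm/engine/protocol.py switches to below:

async def tokenize_with_engine(engine, prompt: str):
    # Reuse the engine's own preprocessor instead of constructing an
    # InputPreprocessor from a ModelConfig on the caller's side.
    preprocessor = await engine.get_input_preprocessor()
    tokenizer_group = preprocessor.get_tokenizer_group()
    tokenizer = await tokenizer_group.get_lora_tokenizer_async()
    return tokenizer.encode(prompt)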

vllm/engine/llm_engine.py

Lines changed: 4 additions & 1 deletion
@@ -39,6 +39,7 @@
 from vllm.model_executor.guided_decoding import (
     get_local_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                           RequestOutputFactory)
 from vllm.pooling_params import PoolingParams
@@ -226,6 +227,7 @@ def __init__(
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
         input_registry: InputRegistry = INPUT_REGISTRY,
+        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         use_cached_outputs: bool = False,
     ) -> None:

@@ -338,7 +340,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             model_config)

         self.input_preprocessor = InputPreprocessor(model_config,
-                                                    self.tokenizer)
+                                                    self.tokenizer,
+                                                    mm_registry)

         self.input_registry = input_registry
         self.input_processor = input_registry.create_input_processor(

vllm/engine/multiprocessing/client.py

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,7 @@
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
 from vllm.inputs import PromptType
+from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -94,6 +95,8 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig,
             parallel_config=engine_config.parallel_config,
             enable_lora=bool(engine_config.lora_config),
         )
+        self.input_preprocessor = InputPreprocessor(self.model_config,
+                                                    self.tokenizer)

         # Send RPCGenerateRequest to the MQLLMEngine.
         self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
@@ -345,6 +348,9 @@ async def _check_success(error_message: str, socket: Socket):
                 or response != VLLM_RPC_SUCCESS_STR):
             raise ValueError(error_message)

+    async def get_input_preprocessor(self) -> InputPreprocessor:
+        return self.input_preprocessor
+
     async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
         return await self.tokenizer.get_lora_tokenizer_async(lora_request)

vllm/engine/protocol.py

Lines changed: 11 additions & 5 deletions
@@ -62,7 +62,6 @@ def generate(
     async def beam_search(
         self,
         prompt: PromptType,
-        model_config: ModelConfig,
         request_id: str,
         params: BeamSearchParams,
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -74,13 +73,14 @@ async def beam_search(
         length_penalty = params.length_penalty
         include_stop_str_in_output = params.include_stop_str_in_output

-        tokenizer = await self.get_tokenizer()
-        input_preprocessor = InputPreprocessor(model_config, tokenizer)
+        preprocessor = await self.get_input_preprocessor()
+        tokenizer_group = preprocessor.get_tokenizer_group()
+        tokenizer = await tokenizer_group.get_lora_tokenizer_async()

         if is_explicit_encoder_decoder_prompt(prompt):
             raise NotImplementedError
         else:
-            processed_inputs = input_preprocessor._prompt_to_llm_inputs(
+            processed_inputs = preprocessor._prompt_to_llm_inputs(
                 prompt,
                 request_id=request_id,
             )
@@ -220,6 +220,7 @@ async def abort(self, request_id: str) -> None:
         Args:
             request_id: The unique id of the request.
         """
+        ...

     @abstractmethod
     async def get_model_config(self) -> ModelConfig:
@@ -228,8 +229,13 @@ async def get_model_config(self) -> ModelConfig:

     @abstractmethod
     async def get_decoding_config(self) -> DecodingConfig:
-        ...
         """Get the decoding configuration of the vLLM engine."""
+        ...
+
+    @abstractmethod
+    async def get_input_preprocessor(self) -> InputPreprocessor:
+        """Get the input processor of the vLLM engine."""
+        ...

     @abstractmethod
     async def get_tokenizer(

vllm/entrypoints/openai/serving_chat.py

Lines changed: 0 additions & 1 deletion
@@ -187,7 +187,6 @@ async def create_chat_completion(
        if isinstance(sampling_params, BeamSearchParams):
            generator = self.engine_client.beam_search(
                prompt=engine_prompt,
-                model_config=self.model_config,
                request_id=request_id,
                params=sampling_params,
            )

vllm/entrypoints/openai/serving_completion.py

Lines changed: 0 additions & 1 deletion
@@ -140,7 +140,6 @@ async def create_completion(
        if isinstance(sampling_params, BeamSearchParams):
            generator = self.engine_client.beam_search(
                prompt=engine_prompt,
-                model_config=self.model_config,
                request_id=request_id,
                params=sampling_params,
            )

vllm/inputs/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -3,7 +3,8 @@
                    SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
                    TokensPrompt, build_explicit_enc_dec_prompt,
                    to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
-from .registry import DummyData, InputContext, InputRegistry
+from .registry import (DummyData, InputContext, InputProcessingContext,
+                       InputRegistry)

 INPUT_REGISTRY = InputRegistry()
 """
@@ -32,6 +33,7 @@
     "INPUT_REGISTRY",
     "DummyData",
     "InputContext",
+    "InputProcessingContext",
     "InputRegistry",
 ]

vllm/inputs/data.py

Lines changed: 7 additions & 6 deletions
@@ -5,6 +5,7 @@

 if TYPE_CHECKING:
     from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
+    from vllm.multimodal.inputs import MultiModalInputsV2


 class TextPrompt(TypedDict):
@@ -36,13 +37,13 @@ class TokensPrompt(TypedDict):

     multi_modal_data: NotRequired["MultiModalDataDict"]
     """
-    Optional multi-modal data to pass to the model,
+    DEPRECATED: Optional multi-modal data to pass to the model,
     if the model supports it.
     """

     mm_processor_kwargs: NotRequired[Dict[str, Any]]
     """
-    Optional multi-modal processor kwargs to be forwarded to the
+    DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the
     multimodal input mapper & processor. Note that if multiple modalities
     have registered mappers etc for the model being considered, we attempt
     to pass the mm_processor_kwargs to each of them.
@@ -176,7 +177,7 @@ def token_inputs(
     return inputs


-DecoderOnlyInputs = TokenInputs
+DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputsV2"]
 """
 The inputs in :class:`~vllm.LLMEngine` before they are
 passed to the model executor.
@@ -191,14 +192,14 @@ class EncoderDecoderInputs(TypedDict):

     This specifies the required data for encoder-decoder models.
     """
-    encoder: TokenInputs
+    encoder: Union[TokenInputs, "MultiModalInputsV2"]
     """The inputs for the encoder portion."""

-    decoder: TokenInputs
+    decoder: Union[TokenInputs, "MultiModalInputsV2"]
     """The inputs for the decoder portion."""


-SingletonInputs = TokenInputs
+SingletonInputs = Union[TokenInputs, "MultiModalInputsV2"]
 """
 A processed :class:`SingletonPrompt` which can be passed to
 :class:`vllm.sequence.Sequence`.
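
Since SingletonInputs (and the related aliases) are now unions, downstream code can no longer assume token-only inputs. A hedged sketch of narrowing on the TypedDict "type" key (the "token"/"multimodal" discriminator values and the MultiModalInputsV2 fields are assumptions about this prototype and may change):

from typing import List

from vllm.inputs import SingletonInputs


def get_prompt_token_ids(inputs: SingletonInputs) -> List[int]:
    # Both members of the union are TypedDicts carrying prompt_token_ids;
    # the "type" key is assumed to discriminate plain token inputs from the
    # new multi-modal processor outputs.
    if inputs["type"] not in ("token", "multimodal"):
        raise AssertionError(f"Unexpected inputs type: {inputs['type']}")
    return inputs["prompt_token_ids"]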
