[Bugfix] Avoid transferring cached multi-modal items from P0 to P1 #16273

Merged 5 commits on Apr 9, 2025

Changes from all commits:

3 changes: 2 additions & 1 deletion vllm/v1/engine/__init__.py
@@ -2,6 +2,7 @@

import enum
import time
from collections.abc import Sequence
from typing import Any, Optional, Union

import msgspec
@@ -52,7 +53,7 @@ class EngineCoreRequest(
# Detokenizer, but set to None when it is added to EngineCoreClient.
prompt: Optional[str]
prompt_token_ids: list[int]
mm_inputs: Optional[list[MultiModalKwargs]]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: SamplingParams
6 changes: 3 additions & 3 deletions vllm/v1/engine/core.py
@@ -31,7 +31,7 @@
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
@@ -105,7 +105,7 @@ def __init__(
)

# Setup MM Input Mapper.
self.mm_input_cache_server = MMInputCacheServer(
self.mm_input_cache_server = MirroredProcessingCache(
vllm_config.model_config)

# Setup batch queue for pipeline parallelism.
@@ -173,7 +173,7 @@ def add_request(self, request: EngineCoreRequest):
# anything that has a hash must have a HIT cache entry here
# as well.
assert request.mm_inputs is not None
request.mm_inputs = self.mm_input_cache_server.get_and_update(
request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
request.mm_inputs, request.mm_hashes)

req = Request.from_engine_core_request(request)
41 changes: 34 additions & 7 deletions vllm/v1/engine/mm_input_cache.py
@@ -1,8 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from typing import Optional

from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.processing import ProcessingCache
from vllm.utils import is_list_of

# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
@@ -11,9 +14,11 @@
# -- Client:
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
# with built-in caching functionality, with mm_hash as its identifier.
# - MirroredProcessingCache to keep track of the cached entries and
# determine whether to send the MultiModalKwargs to P1.
#
# -- Server:
# - MMInputCacheServer to perform caching of the received MultiModalKwargs.
# - MirroredProcessingCache to store the MultiModalKwargs from P0.
#
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
@@ -25,26 +30,48 @@
# variable VLLM_MM_INPUT_CACHE_GIB.


class MMInputCacheServer:
class MirroredProcessingCache:

def __init__(self, model_config):
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs)

def get_and_update(
def get_and_update_p0(
self,
mm_inputs: list[MultiModalKwargs],
mm_inputs: Sequence[MultiModalKwargs],
mm_hashes: list[str],
) -> list[MultiModalKwargs]:
) -> Sequence[Optional[MultiModalKwargs]]:
assert len(mm_inputs) == len(mm_hashes)

if not self.use_cache:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs

full_mm_inputs = []
full_mm_inputs = list[Optional[MultiModalKwargs]]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
if mm_hash in self.mm_cache:
mm_input = None
else:
self.mm_cache[mm_hash] = mm_input

full_mm_inputs.append(mm_input)

return full_mm_inputs

def get_and_update_p1(
self,
mm_inputs: Sequence[Optional[MultiModalKwargs]],
mm_hashes: list[str],
) -> Sequence[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)

if not self.use_cache:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs

full_mm_inputs = list[MultiModalKwargs]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
assert mm_hash is not None
if mm_input is None:
mm_input = self.mm_cache[mm_hash]
else:
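
To make the round trip introduced in this file easier to follow, here is a minimal, self-contained sketch of the mirrored-cache protocol. It is an illustration under simplifying assumptions, not vLLM code: a plain dict stands in for the LRU `ProcessingCache` (which in the real class is bounded by `VLLM_MM_INPUT_CACHE_GIB`), strings stand in for `MultiModalKwargs`, and the `MirroredCacheSketch` name is made up for this example.

```python
from typing import Optional


class MirroredCacheSketch:
    """Toy stand-in for MirroredProcessingCache (dict instead of an LRU cache)."""

    def __init__(self) -> None:
        self.mm_cache: dict[str, str] = {}

    def get_and_update_p0(
        self,
        mm_inputs: list[str],
        mm_hashes: list[str],
    ) -> list[Optional[str]]:
        # P0 (frontend): replace items the server is known to hold with None
        # so the heavy payload is not serialized across the process boundary.
        out: list[Optional[str]] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_hash in self.mm_cache:
                out.append(None)
            else:
                self.mm_cache[mm_hash] = mm_input
                out.append(mm_input)
        return out

    def get_and_update_p1(
        self,
        mm_inputs: list[Optional[str]],
        mm_hashes: list[str],
    ) -> list[str]:
        # P1 (engine core): restore every None placeholder from the mirrored
        # cache; fresh items are stored so later requests can hit.
        out: list[str] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_input is None:
                mm_input = self.mm_cache[mm_hash]
            else:
                self.mm_cache[mm_hash] = mm_input
            out.append(mm_input)
        return out


p0, p1 = MirroredCacheSketch(), MirroredCacheSketch()

# First request: nothing is cached yet, so the full input crosses P0 -> P1.
wire = p0.get_and_update_p0(["pixel_values_a"], ["hash_a"])
assert wire == ["pixel_values_a"]
assert p1.get_and_update_p1(wire, ["hash_a"]) == ["pixel_values_a"]

# Repeated item: P0 sends only a None placeholder, and P1 fills it back in.
wire = p0.get_and_update_p0(["pixel_values_a"], ["hash_a"])
assert wire == [None]
assert p1.get_and_update_p1(wire, ["hash_a"]) == ["pixel_values_a"]
```

Because both sides apply the same update rule keyed by `mm_hash`, the two caches stay in sync without any explicit invalidation traffic, which is the property the comment block at the top of this file describes.
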
23 changes: 17 additions & 6 deletions vllm/v1/engine/processor.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import time
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Literal, Optional, Union

from vllm.config import VllmConfig
@@ -19,6 +19,7 @@
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.utils import (
@@ -47,6 +48,8 @@ def __init__(
self.tokenizer,
mm_registry)

self.mm_input_cache_client = MirroredProcessingCache(self.model_config)

# Multi-modal hasher (for images)
self.use_hash = (
not self.model_config.disable_mm_preprocessor_cache) or \
@@ -231,7 +234,7 @@ def process_inputs(
self.tokenizer.get_lora_tokenizer(lora_request))

# Multimodal related.
sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None
if decoder_inputs["type"] == "multimodal":
@@ -256,20 +259,28 @@
# are multiple modalities.
unique_modalities = set(sorted_item_modalities)
if len(unique_modalities) > 1:
sorted_mm_inputs = []
orig_sorted_mm_inputs = []
used_indices = {modality: 0 for modality in unique_modalities}

for modality in sorted_item_modalities:
items = decoder_mm_inputs.get_items(modality)
item = items[used_indices[modality]]
sorted_mm_inputs.append(MultiModalKwargs.from_items([item
]))

orig_sorted_mm_inputs.append(
MultiModalKwargs.from_items([item]))
used_indices[modality] += 1
else:
sorted_mm_inputs = [
orig_sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
]

if sorted_mm_hashes is not None:
sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
orig_sorted_mm_inputs, sorted_mm_hashes)
else:
sorted_mm_inputs = orig_sorted_mm_inputs

return EngineCoreRequest(
request_id=request_id,
prompt=decoder_inputs.get("prompt"),
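
The reordering logic in the hunk above is easiest to see with concrete data. The sketch below is an assumption-laden illustration: `items_by_modality` and the string placeholders stand in for `decoder_mm_inputs.get_items(...)` and `MultiModalKwargs.from_items([...])`, and only the `used_indices` bookkeeping mirrors the real code.

```python
# Items grouped per modality, each list already in prompt order.
items_by_modality = {
    "image": ["image_item_0", "image_item_1"],
    "audio": ["audio_item_0"],
}

# Modality of each placeholder, in the order the placeholders appear
# in the prompt (i.e. sorted_item_modalities).
sorted_item_modalities = ["image", "audio", "image"]

used_indices = {modality: 0 for modality in set(sorted_item_modalities)}
orig_sorted_mm_inputs = []
for modality in sorted_item_modalities:
    # Take the next unused item of this modality, preserving per-modality order.
    item = items_by_modality[modality][used_indices[modality]]
    orig_sorted_mm_inputs.append(item)
    used_indices[modality] += 1

assert orig_sorted_mm_inputs == ["image_item_0", "audio_item_0", "image_item_1"]
```

Only after this per-item list is built does the new `get_and_update_p0` call run, replacing entries whose `sorted_mm_hashes` are already known to P1 with None before the `EngineCoreRequest` is constructed.
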
14 changes: 9 additions & 5 deletions vllm/v1/request.py
@@ -3,17 +3,16 @@
import enum
from typing import TYPE_CHECKING, Optional, Union

from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreRequest, FinishReason)
from vllm.v1.structured_output.request import StructuredOutputRequest
from vllm.v1.utils import ConstantList

if TYPE_CHECKING:

from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange


class Request:
Expand All @@ -23,9 +22,9 @@ def __init__(
request_id: str,
prompt: Optional[str],
prompt_token_ids: list[int],
multi_modal_inputs: Optional[list["MultiModalKwargs"]],
multi_modal_inputs: Optional[list[MultiModalKwargs]],
multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list["PlaceholderRange"]],
multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
@@ -75,6 +74,11 @@ def __init__(

@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
if request.mm_inputs is not None:
assert isinstance(request.mm_inputs, list)
assert is_list_of(request.mm_inputs, MultiModalKwargs), (
"mm_inputs was not updated in EngineCore.add_request")

return cls(
request_id=request.request_id,
prompt=request.prompt,