
Commit 59c9b6e

ywang96 authored
[V1][VLM] Proper memory profiling for image language models (#11210)
Signed-off-by: Roger Wang <[email protected]> Co-authored-by: ywang96 <[email protected]>
1 parent 66d4b16 commit 59c9b6e

File tree: 6 files changed (+98, -13 lines)

vllm/config.py

Lines changed: 8 additions & 0 deletions
@@ -1280,6 +1280,14 @@ class SchedulerConfig:
 
     is_multimodal_model: bool = False
 
+    # FIXME(woosuk & ywang96): Below are placeholder values. We need to
+    # calculate the actual values from the configurations.
+    # Multimodal encoder run compute budget, only used in V1
+    max_num_encoder_input_tokens = 16384
+
+    # Multimodal encoder cache size, only used in V1
+    encoder_cache_size = 16384
+
     # Whether to perform preemption by swapping or
     # recomputation. If not specified, we determine the mode as follows:
     # We use recomputation by default since it incurs lower overhead than
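For a sense of scale, these two budgets jointly bound how many multimodal items can be encoded and cached in a single V1 step. A rough back-of-the-envelope sketch in Python; the 4096-token per-image cost is a hypothetical figure for illustration, not a value taken from this commit:

# Placeholder budgets introduced above, plus an assumed per-image encoder cost.
max_num_encoder_input_tokens = 16384  # encoder compute budget per step (V1)
encoder_cache_size = 16384            # encoder cache capacity in tokens (V1)
tokens_per_image = 4096               # hypothetical, model-dependent

# Upper bound on images that can be encoded and cached in one scheduling step.
max_images_per_step = min(max_num_encoder_input_tokens,
                          encoder_cache_size) // tokens_per_image
print(max_images_per_step)  # 4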

vllm/model_executor/models/pixtral.py

Lines changed: 5 additions & 0 deletions
@@ -245,6 +245,11 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
             # Do not split, return as tensor of shape [1, fs, hs]
             return image_embeds.unsqueeze(0)
 
+        # If the last split index is the last index in image_tokens, we
+        # ignore it to avoid empty split tensor
+        if split_indices[-1] == len(image_tokens):
+            split_indices = split_indices[:-1]
+
         image_embeds = image_embeds.tensor_split(split_indices.cpu())
         return image_embeds
 
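The new guard matters because torch.tensor_split always keeps a chunk after the last split index, so an index equal to the tensor length yields a trailing empty tensor. A minimal, self-contained demonstration (toy shapes only, not the real Pixtral embedding sizes):

import torch

image_embeds = torch.arange(12).reshape(6, 2)  # toy [num_tokens, hidden] embeddings
split_indices = torch.tensor([2, 4, 6])        # last index == number of rows

chunks = image_embeds.tensor_split(split_indices.cpu())
print([c.shape[0] for c in chunks])  # [2, 2, 2, 0] -> trailing empty split

# With the fix above: drop the final index before splitting.
if split_indices[-1] == image_embeds.shape[0]:
    split_indices = split_indices[:-1]
chunks = image_embeds.tensor_split(split_indices.cpu())
print([c.shape[0] for c in chunks])  # [2, 2, 2]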

vllm/multimodal/registry.py

Lines changed: 20 additions & 3 deletions
@@ -200,6 +200,23 @@ def register_max_image_tokens(
         """
         return self.register_max_multimodal_tokens("image", max_mm_tokens)
 
+    def get_max_tokens_per_item_by_modality(
+        self,
+        model_config: "ModelConfig",
+    ) -> Mapping[str, int]:
+        """
+        Get the maximum number of tokens per data item from each modality
+        for profiling the memory usage of a model.
+
+        Note:
+            This is currently directly used only in V1.
+        """
+
+        return {
+            key: plugin.get_max_multimodal_tokens(model_config)
+            for key, plugin in self._plugins.items()
+        }
+
     def get_max_tokens_by_modality(
         self,
         model_config: "ModelConfig",
@@ -216,9 +233,9 @@ def get_max_tokens_by_modality(
         limits_per_plugin = self._limits_by_model[model_config]
 
         return {
-            key: (limits_per_plugin[key] *
-                  plugin.get_max_multimodal_tokens(model_config))
-            for key, plugin in self._plugins.items()
+            key: limits_per_plugin[key] * max_tokens_per_mm_item
+            for key, max_tokens_per_mm_item in
+            self.get_max_tokens_per_item_by_modality(model_config).items()
         }
 
     def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
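With this refactor, get_max_tokens_per_item_by_modality reports the per-item maximum for each modality, and get_max_tokens_by_modality becomes a simple product with the configured per-prompt item limits. A small sketch of that relationship with made-up numbers (not taken from any real model config):

# Hypothetical per-item maxima and per-prompt limits for a two-modality model.
max_tokens_per_item = {"image": 4096, "audio": 512}
limits_per_plugin = {"image": 2, "audio": 1}

# get_max_tokens_by_modality is now the elementwise product of the two maps.
max_tokens_by_modality = {
    key: limits_per_plugin[key] * max_tokens
    for key, max_tokens in max_tokens_per_item.items()
}
print(max_tokens_by_modality)  # {'image': 8192, 'audio': 512}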

vllm/v1/core/scheduler.py

Lines changed: 3 additions & 4 deletions
@@ -73,14 +73,13 @@ def __init__(
         # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
         # projector if needed). Currently, we assume that the encoder also
         # has the Transformer architecture (e.g., ViT).
-        # FIXME(woosuk): Below are placeholder values. We need to calculate the
-        # actual values from the configurations.
-        self.max_num_encoder_input_tokens = 16384
+        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  #noqa: E501
         # NOTE(woosuk): For the models without encoder (e.g., text-only models),
         # the encoder cache will not be initialized and used, regardless of
         # the cache size. This is because the memory space for the encoder cache
         # is preallocated in the profiling run.
-        self.encoder_cache_manager = EncoderCacheManager(cache_size=16384)
+        self.encoder_cache_manager = EncoderCacheManager(
+            cache_size=self.scheduler_config.encoder_cache_size)
 
     def schedule(self) -> "SchedulerOutput":
         # NOTE(woosuk) on the scheduling algorithm:

vllm/v1/engine/mm_input_mapper.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ def cache_hit_ratio(self, steps):
             logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
                          self.mm_cache_hits / self.mm_cache_total)
 
+    # TODO: Support modalities beyond image.
     def process_inputs(
         self,
         mm_data: MultiModalDataDict,

vllm/v1/worker/gpu_model_runner.py

Lines changed: 61 additions & 6 deletions
@@ -10,15 +10,16 @@
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
-from vllm.inputs import INPUT_REGISTRY, InputRegistry
+from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.multimodal import MultiModalKwargs
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
+from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -35,7 +36,6 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device,
-        input_registry: InputRegistry = INPUT_REGISTRY,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -77,7 +77,12 @@ def __init__(
         self.hidden_size = model_config.get_hidden_size()
 
         # Multi-modal data support
-        self.input_registry = input_registry
+        self.input_registry = INPUT_REGISTRY
+        self.mm_registry = MULTIMODAL_REGISTRY
+        # NOTE: mm_input_mapper is only used for memory profiling.
+        self.mm_input_mapper = MMInputMapperClient(self.model_config)
+        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  # noqa: E501
+        self.encoder_cache_size = self.scheduler_config.encoder_cache_size
 
         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
@@ -599,8 +604,6 @@ def _dummy_run(
         return hidden_states
 
     def profile_run(self) -> None:
-        # TODO(woosuk): Profile the max memory usage of the encoder and
-        # the encoder cache.
         # use an empty tensor instead of `None`` to force Dynamo to pass
         # it by reference, rather by specializing on the value `None`.
         # the `dtype` argument does not matter, and we use `float32` as
@@ -612,6 +615,57 @@
             torch.tensor([], dtype=torch.float32, device=self.device)
             for _ in range(self.num_attn_layers)
         ]
+
+        # Profile with multimodal encoder & encoder cache.
+        # TODO (ywang96): generalize this beyond image modality since
+        # mm_input_mapper only supports image inputs.
+        if self.is_multimodal_model:
+
+            # Create dummy batch of multimodal inputs.
+            dummy_request_data = self.input_registry.dummy_data_for_profiling(
+                model_config=self.model_config,
+                seq_len=self.max_num_tokens,
+                mm_registry=self.mm_registry,
+            )
+            dummy_mm_data = dummy_request_data.multi_modal_data
+            dummy_mm_kwargs, _ = self.mm_input_mapper.process_inputs(
+                mm_data=dummy_mm_data,
+                mm_hashes=None,
+                mm_processor_kwargs=None,
+                precomputed_mm_inputs=None)
+
+            # NOTE: Currently model is profiled with a single non-text
+            # modality even when it supports multiple.
+            max_tokens_per_mm_item = max(
+                self.mm_registry.get_max_tokens_per_item_by_modality(
+                    self.model_config).values())
+
+            max_num_mm_items = min(
+                self.max_num_encoder_input_tokens,
+                self.encoder_cache_size) // max_tokens_per_mm_item
+
+            # Dummy data definition in V0 may contain multiple multimodal items
+            # (e.g, multiple images) for a single request, therefore here we
+            # always replicate first item by max_num_mm_items times since in V1
+            # they are scheduled to be processed separately.
+            batched_dummy_mm_inputs = MultiModalKwargs.batch(
+                [dummy_mm_kwargs[0]] * max_num_mm_items)
+            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
+                batched_dummy_mm_inputs, device=self.device)
+
+            # Run multimodal encoder.
+            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+                **batched_dummy_mm_inputs)
+            assert len(dummy_encoder_outputs) == max_num_mm_items, (
+                "Expected dimension 0 of encoder outputs to match the number "
+                f"of multimodal data items: {max_num_mm_items}, got "
+                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
+                "due to the 'get_multimodal_embeddings' method of the model "
+                "not implemented correctly.")
+
+            # Cache the dummy encoder outputs.
+            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
+
         # Trigger compilation for general shape.
         hidden_states = self._dummy_run(self.model, self.max_num_tokens,
                                         dummy_kv_caches)
@@ -620,6 +674,7 @@ def profile_run(self) -> None:
         # TODO(woosuk): Consider the memory usage of the sampler.
         torch.cuda.synchronize()
         del hidden_states, logits
+        self.encoder_cache.clear()
         gc.collect()
 
     def capture_model(self) -> None:
