[V1][VLM] Proper memory profiling for image language models #11210

Merged (9 commits) on Dec 17, 2024
Changes from all commits
8 changes: 8 additions & 0 deletions vllm/config.py
@@ -1186,6 +1186,14 @@ class SchedulerConfig:

is_multimodal_model: bool = False

# FIXME(woosuk & ywang96): Below are placeholder values. We need to
# calculate the actual values from the configurations.
# Multimodal encoder compute budget, only used in V1
max_num_encoder_input_tokens = 16384

# Multimodal encoder cache size, only used in V1
encoder_cache_size = 16384

# Whether to perform preemption by swapping or
# recomputation. If not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
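These two budgets bound how much encoder work the V1 scheduler admits per step and how many encoder outputs can be cached. A minimal sketch (not part of the patch) of what the placeholder values imply, assuming a hypothetical encoder that emits 1024 tokens per image; the real per-item token count depends on the model configuration, as the FIXME notes:

max_num_encoder_input_tokens = 16384
encoder_cache_size = 16384
tokens_per_image = 1024  # assumed value, for illustration only

# At most this many images can be encoded within a single scheduler step:
images_per_step = max_num_encoder_input_tokens // tokens_per_image  # 16

# At most this many images' encoder outputs fit in the encoder cache at once:
cached_images = encoder_cache_size // tokens_per_image  # 16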
5 changes: 5 additions & 0 deletions vllm/model_executor/models/pixtral.py
@@ -245,6 +245,11 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)

# If the last split index equals the length of image_tokens, drop it
# to avoid producing an empty split tensor
if split_indices[-1] == len(image_tokens):
split_indices = split_indices[:-1]

image_embeds = image_embeds.tensor_split(split_indices.cpu())
return image_embeds

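The guard above matters because torch.tensor_split yields an empty trailing chunk when the last split index equals the size of the dimension being split. A small self-contained illustration (not Pixtral code; shapes are made up):

import torch

embeds = torch.arange(12).reshape(6, 2)   # stands in for [num_tokens, hidden_size]
split_indices = torch.tensor([2, 4, 6])   # last index == size of dim 0

chunks = embeds.tensor_split(split_indices.cpu())
print([c.shape[0] for c in chunks])       # [2, 2, 2, 0] -- empty trailing chunk

chunks = embeds.tensor_split(split_indices[:-1].cpu())
print([c.shape[0] for c in chunks])       # [2, 2, 2]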
23 changes: 20 additions & 3 deletions vllm/multimodal/registry.py
@@ -200,6 +200,23 @@ def register_max_image_tokens(
"""
return self.register_max_multimodal_tokens("image", max_mm_tokens)

def get_max_tokens_per_item_by_modality(
self,
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality
for profiling the memory usage of a model.

Note:
This is currently used directly only in V1.
"""

return {
key: plugin.get_max_multimodal_tokens(model_config)
for key, plugin in self._plugins.items()
}

def get_max_tokens_by_modality(
self,
model_config: "ModelConfig",
@@ -216,9 +233,9 @@ def get_max_tokens_by_modality(
limits_per_plugin = self._limits_by_model[model_config]

return {
key: (limits_per_plugin[key] *
plugin.get_max_multimodal_tokens(model_config))
for key, plugin in self._plugins.items()
key: limits_per_plugin[key] * max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
}

def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
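A hedged usage sketch of the new helper alongside the existing per-request variant; `model_config` is assumed to be an already-constructed vllm.config.ModelConfig for a vision-language model, and the example numbers are illustrative only:

from vllm.multimodal import MULTIMODAL_REGISTRY

# Worst-case encoder tokens for a single item of each modality,
# e.g. {"image": 576} for a hypothetical model.
per_item = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_modality(model_config)

# The per-request totals additionally multiply by the configured item limits
# (limit_mm_per_prompt), e.g. {"image": 2304} if up to 4 images are allowed.
per_request = MULTIMODAL_REGISTRY.get_max_tokens_by_modality(model_config)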
7 changes: 3 additions & 4 deletions vllm/v1/core/scheduler.py
@@ -73,14 +73,13 @@ def __init__(
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also
# has the Transformer architecture (e.g., ViT).
# FIXME(woosuk): Below are placeholder values. We need to calculate the
# actual values from the configurations.
self.max_num_encoder_input_tokens = 16384
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens #noqa: E501
# NOTE(woosuk): For models without an encoder (e.g., text-only models),
# the encoder cache will not be initialized or used, regardless of
# the cache size. This is because the memory space for the encoder cache
# is preallocated in the profiling run.
self.encoder_cache_manager = EncoderCacheManager(cache_size=16384)
self.encoder_cache_manager = EncoderCacheManager(
cache_size=self.scheduler_config.encoder_cache_size)

def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
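For intuition, a purely hypothetical sketch (not the vLLM scheduler) of how a per-step encoder compute budget such as max_num_encoder_input_tokens is typically consumed: each multimodal item scheduled in a step subtracts its encoder token count until the budget is exhausted.

def fits_in_step(pending_item_tokens: list[int], budget: int) -> list[int]:
    """Return indices of the items that fit within this step's encoder budget."""
    scheduled = []
    for idx, num_tokens in enumerate(pending_item_tokens):
        if num_tokens > budget:
            break
        budget -= num_tokens
        scheduled.append(idx)
    return scheduled

# Example: with a 16384-token budget and 4096-token images, four items fit per step.
assert fits_in_step([4096] * 6, 16384) == [0, 1, 2, 3]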
1 change: 1 addition & 0 deletions vllm/v1/engine/mm_input_mapper.py
@@ -54,6 +54,7 @@ def cache_hit_ratio(self, steps):
logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
self.mm_cache_hits / self.mm_cache_total)

# TODO: Support modalities beyond image.
def process_inputs(
self,
mm_data: MultiModalDataDict,
67 changes: 61 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
@@ -10,15 +10,16 @@
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -35,7 +36,6 @@ def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
input_registry: InputRegistry = INPUT_REGISTRY,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
@@ -77,7 +77,12 @@ def __init__(
self.hidden_size = model_config.get_hidden_size()

# Multi-modal data support
self.input_registry = input_registry
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
# NOTE: mm_input_mapper is only used for memory profiling.
self.mm_input_mapper = MMInputMapperClient(self.model_config)
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501
self.encoder_cache_size = self.scheduler_config.encoder_cache_size

# Lazy initialization
# self.model: nn.Module # Set after load_model
@@ -591,8 +596,6 @@ def _dummy_run(
return hidden_states

def profile_run(self) -> None:
# TODO(woosuk): Profile the max memory usage of the encoder and
# the encoder cache.
# use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
@@ -604,6 +607,57 @@ def profile_run(self) -> None:
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]

# Profile with multimodal encoder & encoder cache.
# TODO (ywang96): generalize this beyond image modality since
# mm_input_mapper only supports image inputs.
if self.is_multimodal_model:

# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data
dummy_mm_kwargs, _ = self.mm_input_mapper.process_inputs(
mm_data=dummy_mm_data,
mm_hashes=None,
mm_processor_kwargs=None,
precomputed_mm_inputs=None)

# NOTE: Currently the model is profiled with a single non-text
# modality, even when it supports multiple.
max_tokens_per_mm_item = max(
self.mm_registry.get_max_tokens_per_item_by_modality(
self.model_config).values())

max_num_mm_items = min(
self.max_num_encoder_input_tokens,
self.encoder_cache_size) // max_tokens_per_mm_item

# The dummy data definition in V0 may contain multiple multimodal items
# (e.g., multiple images) for a single request; therefore, we always
# replicate the first item max_num_mm_items times, since in V1 the items
# are scheduled to be processed separately.
batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs[0]] * max_num_mm_items)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs, device=self.device)

# Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
assert len(dummy_encoder_outputs) == max_num_mm_items, (
"Expected dimension 0 of encoder outputs to match the number "
f"of multimodal data items: {max_num_mm_items}, got "
f"{len(dummy_encoder_outputs)=} instead. This is most likely "
"due to the 'get_multimodal_embeddings' method of the model "
"not implemented correctly.")

# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.model, self.max_num_tokens,
dummy_kv_caches)
@@ -612,6 +666,7 @@
# TODO(woosuk): Consider the memory usage of the sampler.
torch.cuda.synchronize()
del hidden_states, logits
self.encoder_cache.clear()
gc.collect()

def capture_model(self) -> None:
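The core of the new profiling path is the sizing arithmetic: the dummy multimodal batch is made as large as either budget allows, so the forward pass exercises worst-case encoder and encoder-cache memory. A condensed, standalone sketch of just that calculation, with assumed numbers (in the patch the real values come from SchedulerConfig and the multimodal registry):

def estimate_max_mm_items(max_num_encoder_input_tokens: int,
                          encoder_cache_size: int,
                          max_tokens_per_mm_item: int) -> int:
    """Number of multimodal items batched together in the profiling run."""
    # The profiling batch can neither exceed the per-step encoder compute
    # budget nor overflow the encoder cache, so take the smaller of the two.
    return min(max_num_encoder_input_tokens,
               encoder_cache_size) // max_tokens_per_mm_item

# With the placeholder budgets of 16384 and an assumed 4096 tokens per image,
# the dummy batch used for profiling contains 4 image items.
assert estimate_max_mm_items(16384, 16384, 4096) == 4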