[V1][Core] Autotune encoder cache budget #11895
vllm/v1/core/encoder_cache_manager.py:

```diff
@@ -1,7 +1,14 @@
-from typing import Dict, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, List, Set, Tuple

+from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.v1.request import Request

+if TYPE_CHECKING:
+    from vllm.config import ModelConfig, SchedulerConfig
+
+logger = init_logger(__name__)
+

 class EncoderCacheManager:
```
```diff
@@ -46,3 +53,72 @@ def get_freed_ids(self) -> List[Tuple[str, int]]:
         freed = self.freed
         self.freed = []
         return freed
+
+
+def compute_encoder_budget(
+    model_config: "ModelConfig",
+    scheduler_config: "SchedulerConfig",
+) -> Tuple[int, int]:
+    """Compute the encoder cache budget based on the model and scheduler
+    configurations.
+
+    Args:
+        model_config: Model configuration.
+        scheduler_config: Scheduler configuration.
+
+    Returns:
+        - Compute budget for encoder execution, in unit of number of tokens
+          in the input sequence.
+        - Space budget for encoder cache size, in unit of number of tokens
+          in the input sequence.
+    """
+
+    if not model_config.is_multimodal_model:
+        return 0, 0
+
+    # TODO: handle encoder-decoder models once we support them.
+    (
+        encoder_compute_budget,
+        encoder_cache_size,
+    ) = _compute_encoder_budget_multimodal(model_config, scheduler_config)
+
+    return encoder_compute_budget, encoder_cache_size
+
+
+def _compute_encoder_budget_multimodal(
+    model_config: "ModelConfig",
+    scheduler_config: "SchedulerConfig",
+) -> Tuple[int, int]:
+    """Compute the encoder cache budget based on the model and scheduler
+    configurations for a multimodal model.
+
+    Args:
+        model_config: Model configuration.
+        scheduler_config: Scheduler configuration.
+
+    Returns:
+        - Compute budget for encoder execution, in unit of number of tokens
+          in the input sequence.
+        - Space budget for encoder cache size, in unit of number of tokens
+          in the input sequence.
+    """
+
+    max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(  # noqa: E501
+        model_config)
+
+    if not max_tokens_by_modality_dict:
+        logger.warning(
+            "All non-text modalities supported by the model have been "
+            "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
+            "not be initialized.")
+        return 0, 0
+
+    _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
+                                    key=lambda item: item[1])
+
+    encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
+                                 max_tokens_per_mm_item)
+    encoder_cache_size = max(scheduler_config.encoder_cache_size,
+                             max_tokens_per_mm_item)
+
+    return encoder_compute_budget, encoder_cache_size
```

Review comment (on `_compute_encoder_budget_multimodal`): QQ: Why do we need this separate function?

Reply: Eventually we might need to use …
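The heart of the autotuning is the pair of `max()` calls above: both the compute budget and the cache size are raised to at least the largest feature size that a single multimodal item can produce, so one maximum-size image (or other item) always fits. Below is a minimal standalone sketch of that rule; `SchedulerConfigStub`, `autotune_budget`, and the numbers are hypothetical illustrations, not vLLM APIs.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class SchedulerConfigStub:
    # Hypothetical stand-in for vllm.config.SchedulerConfig; only the two
    # fields consulted by the budget computation are modeled here.
    max_num_encoder_input_tokens: int
    encoder_cache_size: int


def autotune_budget(scheduler: SchedulerConfigStub,
                    max_tokens_by_modality: Dict[str, int]) -> Tuple[int, int]:
    """Sketch of the autotuning rule used by _compute_encoder_budget_multimodal."""
    if not max_tokens_by_modality:
        # All non-text modalities disabled -> no encoder cache is needed.
        return 0, 0
    # Largest possible feature size of a single multimodal item.
    max_tokens_per_mm_item = max(max_tokens_by_modality.values())
    # Raise both budgets so at least one maximum-size item always fits.
    compute_budget = max(scheduler.max_num_encoder_input_tokens,
                         max_tokens_per_mm_item)
    cache_size = max(scheduler.encoder_cache_size, max_tokens_per_mm_item)
    return compute_budget, cache_size


# Hypothetical numbers: the configured budget (2048) is smaller than the
# largest image feature size (2448 tokens), so both budgets get bumped.
print(autotune_budget(SchedulerConfigStub(2048, 2048), {"image": 2448}))
# -> (2448, 2448)
```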
vllm/v1/worker/gpu_model_runner.py:
```diff
@@ -20,6 +20,7 @@
                         is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                     FlashAttentionMetadata)
+from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
```
```diff
@@ -88,8 +89,12 @@ def __init__(
         self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
         self.mm_input_mapper_profiling.use_cache = False

-        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  # noqa: E501
-        self.encoder_cache_size = self.scheduler_config.encoder_cache_size
+        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
+            model_config=model_config,
+            scheduler_config=scheduler_config,
+        )
+        self.max_num_encoder_input_tokens = encoder_compute_budget
+        self.encoder_cache_size = encoder_cache_size

         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
```
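One consequence of this wiring worth noting: for text-only models, or when every non-text modality is disabled via `limit_mm_per_prompt`, `compute_encoder_budget` returns `(0, 0)`, so the runner's two fields become zero instead of the configured scheduler defaults. The `> 0` guard added to `profile_run` below then skips multimodal profiling entirely. A small sketch of that guard logic follows; the helper name and values are illustrative only.

```python
def should_profile_encoder(is_multimodal_model: bool,
                           max_num_encoder_input_tokens: int,
                           encoder_cache_size: int) -> bool:
    # Mirrors the condition guarding the multimodal branch of profile_run():
    # only profile the encoder cache when the model is multimodal and both
    # autotuned budgets are non-zero.
    return (is_multimodal_model and max_num_encoder_input_tokens > 0
            and encoder_cache_size > 0)


print(should_profile_encoder(False, 0, 0))       # text-only model -> False
print(should_profile_encoder(True, 0, 0))        # all modalities disabled -> False
print(should_profile_encoder(True, 2448, 2448))  # multimodal model -> True
```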
```diff
@@ -721,44 +726,30 @@ def profile_run(self) -> None:
         ]

         # Profile with multimodal encoder & encoder cache.
-        if self.is_multimodal_model:
-
-            # Create dummy batch of multimodal inputs.
-            dummy_request_data = self.input_registry.dummy_data_for_profiling(
-                model_config=self.model_config,
-                seq_len=self.max_num_tokens,
-                mm_registry=self.mm_registry,
-            )
-            dummy_mm_data = dummy_request_data.multi_modal_data
+        # TODO: handle encoder-decoder models once we support them.
+        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
+                and self.encoder_cache_size > 0):

             # NOTE: Currently model is profiled with a single non-text
             # modality with the max possible input tokens even when
             # it supports multiple.
-            max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality(  # noqa: E501
+            max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(  # noqa: E501
                 self.model_config)
             dummy_data_modality, max_tokens_per_mm_item = max(
                 max_tokens_by_modality_dict.items(), key=lambda item: item[1])

             # Check how many items of this modality can be supported by
-            # the encoder cache budget.
-            encoder_cache_budget = min(self.max_num_encoder_input_tokens,
-                                       self.encoder_cache_size)
-            max_num_mm_items_encoder_budget = encoder_cache_budget // \
-                max_tokens_per_mm_item
-
-            # TODO: Allow users to set encoder_cache_budget in case this
-            # happens.
-            assert max_num_mm_items_encoder_budget > 0, (
-                f"Encoder cache budget={encoder_cache_budget} is too small to "
-                f"support the maximum possible size of multimodal embeddings"
-                f"={max_tokens_per_mm_item}.")
+            # the encoder budget.
+            encoder_budget = min(self.max_num_encoder_input_tokens,
+                                 self.encoder_cache_size)
+
+            max_num_mm_items_encoder_budget = cdiv(encoder_budget,
+                                                   max_tokens_per_mm_item)

             # Check how many items of this modality can be supported by
             # the decoder budget.
-            max_mm_items_per_req = max(
-                self.mm_registry.get_mm_limits_per_prompt(
-                    self.model_config).values())
+            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
+                self.model_config)[dummy_data_modality]

             # NOTE: We do not consider max_num_batched_tokens on purpose
             # because the multimodal embeddings can be generated in advance
```
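To make the item-count arithmetic concrete, here is a worked example with hypothetical numbers; the `cdiv` helper below is a local stand-in for the ceiling-division helper used in the diff. Because the autotuned budgets are each at least `max_tokens_per_mm_item`, the rounded-up count is always at least one, which is presumably why the old assertion could be dropped.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, standing in for the cdiv helper used above.
    return -(-a // b)


# Hypothetical numbers: the autotuned compute budget and cache size,
# and the largest feature size of one multimodal item.
max_num_encoder_input_tokens = 10_000
encoder_cache_size = 8_192
max_tokens_per_mm_item = 2_448

# The profiling budget is the tighter of the two autotuned limits.
encoder_budget = min(max_num_encoder_input_tokens, encoder_cache_size)  # 8192

# Round up so the budget is fully covered even when the last item only
# partially fits; the autotuning guarantees this is at least 1.
max_num_mm_items_encoder_budget = cdiv(encoder_budget, max_tokens_per_mm_item)
print(max_num_mm_items_encoder_budget)  # 4

# The final number of dummy items is further capped by the decoder-side
# limit (derived from get_mm_limits_per_prompt); say that limit were 3.
max_num_mm_items_decoder_budget = 3  # hypothetical
print(min(max_num_mm_items_encoder_budget, max_num_mm_items_decoder_budget))  # 3
```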
```diff
@@ -769,6 +760,19 @@ def profile_run(self) -> None:
             max_num_mm_items = min(max_num_mm_items_encoder_budget,
                                    max_num_mm_items_decoder_budget)

+            logger.info(
+                "Encoder cache will be initialized with a budget of %s tokens,"
+                " and profiled with %s %s items of the maximum feature size.",
+                encoder_budget, max_num_mm_items, dummy_data_modality)
+
+            # Create dummy batch of multimodal inputs.
+            dummy_request_data = self.input_registry.dummy_data_for_profiling(
+                model_config=self.model_config,
+                seq_len=self.max_num_tokens,
+                mm_registry=self.mm_registry,
+            )
+            dummy_mm_data = dummy_request_data.multi_modal_data
+
             # Dummy data definition in V0 may contain multiple multimodal items
             # (e.g, multiple images) for a single request, therefore here we
             # always replicate first item by max_num_mm_items times since in V1
```

Review comment (on lines +768 to +774): Note this is just a reordering for better readability.
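For completeness, the trailing comment describes how the dummy multimodal batch is then assembled: V0 dummy data may already contain several items per request, whereas V1 profiling wants exactly `max_num_mm_items` copies of the single largest item. A sketch of that replication idea, with illustrative names rather than the actual vLLM data structures:

```python
# Hypothetical stand-in for the dummy multimodal data produced by
# dummy_data_for_profiling(); in V0 it may already hold several items.
dummy_items = ["image_item_0", "image_item_1"]

max_num_mm_items = 3  # computed from the encoder/decoder budgets above

# V1 profiling replicates the first (maximum-size) item so exactly
# max_num_mm_items items of the profiled modality are encoded.
profiling_items = [dummy_items[0]] * max_num_mm_items
print(profiling_items)  # ['image_item_0', 'image_item_0', 'image_item_0']
```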