[V1][Core] Autotune encoder cache budget #11895
@@ -1,7 +1,15 @@
from typing import Dict, List, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Set, Tuple

from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv
from vllm.v1.request import Request

if TYPE_CHECKING:
    from vllm.config import ModelConfig, SchedulerConfig

logger = init_logger(__name__)


class EncoderCacheManager:
@@ -46,3 +54,79 @@ def get_freed_ids(self) -> List[Tuple[str, int]]:
        freed = self.freed
        self.freed = []
        return freed


def compute_encoder_cache_budget(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
) -> int:
    """Compute the encoder cache budget based on the model and scheduler
    configurations.

    Args:
        model_config: Model configuration.
        scheduler_config: Scheduler configuration.

    Returns:
        The encoder cache budget, in units of number of tokens
        in the input sequence.
    """

    encoder_cache_budget = 0

    # TODO: handle encoder-decoder models once we support them.
    if not model_config.is_multimodal_model:
        return encoder_cache_budget

    max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(  # noqa: E501
        model_config)

    if not max_tokens_by_modality_dict:
        logger.warning(
            "All non-text modalities supported by the model have been "
            "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
            "not be initialized.")
        return encoder_cache_budget

    modality, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
                                           key=lambda item: item[1])

    max_num_batched_tokens = scheduler_config.max_num_batched_tokens
    max_num_reqs = scheduler_config.max_num_seqs
    # In the case where the largest possible multimodal item takes up more
    # tokens than the batch size, it needs to be cached and chunk-prefilled.
    if max_tokens_per_mm_item > max_num_batched_tokens:
        num_items = 1

    # In the case where the largest possible multimodal item fits within the
    # batch size, all items in a batch will be fully prefilled except one.
    else:
        num_items = cdiv(max_num_batched_tokens, max_tokens_per_mm_item)
Review comment: The comment seems a bit confusing to me. I tried to rephrase based on my understanding, but please help clarify:

num_items == 1:
    # The biggest possible multimodal item cannot be prefilled in a batch,
    # so it must be cached and chunk-prefilled.
num_items > 1:
    # A batch can cover all (except the last one) multimodal items.

Meanwhile, I don't fully understand what you meant by "cached" and "chunked prefill" though. I suppose they are orthogonal to the number of items?

Reply: I will clarify this. During profiling we always take the worst case (i.e. requests will all have the biggest possible multimodal item), so what I meant by

Reply: That makes sense. Thanks!

Reply: Clarified via 2a4b1d5
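To make the two cases discussed above concrete, here is a minimal standalone sketch of that branch. The helper name and the numbers are made up for illustration; `cdiv` is the ceiling-division utility imported in the diff above.

```python
from vllm.utils import cdiv  # ceiling division


def num_items_for_profiling(max_tokens_per_mm_item: int,
                            max_num_batched_tokens: int) -> int:
    """Hypothetical helper mirroring the branch in compute_encoder_cache_budget."""
    if max_tokens_per_mm_item > max_num_batched_tokens:
        # The largest item cannot be prefilled within a single batch, so it is
        # cached and chunk-prefilled across steps; one cached item suffices.
        return 1
    # Otherwise one batch spans several items; all but the last item scheduled
    # in the batch can be fully prefilled.
    return cdiv(max_num_batched_tokens, max_tokens_per_mm_item)


print(num_items_for_profiling(16384, 8192))  # -> 1 (item larger than the batch)
print(num_items_for_profiling(2448, 8192))   # -> 4 (item fits several times)
```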
    # NOTE: We need the encoder cache to be able to compute & hold ONE
    # ADDITIONAL multimodal item, which is required only when:
    # - Two requests in the current batch share the same prefix with such item
    #   as part of the prefix.
    # - AND the prefix length is divisible by the block size, triggering the
    #   recomputation of the last block.
    # - AND part of the embeddings of the item is in this last block.

    # This can be improved when we have a global encoder cache that does
    # not associate items to request id only.
    num_items += 1
Review comment: I think this is only applicable to the

Reply: This is applicable to all cases, and is in fact in the

Here's a concrete example:

Time 0: Request 0 gets scheduled for 8192 tokens. Since

Time 1:

Reply: Both cases would need this. Also, regarding this comment:

    # This can be improved when we have a global encoder cache that does
    # not associate items to request id only.

This cannot address the issue fundamentally, because we also need to guarantee the item is always available in the encoder cache when we schedule the request. For example, consider an item used by request A and request B. Request A has finished, so its prefix and mm items are cached. However, due to the encoder cache budget, one item of request A is evicted before request B comes. This would result in the same problem. I guess this can somehow be avoided if we could guarantee that all prefix-cached mm items are always available in the encoder cache as well, but fundamentally this has to be solved by supporting

Reply: That's a good callout! I've adjusted the comment accordingly.
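For readers following along, here is a hedged sketch of the edge case described in the NOTE above, under the assumption that "last block" means the block ending at the shared prefix boundary. All values (block size, prefix length, item span) are hypothetical.

```python
# Illustrative only; all values are hypothetical.
block_size = 16                     # assumed KV-cache block size
prefix_len = 8192                   # shared prefix length of two requests
item_start, item_end = 8000, 8600   # token span of the multimodal item

# The last block of the prefix is recomputed only when the prefix length is
# divisible by the block size.
last_block_recomputed = prefix_len % block_size == 0

# Part of the item's embeddings falls into that last block if the item
# overlaps the range [prefix_len - block_size, prefix_len).
item_in_last_block = item_start < prefix_len and item_end > prefix_len - block_size

# Only when both conditions hold does the cache need room for one extra item.
needs_extra_cached_item = last_block_recomputed and item_in_last_block
print(needs_extra_cached_item)  # -> True for these numbers
```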
    # Number of items needed cannot be bigger than max number of running
    # requests * max number of multimodal items per request.
    max_mm_items_per_req = max(
        MULTIMODAL_REGISTRY.get_mm_limits_per_prompt(model_config).values())

    num_items = min(num_items, max_num_reqs * max_mm_items_per_req)
    encoder_cache_budget = num_items * max_tokens_per_mm_item

    logger.info(
        "Encoder cache will be initialized with a budget of %s tokens,"
        " and profiled with %s %s items of the maximum feature size.",
        encoder_cache_budget, num_items, modality)

    return encoder_cache_budget
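To see the whole computation end to end, here is a worked example with hypothetical values (they are not defaults of any particular model; they only trace the arithmetic of `compute_encoder_cache_budget` above).

```python
# Hypothetical scheduler/model values.
max_num_batched_tokens = 8192   # scheduler_config.max_num_batched_tokens
max_num_reqs = 256              # scheduler_config.max_num_seqs
max_tokens_per_mm_item = 2448   # largest single multimodal item, in tokens
max_mm_items_per_req = 1        # from limit_mm_per_prompt

# The largest item fits within a batch, so count how many items a batch spans
# (ceiling division, cdiv in vllm.utils).
num_items = -(-max_num_batched_tokens // max_tokens_per_mm_item)  # -> 4

# Reserve room for ONE additional item (see the NOTE in the diff).
num_items += 1  # -> 5

# Cap by the decoder-side limit: running requests * items per request.
num_items = min(num_items, max_num_reqs * max_mm_items_per_req)  # -> 5

encoder_cache_budget = num_items * max_tokens_per_mm_item  # -> 12240 tokens
```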
@@ -19,6 +19,7 @@
                         LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_cache_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
@@ -87,8 +88,8 @@ def __init__(
        self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
        self.mm_input_mapper_profiling.use_cache = False

        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  # noqa: E501
        self.encoder_cache_size = self.scheduler_config.encoder_cache_size
        self.encoder_cache_budget = compute_encoder_cache_budget(
            self.model_config, self.scheduler_config)

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
@@ -720,53 +721,27 @@ def profile_run(self) -> None:
        ]

        # Profile with multimodal encoder & encoder cache.
        if self.is_multimodal_model:

            # Create dummy batch of multimodal inputs.
            dummy_request_data = self.input_registry.dummy_data_for_profiling(
                model_config=self.model_config,
                seq_len=self.max_num_tokens,
                mm_registry=self.mm_registry,
            )
            dummy_mm_data = dummy_request_data.multi_modal_data
        # TODO: handle encoder-decoder models once we support them.
        if self.is_multimodal_model and self.encoder_cache_budget > 0:
Review comment: A dumb question: Is it required to check

Reply: It's possible that users will explicitly disable all modalities (e.g., --limit-mm-per-prompt image=0) for whatever reason they have. In that case the model is just a text model with a multimodal encoder loaded but not used (this can also be optimized in the future), so we should just skip the profiling with multimodal data entirely. It is not a required check, but running all the following code means producing all the dummy data and then not using any of it.
||
|
||
# NOTE: Currently model is profiled with a single non-text | ||
# modality with the max possible input tokens even when | ||
# it supports multiple. | ||
max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality( # noqa: E501 | ||
max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501 | ||
self.model_config) | ||
|
||
dummy_data_modality, max_tokens_per_mm_item = max( | ||
max_tokens_by_modality_dict.items(), key=lambda item: item[1]) | ||
            # Check how many items of this modality can be supported by
            # the encoder cache budget.
            encoder_cache_budget = min(self.max_num_encoder_input_tokens,
                                       self.encoder_cache_size)
            max_num_mm_items_encoder_budget = encoder_cache_budget // \
                max_tokens_per_mm_item

            # TODO: Allow users to set encoder_cache_budget in case this
            # happens.
            assert max_num_mm_items_encoder_budget > 0, (
                f"Encoder cache budget={encoder_cache_budget} is too small to "
                f"support the maximum possible size of multimodal embeddings"
                f"={max_tokens_per_mm_item}.")

            # Check how many items of this modality can be supported by
            # the decoder budget.
            max_mm_items_per_req = max(
                self.mm_registry.get_mm_limits_per_prompt(
                    self.model_config).values())

            # NOTE: We do not consider max_num_batched_tokens on purpose
            # because the multimodal embeddings can be generated in advance
            # and chunked prefilled.
            max_num_mm_items_decoder_budget = self.max_num_reqs * \
                max_mm_items_per_req

            max_num_mm_items = min(max_num_mm_items_encoder_budget,
                                   max_num_mm_items_decoder_budget)
            max_num_mm_items = self.encoder_cache_budget // max_tokens_per_mm_item  # noqa: E501

            # Create dummy batch of multimodal inputs.
            dummy_request_data = self.input_registry.dummy_data_for_profiling(
                model_config=self.model_config,
                seq_len=self.max_num_tokens,
                mm_registry=self.mm_registry,
            )
            dummy_mm_data = dummy_request_data.multi_modal_data
Review comment on lines +768 to +774: Note this is just a reordering for better readability.
||
|
||
# Dummy data definition in V0 may contain multiple multimodal items | ||
# (e.g, multiple images) for a single request, therefore here we | ||
|
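To close the loop, here is a small sketch of how the rewritten profile_run sizes its dummy multimodal batch from the budget; the numbers are hypothetical and just continue the worked example above.

```python
# Hypothetical values mirroring the logic in the diff above.
encoder_cache_budget = 12240    # as computed by compute_encoder_cache_budget
max_tokens_per_mm_item = 2448   # largest item of the profiled modality

# Number of dummy multimodal items to profile with: as many of the largest
# items as fit into the encoder cache budget.
max_num_mm_items = encoder_cache_budget // max_tokens_per_mm_item
print(max_num_mm_items)  # -> 5
```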