
[BugFix] Fix prefix caching V0 MLA #14255


Merged
44 changes: 24 additions & 20 deletions vllm/attention/backends/mla/common.py
@@ -313,9 +313,10 @@ def __init__(self, runner):
cache_config = runner.cache_config

self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
self.enable_prefix_caching = cache_config.enable_prefix_caching

if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
if self.chunked_prefill_enabled or self.enable_prefix_caching:
self.context_chunk_workspace_size = min(
# Make sure there is enough for 8 full-length requests or at least
# 4 pages of cache per request
max(
@@ -330,7 +331,7 @@ def __init__(self, runner):
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
assert self.context_chunk_workspace_size >= \
scheduler_config.max_num_seqs * cache_config.block_size

@contextmanager
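The sizing logic above is easy to sanity-check in isolation. A minimal sketch, assuming the collapsed lines take the max of 8 full-length requests and 4 pages of cache per request (the constants are illustrative, not vLLM defaults):

```python
# Hedged sketch of the workspace sizing above. The middle of the min(...)
# expression is collapsed in the diff, so the max(...) term is an assumption
# based on the surrounding comments; the constants are illustrative.
max_model_len = 32 * 1024
max_num_seqs = 256
block_size = 16

context_chunk_workspace_size = min(
    max(8 * max_model_len,                # 8 full-length requests ...
        4 * max_num_seqs * block_size),   # ... or 4 pages of cache per request
    128 * 1024)                           # hard cap on workspace tokens

# The assert from the diff: every scheduled sequence must fit at least one page.
assert context_chunk_workspace_size >= max_num_seqs * block_size

# Reproducing the comment's worst-case estimate: 192 QK head dim, 128 heads,
# fp16 (2 bytes per element), 64K workspace tokens.
workspace_bytes = 2 * (192 * 128) * (64 * 1024)
print(workspace_bytes / 2**30)  # ~3.0 GiB, matching the "3gb" comment
```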
@@ -430,23 +431,23 @@ def prepare_graph_input_buffers(self,
"TritonMLAState does not support encoder/decoder yet")

def begin_forward(self, model_input):
if self.chunked_prefill_enabled:
if not hasattr(self, "chunked_prefill_workspace"):
if self.chunked_prefill_enabled or self.enable_prefix_caching:
if not hasattr(self, "context_chunk_workspace"):
# Note: self.runner.device does not return the correct device
# for this process (init_device sets the correct device, but
# only on the Worker). The only way I've figured out to get the
# correct device is to allocate the workspace on the first call
# to begin_forward and use the device of the input tokens
assert model_input.input_tokens is not None
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
self.context_chunk_workspace = torch.empty(
(self.context_chunk_workspace_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=model_input.input_tokens.device,
)

model_input.attn_metadata.chunked_prefill_workspace = \
self.chunked_prefill_workspace
model_input.attn_metadata.context_chunk_workspace = \
self.context_chunk_workspace
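
The lazy allocation above is a small but load-bearing pattern: the workspace is created on the first `begin_forward` call and inherits the device of the real input tokens rather than `self.runner.device`. A standalone sketch of the same pattern, with illustrative names and sizes:

```python
import torch

# Hedged, self-contained illustration of the allocate-on-first-use pattern
# from the diff; class name, head size, and dtype are illustrative only.
class AttnStateSketch:
    def __init__(self, workspace_size: int, head_size: int, dtype: torch.dtype):
        self.workspace_size = workspace_size
        self.head_size = head_size
        self.dtype = dtype

    def begin_forward(self, input_tokens: torch.Tensor) -> torch.Tensor:
        # Defer allocation until real inputs exist, then take their device.
        if not hasattr(self, "context_chunk_workspace"):
            self.context_chunk_workspace = torch.empty(
                (self.workspace_size, self.head_size),
                dtype=self.dtype,
                device=input_tokens.device,
            )
        return self.context_chunk_workspace

state = AttnStateSketch(workspace_size=128 * 1024, head_size=576,
                        dtype=torch.float16)
workspace = state.begin_forward(torch.zeros(8, dtype=torch.long))
```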


@dataclass
@@ -537,7 +538,7 @@ class MLACommonMetadata(AttentionMetadata):
context_chunk_seq_tot: Optional[List[int]] = None
context_chunk_max_seq_lens: Optional[List[int]] = None
# Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted
chunked_prefill_workspace: Optional[torch.Tensor] = None
context_chunk_workspace: Optional[torch.Tensor] = None

def __post_init__(self):
supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
@@ -747,11 +748,13 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.block_size = input_builder.block_size
self.chunked_prefill_enabled = \
self.runner.scheduler_config.chunked_prefill_enabled
self.enable_prefix_caching = \
self.runner.cache_config.enable_prefix_caching

if self.chunked_prefill_enabled:
if self.chunked_prefill_enabled or self.enable_prefix_caching:
attn_state = self.input_builder.runner.attn_state
self.chunked_prefill_workspace_size = \
attn_state.chunked_prefill_workspace_size
self.context_chunk_workspace_size = \
attn_state.context_chunk_workspace_size
self.page_size = self.runner.block_size

def prepare(self):
@@ -920,7 +923,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
context_chunk_seq_tot = None
context_chunk_max_seq_lens = None

if self.chunked_prefill_enabled and self.num_prefills > 0 \
if (self.chunked_prefill_enabled or self.enable_prefix_caching) \
and self.num_prefills > 0 \
and context_lens_tensor is not None \
and context_lens_tensor[:self.num_prefills].max() > 0:

@@ -936,7 +940,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
# algorithm here and allocate more workspace to prefills with
# longer context lengths
max_context_chunk = \
self.chunked_prefill_workspace_size // num_prefills_with_context
self.context_chunk_workspace_size // num_prefills_with_context

# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
@@ -965,7 +969,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
chunk_seq_lens.max(dim=1).values.tolist()
context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
assert max(context_chunk_seq_tot) <= \
self.chunked_prefill_workspace_size
self.context_chunk_workspace_size

return self.runner.attn_backend.make_metadata(
# Required by ModelRunner
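
Much of the chunking math sits in lines collapsed by the diff, so the following is a hedged sketch of how the per-chunk sequence lengths are plausibly derived from `max_context_chunk`: split the workspace evenly across prefills that have cached context, round down to a page boundary so `gather_cache` only sees whole pages, then clamp each chunk against the per-request context length. Names and values are illustrative:

```python
import torch

# Hedged sketch of the context-chunking logic around max_context_chunk;
# the chunk_starts/chunk_ends derivation is an assumption about the
# collapsed lines, not the exact vLLM implementation.
context_chunk_workspace_size = 128 * 1024
page_size = 16
context_lens = torch.tensor([100, 7000, 300])   # prefills with cached context
num_prefills_with_context = int((context_lens > 0).sum())

# Split the workspace evenly across prefills with context ...
max_context_chunk = context_chunk_workspace_size // num_prefills_with_context
# ... and align down to a page boundary for the gather_cache kernel.
max_context_chunk = (max_context_chunk // page_size) * page_size

num_chunks = int((context_lens.max().item() + max_context_chunk - 1)
                 // max_context_chunk)
chunk_starts = torch.arange(num_chunks).unsqueeze(1) * max_context_chunk
chunk_ends = torch.minimum(context_lens.unsqueeze(0),
                           chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

# Per-chunk totals across all prefills must fit in the workspace,
# mirroring the assert in the diff.
context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
assert max(context_chunk_seq_tot) <= context_chunk_workspace_size
```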
@@ -1288,8 +1292,8 @@ def _compute_prefill_context(
# Fetch from attn_metadata directly, since it is late-bound by
# MLAAttentionState; grabbing it directly from `attn_metadata` avoids
# any weirdness around prefill_metadata caching
assert attn_metadata.chunked_prefill_workspace is not None
workspace = attn_metadata.chunked_prefill_workspace
assert attn_metadata.context_chunk_workspace is not None
workspace = attn_metadata.context_chunk_workspace

for i in range(iters):
toks = prefill_metadata.context_chunk_seq_tot[i]
@@ -1502,12 +1506,12 @@ def forward(
"output is not yet supported for MLAImplBase")

if attn_metadata.is_profile_run and \
attn_metadata.chunked_prefill_workspace is not None:
attn_metadata.context_chunk_workspace is not None:
# During the profile run, try to simulate the worst-case output size
# for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
# since this can be large
_ = torch.empty(
(attn_metadata.chunked_prefill_workspace.shape[0],
(attn_metadata.context_chunk_workspace.shape[0],
self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
device=k_c_normed.device,
dtype=k_c_normed.dtype,