Commit e3baabd

Fix prefix caching V0 R1
Co-authored-by: Ying Zhong <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent eb59b5a commit e3baabd

1 file changed: vllm/attention/backends/mla/common.py (24 additions, 20 deletions)

```diff
@@ -313,9 +313,10 @@ def __init__(self, runner):
         cache_config = runner.cache_config
 
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        self.enable_prefix_caching = cache_config.enable_prefix_caching
 
-        if self.chunked_prefill_enabled:
-            self.chunked_prefill_workspace_size = min(
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
+            self.context_chunk_workspace_size = min(
                 # Max sure there is enough for 8 full length request or at least
                 # 4 pages of cache per request
                 max(
@@ -330,7 +331,7 @@ def __init__(self, runner):
                 # 2*(192*128)*(64*1024) = 3gb
                 # (assuming 192 QK head dim, 128 heads, and fp16)
                 128 * 1024)
-            assert self.chunked_prefill_workspace_size >= \
+            assert self.context_chunk_workspace_size >= \
                 scheduler_config.max_num_seqs * cache_config.block_size
 
     @contextmanager
```
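The sizing rule for the renamed workspace is unchanged: reserve room for either 8 full-length requests or 4 cache pages per sequence (whichever is larger), capped at 128K tokens, and assert that every sequence in a full batch can still gather at least one block. A minimal standalone sketch of that arithmetic, with hypothetical config values, and assuming the two `max(...)` arguments are the quantities named in the comment (the diff only shows the opening `max(`):

```python
# Hypothetical config values; the real ones come from vLLM's model,
# scheduler, and cache configs.
max_model_len = 32 * 1024   # hypothetical model context length
max_num_seqs = 256          # hypothetical max sequences per batch
block_size = 16             # hypothetical KV-cache page size

# Room for 8 full-length requests or 4 pages per request, capped at 128K
# tokens so long-context models do not over-allocate.
context_chunk_workspace_size = min(
    max(8 * max_model_len, 4 * max_num_seqs * block_size),
    128 * 1024)

# Mirrors the assert in the hunk above: each sequence in a full batch must
# be able to gather at least one block of context into the workspace.
assert context_chunk_workspace_size >= max_num_seqs * block_size
print(context_chunk_workspace_size)  # 131072 with the values above
```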
```diff
@@ -430,23 +431,23 @@ def prepare_graph_input_buffers(self,
                 "TritonMLAState does not support encoder/decoder yet")
 
     def begin_forward(self, model_input):
-        if self.chunked_prefill_enabled:
-            if not hasattr(self, "chunked_prefill_workspace"):
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
+            if not hasattr(self, "context_chunk_workspace"):
                 # not self.runner.device does not return the correct device
                 # for this process, (init_device sets the correct device but
                 # only on the Worker). The only way Ive figured out to get the
                 # correct device is to allocate the workspace on the first call
                 # to begin_forward and use the device of the input tokens
                 assert model_input.input_tokens is not None
-                self.chunked_prefill_workspace = torch.empty(
-                    (self.chunked_prefill_workspace_size,
+                self.context_chunk_workspace = torch.empty(
+                    (self.context_chunk_workspace_size,
                      self.model_config.get_head_size()),
                     dtype=self.model_config.dtype,
                     device=model_input.input_tokens.device,
                 )
 
-            model_input.attn_metadata.chunked_prefill_workspace = \
-                self.chunked_prefill_workspace
+            model_input.attn_metadata.context_chunk_workspace = \
+                self.context_chunk_workspace
 
 
 @dataclass
```
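`begin_forward` keeps the lazy-allocation pattern described in its comment: the workspace is created on the first call, on the device of the input tokens, and then attached to the attention metadata. A toy sketch of that pattern (not the real `MLACommonState` class), showing why the buffer is allocated once and then reused:

```python
import torch


class LazyWorkspaceState:
    """Toy stand-in for the real attention state class; illustration only."""

    def __init__(self, workspace_size: int, head_size: int, dtype: torch.dtype):
        self.context_chunk_workspace_size = workspace_size
        self.head_size = head_size
        self.dtype = dtype

    def begin_forward(self, input_tokens: torch.Tensor) -> torch.Tensor:
        # Allocate exactly once, on the first call, using the device of the
        # actual inputs (mirrors the hasattr() check in the hunk above).
        if not hasattr(self, "context_chunk_workspace"):
            self.context_chunk_workspace = torch.empty(
                (self.context_chunk_workspace_size, self.head_size),
                dtype=self.dtype,
                device=input_tokens.device,
            )
        return self.context_chunk_workspace


# The buffer is created lazily on the first forward and reused afterwards.
state = LazyWorkspaceState(workspace_size=1024, head_size=576,
                           dtype=torch.float16)
tokens = torch.zeros(4, dtype=torch.long)  # CPU tensor keeps the example portable
assert state.begin_forward(tokens) is state.begin_forward(tokens)
```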
```diff
@@ -537,7 +538,7 @@ class MLACommonMetadata(AttentionMetadata):
     context_chunk_seq_tot: Optional[List[int]] = None
     context_chunk_max_seq_lens: Optional[List[int]] = None
     # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted
-    chunked_prefill_workspace: Optional[torch.Tensor] = None
+    context_chunk_workspace: Optional[torch.Tensor] = None
 
     def __post_init__(self):
         supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
@@ -747,11 +748,13 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.block_size = input_builder.block_size
         self.chunked_prefill_enabled = \
             self.runner.scheduler_config.chunked_prefill_enabled
+        self.enable_prefix_caching = \
+            self.runner.cache_config.enable_prefix_caching
 
-        if self.chunked_prefill_enabled:
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
             attn_state = self.input_builder.runner.attn_state
-            self.chunked_prefill_workspace_size = \
-                attn_state.chunked_prefill_workspace_size
+            self.context_chunk_workspace_size = \
+                attn_state.context_chunk_workspace_size
             self.page_size = self.runner.block_size
 
     def prepare(self):
```
```diff
@@ -920,7 +923,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         context_chunk_seq_tot = None
         context_chunk_max_seq_lens = None
 
-        if self.chunked_prefill_enabled and self.num_prefills > 0 \
+        if (self.chunked_prefill_enabled or self.enable_prefix_caching) \
+            and self.num_prefills > 0 \
             and context_lens_tensor is not None \
             and context_lens_tensor[:self.num_prefills].max() > 0:
 
@@ -936,7 +940,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             # algorithm here and allocate more workspace to prefills with
             # longer context lengths
             max_context_chunk = \
-                self.chunked_prefill_workspace_size // num_prefills_with_context
+                self.context_chunk_workspace_size // num_prefills_with_context
 
             # align max_context_chunk to page_size by rounding down,
             # currently the `gather_cache` kernel cannot handle
@@ -965,7 +969,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                 chunk_seq_lens.max(dim=1).values.tolist()
             context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
             assert max(context_chunk_seq_tot) <= \
-                self.chunked_prefill_workspace_size
+                self.context_chunk_workspace_size
 
         return self.runner.attn_backend.make_metadata(
             # Required by ModelRunner
```
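Together, these builder hunks keep the context-chunking scheme and only rename the size field: the shared workspace is divided evenly across prefills that have cached context, the per-iteration chunk is rounded down to a page boundary (the `gather_cache` kernel cannot start mid-page), and the final assert checks that each iteration's gathered tokens fit in the workspace. A hedged, self-contained sketch of that partitioning with made-up sizes (not the exact vLLM implementation):

```python
import torch

# Hypothetical sizes for illustration only.
context_chunk_workspace_size = 64          # tokens the workspace holds per iteration
page_size = 16                             # KV-cache block (page) size
context_lens = torch.tensor([100, 37, 0])  # cached-context length of each prefill

num_prefills_with_context = int((context_lens > 0).sum())

# Split the workspace evenly across prefills that have context, then round
# the per-prefill chunk down to a page boundary.
max_context_chunk = context_chunk_workspace_size // num_prefills_with_context
max_context_chunk = (max_context_chunk // page_size) * page_size
assert max_context_chunk > 0

# Number of iterations needed to cover the longest cached context.
num_chunks = -(-int(context_lens.max()) // max_context_chunk)  # ceil division

# chunk_seq_lens[i, j] = tokens of prefill j gathered in iteration i.
chunk_starts = torch.arange(num_chunks).unsqueeze(1) * max_context_chunk
chunk_ends = torch.minimum(context_lens.unsqueeze(0),
                           chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

context_chunk_max_seq_lens = chunk_seq_lens.max(dim=1).values.tolist()
context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
# Each iteration's total gathered tokens must fit in the workspace.
assert max(context_chunk_seq_tot) <= context_chunk_workspace_size
print(chunk_seq_lens)
```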
```diff
@@ -1288,8 +1292,8 @@ def _compute_prefill_context(
         # Fetch from attn_metadata directly, since it late bound by
         # MLAAttentionState, grabbing it directly `attn_metadata` can avoid
         # any weirdness around prefill_metadata caching
-        assert attn_metadata.chunked_prefill_workspace is not None
-        workspace = attn_metadata.chunked_prefill_workspace
+        assert attn_metadata.context_chunk_workspace is not None
+        workspace = attn_metadata.context_chunk_workspace
 
         for i in range(iters):
             toks = prefill_metadata.context_chunk_seq_tot[i]
@@ -1502,12 +1506,12 @@ def forward(
                 "output is not yet supported for MLAImplBase")
 
         if attn_metadata.is_profile_run and \
-            attn_metadata.chunked_prefill_workspace is not None:
+            attn_metadata.context_chunk_workspace is not None:
             # During the profile run try to simulate to worse case output size
             # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
             # since this can be large
             _ = torch.empty(
-                (attn_metadata.chunked_prefill_workspace.shape[0],
+                (attn_metadata.context_chunk_workspace.shape[0],
                  self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
                 device=k_c_normed.device,
                 dtype=k_c_normed.dtype,
```
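The last hunk only renames the field consulted during the profile run, where a throwaway tensor shaped like the worst-case `self.kv_b_proj(kv_c_normed)` output is allocated so peak-memory profiling accounts for it. A minimal sketch of that trick, with hypothetical MLA dimensions standing in for the real model config:

```python
import torch

# Hypothetical MLA dimensions; the real values come from the model config.
num_heads = 128
qk_nope_head_dim = 128
v_head_dim = 128
workspace_tokens = 1024  # stands in for context_chunk_workspace.shape[0]

# During the profile run, allocate (and immediately discard) the worst-case
# intermediate so peak-memory measurement accounts for the up-projection
# output; the real code uses the device and dtype of kv_c_normed.
_ = torch.empty(
    (workspace_tokens, num_heads, qk_nope_head_dim + v_head_dim),
    dtype=torch.float16,
    device="cpu",
)
```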
