@@ -313,9 +313,10 @@ def __init__(self, runner):
         cache_config = runner.cache_config

         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        self.enable_prefix_caching = cache_config.enable_prefix_caching

-        if self.chunked_prefill_enabled:
-            self.chunked_prefill_workspace_size = min(
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
+            self.context_chunk_workspace_size = min(
                 # Max sure there is enough for 8 full length request or at least
                 # 4 pages of cache per request
                 max(
@@ -330,7 +331,7 @@ def __init__(self, runner):
                 # 2*(192*128)*(64*1024) = 3gb
                 # (assuming 192 QK head dim, 128 heads, and fp16)
                 128 * 1024)
-            assert self.chunked_prefill_workspace_size >= \
+            assert self.context_chunk_workspace_size >= \
                 scheduler_config.max_num_seqs * cache_config.block_size

     @contextmanager
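For intuition, a rough sketch of how this sizing plays out. The operands of `max(...)` are elided from the hunk, so the `8 * max_model_len` and `4 * max_num_seqs * block_size` terms below are assumptions inferred from the comments, and all numeric values are illustrative, not taken from the patch:

```python
# Illustrative sizing of the context-chunk workspace (values are assumptions).
max_model_len = 4096    # assumed model context length
max_num_seqs = 256      # assumed scheduler_config.max_num_seqs
block_size = 16         # assumed cache_config.block_size

context_chunk_workspace_size = min(
    # enough for 8 full-length requests, or at least 4 cache pages per request
    max(8 * max_model_len, 4 * max_num_seqs * block_size),
    # hard cap of 128 * 1024 token slots to bound workspace memory
    128 * 1024)

# Mirrors the assert in the patch: a full page per sequence must always fit.
assert context_chunk_workspace_size >= max_num_seqs * block_size
print(context_chunk_workspace_size)  # 32768 with the values above
```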
@@ -430,23 +431,23 @@ def prepare_graph_input_buffers(self,
                 "TritonMLAState does not support encoder/decoder yet")

     def begin_forward(self, model_input):
-        if self.chunked_prefill_enabled:
-            if not hasattr(self, "chunked_prefill_workspace"):
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
+            if not hasattr(self, "context_chunk_workspace"):
                 # not self.runner.device does not return the correct device
                 # for this process, (init_device sets the correct device but
                 # only on the Worker). The only way Ive figured out to get the
                 # correct device is to allocate the workspace on the first call
                 # to begin_forward and use the device of the input tokens
                 assert model_input.input_tokens is not None
-                self.chunked_prefill_workspace = torch.empty(
-                    (self.chunked_prefill_workspace_size,
+                self.context_chunk_workspace = torch.empty(
+                    (self.context_chunk_workspace_size,
                      self.model_config.get_head_size()),
                     dtype=self.model_config.dtype,
                     device=model_input.input_tokens.device,
                 )

-            model_input.attn_metadata.chunked_prefill_workspace = \
-                self.chunked_prefill_workspace
+            model_input.attn_metadata.context_chunk_workspace = \
+                self.context_chunk_workspace


 @dataclass
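The workspace is allocated lazily on the first `begin_forward` call (so it lands on the correct device) and is then attached to `attn_metadata` rather than being broadcast. Its footprint is simply `context_chunk_workspace_size * head_size * dtype_size`; a hedged back-of-envelope, where the MLA head size and dtype are assumptions and not taken from the patch:

```python
import torch

# Illustrative footprint of the lazily allocated workspace above.
context_chunk_workspace_size = 128 * 1024   # token slots (the cap from the sizing code)
head_size = 576                             # assumed MLA head size returned by get_head_size()
dtype_size = torch.finfo(torch.float16).bits // 8  # 2 bytes for fp16

workspace_bytes = context_chunk_workspace_size * head_size * dtype_size
print(f"{workspace_bytes / 2**20:.0f} MiB")  # ~144 MiB with these assumptions
```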
@@ -537,7 +538,7 @@ class MLACommonMetadata(AttentionMetadata):
     context_chunk_seq_tot: Optional[List[int]] = None
     context_chunk_max_seq_lens: Optional[List[int]] = None
     # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted
-    chunked_prefill_workspace: Optional[torch.Tensor] = None
+    context_chunk_workspace: Optional[torch.Tensor] = None

     def __post_init__(self):
         supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
@@ -747,11 +748,13 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.block_size = input_builder.block_size
         self.chunked_prefill_enabled = \
             self.runner.scheduler_config.chunked_prefill_enabled
+        self.enable_prefix_caching = \
+            self.runner.cache_config.enable_prefix_caching

-        if self.chunked_prefill_enabled:
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
             attn_state = self.input_builder.runner.attn_state
-            self.chunked_prefill_workspace_size = \
-                attn_state.chunked_prefill_workspace_size
+            self.context_chunk_workspace_size = \
+                attn_state.context_chunk_workspace_size
             self.page_size = self.runner.block_size

     def prepare(self):
@@ -920,7 +923,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         context_chunk_seq_tot = None
         context_chunk_max_seq_lens = None

-        if self.chunked_prefill_enabled and self.num_prefills > 0 \
+        if (self.chunked_prefill_enabled or self.enable_prefix_caching) \
+            and self.num_prefills > 0 \
                 and context_lens_tensor is not None \
                 and context_lens_tensor[:self.num_prefills].max() > 0:

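The widened gate matters because a prefix-cache hit alone can give a prefill a non-zero computed context even when chunked prefill is disabled, and that cached prefix still has to be gathered through the chunked-context path. A small hedged illustration (block count and sizes are made up):

```python
import torch

# With prefix caching, a new request whose leading blocks are already cached
# arrives as a "prefill" that nevertheless has computed context.
block_size = 16
cached_blocks = 3                       # assumed prefix-cache hit of 3 full blocks
context_lens_tensor = torch.tensor([cached_blocks * block_size, 0])  # 2 prefills
num_prefills = 2

chunked_prefill_enabled = False
enable_prefix_caching = True

# Mirrors the condition in build(): fires even though chunked prefill is off.
needs_context_chunks = bool(
    (chunked_prefill_enabled or enable_prefix_caching)
    and num_prefills > 0
    and context_lens_tensor is not None
    and context_lens_tensor[:num_prefills].max() > 0)
print(needs_context_chunks)  # True
```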
@@ -936,7 +940,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             # algorithm here and allocate more workspace to prefills with
             # longer context lengths
             max_context_chunk = \
-                self.chunked_prefill_workspace_size // num_prefills_with_context
+                self.context_chunk_workspace_size // num_prefills_with_context

             # align max_context_chunk to page_size by rounding down,
             # currently the `gather_cache` kernel cannot handle
@@ -965,7 +969,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                 chunk_seq_lens.max(dim=1).values.tolist()
             context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
             assert max(context_chunk_seq_tot) <= \
-                self.chunked_prefill_workspace_size
+                self.context_chunk_workspace_size

         return self.runner.attn_backend.make_metadata(
             # Required by ModelRunner
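For intuition on the chunking above: the workspace is split evenly across the prefills that have context, the per-iteration chunk is rounded down to a multiple of the page size (the `gather_cache` constraint mentioned in the comments), and the final assert guarantees no iteration gathers more tokens than fit in the workspace. A hedged worked example with made-up numbers:

```python
# Illustrative only: mirrors the shape of the chunking math in build().
context_chunk_workspace_size = 128 * 1024   # token slots
page_size = 16                              # assumed cache block size
context_lens = [1000, 30000, 500]           # assumed per-prefill computed context
num_prefills_with_context = len(context_lens)

# Evenly split the workspace, then round down to a multiple of page_size.
max_context_chunk = context_chunk_workspace_size // num_prefills_with_context
max_context_chunk = (max_context_chunk // page_size) * page_size

# Each iteration gathers at most max_context_chunk tokens per prefill, so the
# first-iteration total stays within the workspace, as the assert requires.
chunk_totals = [min(cl, max_context_chunk) for cl in context_lens]
assert sum(chunk_totals) <= context_chunk_workspace_size
print(max_context_chunk, sum(chunk_totals))  # 43680 31500
```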
@@ -1288,8 +1292,8 @@ def _compute_prefill_context(
         # Fetch from attn_metadata directly, since it late bound by
         # MLAAttentionState, grabbing it directly `attn_metadata` can avoid
         # any weirdness around prefill_metadata caching
-        assert attn_metadata.chunked_prefill_workspace is not None
-        workspace = attn_metadata.chunked_prefill_workspace
+        assert attn_metadata.context_chunk_workspace is not None
+        workspace = attn_metadata.context_chunk_workspace

         for i in range(iters):
             toks = prefill_metadata.context_chunk_seq_tot[i]
@@ -1502,12 +1506,12 @@ def forward(
                 "output is not yet supported for MLAImplBase")

         if attn_metadata.is_profile_run and \
-            attn_metadata.chunked_prefill_workspace is not None:
+            attn_metadata.context_chunk_workspace is not None:
             # During the profile run try to simulate to worse case output size
             # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
             # since this can be large
             _ = torch.empty(
-                (attn_metadata.chunked_prefill_workspace.shape[0],
+                (attn_metadata.context_chunk_workspace.shape[0],
                  self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
                 device=k_c_normed.device,
                 dtype=k_c_normed.dtype,
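The profile-run allocation above simulates the worst-case intermediate produced by `self.kv_b_proj(kv_c_normed)` so the memory profiler reserves room for it. Its size is roughly `workspace_rows * num_heads * (qk_nope_head_dim + v_head_dim) * dtype_size`; a back-of-envelope sketch with assumed DeepSeek-style dimensions, none of which are taken from the patch:

```python
# Illustrative worst-case size of the simulated kv_b_proj output.
workspace_rows = 128 * 1024       # context_chunk_workspace.shape[0]
num_heads = 128                   # assumed
qk_nope_head_dim = 128            # assumed
v_head_dim = 128                  # assumed
dtype_size = 2                    # fp16 / bf16

sim_bytes = workspace_rows * num_heads * (qk_nope_head_dim + v_head_dim) * dtype_size
print(f"{sim_bytes / 2**30:.1f} GiB")  # 8.0 GiB with these assumptions
```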