
Commit 88fa03f

doc: improve mla related documentation (#818)
1 parent 5094eb7 commit 88fa03f

7 files changed: +48 -23 lines


docs/api/decode.rst (-5)

@@ -26,8 +26,3 @@ Batch Decoding
     :members:

     .. automethod:: __init__
-
-.. autoclass:: BatchDecodeMlaWithPagedKVCacheWrapper
-    :members:
-
-    .. automethod:: __init__

docs/tutorials/kv_layout.rst (+35 -10)

@@ -24,16 +24,14 @@ by default).
 Ragged Tensor
 -------------

-In batched inference/serving, the input sequence length may vary across different samples.
-When there is no need to change the sequence length (e.g. in prefilling stage), we can use ``RaggedTensor``
-with a single ragged (variable length) dimension to store the key/value tensors in KV-Cache:
+We use Ragged Tensor to store the variable-length Q/K/V tensors in FlashInfer for batch prefill self-attention:

 .. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/ragged.png
   :width: 400
   :align: center
   :alt: Data structure of Ragged KV-Cache.

-The keys (or values) of all requests are packed into a single ``data`` tensor without padding,
+In a Ragged Tensor, all requests' Q/K/V are packed into a single ``data`` tensor without padding,
 we use a ``indptr`` array (``num_requests+1`` elements, the first element is always zero)
 to store the information of variable sequence lengths of each request
 (``indptr[i+1]-indptr[i]`` is the sequence length of request ``i``), the ``data`` tensor has
@@ -42,7 +40,7 @@ shape ``(indptr[-1], num_heads, head_dim)`` when the layout is ``NHD``.
 We can use ``data[indptr[i]:indptr[i+1]]`` to slice the keys (or values) of request ``i``.

 .. note::
-  ``indptr`` arrays across the flashinfer library should be of type ``int32``. Arrays of type ``int64`` can cause indexing errors.
+    ``indptr`` arrays across the flashinfer library should be of type ``int32``. Arrays of type ``int64`` can cause indexing errors.

 FlashInfer APIs
 ~~~~~~~~~~~~~~~
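
To make the ragged layout in this hunk concrete, here is a small illustrative sketch (not part of the commit; the sequence lengths and head sizes are made up) of building ``indptr`` and slicing out one request's keys:

import torch

# Three requests with lengths 3, 5, 2; indptr is kept in int32, as the note above requires.
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
indptr = torch.cat([torch.zeros(1, dtype=torch.int32),
                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32)])  # [0, 3, 8, 10]

num_heads, head_dim = 8, 128
data = torch.randn(int(indptr[-1]), num_heads, head_dim)  # packed keys (or values), NHD layout, no padding

i = 1
keys_i = data[indptr[i]:indptr[i + 1]]  # keys of request i, shape (5, num_heads, head_dim)
assert keys_i.shape[0] == int(seq_lens[i])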
@@ -127,21 +125,48 @@ when stored in a single tensor, ``kv_data`` has shape:

 .. code:: python

-    (max_num_pages, 2, page_size, num_heads, head_dim) # NHD layout
-    (max_num_pages, 2, num_heads, page_size, head_dim) # HND layout
+    kv_cache_nhd = torch.empty(max_num_pages, 2, page_size, num_heads, head_dim, dtype=torch.bfloat16)  # NHD layout
+    kv_cache_hnd = torch.empty(max_num_pages, 2, num_heads, page_size, head_dim, dtype=torch.bfloat16)  # HND layout

 when stored in a tuple of tensors, ``kv_data = (k_data, v_data)``, and each one of them has shape:

 .. code:: python

-    (max_num_pages, page_size, num_heads, head_dim) # NHD layout
-    (max_num_pages, num_heads, page_size, head_dim) # HND layout
+    k_cache_nhd = torch.empty(max_num_pages, page_size, num_heads, head_dim, dtype=torch.bfloat16)  # NHD layout
+    k_cache_hnd = torch.empty(max_num_pages, num_heads, page_size, head_dim, dtype=torch.bfloat16)  # HND layout
+    v_cache_nhd = torch.empty(max_num_pages, page_size, num_heads, head_dim, dtype=torch.bfloat16)  # NHD layout
+    v_cache_hnd = torch.empty(max_num_pages, num_heads, page_size, head_dim, dtype=torch.bfloat16)  # HND layout
+

 where ``max_num_pages`` is the maximum number of pages used by all requests, ``page_size`` is the number of tokens
 we fit into each page. ``2`` in single tensor storage means K/V (the first one for keys, the second one for values).

 .. note::
-  ``indptr`` arrays across the flashinfer library should be of type ``int32``. Arrays of type ``int64`` can cause indexing errors. This is also true of the ``kv_page_indices`` and ``kv_last_page_lens`` arrays.
+    ``indptr`` arrays across the flashinfer library should be of type ``int32``. Arrays of type ``int64`` can cause indexing errors. This is also true of the ``kv_page_indices`` and ``kv_last_page_lens`` arrays.
+
+.. _mla-page-layout:
+
+Multi-head Latent Attention Page Layout
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Multi-head Latent Attention (MLA) is a new attention mechanism proposed in `DeepSeek v2 <https://arxiv.org/abs/2405.04434>`_ and
+used in later DeepSeek models. MLA unifies the key cache and value cache into a single tensor, so there is no need to store them separately.
+Compared to multi-head attention or grouped query attention, the KV-Cache of MLA does not have the ``num_heads`` dimension,
+so there is no distinction between ``NHD`` and ``HND`` layouts.
+
+MLA separates the RoPE (Rotary Positional Encoding) dimensions from the other head dimensions. We use ``kpe`` (key w/ positional encoding) and ``ckv`` (compressed key/value)
+to name these two components. Users can store them in a single Paged KV-Cache:
+
+.. code:: python
+
+    head_dim_ckv = 512
+    head_dim_kpe = 64
+    mla_paged_kv_cache = torch.empty(max_num_pages, page_size, head_dim_ckv + head_dim_kpe, dtype=torch.bfloat16)
+    ckv = mla_paged_kv_cache[:, :, :head_dim_ckv]  # Slicing here does not copy or move data
+    kpe = mla_paged_kv_cache[:, :, head_dim_ckv:]  # Slicing here does not copy or move data
+
+
+and ``ckv`` and ``kpe`` can then be fed into the MLA attention kernel :class:`flashinfer.mla.BatchMLAPagedAttentionWrapper`.

 FlashInfer APIs
 ~~~~~~~~~~~~~~~
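
To illustrate how the per-request page table indexes into the paged layout above, here is an illustrative sketch only (not part of the commit): ``kv_page_indptr`` and all concrete values below are hypothetical, with the index arrays kept in ``int32`` as the note requires.

import torch

max_num_pages, page_size, num_heads, head_dim = 16, 4, 8, 128
kv_cache_nhd = torch.empty(max_num_pages, 2, page_size, num_heads, head_dim, dtype=torch.bfloat16)

# Request i owns pages kv_page_indices[kv_page_indptr[i]:kv_page_indptr[i+1]];
# only the first kv_last_page_lens[i] tokens of its last page are valid.
kv_page_indptr = torch.tensor([0, 2, 5], dtype=torch.int32)
kv_page_indices = torch.tensor([3, 7, 0, 4, 9], dtype=torch.int32)
kv_last_page_lens = torch.tensor([2, 3], dtype=torch.int32)

i = 1
pages = kv_page_indices[kv_page_indptr[i]:kv_page_indptr[i + 1]].long()
k_pages = kv_cache_nhd[pages, 0]                               # (3, page_size, num_heads, head_dim)
keys_i = torch.cat([k_pages[:-1].flatten(0, 1),                # full pages
                    k_pages[-1, :int(kv_last_page_lens[i])]])  # valid prefix of the last page
# keys_i: (2 * page_size + 3, num_heads, head_dim) tokens of request i's keys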

flashinfer/__init__.py (-3)

@@ -30,9 +30,6 @@
 from .cascade import merge_state as merge_state
 from .cascade import merge_state_in_place as merge_state_in_place
 from .cascade import merge_states as merge_states
-from .decode import (
-    BatchDecodeMlaWithPagedKVCacheWrapper as BatchDecodeMlaWithPagedKVCacheWrapper,
-)
 from .decode import (
     BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper,
 )

flashinfer/decode.py (+5)

@@ -1243,6 +1243,11 @@ def __init__(


 class BatchDecodeMlaWithPagedKVCacheWrapper:
+    r"""Warning: this class is deprecated and will be removed in a future release.
+    Please use :class:`flashinfer.mla.BatchMLAPagedAttentionWrapper` instead, which provides
+    a more efficient and general MLA implementation that supports decode and incremental prefill.
+    """
+
     def __init__(
         self,
         float_workspace_buffer: torch.Tensor,
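
For code still using the deprecated wrapper, a minimal migration sketch (construction mirrors tests/test_deepseek_mla.py in this commit; the ``plan``/``run`` arguments are only hinted at here, see the docstring example in flashinfer/mla.py):

import torch
import flashinfer

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace_buffer)
# Then: mla_wrapper.plan(q_indptr, kv_indptr, ...) followed by mla_wrapper.run(...),
# with ckv/kpe stored in the paged layout described in docs/tutorials/kv_layout.rst.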

flashinfer/mla.py (+5 -2)

@@ -42,7 +42,10 @@ class BatchMLAPagedAttentionWrapper:
     absorbed with :math:`W_{O}`.
     For MLA attention without Matrix Absorption (``head_dim_qk=192`` and ``head_dim_vo=128``, which is
     used in prefilling self-attention stage), please use
-    :class:`flashinfer.prefill.BatchPrefillWithRaggedAttentionWrapper`.
+    :class:`flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper`.
+
+    More information about the Paged KV-Cache layout in MLA can be found in our tutorial
+    :ref:`MLA Page Layout <mla-page-layout>`.

     For more details about the MLA computation, Matrix Absorption and FlashInfer's MLA implementation,
     please refer to our `blog post <http://flashinfer.ai/2025/02/10/flashinfer-deepseek-mla.html>`_.
@@ -76,7 +79,7 @@ class BatchMLAPagedAttentionWrapper:
     >>> kpe = torch.zeros(
     ...     batch_size * 999, 1, head_dim_kpe, dtype=torch.bfloat16, device="cuda"
     ... )
-    >>> sm_scale = 1.0 / ((head_dim_ckv + head_dim_kpe) ** 0.5)
+    >>> sm_scale = 1.0 / ((128 + 64) ** 0.5)  # use head dimension before matrix absorption
     >>> mla_wrapper.plan(
     ...     q_indptr,
     ...     kv_indptr,
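
The ``sm_scale`` fix above (and the matching one in tests/test_deepseek_mla.py) boils down to the following arithmetic, shown here as a worked check that is not part of the commit:

head_dim_ckv, head_dim_kpe = 512, 64        # compressed KV and RoPE dims stored in the MLA KV-Cache
head_dim_qk = 128 + 64                      # per-head QK dim before matrix absorption (nope + rope)

old_sm_scale = 1.0 / ((head_dim_ckv + head_dim_kpe) ** 0.5)  # 1/sqrt(576) ~= 0.0417 (too small)
sm_scale = 1.0 / (head_dim_qk ** 0.5)                        # 1/sqrt(192) ~= 0.0722 (correct scale)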

flashinfer/prefill.py (+2 -2)

@@ -897,7 +897,7 @@ class BatchPrefillWithPagedKVCacheWrapper:
     r"""Wrapper class for prefill/append attention with paged kv-cache for batch of
     requests.

-    Check :ref:`our tutorial<page-layout>` for page table layout.
+    Check :ref:`our tutorial <page-layout>` for page table layout.

     Example
     -------
@@ -1711,7 +1711,7 @@ class BatchPrefillWithRaggedKVCacheWrapper:
     r"""Wrapper class for prefill/append attention with ragged (tensor) kv-cache for
     batch of requests.

-    Check :ref:`our tutorial<ragged-layout>` for ragged kv-cache layout.
+    Check :ref:`our tutorial <ragged-layout>` for ragged kv-cache layout.

     Example
     -------

tests/test_deepseek_mla.py (+1 -1)

@@ -188,7 +188,7 @@ def test_batch_mla_page_attention(
         dtype=torch.half,
         device="cuda",
     )
-    sm_scale = 1.0 / ((head_dim_ckv + head_dim_kpe) ** 0.5)
+    sm_scale = 1.0 / ((128 + 64) ** 0.5)  # use head dimension before matrix absorption
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
    wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
         workspace_buffer, backend=backend
