
Commit 3df7b00

Authored and committed by: Lucas Wilkinson, Woosuk Kwon, simon-mo, mgoin, zhuohan123
[Attention] MLA decode optimizations (vllm-project#12528)
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: simon-mo <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: simon-mo <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Zhuohan Li <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Alexander Matveev <[email protected]>
Co-authored-by: simon-mo <[email protected]>
1 parent d2615b3 commit 3df7b00

File tree

31 files changed: +2266 / -32 lines


csrc/cache.h

Lines changed: 5 additions & 0 deletions

@@ -28,6 +28,11 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              const std::string& kv_cache_dtype,
                              torch::Tensor& k_scale, torch::Tensor& v_scale);
 
+void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
+                          torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
+                          const std::string& kv_cache_dtype,
+                          torch::Tensor& scale);
+
 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                  const double scale, const std::string& kv_cache_dtype);

csrc/cache_kernels.cu

Lines changed: 95 additions & 0 deletions

@@ -245,6 +245,51 @@ __global__ void reshape_and_cache_flash_kernel(
     }
   }
 }
+
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
+__global__ void concat_and_cache_mla_kernel(
+    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
+    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
+    cache_t* __restrict__ kv_cache,     // [num_blocks, block_size,
+                                        //  (kv_lora_rank + pe_dim)]
+    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+    const int block_stride,
+    const int kv_c_stride,
+    const int k_pe_stride,
+    const int kv_lora_rank,
+    const int pe_dim,
+    const int block_size,
+    const float* scale) {
+  const int64_t token_idx = blockIdx.x;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  // NOTE: slot_idx can be -1 if the token is padded
+  if (slot_idx < 0) {
+    return;
+  }
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+
+  auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
+                  int src_stride, int dst_stride, int size, int offset) {
+    for (int i = threadIdx.x; i < size; i += blockDim.x) {
+      const int64_t src_idx = token_idx * src_stride + i;
+      const int64_t dst_idx = block_idx * block_stride +
+                              block_offset * (kv_lora_rank + pe_dim) + i +
+                              offset;
+      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+        dst[dst_idx] = src[src_idx];
+      } else {
+        dst[dst_idx] =
+            fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
+      }
+    }
+  };
+
+  copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
+  copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
+}
+
 }  // namespace vllm
 
 // KV_T is the stored data type of kv-cache.
@@ -343,6 +388,56 @@ void reshape_and_cache_flash(
                              CALL_RESHAPE_AND_CACHE_FLASH);
 }
 
+// KV_T is the stored data type of kv-cache.
+// CACHE_T is the data type of key and value tensors.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)               \
+  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>             \
+      <<<grid, block, 0, stream>>>(                                      \
+          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                      \
+          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                      \
+          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),               \
+          slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride,   \
+          k_pe_stride, kv_lora_rank, pe_dim, block_size,                 \
+          reinterpret_cast<const float*>(scale.data_ptr()));
+
+void concat_and_cache_mla(
+    torch::Tensor& kv_c,          // [num_tokens, kv_lora_rank]
+    torch::Tensor& k_pe,          // [num_tokens, pe_dim]
+    torch::Tensor& kv_cache,      // [num_blocks, block_size, (kv_lora_rank +
+                                  //  pe_dim)]
+    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
+    const std::string& kv_cache_dtype, torch::Tensor& scale) {
+  // NOTE(woosuk): In vLLM V0, kv_c.size(0) is always equal to
+  // slot_mapping.size(0) because both include padding.
+  // In vLLM V1, however, kv_c.size(0) can be larger than slot_mapping.size(0)
+  // since kv_c includes padding for CUDA graphs, while slot_mapping does not.
+  // In this case, slot_mapping.size(0) represents the actual number of tokens
+  // before padding.
+  // For compatibility with both cases, we use slot_mapping.size(0) as the
+  // number of tokens.
+  int num_tokens = slot_mapping.size(0);
+  int kv_lora_rank = kv_c.size(1);
+  int pe_dim = k_pe.size(1);
+  int block_size = kv_cache.size(1);
+
+  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
+
+  int kv_c_stride = kv_c.stride(0);
+  int k_pe_stride = k_pe.stride(0);
+  int block_stride = kv_cache.stride(0);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(kv_lora_rank, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
+                             CALL_CONCAT_AND_CACHE_MLA);
+}
+
 namespace vllm {
 
 template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
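
For reference, the addressing performed by concat_and_cache_mla_kernel is equivalent to the following pure-PyTorch sketch of the unscaled (kAuto) path. It assumes a contiguous kv_cache, so that block_stride equals block_size * (kv_lora_rank + pe_dim); the function name and loop structure are illustrative, not code from this commit:

import torch

def concat_and_cache_mla_ref(kv_c: torch.Tensor, k_pe: torch.Tensor,
                             kv_cache: torch.Tensor,
                             slot_mapping: torch.Tensor) -> None:
    # kv_c:     [num_tokens, kv_lora_rank]
    # k_pe:     [num_tokens, pe_dim]
    # kv_cache: [num_blocks, block_size, kv_lora_rank + pe_dim]
    block_size = kv_cache.size(1)
    kv_lora_rank = kv_c.size(1)
    for token_idx in range(slot_mapping.numel()):
        slot_idx = int(slot_mapping[token_idx])
        if slot_idx < 0:
            # Padded token: the kernel returns early for this block.
            continue
        block_idx = slot_idx // block_size
        block_offset = slot_idx % block_size
        # The first kv_lora_rank entries of the slot hold the latent KV and
        # the next pe_dim entries hold the rotary positional key, matching
        # the two copy() calls in the kernel.
        kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[token_idx]
        kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[token_idx]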

csrc/torch_bindings.cpp

Lines changed: 9 additions & 0 deletions

@@ -463,6 +463,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
   cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
                  &reshape_and_cache_flash);
 
+  // Concat kv_c and k_pe and cache them.
+  cache_ops.def(
+      "concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
+      "                     Tensor! kv_cache,"
+      "                     Tensor slot_mapping,"
+      "                     str kv_cache_dtype,"
+      "                     Tensor scale) -> ()");
+  cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
+
   // Convert the key and value cache to fp8 data type.
   cache_ops.def(
       "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
+import pytest
+import torch
+
+from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
+
+
+def cdiv(a, b):
+    return (a + b - 1) // b
+
+
+@pytest.mark.parametrize("B", [3, 5])
+@pytest.mark.parametrize("L", [1027, 1025])
+@pytest.mark.parametrize("H_Q", [32])
+@pytest.mark.parametrize("H_KV", [32, 8])
+@pytest.mark.parametrize("D_QK", [128, 192, 576])
+@pytest.mark.parametrize("D_V", [128, 512])
+@pytest.mark.parametrize("CACHE_SIZE", [16384])
+@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
+def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
+    assert CACHE_SIZE % PAGE_SIZE == 0
+    dtype = torch.bfloat16
+    seq_len = L  # This represents the number of tokens already in the sequence
+    sm_scale = 1.0 / (D_QK**0.5)
+    num_kv_splits = 8
+
+    num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
+    req_to_page = torch.randint(0,
+                                CACHE_SIZE // PAGE_SIZE,
+                                (B, num_pages_per_batch, 1),
+                                device="cuda")
+    req_to_token = req_to_page * PAGE_SIZE
+    req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
+    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
+        1, 1, -1)
+    req_to_token = req_to_token.view(B, -1)
+    req_to_token = req_to_token[:, :seq_len].contiguous()
+
+    # q represents the new token being generated, one per batch
+    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")
+
+    # k_buffer and v_buffer represent all previous tokens
+    # Page size is 1.
+    k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
+    v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")
+
+    # o will have the same shape as q
+    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
+
+    b_seq_len = torch.full((B, ), seq_len, device="cuda")
+
+    attn_logits = torch.empty(
+        (B, H_Q, num_kv_splits, D_V + 1),
+        dtype=torch.float32,
+        device="cuda",
+    )
+
+    # Call the original implementation.
+    decode_attention_fwd(
+        q,
+        k_buffer,
+        v_buffer,
+        o,
+        req_to_token,
+        b_seq_len,
+        attn_logits,
+        num_kv_splits,
+        sm_scale,
+    )
+
+    # Page size can be larger than 1.
+    k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
+    v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
+
+    o1 = torch.zeros_like(o)
+
+    decode_attention_fwd(
+        q,
+        k_buffer,
+        v_buffer,
+        o1,
+        req_to_page,
+        b_seq_len,
+        attn_logits,
+        num_kv_splits,
+        sm_scale,
+        PAGE_SIZE,
+    )
+
+    assert torch.allclose(o, o1)

tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py

Lines changed: 1 addition & 1 deletion

@@ -5,5 +5,5 @@ class DummyPlatform(CudaPlatform):
     device_name = "DummyDevice"
 
     def get_attn_backend_cls(self, backend_name, head_size, dtype,
-                             kv_cache_dtype, block_size, use_v1):
+                             kv_cache_dtype, block_size, use_v1, use_mla):
         return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501

tests/weight_loading/models.txt

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
 compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
 compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
-compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
+#compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
 compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
 compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
 awq, casperhansen/mixtral-instruct-awq, main

tests/weight_loading/run_model_weight_loading_test.sh

Lines changed: 7 additions & 2 deletions

@@ -3,7 +3,7 @@ SUCCESS=0
 
 while getopts "c:" OPT; do
   case ${OPT} in
-    c )
+    c )
       CONFIG="$OPTARG"
       ;;
     \? )
@@ -18,9 +18,14 @@ IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$CONFIG"
 
 for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
 do
+  if [[ $MODEL_CONFIG == \#* ]]; then
+    echo "=== SKIPPING MODEL: $MODEL_CONFIG ==="
+    continue
+  fi
+
   LOCAL_SUCCESS=0
   IFS=', ' read -r -a array <<< "$MODEL_CONFIG"
-
+
   echo "=== RUNNING MODEL: $MODEL_CONFIG ==="
 
   export QUANTIZATION=${array[0]}

vllm/_custom_ops.py

Lines changed: 13 additions & 0 deletions

@@ -1002,6 +1002,19 @@ def reshape_and_cache_flash(
                                                  v_scale)
 
 
+def concat_and_cache_mla(
+    kv_c: torch.Tensor,
+    k_pe: torch.Tensor,
+    kv_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    scale: torch.Tensor,
+) -> None:
+    torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache,
+                                                slot_mapping, kv_cache_dtype,
+                                                scale)
+
+
 def copy_blocks(key_caches: List[torch.Tensor],
                 value_caches: List[torch.Tensor],
                 block_mapping: torch.Tensor) -> None:
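
A minimal usage sketch of the new Python wrapper, assuming a CUDA build with the op compiled in; the sizes below (kv_lora_rank = 512, pe_dim = 64, small block counts) are illustrative only, and "auto" keeps the cache in the unquantized dtype so the per-tensor scale is unused:

import torch
from vllm import _custom_ops as ops

num_tokens, kv_lora_rank, pe_dim = 16, 512, 64
num_blocks, block_size = 128, 16

kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=torch.bfloat16, device="cuda")
k_pe = torch.randn(num_tokens, pe_dim, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.zeros(num_blocks, block_size, kv_lora_rank + pe_dim,
                       dtype=torch.bfloat16, device="cuda")
# One cache slot per token; a -1 entry would mark a padded token to skip.
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")
scale = torch.ones(1, dtype=torch.float32, device="cuda")

ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
                         kv_cache_dtype="auto", scale=scale)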

vllm/attention/backends/abstract.py

Lines changed: 16 additions & 0 deletions

@@ -276,3 +276,19 @@ def forward(
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
+
+
+class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
+
+    @abstractmethod
+    def forward(
+        self,
+        layer: AttentionLayer,
+        hidden_states_or_cq: torch.Tensor,
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: T,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        raise NotImplementedError
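
As a hypothetical illustration of the new interface (the class and module usage below are not part of this diff), a concrete MLA backend would subclass MLAAttentionImpl and receive the latent kv_c_normed and rotary k_pe instead of full key/value tensors:

from typing import Optional

import torch

from vllm.attention.backends.abstract import (AttentionLayer,
                                              AttentionMetadata,
                                              MLAAttentionImpl)


class ToyMLAImpl(MLAAttentionImpl[AttentionMetadata]):
    """Sketch only: a real backend also implements __init__ and caching."""

    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # A real implementation would write kv_c_normed and k_pe into the
        # paged cache (e.g. via ops.concat_and_cache_mla) and then run MLA
        # attention against the compressed cache.
        raise NotImplementedError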

vllm/attention/backends/mla/__init__.py

Whitespace-only changes.
