Commit c993711

[Bugfix] Remove hardcoded head_size=256 for Deepseek v2 and v3 (vllm-project#12067)
Signed-off-by: Isotr0py <[email protected]>
1 parent 409b228

4 files changed: +23 additions, −40 deletions

tests/kernels/test_attention.py

Lines changed: 3 additions & 3 deletions
@@ -31,9 +31,9 @@
 NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
 NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

-# FlashAttention forward only supports head dimension at most 128
-# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 120, 256]
+# This should be sync with get_supported_head_sizes() in
+# vllm.attention.ops.paged_attn.PagedAttention
+HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]

 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
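
Note: a minimal sanity check, not part of this commit, of how the new comment could be enforced in the test module, assuming PagedAttention.get_supported_head_sizes() returns a list of ints as the comment implies:

from vllm.attention.ops.paged_attn import PagedAttention

HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]

def test_head_sizes_match_paged_attention():
    # Guards against the test constant drifting out of sync with the kernel.
    assert sorted(HEAD_SIZES) == sorted(PagedAttention.get_supported_head_sizes())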

vllm/config.py

Lines changed: 6 additions & 3 deletions
@@ -733,9 +733,12 @@ def get_head_size(self) -> int:
         if hasattr(self.hf_text_config,
                    "model_type") and (self.hf_text_config.model_type
                                       in ('deepseek_v2', 'deepseek_v3')):
-            # FlashAttention supports only head_size 32, 64, 128, 256,
-            # we need to pad head_size 192 to 256
-            return 256
+            qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
+                                       0)
+            qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim",
+                                       0)
+            if qk_rope_head_dim and qk_nope_head_dim:
+                return qk_rope_head_dim + qk_nope_head_dim

         if self.is_attention_free:
             return 0
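
Note: a worked example of the new logic, using a SimpleNamespace stand-in for hf_text_config. The 128/64 split is taken from the published DeepSeek-V2/V3 configs and is consistent with the removed comment, which padded a 192 head size to 256:

from types import SimpleNamespace

# Stand-in for self.hf_text_config with the two fields the new code reads.
hf_text_config = SimpleNamespace(model_type="deepseek_v2",
                                 qk_nope_head_dim=128,
                                 qk_rope_head_dim=64)

qk_rope_head_dim = getattr(hf_text_config, "qk_rope_head_dim", 0)
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 0)
if qk_rope_head_dim and qk_nope_head_dim:
    head_size = qk_rope_head_dim + qk_nope_head_dim

print(head_size)  # 192, instead of the previously hardcoded 256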

vllm/model_executor/models/deepseek_v2.py

Lines changed: 7 additions & 17 deletions
@@ -262,14 +262,8 @@ def __init__(
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale

-        # self.attn = Attention(self.num_heads,
-        #                       self.qk_head_dim,
-        #                       self.scaling,
-        #                       num_kv_heads=self.num_heads)
-
-        # TODO, support head_size 192
         self.attn = Attention(self.num_local_heads,
-                              256,
+                              self.qk_head_dim,
                               self.scaling,
                               num_kv_heads=self.num_local_heads,
                               cache_config=cache_config,
@@ -319,18 +313,14 @@ def forward(
         k = torch.empty_like(q)
         k[..., :self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim:] = k_pe
-        q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
+        # padding value to qk_head_dim for alignment
+        v = torch.nn.functional.pad(
+            v, [0, self.qk_head_dim - self.v_head_dim],
+            value=0).view(-1, self.num_local_heads * self.qk_head_dim)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         attn_output = attn_output.view(
-            -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
+            -1, self.num_local_heads,
+            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                 -1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output
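
Note: a standalone sketch of the surviving padding step, with illustrative sizes rather than values from the commit. Only v still needs zero-padding, from v_head_dim up to qk_head_dim, so that q, k, and v share one head size in the attention call:

import torch

# Illustrative sizes only (DeepSeek-V2-style MLA dims are assumed here).
num_tokens, num_local_heads = 4, 16
qk_head_dim, v_head_dim = 192, 128

v = torch.randn(num_tokens, num_local_heads, v_head_dim)

# Zero-pad v's head dim up to qk_head_dim and flatten the heads, as the new
# code does; q and k already have qk_head_dim per head and need no padding.
v_padded = torch.nn.functional.pad(
    v, [0, qk_head_dim - v_head_dim],
    value=0).view(-1, num_local_heads * qk_head_dim)

print(v_padded.shape)  # torch.Size([4, 3072])  (16 * 192)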

vllm/model_executor/models/deepseek_v3.py

Lines changed: 7 additions & 17 deletions
@@ -269,14 +269,8 @@ def __init__(
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale

-        # self.attn = Attention(self.num_heads,
-        #                       self.qk_head_dim,
-        #                       self.scaling,
-        #                       num_kv_heads=self.num_heads)
-
-        # TODO, support head_size 192
         self.attn = Attention(self.num_local_heads,
-                              256,
+                              self.qk_head_dim,
                               self.scaling,
                               num_kv_heads=self.num_local_heads,
                               cache_config=cache_config,
@@ -326,18 +320,14 @@ def forward(
         k = torch.empty_like(q)
         k[..., :self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim:] = k_pe
-        q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
+        # padding value to qk_head_dim for alignment
+        v = torch.nn.functional.pad(
+            v, [0, self.qk_head_dim - self.v_head_dim],
+            value=0).view(-1, self.num_local_heads * self.qk_head_dim)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         attn_output = attn_output.view(
-            -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
+            -1, self.num_local_heads,
+            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                 -1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output
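
Note: the inverse step after the kernel call, sketched with the same illustrative sizes as above. The attention output carries qk_head_dim channels per head, of which only the first v_head_dim are real values, so they are sliced back out before o_proj:

import torch

# Same illustrative sizes as above; attn_output mimics what self.attn returns.
num_tokens, num_local_heads = 4, 16
qk_head_dim, v_head_dim = 192, 128

attn_output = torch.randn(num_tokens, num_local_heads * qk_head_dim)

# Un-flatten the heads, drop the zero-padded tail of each head, and re-flatten
# for o_proj -- the same slice the updated code performs with qk_head_dim
# instead of the old hardcoded 256.
attn_output = attn_output.view(
    -1, num_local_heads, qk_head_dim)[..., :v_head_dim].reshape(
        -1, num_local_heads * v_head_dim)

print(attn_output.shape)  # torch.Size([4, 2048])  (16 * 128)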
