
Commit ae79743

Authored by xuechendi, kwisniewski98, yangulei, yiliu30, and Yi4Liu
[Deepseek R1][v0] Porting deepseek r1 to habana_main (vllm-project#1161)
JIRA: https://jira.habana-labs.com/browse/SW-227174

Cherry-picked vllm-project#1030 and fixed conflicts after rebase.

Dependency: HabanaAI/vllm-hpu-extension#161

Verified with the three methods below:
1. Test with DeepSeek-V2 BF16 weights => Passed
2. Evaluate accuracy on DeepSeek-R1 with out-of-the-box block FP8 weights => Passed
3. Evaluate accuracy on DeepSeek-R1 with out-of-the-box block FP8 weights + INC-calibrated per-channel scales => Passed the accuracy check; performance reached the goal (numbers are in the JIRA ticket)

== Details ==

1. Test with DeepSeek-V2 BF16 weights:

```
PT_HPU_LAZY_MODE=1 python run_example_tp.py --model DeepSeek-V2-Lite --tokenizer DeepSeek-V2-Lite --osl 32
```

```
(VllmWorkerProcess pid=1039) WARNING 04-25 03:01:53 [hpu_model_runner.py:1039] Configuration: ('decode', 4, 128) was not warmed-up!
(VllmWorkerProcess pid=1038) WARNING 04-25 03:01:53 [hpu_model_runner.py:1039] Configuration: ('decode', 4, 128) was not warmed-up!
(VllmWorkerProcess pid=1041) WARNING 04-25 03:01:53 [hpu_model_runner.py:1039] Configuration: ('decode', 4, 128) was not warmed-up!
WARNING 04-25 03:01:53 [hpu_model_runner.py:1039] Configuration: ('decode', 4, 128) was not warmed-up!
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00, 1.57it/s, est. speed input: 12.59 toks/s, output: 50.37 toks/s]
e2e took 2.5509743690199684 seconds
====================================
Prompt: 'Hello, my name is'
Generated text: '\nI am a 20 year old student from the UK. I am currently studying for a degree in English Literature and Creative Writing at the University of East'
Ground truth: None
====================================
====================================
Prompt: '0.999 compares to 0.9 is '
Generated text: '100%\n0.9999999999999999999999999'
Ground truth: None
====================================
====================================
Prompt: 'The capital of France is'
Generated text: ' Paris, which is also the largest city in the country. The city is located on the Seine River and is known for its beautiful architecture, museums, and art'
Ground truth: None
====================================
====================================
Prompt: 'The future of AI is'
Generated text: ' in the hands of the people\nThe future of AI is in the hands of the people\nThe future of AI is in the hands of the people\nThe'
Ground truth: None
====================================
```

2. Evaluate accuracy on DeepSeek-R1 with out-of-the-box block FP8 weights (limit 256):

| Tasks | Version | Filter           | n-shot | Metric      |   | Value  |   | Stderr |
|-------|--------:|------------------|-------:|-------------|---|-------:|---|-------:|
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑ | 0.9648 | ± | 0.0115 |
|       |         | strict-match     |      5 | exact_match | ↑ | 0.9648 | ± | 0.0115 |

3. Evaluate accuracy on DeepSeek-R1 with out-of-the-box block FP8 weights + INC-calibrated per-channel scales:

| Tasks | Version | Filter           | n-shot | Metric      |   | Value  |   | Stderr |
|-------|--------:|------------------|-------:|-------------|---|-------:|---|-------:|
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑ | 0.9688 | ± | 0.0109 |
|       |         | strict-match     |      5 | exact_match | ↑ | 0.9688 | ± | 0.0109 |

---------

Signed-off-by: Chendi.Xue <[email protected]>
Signed-off-by: kwisniewski98 <[email protected]>
Signed-off-by: Chendi Xue <[email protected]>
Signed-off-by: Yi Liu <[email protected]>
Co-authored-by: kwisniewski98 <[email protected]>
Co-authored-by: Youlei Yang <[email protected]>
Co-authored-by: Yi Liu <[email protected]>
Co-authored-by: Yi Liu <[email protected]>
1 parent 670a544 commit ae79743
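
The BF16 smoke test in the commit message is driven by `run_example_tp.py`, which is not part of this diff. As a rough, hedged sketch of what that test exercises, an equivalent offline run through vLLM's public API might look like the following; the model path, `tensor_parallel_size=4`, and the greedy sampling settings are assumptions rather than values taken from the script.

```
from vllm import LLM, SamplingParams

# Hedged offline-inference sketch of the BF16 smoke test above;
# run_example_tp.py (not shown in this commit) may configure things differently.
prompts = [
    "Hello, my name is",
    "0.999 compares to 0.9 is ",
    "The capital of France is",
    "The future of AI is",
]
# --osl 32 in the commit message suggests a 32-token output length.
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
llm = LLM(model="DeepSeek-V2-Lite",      # assumed local checkpoint path
          tokenizer="DeepSeek-V2-Lite",
          tensor_parallel_size=4,        # assumption: four workers appear in the log
          dtype="bfloat16",
          trust_remote_code=True)

for output in llm.generate(prompts, sampling_params):
    print(output.prompt, "->", output.outputs[0].text)
```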

File tree

18 files changed: +457 -83 lines changed

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5 -t 1
+model_name: "/mnt/weka/llm/DeepSeek-V2-Lite"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.375
+  - name: "exact_match,flexible-extract"
+    value: 0.375
+limit: 256
+num_fewshot: 5
+dtype: "bfloat16"
+trust_remote_code: True
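
This config feeds the lm-eval harness stage added to the Jenkins config below; `value: 0.375` is the minimum `exact_match` expected for DeepSeek-V2-Lite at `limit: 256` and `num_fewshot: 5`. A hedged sketch of an equivalent standalone run with lm-eval's Python API (the CI goes through `run-tests.sh`, so the exact `model_args` below are assumptions):

```
import lm_eval

# Hedged sketch: the Jenkins job drives this via run-tests.sh, which may pass
# different model_args than the ones assumed here.
results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=("pretrained=/mnt/weka/llm/DeepSeek-V2-Lite,"
                "dtype=bfloat16,trust_remote_code=True"),
    tasks=["gsm8k"],
    num_fewshot=5,
    limit=256,
)
# The config above expects exact_match (strict and flexible) >= 0.375.
print(results["results"]["gsm8k"])
```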
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+DeepSeek-V2-Lite.yaml

.jenkins/test_config.yaml

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,9 @@ stages:
       - name: v0_gsm8k_small_g2_tp2
         flavor: g2.s
         command: export PT_HPU_LAZY_MODE=1 && cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 2
+      - name: v0_gsm8k_g2_deepseek-v2-lite_tp1
+        flavor: g3
+        command: export PT_HPU_LAZY_MODE=1 && cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-deepseek.txt -t 1
       #- name: v1_gsm8k_small_g3_tp1
       #  flavor: g3
       #  command: export PT_HPU_LAZY_MODE=1 && export VLLM_USE_V1=1 && export VLLM_CONTIGUOUS_PA=false && cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 1

README_GAUDI.md

Lines changed: 3 additions & 0 deletions
@@ -408,6 +408,9 @@ measurements for a given model. The quantization configuration is used during in
 > If you are prototyping or testing your model with FP8, you can use the `VLLM_SKIP_WARMUP=true` environment variable to disable the warmup stage, which is time-consuming.
 However, disabling this feature in production environments is not recommended, as it can lead to a significant performance decrease.

+> [!TIP]
+> If you are benchmarking an FP8 model with `scale_format=const`, setting `VLLM_DISABLE_MARK_SCALES_AS_CONST=true` can help speed up the warmup stage.
+
 > [!TIP]
 > When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this, set the following environment variables:
 > - `VLLM_ENGINE_ITERATION_TIMEOUT_S` - to adjust the vLLM server timeout. You can set the value in seconds, e.g., 600 equals 10 minutes.
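
Both tips are plain environment variables, so they only need to be exported before vLLM starts. A minimal sketch of doing that from a Python launcher script, assuming the variables are read at engine initialization:

```
import os

# Set before vLLM is imported/initialized so the warmup stage can pick them up.
os.environ["VLLM_DISABLE_MARK_SCALES_AS_CONST"] = "true"
os.environ["VLLM_ENGINE_ITERATION_TIMEOUT_S"] = "600"  # 10 minutes, as in the tip

from vllm import LLM  # noqa: E402  # import after the environment is configured
```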

requirements/hpu.txt

Lines changed: 1 addition & 1 deletion
@@ -9,4 +9,4 @@ numpy==1.26.4
 tabulate
 setuptools>=77.0.3,<80.0.0
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@50a112a
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@7df7dd0

vllm/attention/backends/hpu_attn.py

Lines changed: 246 additions & 0 deletions
@@ -18,6 +18,7 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata, AttentionType)
+from vllm.attention.backends.mla.common import MLACommonImpl
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
                                                HPUPagedAttentionMetadata)
@@ -70,6 +71,49 @@ def copy_blocks(
         HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)


+class HPUMLAAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "HPU_MLA"
+
+    @staticmethod
+    def get_impl_cls() -> Type["HPUMLAImpl"]:
+        return HPUMLAImpl
+
+    @staticmethod
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return HPUMLAMetadata
+
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return (num_blocks, block_size, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: torch.Tensor,
+    ) -> None:
+        HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+
 @dataclass
 class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     """Metadata for HPUAttentionbackend."""
@@ -79,6 +123,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
     context_lens_tensor: Optional[torch.Tensor]
+    input_positions: torch.Tensor
     seq_lens: Optional[List[int]] = None
     encoder_seq_lens: Optional[List[int]] = None
     encoder_seq_lens_tensor: Optional[torch.Tensor] = None
@@ -92,6 +137,207 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     cross_attn_bias: Optional[torch.Tensor] = None


+@dataclass
+class HPUMLAMetadata(HPUAttentionMetadata, AttentionMetadata):
+    pass
+
+
+class HPUMLAImpl(MLACommonImpl[HPUAttentionMetadata], torch.nn.Module):
+
+    def __init__(
+            self,
+            num_heads: int,
+            head_size: int,
+            scale: float,
+            num_kv_heads: int,
+            alibi_slopes: Optional[List[float]],
+            sliding_window: Optional[int],
+            kv_cache_dtype: str,
+            blocksparse_params: Optional[Dict[str, Any]],
+            logits_soft_cap: Optional[float],
+            attn_type: str,
+            # MLA Specific Arguments
+            **kwargs) -> None:
+        torch.nn.Module.__init__(self)
+        MLACommonImpl.__init__(self, num_heads, head_size, scale, num_kv_heads,
+                               alibi_slopes, sliding_window, kv_cache_dtype,
+                               blocksparse_params, logits_soft_cap, attn_type,
+                               **kwargs)
+
+        self.matmul_qk = Matmul()
+        self.softmax = Softmax()
+        self.matmul_av = Matmul()
+        self.batch2block_matmul = Matmul()
+        self.block2batch_matmul = Matmul()
+        self.latent_cache_k = VLLMKVCache()
+        self.fused_scaled_dot_product_attention = kernels.fsdpa()
+
+        if "fsdpa" in enabled_flags():
+            assert alibi_slopes is None, \
+                'Prefill with FusedSDPA not supported with alibi slopes!'
+            self.prefill_impl = 'fsdpa'
+        else:
+            self.prefill_impl = 'naive'
+
+        unsupported_features = [
+            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
+        ]
+        if any(unsupported_features):
+            raise NotImplementedError(
+                "HPUMLAImpl does not support one of the following: "
+                "alibi_slopes, sliding_window, blocksparse_params, "
+                "logits_soft_cap")
+
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "TritonMLAImpl")
+
+    def forward(
+        self,
+        layer: AttentionLayer,
+        q: torch.Tensor,
+        k_c_normed: torch.Tensor,  # key in unified attn
+        k_pe: torch.Tensor,  # value in unified attn
+        kv_cache: torch.Tensor,
+        attn_metadata: HPUAttentionMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if output is not None:
+            raise NotImplementedError(
+                "output is not yet supported for MLAImplBase")
+
+        batch_size = q.shape[0]
+        is_prefill = attn_metadata.is_prompt
+
+        # Restore head dim (for rotary embedding)
+        k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
+        q = q.view(-1, self.num_heads, self.qk_head_dim)
+        assert hasattr(attn_metadata,
+                       "input_positions"), f"attn meta: {attn_metadata}"
+
+        input_positions = attn_metadata.input_positions.view(-1)
+        if not is_prefill:
+            # decode
+            q_nope, q_pe = q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+            # Convert from (B, N, P) to (N, B, P)
+            q_nope = q_nope.transpose(0, 1)
+            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+            decode_ql_nope = torch.bmm(q_nope, self.W_UK_T)
+            # Convert from (N, B, L) to (B, N, L)
+            decode_ql_nope = decode_ql_nope.transpose(0, 1)
+            q_pe, k_pe = \
+                self.rotary_emb(input_positions, q_pe, k_pe)
+        else:
+            # prefill
+            q_pe = q[..., self.qk_nope_head_dim:]
+            q[..., self.qk_nope_head_dim:], k_pe = \
+                self.rotary_emb(input_positions, q_pe, k_pe)
+
+        block_indices = attn_metadata.block_indices
+        block_offsets = attn_metadata.block_offsets
+
+        latent_vec_k = torch.concat(
+            (k_c_normed, k_pe.view(batch_size, -1, self.qk_rope_head_dim)),
+            dim=-1)
+        latent_vec_k = latent_vec_k.view(
+            -1, self.qk_rope_head_dim + self.kv_lora_rank)
+        if is_prefill:
+            latent_vec_k = latent_vec_k.unflatten(0,
+                                                  (block_indices.size(0), -1))
+
+        # write the latent and rope to kv cache
+        if kv_cache is not None and len(kv_cache) == 2:
+            self.latent_cache_k(latent_vec_k, kv_cache[0], block_indices,
+                                block_offsets)
+        k_cache = kv_cache[0]
+        v_cache = None
+
+        if is_prefill:
+            return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata,
+                                         batch_size)
+        else:
+            return self._forward_decode(decode_ql_nope, q_pe,
+                                        (k_cache, v_cache), attn_metadata,
+                                        batch_size)
+
+    def _forward_prefill(  # type: ignore
+            self, q: torch.Tensor, k_c_normed: torch.Tensor,
+            k_pe: torch.Tensor, attn_metadata: HPUAttentionMetadata,
+            batch_size: int) -> torch.Tensor:
+        kv_nope = self.kv_b_proj(k_c_normed)[0]\
+            .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope, v = kv_nope\
+            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+
+        q = q.view(batch_size, -1, self.num_heads, self.qk_head_dim)
+        k = k.view(batch_size, -1, self.num_heads, self.qk_head_dim)
+        v = v.view(batch_size, -1, self.num_heads, self.v_head_dim)
+
+        to_pad = self.qk_head_dim - self.v_head_dim
+        if to_pad > 0:
+            v_padding = torch.zeros(*v.shape[:-1],
+                                    q.shape[-1] - v.shape[-1],
+                                    device=v.device,
+                                    dtype=v.dtype)
+            v_padded = torch.cat((v, v_padding), dim=-1)
+        else:
+            v_padded = v
+
+        out = ops.prompt_attention(
+            impl=self.prefill_impl,
+            query=q,
+            key=k,
+            value=v_padded,
+            is_causal=True,
+            attn_bias=attn_metadata.attn_bias,
+            valid_seq_lengths=attn_metadata.seq_lens_tensor,
+            scale=self.scale,
+            matmul_qk_op=self.matmul_qk,
+            softmax_op=self.softmax,
+            matmul_av_op=self.matmul_av,
+            fsdpa_op=self.fused_scaled_dot_product_attention.apply \
+            if self.fused_scaled_dot_product_attention is not None else None)
+        attn_output = out.view(batch_size, -1, self.num_heads, q.shape[-1])
+        attn_output = attn_output[..., :v.shape[-1]]\
+            .reshape(batch_size, -1, self.num_heads * v.shape[-1])
+
+        return attn_output
+
+    def _forward_decode(  # type: ignore
+            self, q_nope: torch.Tensor, q_pe: torch.Tensor,
+            kv_cache: torch.Tensor, attn_metadata: HPUAttentionMetadata,
+            batch_size: int) -> torch.Tensor:
+        query = torch.cat([q_nope, q_pe], dim=-1)
+
+        key_cache = kv_cache[0].unsqueeze(2)
+        value_cache = kv_cache[1]  # value_cache is None
+        output = HPUPagedAttention.forward_decode(
+            query=query,
+            key_cache=key_cache,
+            value_cache=value_cache,
+            block_list=attn_metadata.block_list,
+            block_mapping=attn_metadata.block_mapping,
+            block_bias=attn_metadata.attn_bias,
+            block_groups=attn_metadata.block_groups,
+            scale=self.scale,
+            matmul_qk_op=self.matmul_qk,
+            matmul_av_op=self.matmul_av,
+            batch2block_matmul_op=self.batch2block_matmul,
+            block2batch_matmul_op=self.block2batch_matmul,
+            keys_fetch_func=self.latent_cache_k.fetch_from_cache,
+            values_fetch_func=None,
+            kv_lora_rank=self.kv_lora_rank)
+        output = output.view(batch_size, 1, -1)
+        result = self._v_up_proj(output)
+        result = result.view(batch_size, 1, -1)
+        return result
+
+
 class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
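
The decode path of `HPUMLAImpl` never materializes full keys: it absorbs the up-projection into the query (`torch.bmm(q_nope, self.W_UK_T)`) and scores directly against the latent KV cache. Below is a standalone numerical sketch of why that is equivalent; the tensor names and shapes are illustrative, not the vLLM internals.

```
import torch

# N heads, B tokens, P = qk_nope_head_dim, L = kv_lora_rank
N, B, P, L = 2, 3, 8, 4
q_nope = torch.randn(N, B, P, dtype=torch.float64)
W_UK = torch.randn(N, L, P, dtype=torch.float64)    # per-head latent (L) -> key (P)
latent = torch.randn(N, B, L, dtype=torch.float64)  # what the latent KV cache stores

# Path 1: up-project the latent into full keys, then take dot products.
k_nope = torch.bmm(latent, W_UK)                    # (N, B, P)
scores_full = (q_nope * k_nope).sum(-1)             # (N, B)

# Path 2: absorb W_UK into the query and score against the latent directly.
ql_nope = torch.bmm(q_nope, W_UK.transpose(1, 2))   # (N, B, L)
scores_absorbed = (ql_nope * latent).sum(-1)        # (N, B)

torch.testing.assert_close(scores_full, scores_absorbed)
```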

vllm/attention/ops/hpu_paged_attn.py

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,8 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,

     @staticmethod
     def forward_decode(**kwargs) -> torch.Tensor:
+        if kwargs.get("kv_lora_rank"):
+            return ops.flat_pa_mla(**kwargs)
         return ops.flat_pa(**kwargs)

     @staticmethod

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 4 additions & 0 deletions
@@ -902,6 +902,10 @@ def grouped_topk(
     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")

+    gating_output = gating_output.float()
+    if e_score_correction_bias is not None:
+        e_score_correction_bias = e_score_correction_bias.float()
+
     if scoring_func == "softmax":
         scores = torch.softmax(gating_output, dim=-1)
     elif scoring_func == "sigmoid":
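
The cast matters because DeepSeek-style routing adds a small per-expert correction bias to sigmoid scores, and bf16 can swallow the differences that decide the top-k experts. A simplified, standalone sketch of that scoring step (it deliberately omits the grouped/`num_expert_group` logic of `grouped_topk`):

```
import torch

tokens, experts, top_k = 4, 8, 2
gating_output = torch.randn(tokens, experts, dtype=torch.bfloat16)
e_score_correction_bias = torch.randn(experts, dtype=torch.bfloat16)

# As in the diff above: promote to float32 before scoring and biasing.
scores = torch.sigmoid(gating_output.float())
biased_scores = scores + e_score_correction_bias.float()
topk_weights, topk_ids = torch.topk(biased_scores, k=top_k, dim=-1)
print(topk_ids)
```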
