fuse fp8 quant in kv copying and add flashinfer decode mla operator in the attention module #737

Merged · 4 commits · Feb 26, 2025

103 changes: 103 additions & 0 deletions lightllm/models/deepseek2/flashinfer_struct.py
@@ -0,0 +1,103 @@
import os
import torch
import numpy as np
import torch.distributed as dist
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.utils.envs_utils import enable_env_vars
from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index


class Deepseek2FlashInferStateInfo(Deepseek2InferStateInfo):
def __init__(self):
super().__init__()
self.prefill_wrapper = None
self.decode_wrapper = None
self.flashinfer_extra_state = None

def init_some_extra_state(self, model, input_ids: torch.Tensor):
super().init_some_extra_state(model, input_ids)
self.flashinfer_extra_state = model.flashinfer_extra_state

import flashinfer

if not self.is_prefill:
if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA"):
self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
self.kv_indices = torch.empty(
self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32
).to(input_ids.device)
repack_kv_index(
self.req_manager.req_to_token_indexs,
self.b_req_idx,
self.b_seq_len,
self.b_start_loc,
self.max_len_in_batch,
self.kv_indices,
)
if self.decode_wrapper is None:
self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
self.flashinfer_extra_state.workspace_buffer,
use_cuda_graph=True,
qo_indptr=self.q_indptr,
kv_indices=self.kv_indices,
kv_indptr=self.kv_starts,
kv_len_arr=self.b_seq_len,
)
self.decode_wrapper.plan(
self.q_indptr,
self.kv_starts,
self.kv_indices,
self.b_seq_len,
self.flashinfer_extra_state.tp_q_head_num,
self.flashinfer_extra_state.kv_lora_rank,
self.flashinfer_extra_state.qk_rope_head_dim,
1,
False, # causal
self.flashinfer_extra_state.softmax_scale,
self.flashinfer_extra_state.q_data_type,
self.flashinfer_extra_state.kv_data_type,
)
else:
if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"):
q_starts = torch.cat(
[self.b_start_loc, self.b_start_loc[-1:] + (self.b_seq_len - self.b_ready_cache_len)[-1:]], dim=0
).int()
kv_starts = torch.cat(
[self.b_kv_start_loc, self.b_kv_start_loc[-1:] + self.b_seq_len[-1:]], dim=0
).int()
if self.prefill_wrapper is None:
self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
self.flashinfer_extra_state.workspace_buffer, "NHD"
)
self.prefill_wrapper.plan(
qo_indptr=q_starts,
kv_indptr=kv_starts,
num_qo_heads=self.flashinfer_extra_state.tp_q_head_num,
num_kv_heads=self.flashinfer_extra_state.tp_q_head_num,
head_dim_qk=self.flashinfer_extra_state.qk_nope_head_dim
+ self.flashinfer_extra_state.qk_rope_head_dim,
head_dim_vo=self.flashinfer_extra_state.qk_nope_head_dim,
q_data_type=self.flashinfer_extra_state.q_data_type,
causal=True,
sm_scale=self.flashinfer_extra_state.softmax_scale,
)
return
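
For reference, a small sketch of how the q_starts / kv_starts index pointers above are formed: exclusive prefix sums over per-request lengths, with one trailing entry so that starts[i + 1] - starts[i] recovers request i's length. The values below are made up for illustration.

import torch

b_seq_len = torch.tensor([5, 3, 7], dtype=torch.int32)          # total kv tokens per request
b_ready_cache_len = torch.tensor([2, 0, 4], dtype=torch.int32)  # tokens already in the cache
q_len = b_seq_len - b_ready_cache_len                            # new (query) tokens per request

# Exclusive prefix sums, padded with a leading zero, mirror the concat-of-starts above.
q_starts = torch.nn.functional.pad(torch.cumsum(q_len, 0), (1, 0)).int()       # [0, 3, 6, 9]
kv_starts = torch.nn.functional.pad(torch.cumsum(b_seq_len, 0), (1, 0)).int()  # [0, 5, 8, 15]
print(q_starts.tolist(), kv_starts.tolist())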

def copy_for_cuda_graph(self, new_infer_state):
super().copy_for_cuda_graph(new_infer_state)
if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA") and not self.is_prefill:
self.decode_wrapper.plan(
new_infer_state.q_indptr,
new_infer_state.kv_starts,
new_infer_state.kv_indices,
new_infer_state.b_seq_len,
new_infer_state.flashinfer_extra_state.tp_q_head_num,
new_infer_state.flashinfer_extra_state.kv_lora_rank,
new_infer_state.flashinfer_extra_state.qk_rope_head_dim,
1,
False, # causal
new_infer_state.flashinfer_extra_state.softmax_scale,
new_infer_state.flashinfer_extra_state.q_data_type,
new_infer_state.flashinfer_extra_state.kv_data_type,
)
return
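
The new code paths are opted into via environment variables read through enable_env_vars (from lightllm.utils.envs_utils). A minimal sketch of the expected toggle behaviour, assuming the same truthy-string parsing as the ENABLE_OPT_DECODE_MHA check this PR replaces:

import os

def enable_env_vars(name: str) -> bool:
    # Assumed behaviour: treat "ON"/"TRUE"/"1" (case-insensitive) as enabled,
    # mirroring the old ENABLE_OPT_DECODE_MHA check removed by this PR.
    return os.getenv(name, "False").upper() in ["ON", "TRUE", "1"]

# Example: opt into the new flashinfer paths before constructing the model.
os.environ["ENABLE_FLASHINFER_PREFILLED"] = "true"   # ragged prefill via flashinfer
os.environ["ENABLE_FLASHINFER_DECODE_MLA"] = "true"  # paged MLA decode via flashinfer
assert enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA")
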
108 changes: 86 additions & 22 deletions lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -5,6 +5,7 @@
import numpy as np
from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv
from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8
from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad import (
context_attention_fwd,
context_attention_fwd_no_prompt_cache,
@@ -20,10 +21,11 @@
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo
from functools import partial
from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
import os
from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
from lightllm.utils.envs_utils import enable_env_vars


class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
Expand Down Expand Up @@ -67,7 +69,6 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
self.tp_o_head_num_ = self.tp_q_head_num_
self.num_heads = network_config["num_attention_heads"]
self.num_kv_heads = network_config["num_key_value_heads"]
self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
return

def _bind_func(self):
@@ -96,18 +97,33 @@ def _bind_attention(self):
)
else:
self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
self._token_attention_kernel = partial(
Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
)
if self.enable_cc_method:
if "triton_fp8kv" in self.mode:
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA"):
self._token_attention_kernel = partial(
Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self
)
else:
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
self._token_attention_kernel = partial(
Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
)
if self.enable_cc_method:
if "triton_fp8kv" in self.mode:
if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"):
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC_fp8, self
)
else:
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
)
else:
if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"):
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self
)
else:
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
)
else:
if "triton_fp8kv" in self.mode:
self._context_attention_kernel = partial(
@@ -205,6 +221,38 @@ def _decompress_kv(
k_nope, v = torch.split(kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
return k_nope, k_rope, v

def _context_attention_flashinfer_kernel_with_CC(
self,
q: torch.Tensor,
kv,
infer_state: Deepseek2FlashInferStateInfo,
layer_weight: Deepseek2TransformerLayerWeight,
out=None,
) -> torch.Tensor:
k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
o_tensor = (
self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
)
k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
return o_tensor

def _context_attention_flashinfer_kernel_with_CC_fp8(
self,
q: torch.Tensor,
kv,
infer_state: Deepseek2FlashInferStateInfo,
layer_weight: Deepseek2TransformerLayerWeight,
out=None,
) -> torch.Tensor:
k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
o_tensor = (
self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
)
k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
return o_tensor

def _context_attention_kernel_with_CC(
self,
q: torch.Tensor,
@@ -345,6 +393,25 @@ def _context_attention_kernel_origin_fp8(

return o_tensor

def _token_gqa_decode_attention_flashinfer(
self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
):
q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)

kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)

infer_state.decode_wrapper.run(
q_nope,
q_rope,
kv[:, :, : -self.qk_rope_head_dim],
kv[:, :, -self.qk_rope_head_dim :],
out=o_tensor,
return_lse=False,
)
return o_tensor
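
A shape-only sketch of the tensor layout consumed by the flashinfer MLA decode call above: q is split into a latent (nope) part and a rope part, and each KV cache row is split the same way. The dimensions and the stand-in projection weight are assumptions for illustration.

import torch

# DeepSeek-V2-like dims, assumed purely for illustration.
kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim, tp_q_head_num = 512, 64, 128, 16
num_tokens, num_cache_slots = 8, 1024

q = torch.randn(num_tokens, tp_q_head_num, qk_nope_head_dim + qk_rope_head_dim)
q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:]

# k_b_proj_.bmm(...) absorbs the per-head k_b projection, mapping q_nope into the latent space.
k_b_proj = torch.randn(tp_q_head_num, qk_nope_head_dim, kv_lora_rank)  # stand-in weight
q_nope = torch.bmm(q_nope.transpose(0, 1), k_b_proj).transpose(0, 1)   # (tokens, heads, kv_lora_rank)

# Each per-layer KV cache row holds the compressed latent and the shared rope key side by side.
kv = torch.randn(num_cache_slots, 1, kv_lora_rank + qk_rope_head_dim)
ckv, k_pe = kv[:, :, :-qk_rope_head_dim], kv[:, :, -qk_rope_head_dim:]
print(q_nope.shape, q_rope.shape, ckv.shape, k_pe.shape)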

def _token_gqa_decode_attention_flashdecoding(
self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
):
@@ -354,7 +421,7 @@ def _token_gqa_decode_attention_flashdecoding(
kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)

if self.enable_opt_decoding_mha:
if enable_env_vars("ENABLE_OPT_DECODE_MHA"):
q = torch.cat([q_nope, q_rope], dim=-1)
q_nope, q_rope = None, None
import lightllm_ppl_mla
@@ -368,7 +435,7 @@
infer_state.b_req_idx,
self.softmax_scale,
q.shape[-1],
q_nope.shape[-1],
self.kv_lora_rank,
)
return o_tensor
else:
@@ -421,16 +488,13 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
return

def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager):
quant_method = vLLMFP8w8a8QuantizationMethod()
quant, scale = quant_method.quantize_scaled_mm_fp8(buffer.reshape(-1, buffer.shape[-1]))
destindex_copy_kv(
quant.T.unsqueeze(1)[:, :, : self.kv_lora_rank].view(torch.uint8),
quant.T.unsqueeze(1)[:, :, self.kv_lora_rank :].view(torch.uint8),
destindex_copy_kv_fp8(
buffer[:, :, : self.kv_lora_rank],
buffer[:, :, self.kv_lora_rank :],
mem_index,
mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank],
mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2],
mem_manager.kv_buffer[self.layer_num_][:, :, -2:],
scale.to(buffer.dtype).view(torch.uint8),
mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank].view(torch.float8_e4m3fn),
mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2].view(torch.float8_e4m3fn),
mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(buffer.dtype),
)
return
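
For context, an unfused reference of the per-token fp8 (e4m3) quantization that the new destindex_copy_kv_fp8 Triton kernel presumably folds into the scatter-copy; the old path computed it separately via vLLMFP8w8a8QuantizationMethod and stored the scale in the trailing bytes of each cache row. The per-token scaling granularity here is an assumption.

import torch

def quantize_kv_fp8(buffer: torch.Tensor):
    # One scale per token row: scale so the largest magnitude maps to the e4m3 max (448),
    # quantize to float8_e4m3fn, and keep the scale in the activation dtype for dequant.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = buffer.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / fp8_max
    quant = (buffer / scale).to(torch.float8_e4m3fn)
    return quant, scale.to(buffer.dtype)

# Example: an 8-token batch of single-head latent+rope rows (512 + 64 = 576 values).
buffer = torch.randn(8, 1, 576, dtype=torch.bfloat16)
quant, scale = quantize_kv_fp8(buffer)
dequant = quant.to(torch.bfloat16) * scale
print(quant.dtype, scale.shape, (dequant - buffer).abs().max())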

32 changes: 31 additions & 1 deletion lightllm/models/deepseek2/model.py
@@ -2,26 +2,54 @@
from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer
from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo
from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights

from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager
from lightllm.utils.log_utils import init_logger
from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
from lightllm.utils.envs_utils import enable_env_vars


logger = init_logger(__name__)


class FlashInferStateExtraInfo:
def __init__(self, model):
num_heads = model.config["num_attention_heads"]
self.tp_q_head_num = num_heads if enable_env_vars("ENABLE_DP") else num_heads // model.world_size_
self.qk_nope_head_dim = model.qk_nope_head_dim
self.qk_rope_head_dim = model.qk_rope_head_dim
self.kv_lora_rank = model.kv_lora_rank
self.q_data_type = model.data_type
self.kv_data_type = model.data_type
self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(model.tp_rank_)
self.max_seq_length = model.max_seq_length
self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5)
if model.config["rope_scaling"] is not None:
rope_scaling = model.config["rope_scaling"]
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0)
scaling_factor = rope_scaling["factor"]
if mscale_all_dim:
mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim)
self.softmax_scale = self.softmax_scale * mscale * mscale


class Deepseek2TpPartModel(LlamaTpPartModel):
# weight class
transformer_weight_class = Deepseek2TransformerLayerWeight

# infer class
transformer_layer_infer_class = Deepseek2TransformerLayerInfer

enable_flashinfer = enable_env_vars("ENABLE_FLASHINFER_PREFILLED") or enable_env_vars(
"ENABLE_FLASHINFER_DECODE_MLA"
)

# infer state class
infer_state_class = Deepseek2InferStateInfo
infer_state_class = Deepseek2FlashInferStateInfo if enable_flashinfer else Deepseek2InferStateInfo

def __init__(self, kvargs):
super().__init__(kvargs)
@@ -37,6 +65,8 @@ def _init_some_value(self):
self.q_lora_rank = self.config["q_lora_rank"]
self.kv_lora_rank = self.config["kv_lora_rank"]
self.head_dim_ = self.kv_lora_rank + self.qk_rope_head_dim
if self.enable_flashinfer:
self.flashinfer_extra_state = FlashInferStateExtraInfo(self)

def _init_custom(self):
self._init_to_get_yarn_rotary()
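
A worked example of the softmax_scale set up in FlashInferStateExtraInfo above, using DeepSeek-V2-like config values and an assumed form of get_deepseek_mscale (the real helper lives in lightllm.models.llama.yarn_rotary_utils):

import math

def deepseek_mscale(scaling_factor: float, mscale_all_dim: float) -> float:
    # Assumed yarn mscale form; returns 1.0 when no scaling is applied.
    if scaling_factor <= 1.0:
        return 1.0
    return 0.1 * mscale_all_dim * math.log(scaling_factor) + 1.0

# Config values assumed here for illustration only.
qk_nope_head_dim, qk_rope_head_dim = 128, 64
scaling_factor, mscale_all_dim = 40.0, 1.0

softmax_scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5   # 1/sqrt(192) ≈ 0.0722
mscale = deepseek_mscale(scaling_factor, mscale_all_dim)         # ≈ 1.3689
softmax_scale *= mscale * mscale                                 # ≈ 0.1352
print(softmax_scale)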