[feature] add deepseekv2 edp support, based on pr628 #662

Open · wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions lightllm/common/basemodel/basemodel.py
@@ -65,6 +65,7 @@ def __init__(self, kvargs):
self.quant_type = kvargs.get("quant_type", None)
self.quant_cfg_path = kvargs.get("quant_cfg", None)
self.mem_fraction = kvargs.get("mem_fraction", 0.9)
self.expert_parallel_mode = kvargs.get("expert_parallel_mode", "etp")

self._init_datatype()
self._init_config()
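
For reference, a minimal sketch of how a caller could select the new mode through kvargs. Only the "expert_parallel_mode" key and its "etp" default come from this diff; every other key, value, and path below is illustrative.

# Hedged sketch, not lightllm's actual launch code.
kvargs = {
    "weight_dir": "/path/to/DeepSeek-V2",   # illustrative checkpoint path
    "tp_rank": 0,
    "world_size": 8,
    "mem_fraction": 0.9,
    "expert_parallel_mode": "edp",          # "etp" (default) or "edp"
}
expert_parallel_mode = kvargs.get("expert_parallel_mode", "etp")  # mirrors the line added above
assert expert_parallel_mode in ("etp", "edp")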
@@ -11,7 +11,7 @@
class TransformerLayerInferTpl(TransformerLayerInfer):
""" """

def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
def __init__(self, layer_num, tp_rank, world_size, network_config, mode, tp_split=True):
super().__init__(layer_num, tp_rank, world_size, network_config, mode)
# need to set by subclass
self.eps_ = 1e-5
@@ -21,6 +21,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
self.tp_o_head_num_ = -1
self.head_dim_ = -1
self.embed_dim_ = -1
self.tp_split_ = tp_split
return

def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
@@ -79,7 +80,7 @@ def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_w
o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
input_embding.add_(o.view(-1, self.embed_dim_))
return
@@ -88,7 +89,7 @@ def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return
@@ -102,7 +103,7 @@ def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_wei
o = self._token_attention_kernel(q, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
input_embding.add_(o.view(-1, self.embed_dim_))
return
@@ -111,7 +112,7 @@ def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return
@@ -125,7 +126,7 @@ def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateIn
o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
input_embding.add_(o.view(-1, self.embed_dim_))
return
@@ -134,7 +135,7 @@ def _splitfuse_ffn(self, input_embdings, infer_state: SplitFuseInferStateInfo, l
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return
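
The tp_split flag added above is what lets an expert-parallel layer skip the tensor-parallel all-reduce: when a rank already holds full (unsplit) outputs, summing across ranks would over-count them. A minimal sketch of a subclass opting out follows; only the tp_split keyword comes from this diff, the subclass name and import path are assumptions.

# Hedged sketch: illustrative subclass, not the PR's actual DeepSeek-V2 layer class.
from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import (
    TransformerLayerInferTpl,  # import path assumed from the repo layout
)

class EdpTransformerLayerInfer(TransformerLayerInferTpl):  # hypothetical name
    def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
        # tp_split=False: the attention and FFN paths above skip dist.all_reduce
        # and add the locally computed output directly.
        super().__init__(layer_num, tp_rank, world_size, network_config, mode, tp_split=False)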
@@ -11,6 +11,10 @@
MultiCOLMMWeight,
ROWBMMWeight,
COLBMMWeight,
MultiCOLMMWeightNoTp,
ROWBMMWeightNoTp,
COLBMMWeightNoTp,
COLMMWeightNoTp
)
from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
from .fused_moe_weight import FusedMoeWeight
@@ -3,6 +3,7 @@
from lightllm.utils.dist_utils import get_world_size, get_rank
import threading
from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
import os

try:
HAS_VLLM = True
@@ -14,7 +15,7 @@

class FusedMoeWeight(BaseWeight):
def __init__(
self, gate_proj_name, down_proj_name, up_proj_name, weight_prefix, n_routed_experts, split_inter_size, data_type
self, gate_proj_name, down_proj_name, up_proj_name, weight_prefix, n_routed_experts, split_inter_size, data_type, expert_parallel_mode="etp"
):
super().__init__()
assert HAS_VLLM, "vllm is not installed, you can't use FusedMoeWeight"
@@ -28,8 +29,11 @@ def __init__(
self.tp_rank_ = get_rank()
self.experts_up_projs = [None] * self.n_routed_experts
self.experts_gate_projs = [None] * self.n_routed_experts
self.expert_gate_up_proj_etp = None
self.expert_down_proj_etp = None
self.w2_list = [None] * self.n_routed_experts
self.quant_method = None
self.expert_parallel_mode = expert_parallel_mode
self.lock = threading.Lock()

def set_quant_method(self, quant_method):
@@ -39,6 +43,7 @@ def set_quant_method(self, quant_method):
self.quant_method.is_moe = True

def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group):

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=input_tensor,
router_logits=router_logits,
@@ -95,27 +100,89 @@ def _fuse(self):
delattr(self, "experts_up_projs")
delattr(self, "experts_gate_projs")

def _load_hf_weights_etp(self, weights):
world_size_ = get_world_size()
assert self.n_routed_experts % world_size_ == 0
n_expert_ep = self.n_routed_experts // world_size_

# tp to ep here
expert_gate_up_proj_last = None
expert_down_proj_last = None

for i_experts_ep in range(n_expert_ep):
expert_up_proj = None
expert_gate_proj = None
expert_gate_up_proj = None
expert_down_proj = None
i_experts = i_experts_ep + n_expert_ep * self.tp_rank_

if f"{self.weight_prefix}.{i_experts}.up_proj.weight" in weights:
expert_up_proj = weights[f"{self.weight_prefix}.{i_experts}.up_proj.weight"]

# self.experts_up_proj[i_experts] = expert_up_proj

if f"{self.weight_prefix}.{i_experts}.gate_proj.weight" in weights:
expert_gate_proj = weights[f"{self.weight_prefix}.{i_experts}.gate_proj.weight"]
# self.experts_gate_proj[i_experts] = expert_gate_proj

if expert_gate_proj is not None and expert_up_proj is not None:
expert_gate_up_proj = torch.cat([expert_gate_proj, expert_up_proj], dim=0)
self.experts_gate_projs[i_experts_ep] = expert_gate_up_proj # self._cuda(expert_gate_up_proj)
expert_gate_up_proj_last = expert_gate_up_proj

if f"{self.weight_prefix}.{i_experts}.down_proj.weight" in weights:
expert_down_proj = weights[f"{self.weight_prefix}.{i_experts}.down_proj.weight"]
self.experts_up_projs[i_experts_ep] = expert_down_proj # self._cuda(expert_down_proj)
expert_down_proj_last = expert_down_proj

with self.lock:
if expert_gate_up_proj_last is not None:
# pack the per-expert tensors into one buffer; some experts may be missing from this shard

if self.expert_gate_up_proj_etp is None:
self.expert_gate_up_proj_etp = torch.zeros(
(n_expert_ep,) + expert_gate_up_proj_last.shape, dtype=expert_gate_up_proj_last.dtype
).cuda(self.tp_rank_)

for i_experts_ep in range(n_expert_ep):
if self.experts_gate_projs[i_experts_ep] is not None:
self.expert_gate_up_proj_etp[i_experts_ep, :] = self.experts_gate_projs[i_experts_ep]

if expert_down_proj_last is not None:
# pack the per-expert tensors into one buffer; some experts may be missing from this shard
if self.expert_down_proj_etp is None:
self.expert_down_proj_etp = torch.zeros(
(n_expert_ep,) + expert_down_proj_last.shape, dtype=expert_down_proj_last.dtype
).cuda(self.tp_rank_)

for i_experts_ep in range(n_expert_ep):
if self.experts_up_projs[i_experts_ep] is not None:
self.expert_down_proj_etp[i_experts_ep, :] = self.experts_up_projs[i_experts_ep]

def load_hf_weights(self, weights):
for i_experts in range(self.n_routed_experts):
w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"

if w1_weight in weights:
self.experts_gate_projs[i_experts] = weights[w1_weight][
self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
]
if w3_weight in weights:
self.experts_up_projs[i_experts] = weights[w3_weight][
self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
]

if w2_weight in weights:
self.w2_list[i_experts] = weights[w2_weight][
:, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
]

self._fuse()
if self.expert_parallel_mode == "etp" or self.expert_parallel_mode == "edp":
self._load_hf_weights_etp(weights)
else:
for i_experts in range(self.n_routed_experts):
w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"

if w1_weight in weights:
self.experts_gate_projs[i_experts] = weights[w1_weight][
self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
]
if w3_weight in weights:
self.experts_up_projs[i_experts] = weights[w3_weight][
self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
]

if w2_weight in weights:
self.w2_list[i_experts] = weights[w2_weight][
:, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
]

self._fuse()

def _cuda(self, cpu_tensor):
if self.tp_rank_ is None:
@@ -124,4 +191,7 @@ def _cuda(self, cpu_tensor):
return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)

def verify_load(self):
return self.w1 is not None and self.w2 is not None
if self.expert_parallel_mode == "etp" or self.expert_parallel_mode == "edp":
return True
else:
return self.w1 is not None and self.w2 is not None
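
To make the ETP/EDP path above concrete: _load_hf_weights_etp assigns each rank a contiguous block of n_routed_experts // world_size experts and keeps their weights unsplit, instead of giving every rank a thin slice of every expert. A standalone sketch of that index mapping, with illustrative numbers:

# Standalone illustration of the index math in _load_hf_weights_etp; not lightllm API.
def experts_owned_by(tp_rank: int, world_size: int, n_routed_experts: int):
    assert n_routed_experts % world_size == 0
    n_expert_ep = n_routed_experts // world_size
    # i_experts = i_experts_ep + n_expert_ep * tp_rank, as in the loader above
    return [i_ep + n_expert_ep * tp_rank for i_ep in range(n_expert_ep)]

# e.g. 160 routed experts across 8 ranks: rank 1 loads experts 20..39
print(experts_owned_by(tp_rank=1, world_size=8, n_routed_experts=160))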
64 changes: 64 additions & 0 deletions lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
@@ -96,6 +96,24 @@ def load_hf_weights(self, weights):
self._post_load_weights()
return

class COLMMWeightNoTp(MMWeight):
def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
super().__init__(weight_name, data_type, split_n_embed, bias_name)
self.start = 0
self.end = split_n_embed

def load_hf_weights(self, weights):
weight = None
if self.weight_name in weights:
weight = weights[self.weight_name].to(self.data_type_)
self.weight = weight[:, self.start : self.end]
if self.bias_name in weights:
bias = weights[self.bias_name]
self.bias = bias.to(self.data_type_).cuda(self.tp_rank_)
if weight is None:
return
self._post_load_weights()
return

class MultiMMWeight(MMWeightTpl):
def __init__(self, weight_names, data_type, split_n_embeds, bias_names=[]):
@@ -172,6 +190,21 @@ def load_hf_weights(self, weights):
self._fuse()
return

class MultiCOLMMWeightNoTp(MultiROWMMWeightNoTP):
def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]):
super().__init__(weight_names, data_type, split_n_embed, bias_names)

def load_hf_weights(self, weights):
weight = None
for i in range(len(self.weight_names)):
if self.weight_names[i] in weights:
weight = weights[self.weight_names[i]].to(self.data_type_)
self.weights[i] = weight[:, self.starts[i] : self.ends[i]]
if self.has_bias and self.bias_names[i] in weights:
bias = weights[self.bias_names[i]].to(self.data_type_)
self.biases[i] = bias[:, self.starts[i] : self.ends[i]]
self._fuse()
return

class BMMWeightTpl(BaseWeightTpl):
def __init__(self, data_type):
@@ -233,6 +266,19 @@ def __init__(
):
super().__init__(weight_name, data_type, split_n_embed, bias_name)

class ROWBMMWeightNoTp(BMMWeight):
load_hf_weights = ROWMMWeight.load_hf_weights

def __init__(
self,
weight_name,
data_type,
split_n_embed,
bias_name=None,
):
super().__init__(weight_name, data_type, split_n_embed, bias_name)
self.start = 0
self.end = split_n_embed

class COLBMMWeight(BMMWeight):
load_hf_weights = COLMMWeight.load_hf_weights
@@ -248,3 +294,21 @@ def __init__(

def _post_load_weights(self):
self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)

class COLBMMWeightNoTp(BMMWeight):
load_hf_weights = COLMMWeightNoTp.load_hf_weights

def __init__(
self,
weight_name,
data_type,
split_n_embed,
bias_name=None,
):
super().__init__(weight_name, data_type, split_n_embed, bias_name)
self.start = 0
self.end = split_n_embed

def _post_load_weights(self):
self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
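
The NoTp variants above differ from their TP counterparts only in the slice they keep: start is pinned to 0 and end to split_n_embed, so every rank loads the same leading block (typically the whole tensor) rather than a per-rank shard. A standalone sketch of the two slicing rules, with made-up shapes:

# Illustration only, not lightllm API.
import torch

full = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)  # stand-in checkpoint tensor
split_n_embed, tp_rank = 4, 1

tp_rows = full[split_n_embed * tp_rank : split_n_embed * (tp_rank + 1), :]  # per-rank shard
no_tp_rows = full[0:split_n_embed, :]  # start=0, end=split_n_embed: identical on every rank

print(tp_rows.shape, no_tp_rows.shape)  # both torch.Size([4, 4]), but different rows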

6 changes: 6 additions & 0 deletions lightllm/common/deepseek2_mem_manager.py
@@ -1,4 +1,5 @@
import torch
import os

from .mem_manager import MemoryManager
from typing import List
@@ -10,6 +11,11 @@ def get_cell_size(self):

def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
self.kv_buffer = torch.empty((layer_num, size, head_num, head_dim), dtype=dtype, device="cuda")
# TODO: etp and edp share the same work buffer here
# it can also serve as scratch space for any kernel that does not need its contents preserved
if os.environ.get("ETP_MODE_ENABLED") == "true":
self.work_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.bfloat16, device="cuda")
self.work_buffer.share_memory_()
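
A quick sizing note on the buffer above, as a standalone sketch: 1024 * 1024 * 1024 bf16 elements is about 2 GiB of device memory per process. How ETP_MODE_ENABLED gets set is not shown in this diff, so the launch side is left out here.

import os
import torch

# Standalone sketch using the same gating and size as the diff above.
if os.environ.get("ETP_MODE_ENABLED") == "true":
    work_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.bfloat16, device="cuda")
    print(work_buffer.numel() * work_buffer.element_size() / 2 ** 30, "GiB")  # -> 2.0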

def alloc_kv_move_buffer(self, max_req_total_len):
self.kv_move_buffer = torch.empty(