diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 019d1dd6f..e8e15622b 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -65,6 +65,7 @@ def __init__(self, kvargs): self.quant_type = kvargs.get("quant_type", None) self.quant_cfg_path = kvargs.get("quant_cfg", None) self.mem_fraction = kvargs.get("mem_fraction", 0.9) + self.expert_parallel_mode = kvargs.get("expert_parallel_mode", "etp") self._init_datatype() self._init_config() diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index bf3b210e8..bedd3f306 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -11,7 +11,7 @@ class TransformerLayerInferTpl(TransformerLayerInfer): """ """ - def __init__(self, layer_num, tp_rank, world_size, network_config, mode): + def __init__(self, layer_num, tp_rank, world_size, network_config, mode, tp_split=True): super().__init__(layer_num, tp_rank, world_size, network_config, mode) # need to set by subclass self.eps_ = 1e-5 @@ -21,6 +21,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode): self.tp_o_head_num_ = -1 self.head_dim_ = -1 self.embed_dim_ = -1 + self.tp_split_ = tp_split return def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: @@ -79,7 +80,7 @@ def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_w o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) - if self.world_size_ > 1: + if self.world_size_ > 1 and self.tp_split_: dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False) input_embding.add_(o.view(-1, self.embed_dim_)) return @@ -88,7 +89,7 @@ def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None - if self.world_size_ > 1: + if self.world_size_ > 1 and self.tp_split_: dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return @@ -102,7 +103,7 @@ def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_wei o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) - if self.world_size_ > 1: + if self.world_size_ > 1 and self.tp_split_: dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False) input_embding.add_(o.view(-1, self.embed_dim_)) return @@ -111,7 +112,7 @@ def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight): input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None - if self.world_size_ > 1: + if self.world_size_ > 1 and self.tp_split_: dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return @@ -125,7 +126,7 @@ def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateIn o = self._splitfuse_attention_kernel(q, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) - if self.world_size_ > 1: + if self.world_size_ > 1 and self.tp_split_: 
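Note on the template change above: tp_split is threaded through the transformer-layer template so the tensor-parallel all-reduce can be skipped when an expert-parallel mode keeps the full projection on every rank. Below is a minimal sketch of that gating pattern with illustrative names (not the actual LightLLM template):

import torch
import torch.distributed as dist

class TpSplitGatedLayer:
    def __init__(self, world_size: int, tp_split: bool = True):
        self.world_size_ = world_size
        self.tp_split_ = tp_split  # False when each rank already holds the full output

    def reduce_if_sharded(self, partial_out: torch.Tensor) -> torch.Tensor:
        # Only sum across ranks when the matmul was actually sharded over them.
        if self.world_size_ > 1 and self.tp_split_:
            dist.all_reduce(partial_out, op=dist.ReduceOp.SUM, async_op=False)
        return partial_out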
dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False) input_embding.add_(o.view(-1, self.embed_dim_)) return @@ -134,7 +135,7 @@ def _splitfuse_ffn(self, input_embdings, infer_state: SplitFuseInferStateInfo, l input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) input1 = None - if self.world_size_ > 1: + if self.world_size_ > 1 and self.tp_split_: dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index c6b1ab500..c7db48817 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -11,6 +11,10 @@ MultiCOLMMWeight, ROWBMMWeight, COLBMMWeight, + MultiCOLMMWeightNoTp, + ROWBMMWeightNoTp, + COLBMMWeightNoTp, + COLMMWeightNoTp ) from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight from .fused_moe_weight import FusedMoeWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py index ddec7198f..4984d6036 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py @@ -3,6 +3,7 @@ from lightllm.utils.dist_utils import get_world_size, get_rank import threading from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod +import os try: HAS_VLLM = True @@ -14,7 +15,7 @@ class FusedMoeWeight(BaseWeight): def __init__( - self, gate_proj_name, down_proj_name, up_proj_name, weight_prefix, n_routed_experts, split_inter_size, data_type + self, gate_proj_name, down_proj_name, up_proj_name, weight_prefix, n_routed_experts, split_inter_size, data_type, expert_parallel_mode="etp" ): super().__init__() assert HAS_VLLM, "vllm is not installed, you can't use FusedMoeWeight" @@ -28,8 +29,11 @@ def __init__( self.tp_rank_ = get_rank() self.experts_up_projs = [None] * self.n_routed_experts self.experts_gate_projs = [None] * self.n_routed_experts + self.expert_gate_up_proj_etp = None + self.expert_down_proj_etp = None self.w2_list = [None] * self.n_routed_experts self.quant_method = None + self.expert_parallel_mode = expert_parallel_mode self.lock = threading.Lock() def set_quant_method(self, quant_method): @@ -39,6 +43,7 @@ def set_quant_method(self, quant_method): self.quant_method.is_moe = True def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=input_tensor, router_logits=router_logits, @@ -95,27 +100,89 @@ def _fuse(self): delattr(self, "experts_up_projs") delattr(self, "experts_gate_projs") + def _load_hf_weights_etp(self, weights): + world_size_ = get_world_size() + assert self.n_routed_experts % world_size_ == 0 + n_expert_ep = self.n_routed_experts // world_size_ + + # tp to ep here + expert_gate_up_proj_last = None + expert_down_proj_last = None + + for i_experts_ep in range(n_expert_ep): + expert_up_proj = None + expert_gate_proj = None + expert_gate_up_proj = None + expert_down_proj = None + i_experts = i_experts_ep + n_expert_ep * self.tp_rank_ + + if f"{self.weight_prefix}.{i_experts}.up_proj.weight" in weights: + expert_up_proj = 
weights[f"{self.weight_prefix}.{i_experts}.up_proj.weight"] + + # self.experts_up_proj[i_experts] = expert_up_proj + + if f"{self.weight_prefix}.{i_experts}.gate_proj.weight" in weights: + expert_gate_proj = weights[f"{self.weight_prefix}.{i_experts}.gate_proj.weight"] + # self.experts_gate_proj[i_experts] = expert_gate_proj + + if expert_gate_proj is not None and expert_up_proj is not None: + expert_gate_up_proj = torch.cat([expert_gate_proj, expert_up_proj], dim=0) + self.experts_gate_projs[i_experts_ep] = expert_gate_up_proj # self._cuda(expert_gate_up_proj) + expert_gate_up_proj_last = expert_gate_up_proj + + if f"{self.weight_prefix}.{i_experts}.down_proj.weight" in weights: + expert_down_proj = weights[f"{self.weight_prefix}.{i_experts}.down_proj.weight"] + self.experts_up_projs[i_experts_ep] = expert_down_proj # self._cuda(expert_down_proj) + expert_down_proj_last = expert_down_proj + + with self.lock: + if expert_gate_up_proj_last is not None: + # package, if there is broken experts + + if self.expert_gate_up_proj_etp is None: + self.expert_gate_up_proj_etp = torch.zeros( + (n_expert_ep,) + expert_gate_up_proj_last.shape, dtype=expert_gate_up_proj_last.dtype + ).cuda(self.tp_rank_) + + for i_experts_ep in range(n_expert_ep): + if self.experts_gate_projs[i_experts_ep] is not None: + self.expert_gate_up_proj_etp[i_experts_ep, :] = self.experts_gate_projs[i_experts_ep] + + if expert_down_proj_last is not None: + # package, if there is broken experts + if self.expert_down_proj_etp is None: + self.expert_down_proj_etp = torch.zeros( + (n_expert_ep,) + expert_down_proj_last.shape, dtype=expert_down_proj_last.dtype + ).cuda(self.tp_rank_) + + for i_experts_ep in range(n_expert_ep): + if self.experts_up_projs[i_experts_ep] is not None: + self.expert_down_proj_etp[i_experts_ep, :] = self.experts_up_projs[i_experts_ep] + def load_hf_weights(self, weights): - for i_experts in range(self.n_routed_experts): - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" - - if w1_weight in weights: - self.experts_gate_projs[i_experts] = weights[w1_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - if w3_weight in weights: - self.experts_up_projs[i_experts] = weights[w3_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - - if w2_weight in weights: - self.w2_list[i_experts] = weights[w2_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] - - self._fuse() + if self.expert_parallel_mode == "etp" or self.expert_parallel_mode == "edp": + self._load_hf_weights_etp(weights) + else: + for i_experts in range(self.n_routed_experts): + w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" + w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" + w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" + + if w1_weight in weights: + self.experts_gate_projs[i_experts] = weights[w1_weight][ + self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : + ] + if w3_weight in weights: + self.experts_up_projs[i_experts] = weights[w3_weight][ + self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : + ] + + if w2_weight in weights: + self.w2_list[i_experts] = 
weights[w2_weight][ + :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) + ] + + self._fuse() def _cuda(self, cpu_tensor): if self.tp_rank_ is None: @@ -124,4 +191,7 @@ def _cuda(self, cpu_tensor): return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_) def verify_load(self): - return self.w1 is not None and self.w2 is not None + if self.expert_parallel_mode == "etp" or self.expert_parallel_mode == "edp": + return True + else: + return self.w1 is not None and self.w2 is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py index d188d5511..5920e3f26 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py @@ -96,6 +96,24 @@ def load_hf_weights(self, weights): self._post_load_weights() return +class COLMMWeightNoTp(MMWeight): + def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): + super().__init__(weight_name, data_type, split_n_embed, bias_name) + self.start = 0 + self.end = split_n_embed + + def load_hf_weights(self, weights): + weight = None + if self.weight_name in weights: + weight = weights[self.weight_name].to(self.data_type_) + self.weight = weight[:, self.start : self.end] + if self.bias_name in weights: + bias = weights[self.bias_name] + self.bias = bias.to(self.data_type_).cuda(self.tp_rank_) + if weight is None: + return + self._post_load_weights() + return class MultiMMWeight(MMWeightTpl): def __init__(self, weight_names, data_type, split_n_embeds, bias_names=[]): @@ -172,6 +190,21 @@ def load_hf_weights(self, weights): self._fuse() return +class MultiCOLMMWeightNoTp(MultiROWMMWeightNoTP): + def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]): + super().__init__(weight_names, data_type, split_n_embed, bias_names) + + def load_hf_weights(self, weights): + weight = None + for i in range(len(self.weight_names)): + if self.weight_names[i] in weights: + weight = weights[self.weight_names[i]].to(self.data_type_) + self.weights[i] = weight[:, self.starts[i] : self.ends[i]] + if self.has_bias and self.bias_names[i] in weights: + bias = weights[self.bias_names[i]].to(self.data_type_) + self.biases[i] = bias[:, self.starts[i] : self.ends[i]] + self._fuse() + return class BMMWeightTpl(BaseWeightTpl): def __init__(self, data_type): @@ -233,6 +266,19 @@ def __init__( ): super().__init__(weight_name, data_type, split_n_embed, bias_name) +class ROWBMMWeightNoTp(BMMWeight): + load_hf_weights = ROWMMWeight.load_hf_weights + + def __init__( + self, + weight_name, + data_type, + split_n_embed, + bias_name=None, + ): + super().__init__(weight_name, data_type, split_n_embed, bias_name) + self.start = 0 + self.end = split_n_embed class COLBMMWeight(BMMWeight): load_hf_weights = COLMMWeight.load_hf_weights @@ -248,3 +294,21 @@ def __init__( def _post_load_weights(self): self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_) + +class COLBMMWeightNoTp(BMMWeight): + load_hf_weights = COLMMWeightNoTp.load_hf_weights + + def __init__( + self, + weight_name, + data_type, + split_n_embed, + bias_name=None, + ): + super().__init__(weight_name, data_type, split_n_embed, bias_name) + self.start = 0 + self.end = split_n_embed + + def _post_load_weights(self): + self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_) + diff --git a/lightllm/common/deepseek2_mem_manager.py 
b/lightllm/common/deepseek2_mem_manager.py index 11f01fdcb..3684e8227 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/deepseek2_mem_manager.py @@ -1,4 +1,5 @@ import torch +import os from .mem_manager import MemoryManager from typing import List @@ -10,6 +11,11 @@ def get_cell_size(self): def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): self.kv_buffer = torch.empty((layer_num, size, head_num, head_dim), dtype=dtype, device="cuda") + # todo, etp or edp use the same work buffer here + # also it can be used for any kernels for work buffer witout save info only + if os.environ.get("ETP_MODE_ENABLED") == "true": + self.work_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.bfloat16, device="cuda") + self.work_buffer.share_memory_() def alloc_kv_move_buffer(self, max_req_total_len): self.kv_move_buffer = torch.empty( diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 6da45a7c5..6f76312af 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -1,6 +1,6 @@ from typing import Tuple import torch -import torch.functional as F +import torch.nn.functional as F import torch.distributed as dist import numpy as np from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight @@ -18,6 +18,7 @@ from lightllm.models.deepseek2.layer_infer.fused_moe import fused_experts, grouped_topk from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.llama.infer_struct import LlamaInferStateInfo from functools import partial @@ -27,7 +28,7 @@ class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer): def __init__( - self, layer_num, tp_rank, world_size, network_config, mode=[], disable_qk_absorb=False, disable_vo_absorb=False + self, layer_num, tp_rank, world_size, network_config, mode=[], disable_qk_absorb=False, disable_vo_absorb=False, expert_parallel_mode="etp" ): self.tp_k_head_num_ = 1 self.tp_v_head_num_ = 1 @@ -35,6 +36,10 @@ def __init__( self.qk_rope_head_dim = network_config["qk_rope_head_dim"] self.q_lora_rank = network_config["q_lora_rank"] self.kv_lora_rank = network_config["kv_lora_rank"] + self.expert_parallel_mode = expert_parallel_mode + + self.n_routed_experts = network_config["n_routed_experts"] + self.is_moe = ( network_config["n_routed_experts"] is not None and layer_num >= network_config["first_k_dense_replace"] @@ -58,7 +63,8 @@ def __init__( if mscale_all_dim: mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) self.softmax_scale = self.softmax_scale * mscale * mscale - super().__init__(layer_num, tp_rank, world_size, network_config, mode) + tp_split = True if expert_parallel_mode == "etp" else False + super().__init__(layer_num, tp_rank, world_size, network_config, mode, tp_split) self.tp_o_head_num_ = self.tp_q_head_num_ self.num_heads = network_config["num_attention_heads"] @@ -75,7 +81,12 @@ def _bind_attention(self): ) self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) if self.is_moe: - self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, 
self) + if self.expert_parallel_mode == "etp": + self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_etp, self) + elif self.expert_parallel_mode == "edp": + self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_edp, self) + else: + self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self) else: self._ffn = partial(LlamaTransformerLayerInfer._ffn, self) @@ -87,7 +98,7 @@ def _get_qkv( layer_weight: Deepseek2TransformerLayerWeight, ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) - if not self.disable_qk_absorb: + if not self.disable_qk_absorb: # ACC if self.q_lora_rank is None: q_nope = layer_weight.fuse_qk_weight_.mm(input) q_rope = layer_weight.q_rope_proj_.mm(input) @@ -98,7 +109,7 @@ def _get_qkv( q_rope = layer_weight.q_rope_proj_.mm(q) q_nope = q_nope.view(-1, self.tp_q_head_num_, self.kv_lora_rank) q_rope = q_rope.view(-1, self.tp_q_head_num_, self.qk_rope_head_dim) - else: + else: # CC if self.q_lora_rank is None: q = layer_weight.q_weight_.mm(input) else: @@ -149,7 +160,7 @@ def _CC_method( ): num_local_heads = self.num_heads num_local_kv_heads = self.num_kv_heads - if self.world_size_ > 1: + if self.world_size_ > 1 and self.expert_parallel_mode == "etp": num_local_heads //= self.world_size_ num_local_kv_heads //= self.world_size_ if infer_state.use_dynamic_prompt_cache: @@ -181,7 +192,7 @@ def _ACC_method( q_ne, q_pe = q num_local_heads = self.num_heads num_local_kv_heads = self.num_kv_heads - if self.world_size_ > 1: + if self.world_size_ > 1 and self.expert_parallel_mode == "etp": num_local_heads //= self.world_size_ num_local_kv_heads //= self.world_size_ # ACC @@ -368,6 +379,7 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): def _moe_ffn( self, input, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight ) -> torch.Tensor: + hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape @@ -391,3 +403,229 @@ def _moe_ffn( hidden_states.add_(shared_output) return hidden_states.view(num_tokens, hidden_dim) + + def _moe_ffn_etp( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight + ) -> torch.Tensor: + world_size_ = self.world_size_ + # num_local_experts = self.n_shared_experts // world_size_ + # local_expert_offset = self.tp_rank_ * num_local_experts + num_experts_per_token = self.num_experts_per_tok + num_experts = self.n_routed_experts + # num_expert_groups = self.n_group + # num_groups_per_token = self.topk_group + gating_scaling_factor = self.routed_scaling_factor + # gating_normalize_prob = self.norm_topk_prob + rank_self = self.tp_rank_ + + hidden_states = input.view(-1, self.embed_dim_) + num_tokens, hidden_dim = hidden_states.shape + + final_hidden_states = torch.empty( + num_tokens, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype + ) + + # router_logits_len = hidden_states.shape[0]*layer_weight.moe_gate.shape[1] + router_logits = layer_weight.moe_gate.mm(hidden_states) + + # now some parameter is not supported yet + # assert gating_normalize_prob is False + # assert num_expert_groups<=1 + + import lightllm_moe_etp_kernel + + lightllm_moe_etp_kernel.moe_fused_all( + router_logits.contiguous(), + hidden_states.contiguous(), + layer_weight.gate_up_proj.weight.contiguous(), # transpose + layer_weight.down_proj.weight.contiguous(), # transpose + layer_weight.experts.expert_gate_up_proj_etp.contiguous(), + layer_weight.experts.expert_down_proj_etp.contiguous(), + 
infer_state.mem_manager.work_buffer.contiguous(), + infer_state.mem_manager.work_buffer.nelement(), + final_hidden_states.contiguous(), + rank_self, + gating_scaling_factor, + num_experts, + num_experts_per_token, + num_tokens, + world_size_, + True, + hidden_dim, + layer_weight.gate_up_proj.weight.size(1) // 2, + layer_weight.experts.expert_gate_up_proj_etp.size(1) // 2, + self.n_shared_experts is not None, + ) + + router_logits = None + + return final_hidden_states.view(num_tokens, hidden_dim) + + def _moe_ffn_edp( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight + ) -> torch.Tensor: + world_size = self.world_size_ + + num_experts_per_token = self.num_experts_per_tok + num_experts = self.n_routed_experts + num_local_experts = num_experts // world_size + # num_expert_groups = self.n_group + # num_groups_per_token = self.topk_group + gating_scaling_factor = self.routed_scaling_factor + # gating_normalize_prob = self.norm_topk_prob + rank_self = self.tp_rank_ + hidden_states = input.view(-1, self.embed_dim_) + num_tokens, hidden_dim = hidden_states.shape + + if self.n_shared_experts is not None: + shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) + # final_hidden_states = torch.empty( + # num_tokens, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype + # ) + + router_logits = layer_weight.moe_gate.mm(hidden_states) + + states_expand_permute, expert_weights, invert_permutation, expert_offset = moe_select( + hidden_states, router_logits, num_experts, num_experts_per_token, + 1, 1, + gating_scaling_factor, False, + gating_method="greedy" + ) + device = hidden_states.device + flat_stats = states_expand_permute.view(-1, hidden_dim) + global_exp_local_token = torch.tensor([expert_offset[i+1] - expert_offset[i] for i in range(num_experts)], dtype=torch.int64, device=device) # [num_experts] + + local_exp_global_token = torch.zeros(num_experts, dtype=torch.int64, device=device) # [rano0_local_exp, rank1_local_exp, ...] 
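At this point the EDP branch has counted, per rank, how many tokens it routes to each global expert; the next step exchanges those counts so every rank knows how many tokens it will receive for its local experts. A hedged, standalone sketch of that count exchange (the function name is illustrative; the split-size math mirrors the code that follows):

import torch
import torch.distributed as dist

def exchange_expert_counts(global_exp_local_token: torch.Tensor, world_size: int):
    # global_exp_local_token[e]: tokens this rank routes to global expert e
    num_experts = global_exp_local_token.numel()
    num_local_experts = num_experts // world_size
    local_exp_global_token = torch.zeros_like(global_exp_local_token)
    # Equal-split all_to_all: after the call, block j of the output holds the
    # counts rank j will send for this rank's local experts.
    dist.all_to_all_single(local_exp_global_token, global_exp_local_token)
    # Tokens this rank sends to / receives from each peer rank.
    input_splits = global_exp_local_token.reshape(world_size, num_local_experts).sum(dim=1)
    output_splits = local_exp_global_token.reshape(world_size, num_local_experts).sum(dim=1)
    return local_exp_global_token, input_splits.tolist(), output_splits.tolist()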
+ dist.all_to_all_single(local_exp_global_token, global_exp_local_token) + + input_splits = global_exp_local_token.reshape(world_size, num_local_experts).sum(dim=1) # [world_size] + output_splists = local_exp_global_token.reshape(world_size, num_local_experts).sum(dim=1) # [world_size] + + local_exp_global_input = torch.zeros((output_splists.sum().item(), hidden_dim), dtype=hidden_states.dtype, device=device) + dist.all_to_all_single(local_exp_global_input, flat_stats, output_split_sizes=output_splists.tolist(), input_split_sizes=input_splits.tolist()) + + input_chunk_idxs = torch.arange(num_experts) + sorted_local_exp_index = input_chunk_idxs.reshape(world_size, num_local_experts).T.ravel() + restore_local_exp_index = input_chunk_idxs.reshape(num_local_experts, world_size).T.ravel() + + # sort chunk by idx + expert_sorted_token = local_exp_global_token.reshape(world_size, -1).sum(dim=0) # [num_local_experts] + sorted_local_exp_global_token = local_exp_global_token.reshape(world_size, -1).transpose(0, 1).contiguous().view(-1) + + def permute_chunks_by_idxs(input: torch.Tensor, split_size: torch.Tensor, sorted_idxs: torch.Tensor): + """ + sort chunks by idx, + """ + splited_input = input.split(split_size.tolist()) + output = torch.cat([splited_input[i] for i in sorted_idxs.tolist()], dim=0) + return output + + expert_sorted_input = permute_chunks_by_idxs(local_exp_global_input, local_exp_global_token, sorted_local_exp_index) + expert_sorted_token_offset = [0] + # new offset + for i in range(num_local_experts): + expert_sorted_token_offset.append(expert_sorted_token_offset[i] + expert_sorted_token[i]) + + down_proj_output = torch.zeros_like(local_exp_global_input) + for i in range(num_local_experts): + token_beg_idx = expert_sorted_token_offset[i] + token_end_idx = expert_sorted_token_offset[i+1] + if token_beg_idx == token_end_idx: + continue + local_expert_idx = i + up_proj_output = F.linear( + expert_sorted_input[token_beg_idx:token_end_idx], + layer_weight.experts.expert_gate_up_proj_etp[local_expert_idx], + ) + # up_proj_output = act_fn(up_proj_output) + act_output = torch.empty(up_proj_output.shape[0], up_proj_output.shape[1] // 2, device=up_proj_output.device, dtype=up_proj_output.dtype) + + silu_and_mul_fwd(up_proj_output, act_output) + + down_proj_output[token_beg_idx : token_end_idx] = F.linear( + act_output, + layer_weight.experts.expert_down_proj_etp[local_expert_idx], + ) + + # restore chunks + restore_down_proj_output = permute_chunks_by_idxs(down_proj_output, sorted_local_exp_global_token, restore_local_exp_index) + input_splits2 = output_splists + output_splits2 = input_splits + + global_exp_local_output = torch.zeros_like(flat_stats, dtype=hidden_states.dtype, device=device) + dist.all_to_all_single(global_exp_local_output, restore_down_proj_output, output_split_sizes=output_splits2.tolist(), input_split_sizes=input_splits2.tolist()) + + final_hidden_states = moe_reduce(global_exp_local_output, expert_weights, invert_permutation, num_experts_per_token) + if shared_output is not None: + final_hidden_states.add_(shared_output) + # now some parameter is not supported yet + # assert gating_normalize_prob is False + # assert num_expert_groups<=1 + return final_hidden_states + +def moe_select(X: torch.Tensor, scores: torch.Tensor, + num_experts: int, num_experts_per_token: int, + num_expert_groups: int = 1, num_groups_per_token: int = 1, + gating_scaling_factor: float = 1.0, + gating_normalize_prob: bool = False, + gating_method: str='greedy'): + origin_shape = X.shape + 
X_expand_permute = X.view(-1, X.shape[-1]) + _scores = scores.softmax(dim=-1, dtype=torch.float32).view(-1, num_experts).type_as(scores) + if 'greedy' in gating_method and len('greedy') == len(gating_method): + expert_weights, expert_indices = torch.topk(_scores, num_experts_per_token, dim=-1) + elif 'grouped_limited_greedy' in gating_method and len('grouped_limited_greedy') == len(gating_method): + group_scores = ( + _scores.view(_scores.shape[0], num_expert_groups, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk( + group_scores, k=num_groups_per_token, dim=-1, sorted=False + )[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(_scores.shape[0], num_expert_groups, num_experts_per_token // num_expert_groups + ) + .reshape(_scores.shape[0], -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + expert_weights, expert_indices = torch.topk(tmp_scores, num_experts_per_token, dim=-1) + + if num_experts_per_token > 1 and gating_normalize_prob: + denominator = expert_weights.sum(dim=-1, keepdim=True) + 1e-20 + expert_weights = expert_weights / denominator + else: + expert_weights *= gating_scaling_factor + + flat_expert_indices = expert_indices.view(-1) # (seqlen * num_experts_per_token) + + sorted_expert_indices, permute_token_idx = flat_expert_indices.sort(stable=True) + X_expand_permute = X_expand_permute.repeat_interleave(num_experts_per_token, dim=0) # (seqlen * num_experts_per_token, hidden_dim) + + X_expand_permute = X_expand_permute[permute_token_idx] + + invert_permutation = torch.full_like(permute_token_idx, -1, device=scores.device) + for i in range(len(permute_token_idx)): + reidx = permute_token_idx[i] + invert_permutation[reidx] = i + + expert_offset = torch.full((num_experts + 1,), -1, device=scores.device) + ptr = 0 + for i in range(num_experts): + while(ptr < len(sorted_expert_indices) and sorted_expert_indices[ptr] < i): + ptr += 1 + expert_offset[i] = ptr + expert_offset[num_experts] = X_expand_permute.size(0) + X_expand_permute = X_expand_permute.view(*origin_shape[:-1], num_experts_per_token, -1) + invert_permutation = invert_permutation.view(*origin_shape[:-1], num_experts_per_token) + + return X_expand_permute, expert_weights, invert_permutation, expert_offset + + +def moe_reduce(Y: torch.Tensor, expert_weights: torch.Tensor, + invert_permutation: torch.Tensor, num_experts_per_token: int): + Y = Y.view(-1, Y.shape[-1]) + Y_out = Y[invert_permutation].view(*expert_weights.shape, -1) + Y_out = (Y_out * expert_weights.unsqueeze(-1)).sum(dim=-2) # [*, hidden_dim] + return Y_out \ No newline at end of file diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 995ad1f11..de72f27f0 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -6,12 +6,17 @@ ROWMMWeight, ROWMMWeightNoTP, MultiROWMMWeight, + MultiROWMMWeightNoTP, COLMMWeight, + COLMMWeightNoTp, MultiCOLMMWeight, + MultiCOLMMWeightNoTp, NormWeight, FusedMoeWeight, ROWBMMWeight, + ROWBMMWeightNoTp, COLBMMWeight, + COLBMMWeightNoTp, ) from functools import partial @@ -71,9 +76,11 @@ def __init__( quant_cfg=None, disable_qk_absorb=False, disable_vo_absorb=False, + expert_parallel_mode="etp" ): self.disable_qk_absorb = disable_qk_absorb 
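Aside on the moe_select helper above: the Python loops that build invert_permutation and expert_offset can also be expressed with vectorized tensor ops. The equivalent below is only an illustration, not part of this patch:

import torch

def invert_and_offsets(sorted_expert_indices: torch.Tensor,
                       permute_token_idx: torch.Tensor,
                       num_experts: int):
    # invert_permutation[permute_token_idx[i]] = i, same as the explicit loop.
    invert_permutation = torch.empty_like(permute_token_idx)
    invert_permutation[permute_token_idx] = torch.arange(
        permute_token_idx.numel(), device=permute_token_idx.device
    )
    # expert_offset[e] = number of routed tokens assigned to experts < e.
    expert_offset = torch.searchsorted(
        sorted_expert_indices,
        torch.arange(num_experts + 1, device=sorted_expert_indices.device),
    )
    return invert_permutation, expert_offset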
self.disable_vo_absorb = disable_vo_absorb + self.expert_parallel_mode = expert_parallel_mode super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg) # mla_type = "ACCM", "MIX" # MIX是prefilled CC,decoding ACC @@ -89,7 +96,9 @@ def _parse_config(self): and self.layer_num_ >= self.network_config_["first_k_dense_replace"] and self.layer_num_ % self.network_config_["moe_layer_freq"] == 0 ) - self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.world_size_ + self.tp_q_head_num_ = self.network_config_["num_attention_heads"] + if self.expert_parallel_mode == "etp": + self.tp_q_head_num_ //= self.world_size_ self.n_routed_experts = self.network_config_["n_routed_experts"] self.q_lora_rank = self.network_config_["q_lora_rank"] self.qk_nope_head_dim = self.network_config_["qk_nope_head_dim"] @@ -104,7 +113,10 @@ def _init_weight_names(self): self.rope_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight" def _init_weight(self): - self._init_qkvo() + if self.expert_parallel_mode == "etp": + self._init_qkvo() + else: + self._init_qkvo_dp() if self.is_moe: self._init_moe() else: @@ -112,12 +124,13 @@ def _init_weight(self): self._init_norm() def _load_q_rope(self, q_weight_): - q_split_n_embed_with_rope = ( - (self.qk_nope_head_dim + self.qk_rope_head_dim) * self.num_attention_heads // self.world_size_ - ) - q_weight_ = q_weight_[ - q_split_n_embed_with_rope * self.tp_rank_ : q_split_n_embed_with_rope * (self.tp_rank_ + 1), : - ] + if self.expert_parallel_mode == "etp": + q_split_n_embed_with_rope = ( + (self.qk_nope_head_dim + self.qk_rope_head_dim) * self.num_attention_heads // self.world_size_ + ) + q_weight_ = q_weight_[ + q_split_n_embed_with_rope * self.tp_rank_ : q_split_n_embed_with_rope * (self.tp_rank_ + 1), : + ] q_weight_ = q_weight_.transpose(0, 1).contiguous() q_nope_proj_, q_rope_proj_ = torch.split( q_weight_.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim), @@ -239,11 +252,105 @@ def _init_qkvo(self): q_split_n_embed, ) - def _load_mlp(self, mlp_prefix, split_inter_size): - self.gate_up_proj = MultiROWMMWeight( - [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], self.data_type_, split_inter_size + def _init_qkvo_dp(self): + q_split_n_embed = self.qk_nope_head_dim * self.tp_q_head_num_ + q_split_n_embed_with_rope = ( + (self.qk_nope_head_dim + self.qk_rope_head_dim) * self.num_attention_heads ) - self.down_proj = COLMMWeight(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size) + if self.q_lora_rank is None: + if not self.disable_qk_absorb: # acc + self.fuse_qk_weight_ = MultiROWMMWeightNoTP( + [ + f"model.layers.{self.layer_num_}.self_attn.q_proj.weight", + f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight", + ], + self.data_type_, + [q_split_n_embed_with_rope, self.tp_q_head_num_], + ) + self.fuse_qk_weight_._fuse = partial(fuse_q_kb, self.fuse_qk_weight_, self) + else: # cc + self.q_weight_ = ROWMMWeightNoTP( + f"model.layers.{self.layer_num_}.self_attn.q_proj.weight", + self.data_type_, + q_split_n_embed_with_rope, + ) + else: + self.q_a_proj_ = ROWMMWeightNoTP( + f"model.layers.{self.layer_num_}.self_attn.q_a_proj.weight", + self.data_type_, + self.q_lora_rank, + ) + if not self.disable_qk_absorb: + self.fuse_qk_weight_ = MultiROWMMWeightNoTP( + [ + f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight", + f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight", + ], + self.data_type_, + [q_split_n_embed_with_rope, 
self.tp_q_head_num_], + ) + self.fuse_qk_weight_._fuse = partial(fuse_q_kb, self.fuse_qk_weight_, self) + else: + self.q_b_proj_ = ROWMMWeightNoTP( + f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight", + self.data_type_, + q_split_n_embed_with_rope, + ) + + self.q_rope_proj_ = ROWMMWeightNoTP( + f"model.layers.{self.layer_num_}.self_attn.q_rope_proj.weight", + self.data_type_, + self.qk_rope_head_dim * self.tp_q_head_num_, + ) + + self.kv_a_proj_with_mqa_ = ROWMMWeightNoTP( + f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight", + self.data_type_, + self.kv_lora_rank + self.qk_rope_head_dim, + ) + if self.disable_qk_absorb: + self.k_b_proj_ = ROWBMMWeightNoTp( + f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight", + self.data_type_, + split_n_embed=self.tp_q_head_num_, + ) + if not self.disable_vo_absorb: + self.fuse_vo_weight_ = MultiCOLMMWeightNoTp( + [ + f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight", + f"model.layers.{self.layer_num_}.self_attn.o_proj.weight", + ], + self.data_type_, + [self.tp_q_head_num_, q_split_n_embed], + ) + self.fuse_vo_weight_._fuse = partial(fuse_vb_o, self.fuse_vo_weight_, self) + else: + self.v_b_proj_ = COLBMMWeightNoTp( + f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight", + self.data_type_, + split_n_embed=self.tp_q_head_num_, + ) + if self.disable_vo_absorb: + self.o_weight_ = COLMMWeightNoTp( + f"model.layers.{self.layer_num_}.self_attn.o_proj.weight", + self.data_type_, + q_split_n_embed, + ) + + + def _load_mlp(self, mlp_prefix, split_inter_size): + if self.expert_parallel_mode == "etp": + self.gate_up_proj = MultiROWMMWeight( + [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], self.data_type_, split_inter_size + ) + self.down_proj = COLMMWeight(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size) + elif self.expert_parallel_mode == "edp": + self.gate_up_proj = MultiROWMMWeightNoTP( + [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], self.data_type_, split_inter_size + ) + self.down_proj = COLMMWeightNoTp(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size) + else: + raise ValueError(f"Invalid expert_parallel_mode: {self.expert_parallel_mode}") def _init_moe(self): moe_intermediate_size = self.network_config_["moe_intermediate_size"] @@ -251,8 +358,9 @@ def _init_moe(self): f"model.layers.{self.layer_num_}.mlp.gate.weight", self.data_type_, moe_intermediate_size ) shared_intermediate_size = moe_intermediate_size * self.network_config_["n_shared_experts"] - shared_split_inter_size = shared_intermediate_size // self.world_size_ - self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", shared_split_inter_size) + + num_shards = self.world_size_ if self.expert_parallel_mode == "etp" else 1 + self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", shared_intermediate_size // num_shards) self.experts = FusedMoeWeight( gate_proj_name="gate_proj", @@ -262,12 +370,13 @@ def _init_moe(self): n_routed_experts=self.n_routed_experts, split_inter_size=moe_intermediate_size // self.world_size_, data_type=self.data_type_, + expert_parallel_mode=self.expert_parallel_mode, ) def _init_ffn(self): inter_size = self.network_config_["intermediate_size"] - split_inter_size = inter_size // self.world_size_ - self._load_mlp(f"model.layers.{self.layer_num_}.mlp", split_inter_size) + num_shards = self.world_size_ if self.expert_parallel_mode == "etp" else 1 + self._load_mlp(f"model.layers.{self.layer_num_}.mlp", inter_size // 
num_shards) def _init_norm(self): self.att_norm_weight_ = NormWeight(f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_) diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 5cc8dfdd8..58ff271df 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -55,8 +55,9 @@ def _init_mem_manager(self): return def _init_weights(self): + tp_split = True if self.expert_parallel_mode == "etp" else False self.pre_post_weight = self.pre_and_post_weight_class( - self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode + self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode, tp_split=tp_split ) self.trans_layers_weight = [ self.transformer_weight_class( @@ -69,6 +70,7 @@ def _init_weights(self): quant_cfg=self.quant_cfg, disable_qk_absorb=self.disable_qk_absorb, disable_vo_absorb=self.disable_vo_absorb, + expert_parallel_mode=self.expert_parallel_mode, ) for i in range(self.config["n_layer"]) ] @@ -84,11 +86,12 @@ def _init_weights(self): return def _init_infer_layer(self): + tp_split = True if self.expert_parallel_mode == "etp" else False self.pre_infer = self.pre_layer_infer_class( - tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode + tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode, tp_split=tp_split ) self.post_infer = self.post_layer_infer_class( - tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode + tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode, tp_split=tp_split ) self.layers_infer = [ self.transformer_layer_infer_class( @@ -99,6 +102,7 @@ def _init_infer_layer(self): mode=self.mode, disable_qk_absorb=self.disable_qk_absorb, disable_vo_absorb=self.disable_vo_absorb, + expert_parallel_mode=self.expert_parallel_mode, ) for i in range(self.config["n_layer"]) ] diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index a642a0fe0..8efa3e940 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -16,11 +16,12 @@ class LlamaPostLayerInfer(PostLayerInferTpl): """ """ - def __init__(self, tp_rank, world_size, network_config, mode): + def __init__(self, tp_rank, world_size, network_config, mode, tp_split=True): super().__init__(tp_rank, world_size, network_config, mode) self.eps_ = network_config["rms_norm_eps"] self.vocab_size_ = network_config["vocab_size"] self.embed_dim_ = network_config["n_embed"] + self.tp_split_ = tp_split return def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: @@ -89,7 +90,7 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_ torch.mm(layer_weight.lm_head_weight_, last_input, out=logic_batch) last_input = None - if self.world_size_ == 1: + if self.world_size_ == 1 or self.tp_split_ == False: gather_data = logic_batch else: gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype) diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py index f60fa6127..fa91e6385 100644 --- a/lightllm/models/llama/layer_infer/pre_layer_infer.py +++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py @@ -13,10 +13,14 @@ class 
LlamaPreLayerInfer(PreLayerInferTpl): """ """ - def __init__(self, tp_rank, world_size, network_config, mode): + def __init__(self, tp_rank, world_size, network_config, mode, tp_split=True): super().__init__(tp_rank, world_size, network_config, mode) - tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.world_size_ + 1, dtype=np.int64) - self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1]) + self.tp_split_ = tp_split + if self.tp_split_: + tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.world_size_ + 1, dtype=np.int64) + self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1]) + else: + self.vob_start_id_, self.vob_end_id_ = 0, network_config["vocab_size"] return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): @@ -24,7 +28,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ ) embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) - if self.world_size_ > 1: + if self.world_size_ > 1 and self.tp_split_: dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False) return input_embdings @@ -33,7 +37,7 @@ def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weigh (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_ ) embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings) - if self.world_size_ > 1: + if self.world_size_ > 1 and self.tp_split_: dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False) return input_embdings diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 010bde881..90cb3f8ac 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -33,12 +33,17 @@ class LlamaTransformerLayerInfer(TransformerLayerInferTpl): """ """ - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) + def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[], tp_split=True): + super().__init__(layer_num, tp_rank, world_size, network_config, mode, tp_split) self.eps_ = network_config["rms_norm_eps"] - self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_ - self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.world_size_ - self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.world_size_ + self.tp_q_head_num_ = network_config["num_attention_heads"] + self.tp_k_head_num_ = network_config["num_key_value_heads"] + self.tp_v_head_num_ = network_config["num_key_value_heads"] + if tp_split: + self.tp_q_head_num_ //= world_size + self.tp_k_head_num_ //= world_size + self.tp_v_head_num_ //= world_size + self.tp_o_head_num_ = self.tp_q_head_num_ self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] self.embed_dim_ = network_config["hidden_size"] diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 25e9bd10c..0fedca12b 100644 --- 
a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -4,15 +4,20 @@ class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, tp_rank, world_size, data_type, network_config, mode): + def __init__(self, tp_rank, world_size, data_type, network_config, mode, tp_split=True): super().__init__(tp_rank, world_size, data_type, network_config, mode) + self.tp_split_ = tp_split return def load_hf_weights(self, weights): vob_size = self.network_config_["vocab_size"] - split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64) - split_start = split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] + if self.tp_split_: + split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + else: + split_start = 0 + split_end = vob_size if "model.embed_tokens.weight" in weights: self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) diff --git a/test/model/test_model_deepseek_v2.py b/test/model/test_model_deepseek_v2.py new file mode 100644 index 000000000..a98ca811e --- /dev/null +++ b/test/model/test_model_deepseek_v2.py @@ -0,0 +1,71 @@ +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +import unittest +from model_infer import test_model_inference +from lightllm.models.bloom.model import BloomTpPartModel +from lightllm.models.llama.model import LlamaTpPartModel +from lightllm.models.starcoder.model import StarcoderTpPartModel +from lightllm.models.starcoder2.model import Starcoder2TpPartModel +from lightllm.models.qwen.model import QWenTpPartModel +from lightllm.models.chatglm2.model import ChatGlm2TpPartModel +from lightllm.models.internlm.model import InternlmTpPartModel +from lightllm.models.stablelm.model import StablelmTpPartModel +from lightllm.models.internlm2.model import Internlm2TpPartModel +from lightllm.models.mistral.model import MistralTpPartModel +from lightllm.models.minicpm.model import MiniCPMTpPartModel +from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel +from lightllm.models.phi3.model import Phi3TpPartModel +from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.cohere.model import CohereTpPartModel +from lightllm.models.mixtral.model import MixtralTpPartModel +from lightllm.models.qwen2.model import Qwen2TpPartModel +from lightllm.utils.config_utils import get_dtype +from lightllm.utils.config_utils import get_config_json +from test_model import get_model + + +class TestModelInfer(unittest.TestCase): + def test_model_infer(self): + model_dir = "/mnt/llm/DeepSeekV2/DeepSeek-V2-Lite-Chat" + model_class = get_model(model_dir) + data_type = get_dtype(model_dir) + mode = "triton_gqa_flashdecoding" + world_size = 2 + batch_size = 1 + input_len = 32 + output_len = 8 + disable_cudagraph = True + graph_max_batch_size = batch_size + graph_max_len_in_batch = 2048 + quant_type = None + quant_cfg = None + extra_model_kvargs = { + "weight_dir": model_dir, + "mode": mode, + "data_type": data_type, + "disable_cudagraph": disable_cudagraph, + "graph_max_batch_size": graph_max_batch_size, + "graph_max_len_in_batch": graph_max_len_in_batch, + "quant_type": quant_type, + "quant_cfg": quant_cfg, # path for mixed quantization 
config. + "expert_parallel_mode": "edp", + } + + test_model_inference( + world_size=world_size, + model_class=model_class, + batch_size=batch_size, + input_len=input_len, + output_len=output_len, + extra_model_kvargs=extra_model_kvargs, + ) + return + +if __name__ == "__main__": + import torch + + torch.multiprocessing.set_start_method("spawn") + unittest.main()
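
For reference, a hedged sketch of how the options introduced by this diff fit together, mirroring the unit test above; the weight path is a placeholder, the values are illustrative, and ETP_MODE_ENABLED only matters when the fused ETP kernel's shared work buffer should be allocated.

import os

os.environ["ETP_MODE_ENABLED"] = "true"  # opt-in bf16 work buffer in Deepseek2MemoryManager

extra_model_kvargs = {
    "weight_dir": "/path/to/DeepSeek-V2-Lite-Chat",  # placeholder path
    "mode": "triton_gqa_flashdecoding",
    "data_type": "bfloat16",
    "disable_cudagraph": True,
    "graph_max_batch_size": 1,
    "graph_max_len_in_batch": 2048,
    "quant_type": None,
    "quant_cfg": None,
    "expert_parallel_mode": "edp",  # new kvarg; "etp" keeps the fused-kernel TP-style path
}
# test_model_inference(world_size=2, model_class=..., batch_size=1,
#                      input_len=32, output_len=8, extra_model_kvargs=extra_model_kvargs)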