edp with cc+acc is done #669

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 3 commits into base: main
@@ -7,6 +7,8 @@
from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv
from typing import Tuple

import os


class TransformerLayerInferTpl(TransformerLayerInfer):
""" """
@@ -21,6 +23,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
self.tp_o_head_num_ = -1
self.head_dim_ = -1
self.embed_dim_ = -1
self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
return

def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
@@ -79,7 +82,7 @@ def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_w
o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
input_embding.add_(o.view(-1, self.embed_dim_))
return
@@ -88,7 +91,7 @@ def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return
@@ -102,7 +105,7 @@ def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_wei
o = self._token_attention_kernel(q, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
input_embding.add_(o.view(-1, self.embed_dim_))
return
@@ -111,7 +114,7 @@ def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return
@@ -125,7 +128,7 @@ def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateIn
o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
input_embding.add_(o.view(-1, self.embed_dim_))
return
@@ -134,7 +137,7 @@ def _splitfuse_ffn(self, input_embdings, infer_state: SplitFuseInferStateInfo, l
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return
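Note (not part of the diff): every all-reduce in this template is now guarded by self.tp_split_, so when EDP_MODE_ENABLED is "true" the per-layer SUM reduction is skipped entirely. A minimal sketch of that pattern, with illustrative class and method names only:

import os
import torch.distributed as dist

class _ReduceSketch:
    def __init__(self, world_size: int):
        self.world_size_ = world_size
        # EDP mode disables tensor-parallel splitting, so each rank's output
        # is already complete and needs no cross-rank reduction.
        self.tp_split_ = os.environ.get("EDP_MODE_ENABLED") != "true"

    def _reduce_if_tp(self, out):
        # Only all-reduce when the weights were actually split across ranks.
        if self.world_size_ > 1 and self.tp_split_:
            dist.all_reduce(out, op=dist.ReduceOp.SUM, async_op=False)
        return out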
@@ -11,6 +11,10 @@
MultiCOLMMWeight,
ROWBMMWeight,
COLBMMWeight,
MultiCOLMMWeightNoTp,
ROWBMMWeightNoTp,
COLBMMWeightNoTp,
COLMMWeightNoTp,
)
from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
from .fused_moe_weight import FusedMoeWeight
@@ -153,7 +153,7 @@ def _load_hf_weights_etp(self, weights):
self.expert_down_proj_etp[i_experts_ep, :] = self.experts_up_projs[i_experts_ep]

def load_hf_weights(self, weights):
if os.environ.get("ETP_MODE_ENABLED") == "true":
if os.environ.get("ETP_MODE_ENABLED") == "true" or os.environ.get("EDP_MODE_ENABLED") == "true":
self._load_hf_weights_etp(weights)
else:
for i_experts in range(self.n_routed_experts):
@@ -184,7 +184,7 @@ def _cuda(self, cpu_tensor):
return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)

def verify_load(self):
if os.environ.get("ETP_MODE_ENABLED") == "true":
if os.environ.get("ETP_MODE_ENABLED") == "true" or os.environ.get("EDP_MODE_ENABLED") == "true":
return True
else:
return self.w1 is not None and self.w2 is not None
70 changes: 70 additions & 0 deletions lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
@@ -319,3 +319,73 @@ def __init__(

def _post_load_weights(self):
self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)


class COLMMWeightNoTp(MMWeight):
def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
super().__init__(weight_name, data_type, split_n_embed, bias_name)
self.start = 0
self.end = split_n_embed

def load_hf_weights(self, weights):
weight = None
if self.weight_name in weights:
weight = weights[self.weight_name].to(self.data_type_)
self.weight = weight[:, self.start : self.end]
if self.bias_name in weights:
bias = weights[self.bias_name]
self.bias = bias.to(self.data_type_).cuda(self.tp_rank_)
if weight is None:
return
self._post_load_weights()
return


class MultiCOLMMWeightNoTp(MultiROWMMWeightNoTP):
def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]):
super().__init__(weight_names, data_type, split_n_embed, bias_names)

def load_hf_weights(self, weights):
weight = None
for i in range(len(self.weight_names)):
if self.weight_names[i] in weights:
weight = weights[self.weight_names[i]].to(self.data_type_)
self.weights[i] = weight[:, self.starts[i] : self.ends[i]]
if self.has_bias and self.bias_names[i] in weights:
bias = weights[self.bias_names[i]].to(self.data_type_)
self.biases[i] = bias[:, self.starts[i] : self.ends[i]]
self._fuse()
return


class ROWBMMWeightNoTp(BMMWeight):
load_hf_weights = ROWMMWeight.load_hf_weights

def __init__(
self,
weight_name,
data_type,
split_n_embed,
bias_name=None,
):
super().__init__(weight_name, data_type, split_n_embed, bias_name)
self.start = 0
self.end = split_n_embed


class COLBMMWeightNoTp(BMMWeight):
load_hf_weights = COLMMWeightNoTp.load_hf_weights

def __init__(
self,
weight_name,
data_type,
split_n_embed,
bias_name=None,
):
super().__init__(weight_name, data_type, split_n_embed, bias_name)
self.start = 0
self.end = split_n_embed

def _post_load_weights(self):
self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
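Note (not part of the diff): the new *NoTp weight classes keep the tensor-parallel loading interface but always cover the full range (start=0, end=split_n_embed), so each rank loads the whole, un-split weight. Conceptually:

# TP variant:   weight[:, rank * shard_size : (rank + 1) * shard_size]   # per-rank shard
# NoTp variant: weight[:, 0 : split_n_embed]                             # full weight on every rank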
16 changes: 15 additions & 1 deletion lightllm/common/deepseek2_mem_manager.py
@@ -14,9 +14,23 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
self.kv_buffer = torch.empty((layer_num, size, head_num, head_dim), dtype=dtype, device="cuda")
# todo: ETP and EDP use the same work buffer here
# it can also be used as a scratch work buffer by any kernel that does not need its contents saved
if os.environ.get("ETP_MODE_ENABLED") == "true":
if os.environ.get("ETP_MODE_ENABLED") == "true" or os.environ.get("EDP_MODE_ENABLED") == "true":
self.work_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.bfloat16, device="cuda")
self.work_buffer.share_memory_()
import lightllm_moe_etp_kernel
import torch.distributed as dist

rank_id = dist.get_rank()
world_size = dist.get_world_size()

# lightllm_moe_etp_kernel.enableP2P(world_size, rank_id)

handle = lightllm_moe_etp_kernel.get_handle(self.work_buffer.contiguous(), rank_id)
handles = [None] * world_size
dist.all_gather_object(handles, handle)
self.handles_work_buffer = handles

lightllm_moe_etp_kernel.init_system(world_size, rank_id, self.work_buffer.contiguous(), handles)

def alloc_kv_move_buffer(self, max_req_total_len):
self.kv_move_buffer = torch.empty(
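Note (illustrative, not part of the diff): the EDP/ETP branch above shares one bf16 work buffer across ranks by exporting a per-rank handle and gathering every rank's handle over torch.distributed. A minimal sketch of that exchange, assuming lightllm_moe_etp_kernel exposes get_handle and init_system exactly as used in the diff:

import torch
import torch.distributed as dist
import lightllm_moe_etp_kernel  # assumed API: get_handle(tensor, rank), init_system(ws, rank, tensor, handles)

def share_work_buffer(work_buffer: torch.Tensor) -> list:
    rank_id = dist.get_rank()
    world_size = dist.get_world_size()

    # Export a handle for this rank's buffer, then collect all ranks' handles.
    handle = lightllm_moe_etp_kernel.get_handle(work_buffer.contiguous(), rank_id)
    handles = [None] * world_size
    dist.all_gather_object(handles, handle)

    # Hand the full handle list to the kernel so it can map peer buffers.
    lightllm_moe_etp_kernel.init_system(world_size, rank_id, work_buffer.contiguous(), handles)
    return handles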
42 changes: 29 additions & 13 deletions lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -69,6 +69,8 @@ def __init__(
self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
self.mla_type = "ACCM"

self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"

return

def _bind_attention(self):
Expand All @@ -78,8 +80,8 @@ def _bind_attention(self):
)
self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
if self.is_moe:
if os.environ.get("ETP_MODE_ENABLED") == "true":
self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_etp, self)
if os.environ.get("ETP_MODE_ENABLED") == "true" or os.environ.get("EDP_MODE_ENABLED") == "true":
self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_etp_edp, self)
else:
self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self)
else:
@@ -155,7 +157,7 @@ def _CC_method(
):
num_local_heads = self.num_heads
num_local_kv_heads = self.num_kv_heads
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
num_local_heads //= self.world_size_
num_local_kv_heads //= self.world_size_
if infer_state.use_dynamic_prompt_cache:
@@ -187,7 +189,7 @@ def _ACC_method(
q_nope, q_rope = q
num_local_heads = self.num_heads
num_local_kv_heads = self.num_kv_heads
if self.world_size_ > 1:
if self.world_size_ > 1 and self.tp_split_:
num_local_heads //= self.world_size_
num_local_kv_heads //= self.world_size_
# ACC
@@ -275,6 +277,10 @@ def _context_attention_kernel_origin(
self, q: Tuple[torch.Tensor, torch.Tensor], kv, infer_state: Deepseek2InferStateInfo, layer_weight, out=None
) -> torch.Tensor:
q_nope, q_rope = q

# EDP is not yet supported here
# assert self.tp_split_ == True

o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out

if infer_state.use_dynamic_prompt_cache:
@@ -440,7 +446,7 @@ def _splitfuse_attention_kernel_with_CC(
torch.cuda.default_stream().wait_event(infer_state.end_event)
return o_tensor

def _moe_ffn_etp(
def _moe_ffn_etp_edp(
self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
) -> torch.Tensor:
world_size_ = self.world_size_
@@ -460,17 +466,25 @@ def _moe_ffn_etp(
final_hidden_states = torch.empty(
num_tokens, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype
)

# router_logits_len = hidden_states.shape[0]*layer_weight.moe_gate.shape[1]
router_logits = layer_weight.moe_gate.mm(hidden_states)


# note: some parameters are not supported yet
# assert gating_normalize_prob is False
# assert num_expert_groups <= 1

import lightllm_moe_etp_kernel

lightllm_moe_etp_kernel.moe_fused_all(
is_etp = True
if os.environ.get("ETP_MODE_ENABLED") == "true":
router_logits = layer_weight.moe_gate.mm(hidden_states)
elif os.environ.get("EDP_MODE_ENABLED") == "true":
router_logits = infer_state.mem_manager.work_buffer[-(num_tokens * num_experts_per_token + hidden_states.nelement()) : -hidden_states.nelement()].view(num_tokens, num_experts_per_token)
router_logits = layer_weight.moe_gate.mm(hidden_states, out=router_logits)
is_etp = False

#print(" hid state addr ", infer_state.mem_manager.work_buffer.data_ptr(),
# hidden_states.data_ptr(),
# hidden_states.shape()
# )

moe_fused_all(
router_logits.contiguous(),
hidden_states.contiguous(),
layer_weight.gate_up_proj.weight.contiguous(), # transpose
Expand All @@ -490,8 +504,10 @@ def _moe_ffn_etp(
layer_weight.gate_up_proj.weight.size(1) // 2,
layer_weight.experts.expert_gate_up_proj_etp.size(1) // 2,
self.n_shared_experts is not None,
is_etp
)

router_logits = None
if os.environ.get("ETP_MODE_ENABLED") == "true":
router_logits = None

return final_hidden_states.view(num_tokens, hidden_dim)
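Note (illustrative, not part of the diff): in the EDP branch, router_logits is not freshly allocated; it is viewed out of the tail of the shared work buffer, just in front of a trailing region sized like hidden_states. The slicing arithmetic, with hypothetical sizes:

import torch

# Hypothetical sizes, for illustration only.
num_tokens = 8
num_experts_per_token = 6
hidden_elems = 8 * 7168  # stands in for hidden_states.nelement()

work_buffer = torch.empty(1024 * 1024, dtype=torch.bfloat16)

# Take num_tokens * num_experts_per_token elements ending right before the
# last hidden_elems elements, then view them as (num_tokens, num_experts_per_token).
router_logits = work_buffer[-(num_tokens * num_experts_per_token + hidden_elems) : -hidden_elems].view(
    num_tokens, num_experts_per_token
)
assert router_logits.shape == (num_tokens, num_experts_per_token)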