
Commit aedfad1

llama : update graph to support MoE
1 parent 861cd67

File tree: 1 file changed (+46, -1 lines)

llama.cpp

Lines changed: 46 additions & 1 deletion
@@ -4223,7 +4223,7 @@ struct llm_build_context {
         cb(ffn_inp, "ffn_inp", il);
 
         // feed-forward network
-        {
+        if (model.layers[il].ffn_gate_inp == nullptr) {
             cur = llm_build_norm(ctx0, ffn_inp, hparams,
                     model.layers[il].ffn_norm, NULL,
                     LLM_NORM_RMS, cb, il);
@@ -4235,6 +4235,51 @@ struct llm_build_context {
                     model.layers[il].ffn_down, NULL,
                     LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             cb(cur, "ffn_out", il);
+        } else {
+            // MoE branch
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            const int n_experts_per_tok = 2; // TODO: param
+
+            ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
+            ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
+
+            // select experts
+            ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
+            ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [n_tokens, num_experts_per_tok, 1]
+            weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok, 1]
+
+            // compute expert outputs
+            ggml_tensor * moe_out;
+
+            for (int i = 0; i < n_experts_per_tok; ++i) {
+                ggml_tensor * cur_expert;
+
+                // TODO: fix
+                ggml_tensor ** ffn_up_exp   = (ggml_tensor **) model.layers[il].ffn_up_exp;
+                ggml_tensor ** ffn_gate_exp = (ggml_tensor **) model.layers[il].ffn_gate_exp;
+                ggml_tensor ** ffn_down_exp = (ggml_tensor **) model.layers[il].ffn_down_exp;
+
+                cur_expert = ggml_mul(ctx0,
+                        ggml_mul_mat_id(ctx0, ffn_up_exp, selected_experts, i, cur),
+                        ggml_silu(ctx0,
+                            ggml_mul_mat_id(ctx0, ffn_gate_exp, selected_experts, i, cur))); // [n_tokens, n_embd]
+
+                cur_expert = ggml_mul_mat_id(ctx0, ffn_down_exp, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+                cur_expert = ggml_mul(ctx0, cur,
+                        ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+
+                if (i == 0) {
+                    moe_out = cur_expert;
+                } else {
+                    moe_out = ggml_add(ctx0, moe_out, cur_expert);
+                }
+            }
+
+            cur = moe_out;
         }
 
         cur = ggml_add(ctx0, cur, ffn_inp);
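
For reference, here is a minimal standalone sketch of the routing math this branch builds into the graph: softmax over the gate logits, top-k expert selection, renormalization of the selected probabilities, and a weighted sum of the selected experts' outputs. It is plain C++ for a single token rather than ggml graph code; the toy `expert_ffn` and `softmax` helpers, the example gate logits, and all dimensions are made-up assumptions for illustration, and the routing weight scales the expert output, as in the usual top-k MoE formulation.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

// softmax over one token's per-expert gate logits
static std::vector<float> softmax(const std::vector<float> & logits) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t e = 0; e < logits.size(); ++e) {
        probs[e] = std::exp(logits[e] - max_logit);
        sum += probs[e];
    }
    for (float & p : probs) {
        p /= sum;
    }
    return probs;
}

// stand-in for one expert's gated FFN: silu(gate(x)) * up(x);
// toy scalar "projections" keep the sketch self-contained
static std::vector<float> expert_ffn(int expert, const std::vector<float> & x) {
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        const float g = (expert + 1) * x[i];      // toy "gate" projection
        const float u = 0.5f * x[i];              // toy "up" projection
        y[i] = (g / (1.0f + std::exp(-g))) * u;   // silu(g) * u
    }
    return y;
}

int main() {
    const int n_experts         = 8;
    const int n_experts_per_tok = 2;

    // one token's hidden state and its gate logits (example values)
    std::vector<float> x      = {0.1f, -0.4f, 0.7f};
    std::vector<float> logits = {0.2f, 1.5f, -0.3f, 0.0f, 2.1f, 0.4f, -1.0f, 0.9f};

    std::vector<float> probs = softmax(logits);

    // pick the top-k expert indices by probability
    std::vector<int> idx(n_experts);
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + n_experts_per_tok, idx.end(),
                      [&](int a, int b) { return probs[a] > probs[b]; });
    idx.resize(n_experts_per_tok);

    // renormalize the selected probabilities so they sum to 1
    float wsum = 0.0f;
    for (int e : idx) {
        wsum += probs[e];
    }

    // weighted sum of the selected experts' outputs
    std::vector<float> moe_out(x.size(), 0.0f);
    for (int e : idx) {
        const float w = probs[e] / wsum;
        const std::vector<float> y = expert_ffn(e, x);
        for (size_t i = 0; i < x.size(); ++i) {
            moe_out[i] += w * y[i];
        }
    }

    for (float v : moe_out) {
        std::printf("%f ", v);
    }
    std::printf("\n");
    return 0;
}
```

Compiling with, e.g., `g++ -std=c++11 moe_sketch.cpp` and running prints one mixed output vector; the graph above expresses the same combination over all tokens at once with ggml_soft_max, ggml_top_k, ggml_get_rows, ggml_mul_mat_id, and ggml_add.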
