@@ -4223,7 +4223,7 @@ struct llm_build_context {
4223
4223
cb(ffn_inp, "ffn_inp", il);
4224
4224
4225
4225
// feed-forward network
4226
- {
4226
+ if (model.layers[il].ffn_gate_inp == nullptr) {
4227
4227
cur = llm_build_norm(ctx0, ffn_inp, hparams,
4228
4228
model.layers[il].ffn_norm, NULL,
4229
4229
LLM_NORM_RMS, cb, il);
@@ -4235,6 +4235,51 @@ struct llm_build_context {
4235
4235
model.layers[il].ffn_down, NULL,
4236
4236
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
4237
4237
cb(cur, "ffn_out", il);
4238
+ } else {
4239
+ // MoE branch
4240
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
4241
+ model.layers[il].ffn_norm, NULL,
4242
+ LLM_NORM_RMS, cb, il);
4243
+ cb(cur, "ffn_norm", il);
4244
+
4245
+ const int n_experts_per_tok = 2; // TODO: param
4246
+
4247
+ ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
4248
+ ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
4249
+
4250
+ // select experts
4251
+ ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
4252
+ ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [n_tokens, num_experts_per_tok, 1]
4253
+ weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok, 1]
4254
+
4255
+ // compute expert outputs
4256
+ ggml_tensor * moe_out;
4257
+
4258
+ for (int i = 0; i < n_experts_per_tok; ++i) {
4259
+ ggml_tensor * cur_expert;
4260
+
4261
+ // TODO: fix
4262
+ ggml_tensor ** ffn_up_exp = (ggml_tensor **) model.layers[il].ffn_up_exp;
4263
+ ggml_tensor ** ffn_gate_exp = (ggml_tensor **) model.layers[il].ffn_gate_exp;
4264
+ ggml_tensor ** ffn_down_exp = (ggml_tensor **) model.layers[il].ffn_down_exp;
4265
+
4266
+ cur_expert = ggml_mul(ctx0,
4267
+ ggml_mul_mat_id(ctx0, ffn_up_exp, selected_experts, i, cur),
4268
+ ggml_silu(ctx0,
4269
+ ggml_mul_mat_id(ctx0, ffn_gate_exp, selected_experts, i, cur))); // [n_tokens, n_embd]
4270
+
4271
+ cur_expert = ggml_mul_mat_id(ctx0, ffn_down_exp, selected_experts, i, cur_expert); // [n_tokens, n_embd]
4272
+ cur_expert = ggml_mul(ctx0, cur,
4273
+ ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
4274
+
4275
+ if (i == 0) {
4276
+ moe_out = cur_expert;
4277
+ } else {
4278
+ moe_out = ggml_add(ctx0, moe_out, cur_expert);
4279
+ }
4280
+ }
4281
+
4282
+ cur = moe_out;
4238
4283
}
4239
4284
4240
4285
cur = ggml_add(ctx0, cur, ffn_inp);
0 commit comments