Skip to content

Commit cefebb3

Browse files
committed
test-backend-ops : add moe test
1 parent e640cbe commit cefebb3

File tree

1 file changed

+116
-12
lines changed

1 file changed

+116
-12
lines changed

Diff for: tests/test-backend-ops.cpp

+116-12
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
5151
t.join();
5252
}
5353

54-
if (tensor->type == GGML_TYPE_F32) {
54+
if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
5555
ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
5656
} else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
5757
GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
@@ -233,14 +233,18 @@ static bool ggml_is_view_op(enum ggml_op op) {
233233
struct test_case {
234234
virtual ~test_case() {}
235235

236+
virtual std::string op_desc(ggml_tensor * t) {
237+
return ggml_op_desc(t);
238+
}
239+
236240
virtual std::string vars() {
237241
return "";
238242
}
239243

240244
virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
241245

242246
virtual double max_nmse_err() {
243-
return 1e-6;
247+
return 1e-7;
244248
}
245249

246250
virtual void initialize_tensors(ggml_context * ctx) {
@@ -270,13 +274,13 @@ struct test_case {
270274

271275
ggml_tensor * out = build_graph(ctx);
272276

273-
if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
274-
//printf(" %s: skipping\n", ggml_op_desc(out));
277+
if (op_name != nullptr && op_desc(out) != op_name) {
278+
//printf(" %s: skipping\n", op_desc(out).c_str());
275279
ggml_free(ctx);
276280
return true;
277281
}
278282

279-
printf(" %s(%s): ", ggml_op_desc(out), vars().c_str());
283+
printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
280284
fflush(stdout);
281285

282286
// check if backends support op
@@ -317,29 +321,40 @@ struct test_case {
317321
for (size_t i = 0; i < f1.size(); i++) {
318322
// check for nans
319323
if (std::isnan(f1[i]) || std::isnan(f2[i])) {
320-
printf("NaN at index %zu ", i);
324+
printf("[%s] NaN at index %zu ", ggml_op_desc(t1), i);
321325
ud->ok = false;
322326
return true;
323327
}
324328
// check for infs: both must be inf of the same sign, or both must be finite
325329
if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
326330
if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
327331
if (std::signbit(f1[i]) != std::signbit(f2[i])) {
328-
printf("inf sign mismatch: %f %f ", f1[i], f2[i]);
332+
printf("[%s] inf sign mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
329333
ud->ok = false;
330334
return true;
331335
}
332336
} else {
333-
printf("inf mismatch: %f %f ", f1[i], f2[i]);
337+
printf("[%s] inf mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
334338
ud->ok = false;
335339
return true;
336340
}
337341
}
338342
}
339343

344+
//if (t1->op == GGML_OP_SOFT_MAX) {
345+
// printf("[%s] ", ggml_op_desc(t1));
346+
// for (int i = 0; i < f1.size(); i++) {
347+
// printf("(%x, %x) ", *(uint32_t*)&f1[i], *(uint32_t*)&f2[i]);
348+
// }
349+
// printf("\n");
350+
//}
340351
double err = nmse(f1.data(), f2.data(), f1.size());
341352
if (err > ud->max_err) {
342-
printf("NMSE = %f ", err);
353+
printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
354+
//for (int i = 0; i < f1.size(); i++) {
355+
// printf("(%f, %f) ", f1[i], f2[i]);
356+
//}
357+
//printf("\n");
343358
ud->ok = false;
344359
}
345360
return true;
@@ -374,13 +389,13 @@ struct test_case {
374389

375390
ggml_tensor * out = build_graph(ctx);
376391

377-
if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
378-
//printf(" %s: skipping\n", ggml_op_desc(out));
392+
if (op_name != nullptr && op_desc(out) != op_name) {
393+
//printf(" %s: skipping\n", op_desc(out).c_str());
379394
ggml_free(ctx);
380395
return true;
381396
}
382397

383-
int len = printf(" %s(%s): ", ggml_op_desc(out), vars().c_str());
398+
int len = printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
384399
fflush(stdout);
385400

386401
// check if backends support op
@@ -1122,6 +1137,91 @@ struct test_sum_rows : public test_case {
11221137
}
11231138
};
11241139

1140+
struct test_moe : public test_case {
1141+
const int n_experts = 8;
1142+
const int n_experts_per_tok = 2;
1143+
const int n_tokens = 1;
1144+
const int n_embd = 4096;
1145+
const int n_ff = 14336;
1146+
1147+
std::string op_desc(ggml_tensor * t) override {
1148+
return "MOE";
1149+
GGML_UNUSED(t);
1150+
}
1151+
1152+
std::string vars() override {
1153+
return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
1154+
}
1155+
1156+
test_moe() {
1157+
}
1158+
1159+
ggml_tensor * build_graph(ggml_context * ctx) override {
1160+
ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts);
1161+
1162+
std::vector<ggml_tensor *> ffn_up_exp(n_experts);
1163+
std::vector<ggml_tensor *> ffn_gate_exp(n_experts);
1164+
std::vector<ggml_tensor *> ffn_down_exp(n_experts);
1165+
1166+
for (int i = 0; i < n_experts; ++i) {
1167+
ffn_up_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
1168+
ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
1169+
ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
1170+
}
1171+
1172+
ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
1173+
1174+
ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); // [n_tokens, num_experts]
1175+
ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_tokens, num_experts]
1176+
1177+
// select experts
1178+
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
1179+
1180+
ggml_tensor * weights = ggml_get_rows(ctx,
1181+
ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts);
1182+
printf("get rows args %ld %ld %ld %ld, %ld %ld %ld %ld\n",
1183+
weights->src[0]->ne[0], weights->src[0]->ne[1], weights->src[0]->ne[2], weights->src[0]->ne[3],
1184+
weights->src[1]->ne[0], weights->src[1]->ne[1], weights->src[1]->ne[2], weights->src[1]->ne[3]);
1185+
1186+
1187+
weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
1188+
1189+
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);
1190+
1191+
weights = ggml_div(ctx, weights, weights_sum); // [n_tokens, num_experts_per_tok]
1192+
1193+
// compute expert outputs
1194+
ggml_tensor * moe_out = nullptr;
1195+
1196+
for (int i = 0; i < n_experts_per_tok; ++i) {
1197+
ggml_tensor * cur_expert;
1198+
1199+
ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur);
1200+
1201+
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur);
1202+
1203+
cur_gate = ggml_silu(ctx, cur_gate);
1204+
1205+
cur_expert = ggml_mul(ctx, cur_up, cur_gate); // [n_tokens, n_embd]
1206+
1207+
cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert); // [n_tokens, n_embd]
1208+
1209+
cur_expert = ggml_mul(ctx, cur_expert,
1210+
ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
1211+
1212+
if (i == 0) {
1213+
moe_out = cur_expert;
1214+
} else {
1215+
moe_out = ggml_add(ctx, moe_out, cur_expert);
1216+
}
1217+
}
1218+
1219+
cur = moe_out;
1220+
1221+
return cur;
1222+
}
1223+
};
1224+
11251225
enum test_mode {
11261226
MODE_TEST,
11271227
MODE_PERF,
@@ -1140,11 +1240,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
11401240
GGML_TYPE_Q6_K
11411241
};
11421242

1243+
test_cases.emplace_back(new test_moe());
1244+
11431245
// unary ops
11441246
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
11451247
test_cases.emplace_back(new test_unary((ggml_unary_op) op));
11461248
}
11471249

1250+
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
11481251
for (ggml_type type : all_types) {
11491252
for (int b : {1, 7}) {
11501253
for (bool v : {false, true}) {
@@ -1265,6 +1368,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
12651368
test_cases.emplace_back(new test_concat());
12661369

12671370
for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
1371+
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
12681372
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
12691373
}
12701374

0 commit comments

Comments
 (0)