Skip to content

Commit 54d254b

Browse files
committed
test-backend-ops : cleanup, add moe test for batches
1 parent 54ba263 commit 54d254b

File tree

1 file changed

+32
-35
lines changed

Diff for: tests/test-backend-ops.cpp

+32-35
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,6 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
2020
size_t size = ggml_nelements(tensor);
2121
std::vector<float> data(size);
2222

23-
std::random_device rd;
24-
2523
#if 0
2624
std::default_random_engine generator(rd());
2725
std::uniform_real_distribution<float> distribution(min, max);
@@ -31,6 +29,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
3129
}
3230
#endif
3331
auto init_thread = [&](size_t start, size_t end) {
32+
std::random_device rd;
3433
std::default_random_engine generator(rd());
3534
std::uniform_real_distribution<float> distribution(min, max);
3635

@@ -341,13 +340,6 @@ struct test_case {
341340
}
342341
}
343342

344-
//if (t1->op == GGML_OP_SOFT_MAX) {
345-
// printf("[%s] ", ggml_op_desc(t1));
346-
// for (int i = 0; i < f1.size(); i++) {
347-
// printf("(%x, %x) ", *(uint32_t*)&f1[i], *(uint32_t*)&f2[i]);
348-
// }
349-
// printf("\n");
350-
//}
351343
double err = nmse(f1.data(), f2.data(), f1.size());
352344
if (err > ud->max_err) {
353345
printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
@@ -447,8 +439,9 @@ struct test_case {
447439
return size;
448440
};
449441
for (int i = 0; i < gf->n_nodes; i++) {
450-
if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out)
442+
if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out) {
451443
continue;
444+
}
452445
mem += tensor_op_size(gf->nodes[i]);
453446
}
454447

@@ -1137,23 +1130,26 @@ struct test_sum_rows : public test_case {
11371130
}
11381131
};
11391132

1133+
// Mixtral MOE
11401134
struct test_moe : public test_case {
1141-
const int n_experts = 8;
1142-
const int n_experts_per_tok = 2;
1143-
const int n_tokens = 1;
1144-
const int n_embd = 4096;
1145-
const int n_ff = 14336;
1135+
const int n_experts;
1136+
const int n_experts_per_tok;
1137+
const int n_tokens;
1138+
const int n_embd;
1139+
const int n_ff;
11461140

11471141
std::string op_desc(ggml_tensor * t) override {
11481142
return "MOE";
1143+
11491144
GGML_UNUSED(t);
11501145
}
11511146

11521147
std::string vars() override {
11531148
return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
11541149
}
11551150

1156-
test_moe() {
1151+
test_moe(int n_experts = 8, int n_experts_per_tok = 2, int n_tokens = 1, int n_embd = 4096, int n_ff = 14336)
1152+
: n_experts(n_experts), n_experts_per_tok(n_experts_per_tok), n_tokens(n_tokens), n_embd(n_embd), n_ff(n_ff) {
11571153
}
11581154

11591155
ggml_tensor * build_graph(ggml_context * ctx) override {
@@ -1171,24 +1167,20 @@ struct test_moe : public test_case {
11711167

11721168
ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
11731169

1174-
ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); // [n_tokens, num_experts]
1175-
ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd)); // [n_tokens, num_experts]
1170+
ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
1171+
ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd));
11761172

11771173
// select experts
1178-
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
1174+
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
11791175

11801176
ggml_tensor * weights = ggml_get_rows(ctx,
11811177
ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts);
1182-
printf("get rows args %ld %ld %ld %ld, %ld %ld %ld %ld\n",
1183-
weights->src[0]->ne[0], weights->src[0]->ne[1], weights->src[0]->ne[2], weights->src[0]->ne[3],
1184-
weights->src[1]->ne[0], weights->src[1]->ne[1], weights->src[1]->ne[2], weights->src[1]->ne[3]);
11851178

1186-
1187-
weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
1179+
weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens);
11881180

11891181
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);
11901182

1191-
weights = ggml_div(ctx, weights, weights_sum); // [n_tokens, num_experts_per_tok]
1183+
weights = ggml_div(ctx, weights, weights_sum);
11921184

11931185
// compute expert outputs
11941186
ggml_tensor * moe_out = nullptr;
@@ -1202,9 +1194,9 @@ struct test_moe : public test_case {
12021194

12031195
cur_gate = ggml_silu(ctx, cur_gate);
12041196

1205-
cur_expert = ggml_mul(ctx, cur_up, cur_gate); // [n_tokens, n_embd]
1197+
cur_expert = ggml_mul(ctx, cur_up, cur_gate);
12061198

1207-
cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert); // [n_tokens, n_embd]
1199+
cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert);
12081200

12091201
cur_expert = ggml_mul(ctx, cur_expert,
12101202
ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
@@ -1240,8 +1232,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
12401232
GGML_TYPE_Q6_K
12411233
};
12421234

1243-
test_cases.emplace_back(new test_moe());
1244-
12451235
// unary ops
12461236
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
12471237
test_cases.emplace_back(new test_unary((ggml_unary_op) op));
@@ -1374,6 +1364,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
13741364

13751365
test_cases.emplace_back(new test_sum_rows());
13761366

1367+
test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 14336));
1368+
test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
1369+
13771370
// run tests
13781371
if (mode == MODE_TEST) {
13791372
ggml_backend_t backend_cpu = ggml_backend_cpu_init();
@@ -1389,14 +1382,17 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
13891382
ggml_backend_free(backend_cpu);
13901383

13911384
return n_ok == test_cases.size();
1392-
} else if (mode == MODE_PERF) {
1385+
}
1386+
1387+
if (mode == MODE_PERF) {
13931388
for (auto & test : test_cases) {
13941389
test->eval_perf(backend, op_name);
13951390
}
13961391
return true;
1397-
} else {
1398-
GGML_ASSERT(false);
13991392
}
1393+
1394+
GGML_ASSERT(false);
1395+
return false;
14001396
}
14011397

14021398
static void usage(char ** argv) {
@@ -1469,11 +1465,12 @@ int main(int argc, char ** argv) {
14691465
}
14701466

14711467
printf("%zu/%zu backends passed\n", n_ok, ggml_backend_reg_get_count());
1468+
14721469
if (n_ok != ggml_backend_reg_get_count()) {
14731470
printf("\033[1;31mFAIL\033[0m\n");
14741471
return 1;
1475-
} else {
1476-
printf("\033[1;32mOK\033[0m\n");
1477-
return 0;
14781472
}
1473+
1474+
printf("\033[1;32mOK\033[0m\n");
1475+
return 0;
14791476
}

0 commit comments

Comments (0)