@@ -51,7 +51,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
         t.join();
     }
 
-    if (tensor->type == GGML_TYPE_F32) {
+    if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
         ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
     } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
         GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
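Note on the hunk above: extending the F32 branch to GGML_TYPE_I32 lets tests upload integer index tensors (for example the row ids consumed by ggml_get_rows or ggml_mul_mat_id) through the same raw-byte copy. As a minimal sketch of such an upload on its own, with init_tensor_indices and its callers being hypothetical and only ggml_backend_tensor_set / ggml_nelements taken from the real API:

// Sketch only, not part of the patch; assumes the includes already used by
// test-backend-ops.cpp (<vector>, <cstdint>, ggml.h, ggml-backend.h).
static void init_tensor_indices(ggml_tensor * tensor, int32_t n_rows) {
    const size_t n = ggml_nelements(tensor);
    std::vector<int32_t> idx(n);
    for (size_t i = 0; i < n; i++) {
        idx[i] = (int32_t)(i % n_rows); // keep every index in [0, n_rows)
    }
    // sizeof(int32_t) == sizeof(float), so the byte count matches the F32 branch above
    ggml_backend_tensor_set(tensor, idx.data(), 0, n * sizeof(int32_t));
}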
@@ -233,14 +233,18 @@ static bool ggml_is_view_op(enum ggml_op op) {
 struct test_case {
     virtual ~test_case() {}
 
+    virtual std::string op_desc(ggml_tensor * t) {
+        return ggml_op_desc(t);
+    }
+
     virtual std::string vars() {
         return "";
     }
 
     virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
 
     virtual double max_nmse_err() {
-        return 1e-6;
+        return 1e-7;
     }
 
     virtual void initialize_tensors(ggml_context * ctx) {
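The new op_desc() hook lets a test case report a custom name instead of the ggml op of its output tensor, which is what allows the composite MoE graph added below to be filtered as "MOE". A hedged illustration of how a derived case could use the hook together with a per-test error bound; test_fused_example is made up for this note, while the virtuals and ggml calls are the existing ones:

// Illustrative only, not part of the patch.
struct test_fused_example : public test_case {
    std::string op_desc(ggml_tensor * t) override {
        return "FUSED_EXAMPLE";   // reported instead of ggml_op_desc(t)
        GGML_UNUSED(t);
    }

    double max_nmse_err() override {
        return 1e-4;              // composite graphs accumulate more rounding error
    }

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
        ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
        return ggml_mul_mat(ctx, a, b);
    }
};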
@@ -270,13 +274,13 @@ struct test_case {
 
         ggml_tensor * out = build_graph(ctx);
 
-        if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
-            //printf("  %s: skipping\n", ggml_op_desc(out));
+        if (op_name != nullptr && op_desc(out) != op_name) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
             ggml_free(ctx);
             return true;
         }
 
-        printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
+        printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
         fflush(stdout);
 
         // check if backends support op
@@ -317,29 +321,40 @@ struct test_case {
             for (size_t i = 0; i < f1.size(); i++) {
                 // check for nans
                 if (std::isnan(f1[i]) || std::isnan(f2[i])) {
-                    printf("NaN at index %zu ", i);
+                    printf("[%s] NaN at index %zu ", ggml_op_desc(t1), i);
                     ud->ok = false;
                     return true;
                 }
                 // check for infs: both must be inf of the same sign, or both must be finite
                 if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
                     if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
                         if (std::signbit(f1[i]) != std::signbit(f2[i])) {
-                            printf("inf sign mismatch: %f %f ", f1[i], f2[i]);
+                            printf("[%s] inf sign mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
                             ud->ok = false;
                             return true;
                         }
                     } else {
-                        printf("inf mismatch: %f %f ", f1[i], f2[i]);
+                        printf("[%s] inf mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
                         ud->ok = false;
                         return true;
                     }
                 }
             }
 
+            // if (t1->op == GGML_OP_SOFT_MAX) {
+            //     printf("[%s] ", ggml_op_desc(t1));
+            //     for (int i = 0; i < f1.size(); i++) {
+            //         printf("(%x, %x) ", *(uint32_t*)&f1[i], *(uint32_t*)&f2[i]);
+            //     }
+            //     printf("\n");
+            // }
             double err = nmse(f1.data(), f2.data(), f1.size());
             if (err > ud->max_err) {
-                printf("NMSE = %f ", err);
+                printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
+                // for (int i = 0; i < f1.size(); i++) {
+                //     printf("(%f, %f) ", f1[i], f2[i]);
+                // }
+                // printf("\n");
                 ud->ok = false;
             }
             return true;
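For reference, the nmse() call above normalizes the squared error by the energy of the reference output before comparing it against max_nmse_err(). Below is a sketch consistent with how it is used here; the real helper lives earlier in the file and may differ in detail, so the name is suffixed to make clear this is not it:

// Sketch of a normalized mean squared error, as assumed by the checks above.
static double nmse_sketch(const float * a, const float * b, size_t n) {
    double sum_err = 0.0; // sum of squared differences
    double sum_ref = 0.0; // sum of squares of the reference values
    for (size_t i = 0; i < n; i++) {
        const double d = (double) a[i] - (double) b[i];
        sum_err += d * d;
        sum_ref += (double) a[i] * (double) a[i];
    }
    return sum_err / sum_ref;
}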
@@ -374,13 +389,13 @@ struct test_case {
 
         ggml_tensor * out = build_graph(ctx);
 
-        if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
-            //printf("  %s: skipping\n", ggml_op_desc(out));
+        if (op_name != nullptr && op_desc(out) != op_name) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
             ggml_free(ctx);
             return true;
         }
 
-        int len = printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
+        int len = printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
         fflush(stdout);
 
         // check if backends support op
@@ -1122,6 +1137,91 @@ struct test_sum_rows : public test_case {
     }
 };
 
+struct test_moe : public test_case {
+    const int n_experts = 8;
+    const int n_experts_per_tok = 2;
+    const int n_tokens = 1;
+    const int n_embd = 4096;
+    const int n_ff = 14336;
+
+    std::string op_desc(ggml_tensor * t) override {
+        return "MOE";
+        GGML_UNUSED(t);
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
+    }
+
+    test_moe() {
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts);
+
+        std::vector<ggml_tensor *> ffn_up_exp(n_experts);
+        std::vector<ggml_tensor *> ffn_gate_exp(n_experts);
+        std::vector<ggml_tensor *> ffn_down_exp(n_experts);
+
+        for (int i = 0; i < n_experts; ++i) {
+            ffn_up_exp[i]   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+            ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+            ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
+        }
+
+        ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
+
+        ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); // [n_tokens, num_experts]
+        ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_tokens, num_experts]
+
+        // select experts
+        ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
+
+        ggml_tensor * weights = ggml_get_rows(ctx,
+            ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts);
+        printf("get rows args %ld %ld %ld %ld, %ld %ld %ld %ld\n",
+            weights->src[0]->ne[0], weights->src[0]->ne[1], weights->src[0]->ne[2], weights->src[0]->ne[3],
+            weights->src[1]->ne[0], weights->src[1]->ne[1], weights->src[1]->ne[2], weights->src[1]->ne[3]);
+
+
+        weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
+
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);
+
+        weights = ggml_div(ctx, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+
+        // compute expert outputs
+        ggml_tensor * moe_out = nullptr;
+
+        for (int i = 0; i < n_experts_per_tok; ++i) {
+            ggml_tensor * cur_expert;
+
+            ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur);
+
+            ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur);
+
+            cur_gate = ggml_silu(ctx, cur_gate);
+
+            cur_expert = ggml_mul(ctx, cur_up, cur_gate); // [n_tokens, n_embd]
+
+            cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+
+            cur_expert = ggml_mul(ctx, cur_expert,
+                ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+
+            if (i == 0) {
+                moe_out = cur_expert;
+            } else {
+                moe_out = ggml_add(ctx, moe_out, cur_expert);
+            }
+        }
+
+        cur = moe_out;
+
+        return cur;
+    }
+};
+
 enum test_mode {
     MODE_TEST,
     MODE_PERF,
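To make the routing math in test_moe::build_graph easier to follow, here is a hedged scalar reference for a single token: softmax over the router logits, keep the n_experts_per_tok most probable experts (the ggml_top_k + ggml_get_rows step), then renormalize the selected probabilities so they sum to 1 (the ggml_sum_rows + ggml_div step). route_token is purely illustrative and not part of the patch:

#include <algorithm>
#include <cmath>
#include <numeric>
#include <utility>
#include <vector>

// Returns (expert id, normalized routing weight) pairs for one token.
static std::vector<std::pair<int, float>> route_token(const std::vector<float> & logits, int n_experts_per_tok) {
    // softmax over the router logits
    const float max_l = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); i++) {
        probs[i] = std::exp(logits[i] - max_l);
        sum += probs[i];
    }
    for (float & p : probs) { p /= sum; }

    // pick the n_experts_per_tok most probable experts
    std::vector<int> ids(logits.size());
    std::iota(ids.begin(), ids.end(), 0);
    std::partial_sort(ids.begin(), ids.begin() + n_experts_per_tok, ids.end(),
        [&probs](int a, int b) { return probs[a] > probs[b]; });

    // renormalize the selected probabilities so they sum to 1
    float sel_sum = 0.0f;
    for (int i = 0; i < n_experts_per_tok; i++) { sel_sum += probs[ids[i]]; }

    std::vector<std::pair<int, float>> out;
    for (int i = 0; i < n_experts_per_tok; i++) {
        out.push_back({ids[i], probs[ids[i]] / sel_sum});
    }
    return out;
}

Each selected expert's SiLU-gated FFN output is then scaled by its normalized weight and summed, which is what the ggml_mul_mat_id / ggml_silu / ggml_mul / ggml_add chain in the loop above builds.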
@@ -1140,11 +1240,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         GGML_TYPE_Q6_K
     };
 
+    test_cases.emplace_back(new test_moe());
+
     // unary ops
     for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
         test_cases.emplace_back(new test_unary((ggml_unary_op) op));
     }
 
+    test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
     for (ggml_type type : all_types) {
         for (int b : {1, 7}) {
             for (bool v : {false, true}) {
@@ -1265,6 +1368,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_concat());
 
     for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
     }
 
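Because test_moe reports "MOE" through op_desc(), it can be run in isolation with the op filter that the eval paths above compare against op_name. Assuming the test-backend-ops command line of that time (test-backend-ops [mode] [-o op] [-b backend]), something like ./bin/test-backend-ops test -o MOE should select only this case.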
1374
0 commit comments