@@ -1540,16 +1540,18 @@ struct test_flash_attn_ext : public test_case {

     const float max_bias; // ALiBi

+    const ggml_type type_KV;
+
     std::string vars() override {
-        return VARS_TO_STR6(hs, nh, kv, nb, mask, max_bias);
+        return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV);
     }

     double max_nmse_err() override {
         return 5e-4;
     }

-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f)
-        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
+        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
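The hunk cuts off right after the Q tensor is allocated. For orientation, here is a minimal sketch of how the rest of `build_graph` would presumably use `type_KV`: K and V are allocated in the KV-cache type under test, so the quantized flash-attention path is what actually gets exercised. Only the use of `type_KV` for K and V is implied by this diff; the mask shape, the `GGML_PAD`/`GGML_KQ_MASK_PAD` padding, and the exact `ggml_flash_attn_ext` signature are assumptions.

```cpp
// Hypothetical continuation of build_graph() above; everything except
// the use of type_KV for K and V is an assumption, not part of this diff.
ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs, kv, nh, 1);
ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs, kv, nh, 1);
ggml_tensor * m = mask
    ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1)
    : nullptr;
return ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
```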
@@ -2238,7 +2240,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         for (int nh : { 32, }) {
             for (int kv : { 512, 1024, }) {
                 for (int nb : { 1, 2, 4, 8, }) {
-                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias));
+                    for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV));
+                    }
                 }
             }
         }
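The new inner loop triples the flash-attention test matrix (one case per KV type), which is why `vars()` switches from `VARS_TO_STR6` to `VARS_TO_STR7`: each case needs a distinct name that now includes `type_KV`. The macro definitions are not part of this diff; the following self-contained sketch only illustrates the assumed `name=value` naming convention behind those helpers.

```cpp
// Standalone illustration of the assumed VARS_TO_STRn naming scheme;
// the real helpers in test-backend-ops.cpp may differ in detail.
#include <cstdio>
#include <cstdint>
#include <string>

static std::string var_to_str(int64_t v)      { return std::to_string(v); }
static std::string var_to_str(bool v)         { return v ? "1" : "0"; }
static std::string var_to_str(float v)        { return std::to_string(v); }
static std::string var_to_str(const char * v) { return v; }

// Stringify a variable as "name=value".
#define VAR_TO_STR(x) (std::string(#x) + "=" + var_to_str(x))

int main() {
    int64_t hs = 128, nh = 32, kv = 512, nb = 8;
    bool mask = true;
    float max_bias = 0.0f;
    const char * type_KV = "q8_0"; // stand-in for ggml_type_name(type_KV)

    // Analogue of VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV).
    std::string name = VAR_TO_STR(hs) + "," + VAR_TO_STR(nh) + "," +
                       VAR_TO_STR(kv) + "," + VAR_TO_STR(nb) + "," +
                       VAR_TO_STR(mask) + "," + VAR_TO_STR(max_bias) + "," +
                       VAR_TO_STR(type_KV);
    printf("%s\n", name.c_str());
    // -> hs=128,nh=32,kv=512,nb=8,mask=1,max_bias=0.000000,type_KV=q8_0
    return 0;
}
```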