Skip to content

Commit 0e4a029

Browse files
committed
[SYCL][E2E][Joint Matrix] Add k=32 for bfloat16 tests
1 parent 28e8416 commit 0e4a029

4 files changed

+14
-0
lines changed

sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,11 @@ int main() {
229229
test_ewops_c<float, 16, 16>();
230230
// This combination is not currently supported for sub group size = 32 in IGC
231231
#if (!defined(SG_SZ) || SG_SZ != 32)
232+
test_ewops_ab<bfloat16, 1, 32, use::a, layout::row_major, 1>();
232233
test_ewops_ab<bfloat16, 32, 16, use::a, layout::row_major, 1>();
234+
test_ewops_ab<bfloat16, 32, 32, use::a, layout::row_major, 1>();
233235
test_ewops_ab<bfloat16, 16, 64, use::b, layout::ext_intel_packed, 2>();
236+
test_ewops_ab<bfloat16, 32, 64, use::b, layout::ext_intel_packed, 2>();
234237
test_ewops_c<float, 1, 64>();
235238
test_ewops_c<float, 32, 64>();
236239
#endif

sycl/test-e2e/Matrix/element_wise_ops_impl.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ int main() {
133133
// IGC
134134
passed &= test<bfloat16, float, 16, 16, 16, 2, class pvc_bf16_16x16x16>();
135135
passed &= test<bfloat16, float, 1, 64, 16, 2, class pvc_bf16_1x64x16>();
136+
passed &= test<bfloat16, float, 1, 64, 32, 2, class pvc_bf16_1x64x32>();
136137
passed &= test<bfloat16, float, 32, 64, 16, 2, class pvc_bf16_32x64x16>();
138+
passed &= test<bfloat16, float, 32, 64, 32, 2, class pvc_bf16_32x64x32>();
137139
#endif
138140
break;
139141
}

sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,11 @@ size_t matrix_size = -1;
483483
MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size);
484484
test<bfloat16, float, VnniFactor, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16,
485485
MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size);
486+
test<bfloat16, float, VnniFactor, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32, MCache1,
487+
NCache1, /*KCache1*/ 32, MCache2, NCache2, KCache2>(matrix_size);
488+
test<bfloat16, float, VnniFactor, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32,
489+
MCache1, NCache1, /*KCache1*/ 32, MCache2, NCache2, KCache2>(
490+
matrix_size);
486491
#endif
487492
break;
488493
}

sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,12 @@ int main() {
127127
bfloat16, float>();
128128
res += gemm_row_major<1, 64, 16, class bf16_1x64x16, bfloat16, bfloat16,
129129
float>();
130+
res += gemm_row_major<1, 64, 32, class bf16_1x64x32, bfloat16, bfloat16,
131+
float>();
130132
res += gemm_row_major<32, 64, 16, class bf16_32x64x16, bfloat16,
131133
bfloat16, float>();
134+
res += gemm_row_major<32, 64, 32, class bf16_32x64x32, bfloat16,
135+
bfloat16, float>();
132136
}
133137
break;
134138
}

0 commit comments

Comments
 (0)