Skip to content

Commit 086f6b2

Browse files
authored
[SYCL][CUDA][Matrix] Adding test case for tf32 (intel/llvm-test-suite#963)
Test for intel#5870
1 parent cc44c09 commit 086f6b2

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

SYCL/Matrix/joint_matrix_tensorcore.cpp

+21 -10
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
7474
}
7575

7676
template <typename T1, typename T2, size_t Sub_Tiles_M, size_t Sub_Tiles_K,
77-
size_t Sub_Tiles_N, size_t M, size_t K, size_t N>
77+
size_t Sub_Tiles_N, size_t M, size_t K, size_t N, typename T3 = T1>
7878
void test() {
7979

8080
constexpr auto Big_M =
@@ -131,19 +131,19 @@ void test() {
131131
range<2> GlobalRange = {Sub_Tiles_M, Sub_Tiles_N * N_THREADS_PER_MATRIX_OP};
132132

133133
cgh.parallel_for<KernelName<T1, T2, M, K, N>>(
134-
nd_range<2>(GlobalRange, LocalRange), [=
135-
](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
134+
nd_range<2>(GlobalRange, LocalRange),
135+
[=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
136136
sycl::sub_group sg = item.get_sub_group();
137137
const auto m =
138-
item.get_group()
139-
.get_id()[0]; // row id of current submatrix of BIG C matrix
138+
item.get_group().get_group_id()[0]; // row id of current submatrix
139+
// of BIG C matrix
140140
const auto n =
141-
item.get_group().get_id()[1]; // column id of current
142-
// submatrix of BIG C matrix
141+
item.get_group().get_group_id()[1]; // column id of current
142+
// submatrix of BIG C matrix
143143

144-
joint_matrix<T1, matrix_use::a, M, K, matrix_layout::row_major> sub_a;
144+
joint_matrix<T3, matrix_use::a, M, K, matrix_layout::row_major> sub_a;
145145

146-
joint_matrix<T1, matrix_use::b, K, N, matrix_layout::row_major> sub_b;
146+
joint_matrix<T3, matrix_use::b, K, N, matrix_layout::row_major> sub_b;
147147

148148
joint_matrix<T2, matrix_use::accumulator, M, N,
149149
matrix_layout::row_major>
@@ -163,6 +163,14 @@ void test() {
163163
accB.get_pointer() + (k * K * Big_N) + (n * N),
164164
Big_N);
165165

166+
// Convert values if using tf32
167+
if constexpr (std::is_same<T3, precision::tf32>::value) {
168+
for (auto i = 0; i < 4; ++i) {
169+
sub_a.data[i] = round_to_tf32(sub_a.data[i]);
170+
sub_b.data[i] = round_to_tf32(sub_b.data[i]);
171+
}
172+
}
173+
166174
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
167175
}
168176
joint_matrix_store(
@@ -182,7 +190,6 @@ void test() {
182190
};
183191

184192
int main() {
185-
186193
// A/B half, Accumulator float
187194
test<half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>();
188195
test<half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>();
@@ -208,5 +215,9 @@ int main() {
208215
test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>();
209216
test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>();
210217

218+
// A/B tf32
219+
test<float, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 8, 16,
220+
precision::tf32>();
221+
211222
return 0;
212223
};

0 commit comments

Comments
 (0)