@@ -74,7 +74,7 @@ T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
74
74
}
75
75
76
76
template <typename T1, typename T2, size_t Sub_Tiles_M, size_t Sub_Tiles_K,
77
- size_t Sub_Tiles_N, size_t M, size_t K, size_t N>
77
+ size_t Sub_Tiles_N, size_t M, size_t K, size_t N, typename T3 = T1 >
78
78
void test () {
79
79
80
80
constexpr auto Big_M =
@@ -131,19 +131,19 @@ void test() {
131
131
range<2 > GlobalRange = {Sub_Tiles_M, Sub_Tiles_N * N_THREADS_PER_MATRIX_OP};
132
132
133
133
cgh.parallel_for <KernelName<T1, T2, M, K, N>>(
134
- nd_range<2 >(GlobalRange, LocalRange), [=
135
- ](nd_item<2 > item) [[sycl::reqd_work_group_size (1 , 1 , 32 )]] {
134
+ nd_range<2 >(GlobalRange, LocalRange),
135
+ [= ](nd_item<2 > item) [[sycl::reqd_work_group_size (1 , 1 , 32 )]] {
136
136
sycl::sub_group sg = item.get_sub_group ();
137
137
const auto m =
138
- item.get_group ()
139
- . get_id ()[ 0 ]; // row id of current submatrix of BIG C matrix
138
+ item.get_group (). get_group_id ()[ 0 ]; // row id of current submatrix
139
+ // of BIG C matrix
140
140
const auto n =
141
- item.get_group ().get_id ()[1 ]; // column id of current
142
- // submatrix of BIG C matrix
141
+ item.get_group ().get_group_id ()[1 ]; // column id of current
142
+ // submatrix of BIG C matrix
143
143
144
- joint_matrix<T1 , matrix_use::a, M, K, matrix_layout::row_major> sub_a;
144
+ joint_matrix<T3 , matrix_use::a, M, K, matrix_layout::row_major> sub_a;
145
145
146
- joint_matrix<T1 , matrix_use::b, K, N, matrix_layout::row_major> sub_b;
146
+ joint_matrix<T3 , matrix_use::b, K, N, matrix_layout::row_major> sub_b;
147
147
148
148
joint_matrix<T2, matrix_use::accumulator, M, N,
149
149
matrix_layout::row_major>
@@ -163,6 +163,14 @@ void test() {
163
163
accB.get_pointer () + (k * K * Big_N) + (n * N),
164
164
Big_N);
165
165
166
+ // Convert values if using tf32
167
+ if constexpr (std::is_same<T3, precision::tf32>::value) {
168
+ for (auto i = 0 ; i < 4 ; ++i) {
169
+ sub_a.data [i] = round_to_tf32 (sub_a.data [i]);
170
+ sub_b.data [i] = round_to_tf32 (sub_b.data [i]);
171
+ }
172
+ }
173
+
166
174
sub_c = joint_matrix_mad (sg, sub_a, sub_b, sub_c);
167
175
}
168
176
joint_matrix_store (
@@ -182,7 +190,6 @@ void test() {
182
190
};
183
191
184
192
int main () {
185
-
186
193
// A/B half, Accumulator float
187
194
test<half, float , SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16 , 16 , 16 >();
188
195
test<half, float , SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8 , 16 , 32 >();
@@ -208,5 +215,9 @@ int main() {
208
215
test<uint16_t , float , SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8 , 16 , 32 >();
209
216
test<uint16_t , float , SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32 , 16 , 8 >();
210
217
218
+ // A/B tf32
219
+ test<float , float , SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16 , 8 , 16 ,
220
+ precision::tf32>();
221
+
211
222
return 0 ;
212
223
};
0 commit comments