Skip to content

Commit 84426d1

Browse files
authored
[SYCL][Joint Matrix Tests] Add fill/store/apply tests for 16x16x16, 32x64x16 (#12629)
1 parent 75f6cd2 commit 84426d1

8 files changed

+177
-179
lines changed

sycl/test-e2e/Matrix/SG32/element_wise_ops.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,16 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
// REQUIRES: matrix
8+
// REQUIRES: aspect-ext_intel_matrix
99
// REQUIRES-INTEL-DRIVER: lin: 27501, win: 101.4943
10+
// SG size = 32 is not currently supported for SYCL Joint Matrix by IGC on DG2
11+
// UNSUPPORTED: gpu-intel-dg2
1012

1113
// RUN: %{build} -o %t.out
1214
// RUN: %{run} %t.out
1315

14-
#include <iostream>
15-
#include <sycl/sycl.hpp>
16+
#include "../common.hpp"
1617

17-
using namespace sycl;
18-
using namespace sycl::ext::oneapi::experimental::matrix;
19-
20-
constexpr size_t SG_SZ = 32;
21-
constexpr size_t TN = 16;
18+
#define SG_SZ 32
2219

2320
#include "../element_wise_ops_impl.hpp"

sycl/test-e2e/Matrix/XMX8/element_wise_ops.cpp

Lines changed: 0 additions & 22 deletions
This file was deleted.

sycl/test-e2e/Matrix/common.hpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,21 @@ float make_fp32(bfloat16 x) {
4242
return *res;
4343
}
4444

45-
template <typename Ta, typename Tb, typename Tc, unsigned int VF = 1>
45+
template <typename Ta, typename Tb, typename Tc, unsigned int VF = 1,
46+
typename F = std::nullptr_t>
4647
void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
4748
bool transpose_c = false, bool colmajor_a = false,
48-
bool colmajor_b = false) {
49+
bool colmajor_b = false, F &&lambda = {}) {
4950
for (unsigned int m = 0; m < M; m++) {
5051
for (unsigned int n = 0; n < N; n++) {
51-
for (unsigned int k = 0; k < K; k++) {
52+
int c_ind = transpose_c ? (n * M + m) : m * N + n;
53+
Tc acc = *(C + c_ind);
5254

55+
for (unsigned int k = 0; k < K; k++) {
5356
int a_ind = colmajor_a ? (k * M + m) : m * K + k;
5457
int b_ind = colmajor_b ? (n * K + k) : k * N + n;
55-
int c_ind = transpose_c ? (n * M + m) : m * N + n;
56-
5758
Ta *va = (Ta *)(A + a_ind * VF);
5859
Tb *vb = (Tb *)(B + b_ind * VF);
59-
Tc acc = *(C + c_ind);
6060

6161
for (unsigned int i = 0; i < VF; i++) {
6262
if constexpr (std::is_same_v<Ta, bfloat16> &&
@@ -74,9 +74,12 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
7474
else
7575
assert(false && "Unsupported type in matrix_multiply_ref.");
7676
}
77+
}
7778

78-
*(C + c_ind) = acc;
79+
if constexpr (!std::is_same_v<F, std::nullptr_t>) {
80+
lambda(acc);
7981
}
82+
*(C + c_ind) = acc;
8083
}
8184
}
8285
}
@@ -132,8 +135,7 @@ void matrix_rand(unsigned int rows, unsigned int cols, T *src, T val) {
132135
if constexpr (std::is_same_v<T, bfloat16> || std::is_same_v<T, float> ||
133136
std::is_same_v<T, double>) {
134137
src[i * cols + j] = T(fdistr(dev));
135-
} else if constexpr (std::is_same_v<T, int8_t> ||
136-
std::is_same_v<T, int32_t>) {
138+
} else if constexpr (std::is_integral_v<T>) {
137139
src[i * cols + j] = T(idistr(dev));
138140
} else {
139141
assert(false && "Unsupported type in matrix_rand.");
@@ -170,8 +172,9 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
170172
}
171173
} else if constexpr (exact || std::is_same_v<T1, int32_t>) {
172174
if (src[i * cols + j] != ref[i * cols + j]) {
173-
std::cout << "Incorrect result in matrix." << "i: " << i
174-
<< ", j: " << j << ", Ref: " << ref[i * cols + j]
175+
std::cout << "Incorrect result in matrix."
176+
<< "i: " << i << ", j: " << j
177+
<< ", Ref: " << ref[i * cols + j]
175178
<< ", Val: " << src[i * cols + j] << "\n";
176179
return false;
177180
}

sycl/test-e2e/Matrix/element_wise_all_ops.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,4 @@
1414
// RUN: %{run} %t.out
1515

1616
#include "common.hpp"
17-
1817
#include "element_wise_all_ops_impl.hpp"

sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8+
89
template <typename T, size_t NUM_ROWS, size_t NUM_COLS>
910
void assert_ops_ref(host_accessor<T, 2, access::mode::read> mat,
1011
const float ref) {
@@ -105,8 +106,11 @@ void verify_op_c(const T l, const T r, const float ref, OP op) {
105106

106107
// Avoid same kernel name for different types
107108
template <typename T, class name> class ewops_a {};
108-
template <typename T, size_t NROWS, size_t NCOLS, size_t SROWS, size_t SCOLS>
109-
void test_ewops_a() {
109+
template <typename T, size_t SROWS, size_t SCOLS> void test_ewops_a() {
110+
std::cout << "Test A " << SROWS << "x" << SCOLS << "\n";
111+
112+
static constexpr size_t NROWS = SROWS * 2;
113+
static constexpr size_t NCOLS = SCOLS * 2;
110114

111115
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_add>>(
112116
T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; });
@@ -135,64 +139,87 @@ void test_ewops_a() {
135139
T(5.0), T(2.0), 2.0,
136140
[](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); });
137141
}
142+
138143
// Avoid same kernel name for different types and numbers of columns
139-
template <typename T, size_t COLS, class name> class ewops_c {};
140-
template <typename T, size_t NROWS, size_t NCOLS, size_t SROWS, size_t SCOLS>
141-
void test_ewops_c() {
144+
template <typename T, size_t ROWS, size_t COLS, class name> class ewops_c {};
145+
template <typename T, size_t SROWS, size_t SCOLS> void test_ewops_c() {
146+
std::cout << "Test C " << SROWS << "x" << SCOLS << "\n";
142147

143-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_add>>(
148+
static constexpr size_t NROWS = SROWS * 2;
149+
static constexpr size_t NCOLS = SCOLS * 2;
150+
151+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
152+
ewops_c<T, SROWS, SCOLS, class c_add>>(
144153
T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; });
145-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_sub>>(
154+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
155+
ewops_c<T, SROWS, SCOLS, class c_sub>>(
146156
T(5.0), T(2.0), 3.0, [](auto l, auto r) { return l - r; });
147-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_mul>>(
157+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
158+
ewops_c<T, SROWS, SCOLS, class c_mul>>(
148159
T(5.0), T(2.0), 10.0, [](auto l, auto r) { return l * r; });
149-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_div>>(
160+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
161+
ewops_c<T, SROWS, SCOLS, class c_div>>(
150162
T(5.0), T(2.0), 2.5, [](auto l, auto r) { return l / r; });
151163
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
152-
ewops_c<T, SCOLS, class c_logical>>(
164+
ewops_c<T, SROWS, SCOLS, class c_logical>>(
153165
T(5.0), T(5.0), 5.0, [](auto l, auto r) { return l == r ? l : T(1.0); });
154-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_eq>>(
166+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
167+
ewops_c<T, SROWS, SCOLS, class c_eq>>(
155168
T(5.0), T(4.0), 4.0, [](auto l, auto r) { return l == r ? l : r; });
156-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_ne>>(
169+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
170+
ewops_c<T, SROWS, SCOLS, class c_ne>>(
157171
T(5.0), T(5.0), 1.0, [](auto l, auto r) { return l != r ? l : T(1.0); });
158-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_gt>>(
172+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
173+
ewops_c<T, SROWS, SCOLS, class c_gt>>(
159174
T(5.0), T(2.0), 3.0,
160175
[](auto l, auto r) { return l > r ? T(3.0) : T(2.0); });
161-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_lt>>(
176+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
177+
ewops_c<T, SROWS, SCOLS, class c_lt>>(
162178
T(5.0), T(2.0), 2.0,
163179
[](auto l, auto r) { return l < r ? T(3.0) : T(2.0); });
164-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_ge>>(
180+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
181+
ewops_c<T, SROWS, SCOLS, class c_ge>>(
165182
T(5.0), T(2.0), 3.0,
166183
[](auto l, auto r) { return l >= r ? T(3.0) : T(2.0); });
167-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_le>>(
184+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
185+
ewops_c<T, SROWS, SCOLS, class c_le>>(
168186
T(5.0), T(2.0), 2.0,
169187
[](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); });
170188
}
171189

172190
int main() {
173-
static constexpr size_t TM = 8;
174-
175-
static constexpr size_t MATRIX_M = TM * 2;
176-
static constexpr size_t MATRIX_N = 32;
177-
static constexpr size_t MATRIX_K = 32;
178191
queue q;
179192
std::vector<combination> combinations =
180193
q.get_device()
181194
.get_info<sycl::ext::oneapi::experimental::info::device::
182195
matrix_combinations>();
196+
183197
for (unsigned int i = 0; i < combinations.size(); i++) {
184198
if (combinations[i].atype == matrix_type::bf16) {
185-
if (combinations[i].nsize == 0 || combinations[i].nsize == 16) {
186-
test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, 16>();
187-
test_ewops_c<float, MATRIX_M, MATRIX_N, TM, 16>();
188-
break;
199+
200+
if (combinations[i].nsize == 0 ||
201+
(combinations[i].msize == 0 && combinations[i].nsize == 16)) {
202+
test_ewops_a<bfloat16, 8, 16>();
203+
test_ewops_c<float, 8, 16>();
204+
}
205+
206+
if (combinations[i].msize == 16 && combinations[i].nsize == 16) {
207+
test_ewops_c<float, 16, 16>();
208+
}
209+
210+
// This combination is not currently supported for sub group size = 32 in IGC
211+
#if (!defined(SG_SZ) || SG_SZ != 32)
212+
if (combinations[i].msize == 32 && combinations[i].nsize == 64) {
213+
test_ewops_c<float, 32, 64>();
189214
}
215+
#endif
216+
190217
if (combinations[i].nsize == 8) {
191-
test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, 16>();
192-
test_ewops_c<float, MATRIX_M, MATRIX_N, TM, 8>();
193-
break;
218+
test_ewops_a<bfloat16, 8, 16>();
219+
test_ewops_c<float, 8, 8>();
194220
}
195221
}
196222
}
223+
197224
return 0;
198225
}

sycl/test-e2e/Matrix/element_wise_ops.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,10 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
// REQUIRES: matrix
8+
// REQUIRES: aspect-ext_intel_matrix
99

1010
// RUN: %{build} -o %t.out
1111
// RUN: %{run} %t.out
1212

13-
#include <iostream>
14-
#include <sycl/sycl.hpp>
15-
16-
using namespace sycl;
17-
using namespace sycl::ext::oneapi::experimental::matrix;
18-
19-
#define SG_SZ 16
20-
constexpr size_t TN = 16;
21-
13+
#include "common.hpp"
2214
#include "element_wise_ops_impl.hpp"

0 commit comments

Comments
 (0)