|
5 | 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
6 | 6 | //
|
7 | 7 | //===----------------------------------------------------------------------===//
|
| 8 | + |
8 | 9 | template <typename T, size_t NUM_ROWS, size_t NUM_COLS>
|
9 | 10 | void assert_ops_ref(host_accessor<T, 2, access::mode::read> mat,
|
10 | 11 | const float ref) {
|
@@ -105,8 +106,11 @@ void verify_op_c(const T l, const T r, const float ref, OP op) {
|
105 | 106 |
|
106 | 107 | // Avoid same kernel name for different types
|
107 | 108 | template <typename T, class name> class ewops_a {};
|
108 |
| -template <typename T, size_t NROWS, size_t NCOLS, size_t SROWS, size_t SCOLS> |
109 |
| -void test_ewops_a() { |
| 109 | +template <typename T, size_t SROWS, size_t SCOLS> void test_ewops_a() { |
| 110 | + std::cout << "Test A " << SROWS << "x" << SCOLS << "\n"; |
| 111 | + |
| 112 | + static constexpr size_t NROWS = SROWS * 2; |
| 113 | + static constexpr size_t NCOLS = SCOLS * 2; |
110 | 114 |
|
111 | 115 | verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_add>>(
|
112 | 116 | T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; });
|
@@ -135,64 +139,87 @@ void test_ewops_a() {
|
135 | 139 | T(5.0), T(2.0), 2.0,
|
136 | 140 | [](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); });
|
137 | 141 | }
|
| 142 | + |
138 | 143 | // Avoid same kernel name for different types and numbers of columns
|
139 |
| -template <typename T, size_t COLS, class name> class ewops_c {}; |
140 |
| -template <typename T, size_t NROWS, size_t NCOLS, size_t SROWS, size_t SCOLS> |
141 |
| -void test_ewops_c() { |
| 144 | +template <typename T, size_t ROWS, size_t COLS, class name> class ewops_c {}; |
| 145 | +template <typename T, size_t SROWS, size_t SCOLS> void test_ewops_c() { |
| 146 | + std::cout << "Test C " << SROWS << "x" << SCOLS << "\n"; |
142 | 147 |
|
143 |
| - verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_add>>( |
| 148 | + static constexpr size_t NROWS = SROWS * 2; |
| 149 | + static constexpr size_t NCOLS = SCOLS * 2; |
| 150 | + |
| 151 | + verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, |
| 152 | + ewops_c<T, SROWS, SCOLS, class c_add>>( |
144 | 153 | T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; });
|
145 |
| - verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_sub>>( |
| 154 | + verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, |
| 155 | + ewops_c<T, SROWS, SCOLS, class c_sub>>( |
146 | 156 | T(5.0), T(2.0), 3.0, [](auto l, auto r) { return l - r; });
|
147 |
| - verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_mul>>( |
| 157 | + verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, |
| 158 | + ewops_c<T, SROWS, SCOLS, class c_mul>>( |
148 | 159 | T(5.0), T(2.0), 10.0, [](auto l, auto r) { return l * r; });
|
149 |
| - verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_div>>( |
| 160 | + verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, |
| 161 | + ewops_c<T, SROWS, SCOLS, class c_div>>( |
150 | 162 | T(5.0), T(2.0), 2.5, [](auto l, auto r) { return l / r; });
|
151 | 163 | verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
|
152 |
| - ewops_c<T, SCOLS, class c_logical>>( |
| 164 | + ewops_c<T, SROWS, SCOLS, class c_logical>>( |
153 | 165 | T(5.0), T(5.0), 5.0, [](auto l, auto r) { return l == r ? l : T(1.0); });
|
154 |
| - verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_eq>>( |
| 166 | + verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, |
| 167 | + ewops_c<T, SROWS, SCOLS, class c_eq>>( |
155 | 168 | T(5.0), T(4.0), 4.0, [](auto l, auto r) { return l == r ? l : r; });
|
156 |
| - verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_ne>>( |
| 169 | + verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, |
| 170 | + ewops_c<T, SROWS, SCOLS, class c_ne>>( |
157 | 171 | T(5.0), T(5.0), 1.0, [](auto l, auto r) { return l != r ? l : T(1.0); });
|
158 |
| - verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_gt>>( |
| 172 | + verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, |
| 173 | + ewops_c<T, SROWS, SCOLS, class c_gt>>( |
159 | 174 | T(5.0), T(2.0), 3.0,
|
160 | 175 | [](auto l, auto r) { return l > r ? T(3.0) : T(2.0); });
|
161 |
| - verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_lt>>( |
| 176 | + verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, |
| 177 | + ewops_c<T, SROWS, SCOLS, class c_lt>>( |
162 | 178 | T(5.0), T(2.0), 2.0,
|
163 | 179 | [](auto l, auto r) { return l < r ? T(3.0) : T(2.0); });
|
164 |
| - verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_ge>>( |
| 180 | + verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, |
| 181 | + ewops_c<T, SROWS, SCOLS, class c_ge>>( |
165 | 182 | T(5.0), T(2.0), 3.0,
|
166 | 183 | [](auto l, auto r) { return l >= r ? T(3.0) : T(2.0); });
|
167 |
| - verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_le>>( |
| 184 | + verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, |
| 185 | + ewops_c<T, SROWS, SCOLS, class c_le>>( |
168 | 186 | T(5.0), T(2.0), 2.0,
|
169 | 187 | [](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); });
|
170 | 188 | }
|
171 | 189 |
|
172 | 190 | int main() {
|
173 |
| - static constexpr size_t TM = 8; |
174 |
| - |
175 |
| - static constexpr size_t MATRIX_M = TM * 2; |
176 |
| - static constexpr size_t MATRIX_N = 32; |
177 |
| - static constexpr size_t MATRIX_K = 32; |
178 | 191 | queue q;
|
179 | 192 | std::vector<combination> combinations =
|
180 | 193 | q.get_device()
|
181 | 194 | .get_info<sycl::ext::oneapi::experimental::info::device::
|
182 | 195 | matrix_combinations>();
|
| 196 | + |
183 | 197 | for (unsigned int i = 0; i < combinations.size(); i++) {
|
184 | 198 | if (combinations[i].atype == matrix_type::bf16) {
|
185 |
| - if (combinations[i].nsize == 0 || combinations[i].nsize == 16) { |
186 |
| - test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, 16>(); |
187 |
| - test_ewops_c<float, MATRIX_M, MATRIX_N, TM, 16>(); |
188 |
| - break; |
| 199 | + |
| 200 | + if (combinations[i].nsize == 0 || |
| 201 | + (combinations[i].msize == 0 && combinations[i].nsize == 16)) { |
| 202 | + test_ewops_a<bfloat16, 8, 16>(); |
| 203 | + test_ewops_c<float, 8, 16>(); |
| 204 | + } |
| 205 | + |
| 206 | + if (combinations[i].msize == 16 && combinations[i].nsize == 16) { |
| 207 | + test_ewops_c<float, 16, 16>(); |
| 208 | + } |
| 209 | + |
| 210 | +// This combination is not currently supported for sub group size = 32 in IGC |
| 211 | +#if (!defined(SG_SZ) || SG_SZ != 32) |
| 212 | + if (combinations[i].msize == 32 && combinations[i].nsize == 64) { |
| 213 | + test_ewops_c<float, 32, 64>(); |
189 | 214 | }
|
| 215 | +#endif |
| 216 | + |
190 | 217 | if (combinations[i].nsize == 8) {
|
191 |
| - test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, 16>(); |
192 |
| - test_ewops_c<float, MATRIX_M, MATRIX_N, TM, 8>(); |
193 |
| - break; |
| 218 | + test_ewops_a<bfloat16, 8, 16>(); |
| 219 | + test_ewops_c<float, 8, 8>(); |
194 | 220 | }
|
195 | 221 | }
|
196 | 222 | }
|
| 223 | + |
197 | 224 | return 0;
|
198 | 225 | }
|
0 commit comments