This repository was archived by the owner on Mar 28, 2023. It is now read-only.

[SYCL][CUDA] fp19 matrix mad test update #881

Closed
wants to merge 5 commits into from
Changes from all commits
62 changes: 43 additions & 19 deletions SYCL/Matrix/joint_matrix_tensorcore.cpp
@@ -15,9 +15,9 @@ using namespace sycl::ext::oneapi::experimental::matrix;
 // Example usage of Nvidia matrix multiply.
 // Optimizations such as memory paddings for avoiding bank conflicts are not
 // included in this test which aids clarity for what is going on. This example
-// forms a "Big matrix" corresponding to a single "TILE" using cuda example
-// terminology. Multiple TILES can be used to construct yet larger matrices.
-// This example uses row_major a, b, and accumulator matrices.
+// forms a "Big matrix" corresponding to a single "TILE". Multiple TILES can be
+// used to construct yet larger matrices. This example uses row_major a, b, and
+// accumulator matrices.
 
 // M, N, K define the unit sizes of dimensions of the three types (a, b,
 // accumulator) of matrices per subgroup operation:
@@ -43,17 +43,28 @@ class TypeHelper;
 template <typename T1, typename T2, size_t M, size_t K, size_t N>
 using KernelName = class TypeHelper<T1, T2, M, K, N>;
 
-float make_fp32(short x) {
-  unsigned int y = x;
+float make_fp32(uint16_t x) {
+  uint32_t y = x;
   y = y << 16;
   float *res = reinterpret_cast<float *>(&y);
   return *res;
 }
 
-unsigned short make_bf16(float x) {
-  int *res = reinterpret_cast<int *>(&x);
+uint16_t make_bf16(float x) {
+  uint32_t *res = reinterpret_cast<uint32_t *>(&x);
   *res = *res >> 16;
-  return (unsigned short)*res;
+  return (uint16_t)*res;
 }
 
+uint32_t make_tf32(float const &x) {
+  uint32_t res = reinterpret_cast<uint32_t const &>(x);
+  res += 0x1000u;
+  return res;
+}
+
+float tf32_to_fp32(uint32_t x) {
+  uint32_t bits = (x & ~0x1fffu);
+  return reinterpret_cast<float const &>(bits);
+}
+
 template <typename T1, typename T2, size_t Big_N, size_t Big_K>
@@ -63,9 +74,11 @@ T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
   if constexpr (std::is_same<T1, uint16_t>::value) {
     for (int k = 0; k < Big_K; k++)
       res += make_fp32(A[m * Big_K + k]) * make_fp32(B[k * Big_N + n]);
+  } else if constexpr (std::is_same<T1, uint32_t>::value) {
+    for (int k = 0; k < Big_K; k++)
+      res += tf32_to_fp32(A[m * Big_K + k]) * tf32_to_fp32(B[k * Big_N + n]);
   } else {
     for (int k = 0; k < Big_K; k++)
-
       res +=
           static_cast<T2>(A[m * Big_K + k]) * static_cast<T2>(B[k * Big_N + n]);
   }
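The helpers above pack and unpack the storage formats by bit manipulation: bfloat16 is the top 16 bits of an IEEE-754 binary32 value (sign, 8-bit exponent, 7-bit mantissa), which make_bf16 obtains by truncation and make_fp32 restores by shifting back up; tf32 keeps the sign, the full 8-bit exponent and 10 mantissa bits inside a 32-bit word, and make_tf32 adds 0x1000 so that discarding the low 13 bits in tf32_to_fp32 acts as round-to-nearest rather than plain truncation. The following standalone sketch, which is not part of the PR and uses memcpy instead of reinterpret_cast for the bit casts, illustrates the precision each format retains:

// Standalone illustration (not from the PR): round-trip values through the
// bf16 and tf32 storage formats used above. memcpy is used for the bit casts.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>

static uint32_t float_to_bits(float f) {
  uint32_t b;
  std::memcpy(&b, &f, sizeof(b));
  return b;
}

static float bits_to_float(uint32_t b) {
  float f;
  std::memcpy(&f, &b, sizeof(f));
  return f;
}

int main() {
  for (int i = 0; i < 10; i++) {
    float x = 0.1f * i; // same values the bf16 path uses to fill A and B

    // bf16: keep only the top 16 bits (sign, 8-bit exponent, 7-bit mantissa).
    uint16_t bf = static_cast<uint16_t>(float_to_bits(x) >> 16);
    float bf_as_f32 = bits_to_float(static_cast<uint32_t>(bf) << 16);

    // tf32: round into the top 19 bits (sign, 8-bit exponent, 10-bit
    // mantissa); the low 13 bits of the word are ignored by the hardware.
    uint32_t tf = float_to_bits(x) + 0x1000u;
    float tf_as_f32 = bits_to_float(tf & ~0x1fffu);

    // Relative error is bounded by 2^-7 for truncated bf16 and 2^-11 for
    // rounded tf32; assert with a factor-of-two margin.
    assert(std::fabs(bf_as_f32 - x) <= x * (1.0f / 64.0f));
    assert(std::fabs(tf_as_f32 - x) <= x * (1.0f / 1024.0f));
    std::cout << x << " -> bf16 " << bf_as_f32 << ", tf32 " << tf_as_f32 << "\n";
  }
  return 0;
}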
@@ -105,6 +118,14 @@ void test() {
     for (int i = 0; i < Big_K * Big_N; i++) {
       B[i] = make_bf16(0.1f * (i % 10));
     }
+  } else if constexpr (std::is_same<T1, uint32_t>::value) {
+    for (int i = 0; i < Big_M * Big_K; i++) {
+      A[i] = make_tf32(1.0f * (i % 10));
+    }
+
+    for (int i = 0; i < Big_K * Big_N; i++) {
+      B[i] = make_tf32(1.0f * (i % 10));
+    }
   } else {
     for (int i = 0; i < Big_M * Big_K; i++) {
       A[i] = i % 100;
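A note on the new initialization: the tf32 path fills A and B with the whole numbers 0 through 9 rather than the 0.1f steps used for bfloat16. Those values fit exactly in tf32's 10-bit mantissa, so converting them and accumulating their products in fp32 loses nothing, which presumably allows an exact comparison against the reference. A quick standalone check of that exactness, with hypothetical helper names that merely mirror make_tf32 and tf32_to_fp32:

// Standalone check (not from the test): the integer-valued inputs used on the
// tf32 path survive the conversion exactly, and their products accumulate
// exactly in fp32.
#include <cassert>
#include <cstdint>
#include <cstring>

static uint32_t to_tf32(float x) { // same bit trick as make_tf32
  uint32_t b;
  std::memcpy(&b, &x, sizeof(b));
  return b + 0x1000u;
}

static float from_tf32(uint32_t x) { // same bit trick as tf32_to_fp32
  uint32_t b = x & ~0x1fffu;
  float f;
  std::memcpy(&f, &b, sizeof(f));
  return f;
}

int main() {
  float acc = 0.0f, ref = 0.0f;
  for (int i = 0; i < 10; i++) {
    float v = 1.0f * (i % 10);          // the values written into A and B
    assert(from_tf32(to_tf32(v)) == v); // 0..9 are exact in tf32
    acc += from_tf32(to_tf32(v)) * from_tf32(to_tf32(v));
    ref += v * v;
  }
  assert(acc == ref); // small integer sums stay exact in fp32
  return 0;
}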
@@ -131,15 +152,15 @@ void test() {
     range<2> GlobalRange = {Sub_Tiles_M, Sub_Tiles_N * N_THREADS_PER_MATRIX_OP};
 
     cgh.parallel_for<KernelName<T1, T2, M, K, N>>(
-        nd_range<2>(GlobalRange, LocalRange), [=
-    ](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
-          sycl::sub_group sg = item.get_sub_group();
-          const auto m =
-              item.get_group()
-                  .get_id()[0]; // row id of current submatrix of BIG C matrix
-          const auto n =
-              item.get_group().get_id()[1]; // column id of current
-                                            // submatrix of BIG C matrix
+        nd_range<2>(GlobalRange, LocalRange),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          auto sg = item.get_sub_group();
+          auto Group = item.get_group();
+          const auto m = Group.get_group_id()[0];
+          // row id of current submatrix of BIG C matrix
+          const auto n = Group.get_group_id()[1];
+          // column id of current
+          // submatrix of BIG C matrix
 
           joint_matrix<T1, matrix_use::a, M, K, matrix_layout::row_major> sub_a;
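The reformatted kernel makes the tiling easier to follow: work-group (m, n) is responsible for the M x N sub-tile of the big C/D matrix whose top-left element sits at row m * M and column n * N, and with row-major storage that element has flat index (m * M) * Big_N + n * N, the same convention matrix_ref_mn uses to index A and B. A small host-side sketch of that mapping; sub_tile_index is a hypothetical helper, not code from the test:

// Hypothetical helper (not from the test) mirroring the indexing convention:
// the flat offset of element (row, col) inside sub-tile (m, n) of a row-major
// matrix that is Big_N columns wide and built from M x N sub-tiles.
#include <cstddef>

template <size_t M, size_t N, size_t Big_N>
constexpr size_t sub_tile_index(size_t m, size_t n, size_t row, size_t col) {
  return (m * M + row) * Big_N + (n * N + col);
}

int main() {
  constexpr size_t M = 16, N = 16, Sub_Tiles_N = 2;
  constexpr size_t Big_N = Sub_Tiles_N * N; // big matrix is two sub-tiles wide
  // Element (0, 0) of sub-tile (1, 1) sits at row 16, column 16.
  static_assert(sub_tile_index<M, N, Big_N>(1, 1, 0, 0) == 16 * Big_N + 16,
                "sub-tile indexing");
  return 0;
}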

@@ -203,10 +224,13 @@ int main() {
 
   test<double, double, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 4, 8>();
 
-  // A/B bf16
+  // A/B bfloat16 (using the bfloat16 storage type uint16_t directly)
   test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 16, 16>();
   test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8, 16, 32>();
   test<uint16_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32, 16, 8>();
 
+  // A/B tf32 (using the tf32 storage type uint32_t directly)
+  test<uint32_t, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 8, 16>();
+
   return 0;
 };
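The verification loop that compares the device result against matrix_ref_mn lies outside the hunks shown in this diff. As a rough illustration only, a host-side check for a big matrix D copied back from the device could look like the sketch below; ref_mn, check and the tolerance are assumptions, not the test's actual code, and they follow the generic (non-bf16, non-tf32) branch of matrix_ref_mn:

// Hedged sketch (not the test's code): checking a device result D against a
// matrix_ref_mn-style reference on the host, following the generic branch.
#include <cassert>
#include <cmath>
#include <cstddef>

template <typename T1, typename T2, size_t Big_N, size_t Big_K>
T2 ref_mn(size_t m, size_t n, const T1 *A, const T1 *B, const T2 *C) {
  T2 res = C[m * Big_N + n];
  for (size_t k = 0; k < Big_K; k++)
    res += static_cast<T2>(A[m * Big_K + k]) * static_cast<T2>(B[k * Big_N + n]);
  return res;
}

template <typename T1, typename T2, size_t Big_M, size_t Big_N, size_t Big_K>
void check(const T1 *A, const T1 *B, const T2 *C, const T2 *D) {
  for (size_t m = 0; m < Big_M; m++)
    for (size_t n = 0; n < Big_N; n++)
      // Integer-valued inputs keep this exact; bf16/tf32 inputs would need a
      // tolerance scaled to the format's precision instead.
      assert(std::fabs(D[m * Big_N + n] -
                       ref_mn<T1, T2, Big_N, Big_K>(m, n, A, B, C)) < 1e-5);
}

int main() {
  // Tiny smoke test of the checker itself: D = A * B, computed by hand.
  constexpr size_t Big_M = 2, Big_N = 2, Big_K = 2;
  float A[] = {1, 2, 3, 4}, B[] = {5, 6, 7, 8};
  float C[] = {0, 0, 0, 0};
  float D[] = {19, 22, 43, 50};
  check<float, float, Big_M, Big_N, Big_K>(A, B, C, D);
  return 0;
}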