
[SYCL][CUDA] tf32 matrix MAD impl using uint32_t #5709

Closed · wants to merge 6 commits

Changes from 4 commits

16 changes: 16 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
@@ -81,6 +81,10 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2)
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2)
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8)

// m16n16k8 fp19
__SYCL_JOINT_MATRIX_OVERLOAD(uint32_t, a, 16, 8, int32_t, 4)
__SYCL_JOINT_MATRIX_OVERLOAD(uint32_t, b, 8, 16, int32_t, 4)

#undef __SYCL_JOINT_MATRIX_OVERLOAD
} // namespace experimental::matrix

@@ -271,6 +275,15 @@ struct joint_matrix_load_impl<
__dmma_m8n8k4_ld_c(res.data, src.get(), stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, uint32_t>::value) {
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
if constexpr (NumRows == 16 && NumCols == 8) {
__mma_tf32_m16n16k8_ld_a(res.data, tileptr, stride,
get_layout_id<Layout>());
} else if constexpr (NumRows == 8 && NumCols == 16) {
__mma_tf32_m16n16k8_ld_b(res.data, tileptr, stride,
get_layout_id<Layout>());
}
}
}
};
@@ -495,6 +508,9 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
}
} else if constexpr (M == 16 && N == 16 && K == 8) {
__mma_tf32_m16n16k8_mma_f32(D.data, A.data, B.data, C.data,
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same<T1, double>::value) {
__dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
get_layout_pair_id<LayoutA, LayoutB>(), 0);
118 changes: 118 additions & 0 deletions sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp
@@ -0,0 +1,118 @@
// REQUIRES: cuda

// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s

// IMPORTANT: before updating sm version support beyond sm_86 read the following
// NOTE!

// NOTE: Technically the 'wrong' PTX instruction is called by
// joint_matrix_load/joint_matrix_store in this case: the load and store
// instructions use shape m16n16k16 rather than the correct shape m16n16k8.
// The 'wrong' PTX instruction is used because it produces the correct SASS
// instructions for all currently supported sm versions: sm_80 and sm_86.
// This PTX instruction redundancy stems from the PTX naming convention for
// the mnk shape triple; however, we cannot know a priori that future sm
// versions will behave in the same way, or that this redundancy will persist
// as new architectures are released. This should be validated before
// supporting any sm versions beyond sm_86. We choose the m16n16k16
// instruction because it gives the significant advantage of a portable
// interface across the Intel and Nvidia backends.
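
// For reference, the NVVM intrinsic shapes exercised by the CHECK lines below
// pair up as follows (summarized from this test; the '*' parts elide the
// layout/stride suffixes):
//   a/b fragment loads:     llvm.nvvm.wmma.m16n16k8.load.{a,b}.*.tf32
//   accumulator load/store: llvm.nvvm.wmma.m16n16k16.load.c.*.f32 and
//                           llvm.nvvm.wmma.m16n16k16.store.d.*.f32
//   matrix multiply-add:    llvm.nvvm.wmma.m16n16k8.mma.*.tf32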

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// M, N, K define the sizes of the dimensions of the three matrix types (a, b,
// accumulator) used per subgroup operation.
constexpr int M = 16; // number of rows of accumulator,
                      // number of rows of a.
constexpr int N = 16; // number of cols of accumulator,
                      // number of cols of b.
constexpr int K = 8;  // number of cols of a / number of rows of b.

uint32_t A[M * K];
Contributor: add a comment that uint32 is used here as a storage for fp19

Contributor Author: OK

Contributor Author: Thanks for the comments, I've updated both tests now.
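
The requested comment could read something like this (a hypothetical sketch, not the wording that landed in the later commit):

// uint32_t is used here purely as a storage type for tf32 ("fp19") values:
// an fp32 bit pattern whose low 13 mantissa bits are ignored by the tensor
// cores.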

uint32_t B[K * N];
float C[M * N];
float D[M * N];

int main() {

buffer<uint32_t, 1> bufA(A, range<1>(M * K));
buffer<uint32_t, 1> bufB(B, range<1>(K * N));
buffer<float, 1> bufC(C, range<1>(M * N));
buffer<float, 1> bufD(D, range<1>(M * N));

Contributor: Can you add a complete example in test/matrix where you show the necessary "manual" conversion function from float to fp19(uint32) during initialization and then from fp19 to float during accumulation and verification?

Contributor Author: Yeah, it's here: intel/llvm-test-suite#881

For the float to fp19:

uint32_t make_tf32(float const &x);

For the fp19 to float:

float tf32_to_fp32(uint32_t x);

(I'll rename both to e.g. make_fp19)
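
A minimal host-side sketch of those two helpers, assuming plain truncation of the low 13 mantissa bits (the names follow the thread above; the actual implementations in intel/llvm-test-suite#881 may round to nearest instead):

#include <cstdint>
#include <cstring>

// tf32 keeps fp32's sign bit and 8 exponent bits but only 10 mantissa bits,
// so a float can be stored in a uint32_t with the low 13 mantissa bits
// cleared.
uint32_t make_tf32(float const &x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits)); // reinterpret the fp32 bit pattern
  return bits & 0xFFFFE000u;            // drop the low 13 mantissa bits
}

float tf32_to_fp32(uint32_t x) {
  float f;
  std::memcpy(&f, &x, sizeof(f)); // the stored pattern is already a valid fp32
  return f;
}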

queue q;

q.submit([&](handler &cgh) {
auto accC = bufC.get_access<access::mode::read_write>(cgh);
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);
auto accD = bufD.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class row_row>(
nd_range<2>({1, 32}, {1, 32}),
[=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
sycl::sub_group sg = item.get_sub_group();

joint_matrix<float, matrix_use::accumulator, M, N,
matrix_layout::row_major>
sub_c;

joint_matrix<uint32_t, matrix_use::a, M, K, matrix_layout::row_major>
sub_a;

joint_matrix<uint32_t, matrix_use::b, K, N, matrix_layout::row_major>
sub_b;

//CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}}
joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
//CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8) #{{.*}}
joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
//CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}}
joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
//CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32(i32 %10, i32 %11, i32 %12, i32 %13, i32 %15, i32 %16, i32 %17, i32 %18, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}}
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
//CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %20, float %21, float %22, float %23, float %24, float %25, float %26, float %27, i32 16) #{{.*}}
joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
});
});

q.submit([&](handler &cgh) {
auto accC = bufC.get_access<access::mode::read_write>(cgh);
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);
auto accD = bufD.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class col_col>(
nd_range<2>({1, 32}, {1, 32}),
[=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
sycl::sub_group sg = item.get_sub_group();

joint_matrix<float, matrix_use::accumulator, M, N,
matrix_layout::col_major>
sub_c;

joint_matrix<uint32_t, matrix_use::a, M, K, matrix_layout::col_major>
sub_a;

joint_matrix<uint32_t, matrix_use::b, K, N, matrix_layout::col_major>
sub_b;

//CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}}
joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
//CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8) #{{.*}}
joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
//CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}}
joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
//CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(i32 %10, i32 %11, i32 %12, i32 %13, i32 %15, i32 %16, i32 %17, i32 %18, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}}
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
//CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %20, float %21, float %22, float %23, float %24, float %25, float %26, float %27, i32 16) #{{.*}}
joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
});
});

return 0;
}