[SYCL][ext][CUDA] Use float as storage type for tf32 joint matrix #5870

Merged: 32 commits, Jun 8, 2022

Commits (32)

025cf7e  Added bfloat16 support for cuda backend. (JackAKirk, Jan 25, 2022)
66b4e33  deleted intel namespace bfloat16. (JackAKirk, Jan 25, 2022)
2d04406  Format. (JackAKirk, Jan 25, 2022)
9418f74  Changed extension macro name. (JackAKirk, Jan 25, 2022)
65fddfa  Merge branch 'sycl' into bf16-cvt-ext (JackAKirk, Feb 17, 2022)
4d99f3f  fixed test. (JackAKirk, Feb 17, 2022)
f6cf7b8  Implemented fp19 mma using the natural storage type uint32_t. (JackAKirk, Mar 2, 2022)
35302b5  format (JackAKirk, Mar 2, 2022)
712af98  format (JackAKirk, Mar 2, 2022)
3530643  format (JackAKirk, Mar 2, 2022)
fa67ff9  added comment relating uint32_t to fp19 (JackAKirk, Mar 3, 2022)
3982001  Used neg ptx7.0 builtin for unary minus (JackAKirk, Mar 4, 2022)
8d2d11f  Replaced SYCL_EXT_INTEL_BF16_CONVERSION.asciidoc with SYCL_EXT_ONEAPI… (JackAKirk, Mar 7, 2022)
d8bc53f  Merge branch 'sycl' into bf16-cvt-ext (JackAKirk, Mar 8, 2022)
bfc68d2  fp19 comments ->tf32 (JackAKirk, Mar 10, 2022)
2f9b7d7  Merge branch 'sycl' into bf16-cvt-ext (JackAKirk, Mar 15, 2022)
8a29c44  Renamed extension to cover all bfloat16 funct. (JackAKirk, Mar 15, 2022)
ca1d735  Merge remote-tracking branch 'Jack/fp19-matrix-uint32_t' into tf32-jo… (hdelan, Mar 16, 2022)
52c8e20  Changing impl to accept float with boolean switch to tell whether tf3… (hdelan, Mar 17, 2022)
813aa4b  Final impl (hdelan, Mar 23, 2022)
0630667  Adding sycl test (hdelan, Mar 23, 2022)
61b3d8f  Device code check passing (hdelan, Mar 23, 2022)
23cb7da  Changing to precision enum (hdelan, Mar 24, 2022)
be60cdd  Responding to comments. Using precision::tf32 as empty class and floa… (hdelan, Mar 31, 2022)
618c807  Add device code check for conversion builtin (hdelan, Mar 31, 2022)
1499836  Merge remote-tracking branch 'Jack/bf16-cvt-ext' into tf32-joint-matrix (hdelan, Apr 5, 2022)
b03e661  Zeroing out bottom bits in sware impl (hdelan, Apr 14, 2022)
077e0f4  Changing names (hdelan, Apr 14, 2022)
1b5503c  Merge branch 'sycl' into tf32-joint-matrix (hdelan, Apr 14, 2022)
5b5bbcc  Removing newline (hdelan, Apr 19, 2022)
560b02d  Removing truncate function (hdelan, Apr 21, 2022)
13b1efb  Updating test (hdelan, Apr 22, 2022)

82 changes: 64 additions & 18 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
@@ -18,6 +18,10 @@ enum class matrix_use { a, b, accumulator };

enum class matrix_layout { row_major, col_major, packed_a, packed_b };

namespace precision {
class tf32 {};
} // namespace precision

template <typename T, matrix_use Use, size_t Rows = sycl::dynamic_extent,
size_t Cols = sycl::dynamic_extent,
matrix_layout Layout = matrix_layout::row_major,
@@ -81,18 +85,23 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2)
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2)
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8)

// m16n16k8 tf32
__SYCL_JOINT_MATRIX_OVERLOAD(precision::tf32, a, 16, 8, float, 4)
__SYCL_JOINT_MATRIX_OVERLOAD(precision::tf32, b, 8, 16, float, 4)

#undef __SYCL_JOINT_MATRIX_OVERLOAD
} // namespace experimental::matrix

namespace detail {

template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
template <typename S, typename T,
sycl::ext::oneapi::experimental::matrix::matrix_use Use,
size_t NumRows, size_t NumCols,
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
access::address_space Space, typename Cond = void>
struct joint_matrix_load_impl {
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
multi_ptr<T, Space> src, size_t stride);
};

@@ -111,18 +120,19 @@ constexpr int get_layout_id<
return 1;
}

template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
template <typename S, typename T,
sycl::ext::oneapi::experimental::matrix::matrix_use Use,
size_t NumRows, size_t NumCols,
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
access::address_space Space>
struct joint_matrix_load_impl<
T, Use, NumRows, NumCols, Layout, Space,
S, T, Use, NumRows, NumCols, Layout, Space,
typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
matrix::matrix_layout::row_major ||
Layout == sycl::ext::oneapi::experimental::
matrix::matrix_layout::col_major>> {
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
Review comment (Contributor):

I am tagging @yubingex007-a11y here, as changing the type of the load will be necessary to handle the tf32 case: the type of the memory can be different from the type of the joint matrix.
However, @JackAKirk, we should restrict this flexibility to tf32 only.
Can this work in the case of bfloat16, i.e. a load from float to bfloat16?

Reply (Contributor):

Yeah, the final bfloat16 cuda impl is ready now using the old API (#5964).

Sounds fine to restrict the flexibility: I think the way this is implemented already restricts it to the tf32 type. If we add subbyte/single-bit cases then I think those would also hit the case where the type of the memory differs from the type of the joint matrix.

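For illustration (not part of the diff): the call pattern this thread describes, in which the fragment element type S (precision::tf32) differs from the memory element type T (float). This is a minimal device-side sketch assuming the API added by this PR; the function name and pointer setup are hypothetical.

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// S = precision::tf32 is the fragment element type; the data itself is
// loaded from T = float memory. The enable_if constraint added further down
// in this diff permits exactly this S/T pair (plus the trivial S == T case).
SYCL_EXTERNAL void load_tf32_a_fragment(
    sycl::sub_group sg,
    multi_ptr<float, access::address_space::global_space> src) {
  joint_matrix<precision::tf32, matrix_use::a, 16, 8,
               matrix_layout::row_major>
      frag;
  joint_matrix_load(sg, frag, src, /*stride=*/8); // float -> tf32 fragment
}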
multi_ptr<T, Space> src, size_t stride) {
if constexpr (std::is_same<T, uint16_t>::value) {
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
@@ -247,15 +257,27 @@ struct joint_matrix_load_impl<
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, float>::value) {
if constexpr (NumRows == 16 && NumCols == 16) {
__hmma_m16n16k16_ld_c_f32(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (NumRows == 8 && NumCols == 32) {
__hmma_m8n32k16_ld_c_f32(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (NumRows == 32 && NumCols == 8) {
__hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
get_layout_id<Layout>());
if (std::is_same<S, float>::value) {
Review comment (Contributor):

Is there a reason for this not to be if constexpr?

Reply (Contributor):

I think if constexpr is C++17, so it may have been removed for that reason. Although I noticed that for some reason C++17 is allowed in the extension namespace, which I don't understand.

Reply (Contributor):

I believe extensions that need to be explicitly included are allowed to use C++17 features.

Reply (JackAKirk, Jun 8, 2022):

Allowed in the sense that tests don't appear to fail due to C++17 in the extension namespace, whereas they do fail due to C++17 in other namespaces!

Reply (Contributor):

I see, thanks.

Reply (JackAKirk, Jun 8, 2022):

In that case I am happy to ensure C++17 is fully employed where appropriate in this extension in the follow-on PR: #5964

Reply (Contributor):

> In that case I am happy to ensure C++17 is fully employed where appropriate in this extension in the follow-on PR: #5964

As long as it doesn't bleed into sycl.hpp then it should be fine. There should be a test that fails if it does.

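For background (a standalone sketch, not part of the diff): the practical difference the thread is discussing. With a plain if, both branches must compile for every instantiation and the optimizer is relied on to drop the dead one; with if constexpr, the untaken branch is discarded at instantiation time and may even contain code that is invalid for the current type. The stringify helper below is hypothetical.

#include <string>
#include <type_traits>

// Each return statement is only valid for some Ts, so this compiles only
// because if constexpr discards the untaken branch; a plain `if` would be
// ill-formed here.
template <typename T> std::string stringify(T v) {
  if constexpr (std::is_arithmetic_v<T>)
    return std::to_string(v); // no std::to_string(const char *)
  else
    return std::string(v); // no std::string(int)
}

int main() {
  std::string a = stringify(42);      // arithmetic branch
  std::string b = stringify("hello"); // string branch
  return (a == "42" && b == "hello") ? 0 : 1;
}

In the load implementation above, the runtime if happens to compile because both branches are well-formed for either S; if constexpr would simply make that intent explicit.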
if constexpr (NumRows == 16 && NumCols == 16) {
__hmma_m16n16k16_ld_c_f32(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (NumRows == 8 && NumCols == 32) {
__hmma_m8n32k16_ld_c_f32(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (NumRows == 32 && NumCols == 8) {
__hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
get_layout_id<Layout>());
}
} else if (std::is_same<S, sycl::ext::oneapi::experimental::matrix::
precision::tf32>::value) {
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
if constexpr (NumRows == 16 && NumCols == 8) {
__mma_tf32_m16n16k8_ld_a(reinterpret_cast<int32_t *>(res.data),
tileptr, stride, get_layout_id<Layout>());
} else if constexpr (NumRows == 8 && NumCols == 16) {
__mma_tf32_m16n16k8_ld_b(reinterpret_cast<int32_t *>(res.data),
tileptr, stride, get_layout_id<Layout>());
}
}
} else if constexpr (std::is_same<T, double>::value) {
if constexpr (Use ==
@@ -495,6 +517,10 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
}
} else if constexpr (M == 16 && N == 16 && K == 8) {
__mma_tf32_m16n16k8_mma_f32(D.data, reinterpret_cast<int32_t *>(A.data),
reinterpret_cast<int32_t *>(B.data), C.data,
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same<T1, double>::value) {
__dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
get_layout_pair_id<LayoutA, LayoutB>(), 0);
@@ -507,13 +533,18 @@

namespace experimental::matrix {

template <typename Group, typename T, matrix_use Use, size_t NumRows,
size_t NumCols, matrix_layout Layout, access::address_space Space>
template <typename Group, typename S, typename T, matrix_use Use,
size_t NumRows, size_t NumCols, matrix_layout Layout,
access::address_space Space,
std::enable_if_t<std::is_same<S, T>::value ||
(std::is_same<S, precision::tf32>::value &&
std::is_same<T, float>::value),
bool> = true>
void joint_matrix_load(
Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group> &res,
Group sg, joint_matrix<S, Use, NumRows, NumCols, Layout, Group> &res,
multi_ptr<T, Space> src, size_t stride) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
sycl::ext::oneapi::detail::joint_matrix_load_impl<S, T, Use, NumRows, NumCols,
Layout, Space>{}
.load(res, src, stride);
#else
@@ -573,6 +604,21 @@ joint_matrix_mad(
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

// This function rounds the bottom 13 bits up or down, and then zeros out the
// bottom bits
float round_to_tf32(float a) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
int32_t tmp_int = __nvvm_f2tf32_rna(a);
return __nvvm_bitcast_i2f(tmp_int);
#else
uint32_t tmp_uint = reinterpret_cast<uint32_t &>(a);
tmp_uint += 0x1000u;
tmp_uint &= 0xFFFFE000u;
float ret = reinterpret_cast<float &>(tmp_uint);
return ret;
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

} // namespace experimental::matrix
} // namespace oneapi
} // namespace ext
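As an aside (not part of the PR), the host fallback of round_to_tf32 above can be exercised in isolation: add half of 2^13 to round to nearest at mantissa bit 13, then clear the low 13 bits so only the 10 mantissa bits that tf32 keeps survive. A standalone sketch follows; the name round_to_tf32_host and the sample values are hypothetical, and memcpy stands in for the reinterpret_cast to keep the example free of aliasing concerns.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Standalone re-implementation of the non-NVPTX branch above.
float round_to_tf32_host(float a) {
  uint32_t bits;
  std::memcpy(&bits, &a, sizeof(bits)); // type-pun the float to its bits
  bits += 0x1000u;                      // round: add half of 2^13
  bits &= 0xFFFFE000u;                  // truncate: clear 13 low mantissa bits
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  float exact = 1.0f + 0x1p-10f; // smallest step tf32 represents above 1
  float half = 1.0f + 0x1p-11f;  // halfway case
  std::printf("%g -> %g\n", exact, round_to_tf32_host(exact)); // unchanged
  std::printf("%g -> %g\n", half, round_to_tf32_host(half));   // rounds up to 1 + 2^-10
  return 0;
}

On device the builtin __nvvm_f2tf32_rna performs the rounding instead; the bit arithmetic here mirrors only the host path.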
141 changes: 141 additions & 0 deletions sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp
@@ -0,0 +1,141 @@
// REQUIRES: cuda

// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s

// IMPORTANT: before updating sm version support beyond sm_86 read the following
// NOTE!

// NOTE: Technically the 'wrong' ptx instruction is called by
// joint_matrix_load/joint_matrix_store in this case: notice that the load and
// store instructions use the shape m16n16k16 rather than the correct shape
// m16n16k8. The 'wrong' ptx instruction is used because it produces the
// correct SASS instructions for all currently supported sm versions: sm_80
// and sm_86. This ptx instruction redundancy stems from the ptx naming
// convention for the mnk shape triple; however, we cannot know a priori that
// future sm versions will behave the same way and that this redundancy will
// persist as future architectures are released. This should be validated
// before supporting any sm versions beyond sm_86. We choose the m16n16k16
// instruction because it gives the significant advantage of a portable
// interface across the Intel and Nvidia backends.

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// M, N, K define the sizes of dimensions of the three matrix types (a, b,
// accumulator) used per subgroup operation.
constexpr int M = 16; // number of rows of accumulator,
// number of rows of a.
constexpr int N = 16; // number of cols of accumulator,
// number of cols of b.
constexpr int K = 8; // number of cols of a/number of rows of b.

// float is used in this test as the storage type for tf32
float A[M * K];
float B[K * N];
float C[M * N];
float D[M * N];

int main() {

buffer<float, 1> bufA(A, range<1>(M * K)); // will be used as tf32
buffer<float, 1> bufB(B, range<1>(K * N)); // will be used as tf32
buffer<float, 1> bufC(C, range<1>(M * N));
buffer<float, 1> bufD(D, range<1>(M * N));

queue q;

q.submit([&](handler &cgh) {
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);
auto accC = bufC.get_access<access::mode::read_write>(cgh);
auto accD = bufD.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class row_row>(
nd_range<2>({1, 32}, {1, 32}),
[=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
sycl::sub_group sg = item.get_sub_group();

joint_matrix<precision::tf32, matrix_use::a, M, K,
matrix_layout::row_major>
sub_a;

joint_matrix<precision::tf32, matrix_use::b, K, N,
matrix_layout::row_major>
sub_b;

joint_matrix<float, matrix_use::accumulator, M, N,
matrix_layout::row_major>
sub_c;

//CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8) #{{.*}}
joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
//CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}}
joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
//CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) #{{.*}}
joint_matrix_load(sg, sub_c, accC.get_pointer(), N);

// CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
// Round a, b to tf32
for (auto i = 0; i < 4; ++i)
sub_a.data[i] = round_to_tf32(sub_a.data[i]);

for (auto i = 0; i < 4; ++i)
sub_b.data[i] = round_to_tf32(sub_b.data[i]);

//CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 %{{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) #{{.*}}
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
//CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}}
joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
});
});

q.submit([&](handler &cgh) {
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);
auto accC = bufC.get_access<access::mode::read_write>(cgh);
auto accD = bufD.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class col_col>(
nd_range<2>({1, 32}, {1, 32}),
[=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
sycl::sub_group sg = item.get_sub_group();

joint_matrix<precision::tf32, matrix_use::a, M, K,
matrix_layout::col_major>
sub_a;

joint_matrix<precision::tf32, matrix_use::b, K, N,
matrix_layout::col_major>
sub_b;

joint_matrix<float, matrix_use::accumulator, M, N,
matrix_layout::col_major>
sub_c;

//CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8) #{{.*}}
joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
//CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}}
joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
//CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, i32 {{.*}}) #{{.*}}
joint_matrix_load(sg, sub_c, accC.get_pointer(), N);

// CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
// Round a, b to tf32
for (auto i = 0; i < 4; ++i)
sub_a.data[i] = round_to_tf32(sub_a.data[i]);

for (auto i = 0; i < 4; ++i)
sub_b.data[i] = round_to_tf32(sub_b.data[i]);

//CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) #{{.*}}
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
//CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) #{{.*}}
joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
});
});

return 0;
}
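As a usage sketch (not part of the PR), a host-side reference against which D could be checked after the kernels above: emulate tf32 by truncating the inputs, then accumulate in fp32. It reuses the hypothetical round_to_tf32_host helper sketched earlier and assumes M, N, K as defined in the test; agreement with the hardware MMA up to fp32 accumulation order is an assumption, not something this test verifies.

// Hypothetical host reference for D = A * B + C with tf32-rounded inputs.
void reference_mad(const float *A, const float *B, const float *C, float *D) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = C[m * N + n]; // fp32 accumulator, as in the tensor core op
      for (int k = 0; k < K; ++k)
        acc += round_to_tf32_host(A[m * K + k]) *
               round_to_tf32_host(B[k * N + n]);
      D[m * N + n] = acc;
    }
}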