[SYCL][ext][CUDA] Use float as storage type for tf32 joint matrix #5870
@@ -18,6 +18,10 @@ enum class matrix_use { a, b, accumulator };

 enum class matrix_layout { row_major, col_major, packed_a, packed_b };

+namespace precision {
+class tf32 {};
+} // namespace precision
+
 template <typename T, matrix_use Use, size_t Rows = sycl::dynamic_extent,
           size_t Cols = sycl::dynamic_extent,
           matrix_layout Layout = matrix_layout::row_major,
@@ -81,18 +85,23 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2)
 __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2)
 __SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8)

+// m16n16k8 tf32
+__SYCL_JOINT_MATRIX_OVERLOAD(precision::tf32, a, 16, 8, float, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD(precision::tf32, b, 8, 16, float, 4)
+
 #undef __SYCL_JOINT_MATRIX_OVERLOAD
 } // namespace experimental::matrix

 namespace detail {

-template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
+template <typename S, typename T,
+          sycl::ext::oneapi::experimental::matrix::matrix_use Use,
           size_t NumRows, size_t NumCols,
           sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
           access::address_space Space, typename Cond = void>
 struct joint_matrix_load_impl {
   void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
-                T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
+                S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
             multi_ptr<T, Space> src, size_t stride);
 };

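For readers following the diff: the overload macro above pairs an element-type tag with a separate per-work-item storage type, which is what lets precision::tf32 fragments be backed by float. A minimal standalone sketch of that idea (the macro body here is a simplified stand-in written for illustration, not the extension's actual definition):

#include <cstddef>

enum class matrix_use { a, b, accumulator };
namespace precision { class tf32 {}; }

template <typename T, matrix_use Use, std::size_t Rows, std::size_t Cols>
struct joint_matrix; // primary template; specialized per supported shape

// Simplified stand-in for __SYCL_JOINT_MATRIX_OVERLOAD: the element-type tag
// and the per-work-item storage type are independent macro parameters.
#define JOINT_MATRIX_OVERLOAD(type, use, M, N, frag_type, frag_size)           \
  template <>                                                                  \
  struct joint_matrix<type, matrix_use::use, M, N> {                           \
    frag_type data[frag_size];                                                 \
  };

// The tf32 overloads: tagged precision::tf32, stored as float.
JOINT_MATRIX_OVERLOAD(precision::tf32, a, 16, 8, float, 4)
JOINT_MATRIX_OVERLOAD(precision::tf32, b, 8, 16, float, 4)
#undef JOINT_MATRIX_OVERLOAD

int main() {
  joint_matrix<precision::tf32, matrix_use::a, 16, 8> frag;
  static_assert(sizeof(frag.data) == 4 * sizeof(float), "float storage");
}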
@@ -111,18 +120,19 @@ constexpr int get_layout_id<
   return 1;
 }

-template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
+template <typename S, typename T,
+          sycl::ext::oneapi::experimental::matrix::matrix_use Use,
           size_t NumRows, size_t NumCols,
           sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
           access::address_space Space>
 struct joint_matrix_load_impl<
-    T, Use, NumRows, NumCols, Layout, Space,
+    S, T, Use, NumRows, NumCols, Layout, Space,
     typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
                                             matrix::matrix_layout::row_major ||
                               Layout == sycl::ext::oneapi::experimental::
                                             matrix::matrix_layout::col_major>> {
   void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
-                T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
+                S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
             multi_ptr<T, Space> src, size_t stride) {
     if constexpr (std::is_same<T, uint16_t>::value) {
       int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());

@@ -247,15 +257,27 @@ struct joint_matrix_load_impl<
                                  get_layout_id<Layout>());
       }
     } else if constexpr (std::is_same<T, float>::value) {
-      if constexpr (NumRows == 16 && NumCols == 16) {
-        __hmma_m16n16k16_ld_c_f32(res.data, src.get(), stride,
-                                  get_layout_id<Layout>());
-      } else if constexpr (NumRows == 8 && NumCols == 32) {
-        __hmma_m8n32k16_ld_c_f32(res.data, src.get(), stride,
-                                 get_layout_id<Layout>());
-      } else if constexpr (NumRows == 32 && NumCols == 8) {
-        __hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
-                                 get_layout_id<Layout>());
+      if (std::is_same<S, float>::value) {
+        if constexpr (NumRows == 16 && NumCols == 16) {
+          __hmma_m16n16k16_ld_c_f32(res.data, src.get(), stride,
+                                    get_layout_id<Layout>());
+        } else if constexpr (NumRows == 8 && NumCols == 32) {
+          __hmma_m8n32k16_ld_c_f32(res.data, src.get(), stride,
+                                   get_layout_id<Layout>());
+        } else if constexpr (NumRows == 32 && NumCols == 8) {
+          __hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
+                                   get_layout_id<Layout>());
+        }
+      } else if (std::is_same<S, sycl::ext::oneapi::experimental::matrix::
+                                     precision::tf32>::value) {
+        int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
+        if constexpr (NumRows == 16 && NumCols == 8) {
+          __mma_tf32_m16n16k8_ld_a(reinterpret_cast<int32_t *>(res.data),
+                                   tileptr, stride, get_layout_id<Layout>());
+        } else if constexpr (NumRows == 8 && NumCols == 16) {
+          __mma_tf32_m16n16k8_ld_b(reinterpret_cast<int32_t *>(res.data),
+                                   tileptr, stride, get_layout_id<Layout>());
+        }
       }
     } else if constexpr (std::is_same<T, double>::value) {
       if constexpr (Use ==

Review thread on the if (std::is_same<S, float>::value) line:
- Is there a reason for this not to be
- I think
- I believe extensions that need to be explicitly included are allowed to use C++17 features.
- Allowed in the sense that tests don't appear to fail due to C++17 in the extension namespace, whereas tests do fail due to C++17 constructs in other namespaces!
- I see, thanks.
- In that case I am happy to ensure C++17 is fully employed where appropriate in this extension in the follow-on PR: #5964
- As long as it doesn't bleed into sycl.hpp then it should be fine. There should be a test that fails if it does.

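For context on the "if" versus "if constexpr" question in the thread above: a plain if evaluates a compile-time-constant condition at run time and both branches must type-check for every instantiation, whereas if constexpr discards the untaken branch. A minimal standalone sketch of the difference (illustrative only, not the extension's code):

#include <iostream>
#include <type_traits>

template <typename T> void describe(T) {
  // if constexpr discards the untaken branch at instantiation time, so that
  // branch may contain code that is only valid for the taken type.
  if constexpr (std::is_same<T, float>::value)
    std::cout << "float path\n";
  else
    std::cout << "non-float path\n";
  // A plain `if (std::is_same<T, float>::value)` would also select the right
  // branch (the condition folds to a constant), but both branches must still
  // compile for every T this template is instantiated with.
}

int main() {
  describe(1.0f); // prints: float path
  describe(1);    // prints: non-float path
}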
@@ -495,6 +517,10 @@ struct joint_matrix_mad_impl<
                                 get_layout_pair_id<LayoutA, LayoutB>(), 0);
         }
       }
+    } else if constexpr (M == 16 && N == 16 && K == 8) {
+      __mma_tf32_m16n16k8_mma_f32(D.data, reinterpret_cast<int32_t *>(A.data),
+                                  reinterpret_cast<int32_t *>(B.data), C.data,
+                                  get_layout_pair_id<LayoutA, LayoutB>(), 0);
     } else if constexpr (std::is_same<T1, double>::value) {
       __dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
                             get_layout_pair_id<LayoutA, LayoutB>(), 0);

@@ -507,13 +533,18 @@

 namespace experimental::matrix {

-template <typename Group, typename T, matrix_use Use, size_t NumRows,
-          size_t NumCols, matrix_layout Layout, access::address_space Space>
+template <typename Group, typename S, typename T, matrix_use Use,
+          size_t NumRows, size_t NumCols, matrix_layout Layout,
+          access::address_space Space,
+          std::enable_if_t<std::is_same<S, T>::value ||
+                               (std::is_same<S, precision::tf32>::value &&
+                                std::is_same<T, float>::value),
+                           bool> = true>
 void joint_matrix_load(
-    Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group> &res,
+    Group sg, joint_matrix<S, Use, NumRows, NumCols, Layout, Group> &res,
     multi_ptr<T, Space> src, size_t stride) {
 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
-  sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
+  sycl::ext::oneapi::detail::joint_matrix_load_impl<S, T, Use, NumRows, NumCols,
                                                     Layout, Space>{}
       .load(res, src, stride);
 #else

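The enable_if constraint above is what keeps the new flexibility narrow: the fragment element type S must equal the memory element type T, except for the single sanctioned mismatch of S = precision::tf32 with T = float. A standalone sketch of the same predicate (tf32 below is a local stand-in tag, not the extension's type):

#include <type_traits>

struct tf32 {}; // local stand-in for precision::tf32

template <typename S, typename T>
constexpr bool load_allowed =
    std::is_same<S, T>::value ||
    (std::is_same<S, tf32>::value && std::is_same<T, float>::value);

static_assert(load_allowed<float, float>, "matching types are allowed");
static_assert(load_allowed<tf32, float>, "tf32 fragment from float memory");
static_assert(!load_allowed<tf32, double>, "any other mismatch is rejected");
static_assert(!load_allowed<double, float>, "any other mismatch is rejected");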
@@ -573,6 +604,26 @@ joint_matrix_mad(
 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
 }

+float float_to_tf32(float a) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  int32_t tmp_int = __nvvm_f2tf32_rna(a);
+  return __nvvm_bitcast_i2f(tmp_int);
+#else
+  uint32_t tmp_uint = reinterpret_cast<uint32_t &>(a);
+  tmp_uint += 0x1000u;
+  float ret = reinterpret_cast<float &>(tmp_uint);
+  return ret;
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+// This function just zeros out the bottom 13 bits of the tf32 type
+float tf32_to_float(float a) {
+  uint32_t tmp_uint = reinterpret_cast<uint32_t &>(a);
+  tmp_uint &= 0xFFFFE000u;
+  float ret = reinterpret_cast<float &>(tmp_uint);
+  return ret;
+}
+
 } // namespace experimental::matrix
 } // namespace oneapi
 } // namespace ext

Review comment on float_to_tf32:
- This is in sync with the fact that element indexing of a joint matrix of type tf32 yields float.

Review thread on tf32_to_float:
- is there a use case for this?
- I have renamed the function.

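A standalone host-side sketch of the end-to-end effect of those two helpers, composing float_to_tf32's host rounding with tf32_to_float's masking (this sketch uses memcpy for the type pun rather than the reinterpret_cast above, and assumes finite, normal inputs):

#include <cstdint>
#include <cstdio>
#include <cstring>

// tf32 keeps float's 8 exponent bits but only 10 explicit mantissa bits, so
// the bottom 13 mantissa bits are rounded away: adding 0x1000 carries into
// bit 13 when the discarded bits are at least half an ulp of the 10-bit
// mantissa, then the mask zeros them out.
float round_to_tf32(float a) {
  uint32_t bits;
  std::memcpy(&bits, &a, sizeof(bits));
  bits += 0x1000u;      // round (as in float_to_tf32's host path)
  bits &= 0xFFFFE000u;  // zero the bottom 13 bits (as in tf32_to_float)
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  float pi = 3.14159265f;
  // 0x40490FDB becomes 0x40490000: prints 3.14159274 -> 3.14062500
  std::printf("%.8f -> %.8f\n", pi, round_to_tf32(pi));
}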
New test file (141 added lines):

@@ -0,0 +1,141 @@
// REQUIRES: cuda

// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s

// IMPORTANT: before updating sm version support beyond sm_86, read the
// following NOTE!

// NOTE: Technically the 'wrong' ptx instruction is called by
// joint_matrix_load/joint_matrix_store in this case: notice that the load and
// store instructions use shape m16n16k16, rather than the correct shape
// m16n16k8. The 'wrong' ptx instruction is used because it produces the
// correct SASS instructions for all currently supported sm versions: sm_80
// and sm_86. This ptx instruction redundancy stems from the ptx naming
// convention for the mnk shape triple; however, we cannot know a priori that
// future sm versions will behave the same way and that the redundancy will
// persist as new architectures are released. This should be validated before
// supporting any sm versions beyond sm_86. We choose the m16n16k16
// instruction because it gives the significant advantage of a portable
// interface across the Intel and Nvidia backends.

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// M, N, K define the sizes of dimensions of the three matrix types (a, b,
// accumulator) used per subgroup operation.
constexpr int M = 16; // number of rows of a,
                      // number of rows of accumulator.
constexpr int N = 16; // number of cols of b,
                      // number of cols of accumulator.
constexpr int K = 8;  // number of cols of a/number of rows of b.

// float is used in this test as the storage type for tf32
float A[M * K];
float B[K * N];
float C[M * N];
float D[M * N];

int main() {

  buffer<float, 1> bufA(A, range<1>(M * K)); // will be used as tf32
  buffer<float, 1> bufB(B, range<1>(K * N)); // will be used as tf32
  buffer<float, 1> bufC(C, range<1>(M * N));
  buffer<float, 1> bufD(D, range<1>(M * N));

  queue q;

  q.submit([&](handler &cgh) {
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    cgh.parallel_for<class row_row>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<precision::tf32, matrix_use::a, M, K,
                       matrix_layout::row_major>
              sub_a;

          joint_matrix<precision::tf32, matrix_use::b, K, N,
                       matrix_layout::row_major>
              sub_b;

          joint_matrix<float, matrix_use::accumulator, M, N,
                       matrix_layout::row_major>
              sub_c;

          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);

          // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
          // Round a, b to tf32
          for (auto i = 0; i < 4; ++i)
            sub_a.data[i] = float_to_tf32(sub_a.data[i]);

          for (auto i = 0; i < 4; ++i)
            sub_b.data[i] = float_to_tf32(sub_b.data[i]);

          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 %{{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          //CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
        });
  });

Review thread on the rounding loop:
- This should be the expected way to perform the rounding, if users want to, but I still find exposing ".data" different from the element-wise indexing we are currently doing.
- This is exactly what we will do (but in a future PR): I switched the impl here to use

  q.submit([&](handler &cgh) {
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    cgh.parallel_for<class col_col>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<precision::tf32, matrix_use::a, M, K,
                       matrix_layout::col_major>
              sub_a;

          joint_matrix<precision::tf32, matrix_use::b, K, N,
                       matrix_layout::col_major>
              sub_b;

          joint_matrix<float, matrix_use::accumulator, M, N,
                       matrix_layout::col_major>
              sub_c;

          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, i32 {{.*}}) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);

          // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
          // Round a, b to tf32
          for (auto i = 0; i < 4; ++i)
            sub_a.data[i] = float_to_tf32(sub_a.data[i]);

          for (auto i = 0; i < 4; ++i)
            sub_b.data[i] = float_to_tf32(sub_b.data[i]);

          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          //CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
        });
  });

  return 0;
}

Review thread on the rounding loop:
- does data contain the WI portion of the matrix?
- The only dimensions allowed by wmma for tf32 are m=16, n=16, k=8. This means that A and B are 16x8 and 8x16 matrices, respectively. wmma operations are executed by a single warp, i.e. 32 threads, so when A and B are loaded into registers, each thread holds a fragment of each matrix with dimensions 8x16/32 = 4 elements. See here https://github.com/intel/llvm/pull/5870/files/618c80750930b0eaec8cde468c880d52ba54c80c#diff-f71a436bdeda598b29caad471fa637a2844a12f38fe4e85b15b2ccb37bd09833R37 that
- This sounds good, so this is really implementing the element indexing we added in https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp#L68
- There's no reason to not match the current syntax you use: we waited to implement the
- The only caveat is that ideally we would add to the wi_data API in the way described here: #5964 (comment)
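The fragment-size arithmetic from that thread, as a compile-time check (a standalone sketch; the names are illustrative, and warp size 32 holds on all current CUDA hardware):

// Each m16n16k8 tf32 operand has 16 * 8 = 128 elements. A wmma operation is
// executed by one warp of 32 threads, so each work-item's fragment holds
// 128 / 32 = 4 elements, hence float data[4] and the
// `for (auto i = 0; i < 4; ++i)` rounding loops in the test above.
constexpr int kRows = 16, kCols = 8, kWarpSize = 32;
static_assert(kRows * kCols / kWarpSize == 4,
              "4 elements per work-item fragment");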
Review thread on the load type change:
- I am tagging @yubingex007-a11y here as changing the type of the load will be necessary to handle the tf32 case: the type of the memory can be different from the type of the joint matrix. However, @JackAKirk, we should restrict this flexibility to tf32 only. Can this work in the case of bfloat16? Load from float to bfloat16?
- Yeah, the final bfloat16 cuda impl is ready now using the old API (#5964). Sounds fine to restrict the flexibility: I think the way this is implemented already restricts it to the tf32 type. If we add subbyte/single-bit cases then I think those would also have a memory type different from the joint matrix type.
- This is currently restricted to only being used by tf32 when the other datatype is float. See https://github.com/intel/llvm/pull/5870/files/618c80750930b0eaec8cde468c880d52ba54c80c#diff-f71a436bdeda598b29caad471fa637a2844a12f38fe4e85b15b2ccb37bd09833R539