[SYCL][Matrix] Add support for tf32 type #5920
@@ -70,42 +70,65 @@ struct joint_matrix {
   }
 };
 
-template <typename Group, typename T, size_t NumRows, size_t NumCols,
-          matrix_layout Layout = matrix_layout::row_major,
+// class tf32 should not hold actual data. It is a tag type only: an empty
+// class with no member variables. Morally, it is equivalent to an enumeration;
+// it just uses the type system to communicate the desired accuracy of
+// arithmetic computations. Users can't construct a tf32.
+namespace precision {
+class tf32 {};
+} // namespace precision
+
+// Differentiating between the "element type" and the "storage element type"
+template <typename T> struct helper_traits {
+  using element_type = T;
+  using storage_element_type = T;
+  using fill_argument_type = T;
+};
+
+template <> struct helper_traits<precision::tf32> {
+  using element_type = precision::tf32;
+  using storage_element_type = float;
+  using fill_argument_type = float;
+};
+
+template <typename Group, typename Te, size_t NumRows, size_t NumCols,
+          matrix_layout Layout = matrix_layout::row_major, typename Tm,
           access::address_space Space>
 inline __SYCL_ALWAYS_INLINE void
 joint_matrix_load(Group sg,
-                  joint_matrix<T, NumRows, NumCols, Layout, Group> &res,
-                  multi_ptr<T, Space> src, size_t stride, matrix_layout MemL) {
+                  joint_matrix<Te, NumRows, NumCols, Layout, Group> &res,
+                  multi_ptr<Tm, Space> src, size_t stride, matrix_layout MemL) {
 #ifdef __SYCL_DEVICE_ONLY__
-  T *Ptr = src.get();
+  // For the non-tf32 case, check that Te is the same as Tm
+  Tm *Ptr = src.get();
+  using Ts = typename helper_traits<Te>::storage_element_type;
   switch (MemL) {
   default:
     assert(false && "Invalid Memory Layout!");
   case matrix_layout::row_major:
     res.spvm =
-        __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
+        __spirv_JointMatrixLoadINTEL<Te, Ts, NumRows, NumCols,
                                      spv_matrix_layout_traits<Layout>::value>(
             Ptr, stride, __spv::MatrixLayout::RowMajor,
             spv_scope_traits<Group>::value);
     break;
   case matrix_layout::col_major:
     res.spvm =
-        __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
+        __spirv_JointMatrixLoadINTEL<Te, Ts, NumRows, NumCols,
                                      spv_matrix_layout_traits<Layout>::value>(
             Ptr, stride, __spv::MatrixLayout::ColumnMajor,
             spv_scope_traits<Group>::value);
     break;
   case matrix_layout::packed_a:
     res.spvm =
-        __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
+        __spirv_JointMatrixLoadINTEL<Te, Ts, NumRows, NumCols,
                                      spv_matrix_layout_traits<Layout>::value>(
             Ptr, stride, __spv::MatrixLayout::PackedA,
             spv_scope_traits<Group>::value);
     break;
   case matrix_layout::packed_b:
     res.spvm =
-        __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
+        __spirv_JointMatrixLoadINTEL<Te, Ts, NumRows, NumCols,
                                      spv_matrix_layout_traits<Layout>::value>(
             Ptr, stride, __spv::MatrixLayout::PackedB,
             spv_scope_traits<Group>::value);
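To make the tag-type plus traits pattern above concrete, here is a minimal standalone sketch (hypothetical names, not the actual SYCL headers) of how an empty precision tag lets the element type requested by the user differ from the type the data is actually stored and passed around as:

#include <cstddef>
#include <type_traits>

namespace precision {
class tf32 {}; // tag only: never instantiated, carries no data
} // namespace precision

template <typename T> struct helper_traits {
  using storage_element_type = T;
};
template <> struct helper_traits<precision::tf32> {
  using storage_element_type = float;
};

// Hypothetical stand-in for joint_matrix_load: the element type Te (possibly
// the tf32 tag) is independent of the in-memory type Tm, and the values that
// are actually moved around have the storage type Ts.
template <typename Te, typename Tm> void load(Tm *src, std::size_t n) {
  using Ts = typename helper_traits<Te>::storage_element_type;
  static_assert(std::is_same_v<Ts, Tm>,
                "memory must hold the storage type of the element type");
  (void)src;
  (void)n;
}

int main() {
  float buf[16] = {};
  load<precision::tf32>(buf, 16); // a "tf32" matrix is fed from float memory
  load<float>(buf, 16);           // a plain float matrix
}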
@@ -231,12 +254,16 @@ class wi_element {
   std::size_t idx;
 
 public:
+  using storage_element_type = typename helper_traits<T>::storage_element_type;
   wi_element(joint_matrix<T, NumRows, NumCols, Layout, Group> &Mat,
              std::size_t i)
       : M(Mat), idx(i) {}
-  operator T() {
+  operator storage_element_type() {
 #ifdef __SYCL_DEVICE_ONLY__
-    return __spirv_VectorExtractDynamic(M.spvm, idx);
+    // __spirv_VectorExtractDynamic returns storage_element_type
+    storage_element_type elem =
+        __spirv_VectorExtractDynamic<T, storage_element_type>(M.spvm, idx);
+    return elem;
 #else
     throw runtime_error("joint matrix is not supported on host device.",
                         PI_INVALID_DEVICE);
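In user code, the effect of this conversion operator is that the per-work-item slice of a tf32 matrix reads back as float. A short usage sketch in the spirit of the test added below (names taken from that test):

auto wi_slice = sub_a.get_wi_data(); // sub_a is a joint_matrix<precision::tf32, ...>
for (int i = 0; i < wi_slice.length(); i++) {
  float v = wi_slice[i];          // operator storage_element_type(): yields float
  wi_slice[i] = round_to_tf32(v); // operator=(const T2 &): stored back as float
}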
@@ -245,7 +272,10 @@ class wi_element {
 
   explicit operator bool() {
 #ifdef __SYCL_DEVICE_ONLY__
-    return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast<T>(0);
+    // __spirv_VectorExtractDynamic returns storage_element_type
+    storage_element_type elems =
+        __spirv_VectorExtractDynamic<T, storage_element_type>(M.spvm, idx);
+    return elems != static_cast<storage_element_type>(0);
 #else
     throw runtime_error("joint matrix is not supported on host device.",
                         PI_INVALID_DEVICE);
@@ -254,7 +284,9 @@ class wi_element {
 
   template <typename T2> wi_element &operator=(const T2 &rhs) {
 #ifdef __SYCL_DEVICE_ONLY__
-    M.spvm = __spirv_VectorInsertDynamic(M.spvm, static_cast<T>(rhs), idx);
+    // __spirv_VectorInsertDynamic takes storage_element_type as argument
+    M.spvm = __spirv_VectorInsertDynamic(
+        M.spvm, static_cast<storage_element_type>(rhs), idx);
     return *this;
 #else
     (void)rhs;
@@ -279,10 +311,12 @@ class wi_element {
 #if __SYCL_DEVICE_ONLY__
 #define OP(op) \
   template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
+    storage_element_type elems = \
+        __spirv_VectorExtractDynamic<T, storage_element_type>(M.spvm, idx); \
     M.spvm = __spirv_VectorInsertDynamic( \
         M.spvm, \
-        static_cast<T>(__spirv_VectorExtractDynamic(M.spvm, idx) \
-                           op static_cast<T>(rhs)), \
+        static_cast<storage_element_type>( \
+            elems op static_cast<storage_element_type>(rhs)), \
         idx); \
     return *this; \
   }

Review thread on the OP macro:
- Perhaps it is a good opportunity to remove all these macros?
- @yubingex007-a11y, I remember you changed this code to use macros to make it more compact.
- Sometimes macros are the best or only reasonable solution. #define OP something :-)
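For readers less comfortable with the macro, this is roughly what OP(*) expands to after this change (hand-expanded for illustration; formatting approximate):

template <typename T2> wi_element &operator*=(const T2 &rhs) {
  // Read the element back in its storage type (float when T is precision::tf32).
  storage_element_type elems =
      __spirv_VectorExtractDynamic<T, storage_element_type>(M.spvm, idx);
  // Combine with rhs in the storage type and write the result back.
  M.spvm = __spirv_VectorInsertDynamic(
      M.spvm,
      static_cast<storage_element_type>(
          elems * static_cast<storage_element_type>(rhs)),
      idx);
  return *this;
}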
@@ -0,0 +1,174 @@
// RUN: %clangxx -fsycl -O2 %s -o %t.out

#include <sycl/sycl.hpp>
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
#include <iostream>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

auto constexpr SG_SZ = 8;

#define TM 8
#define TN SG_SZ
#define TK 16

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
  T *mat;

public:
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
  big_matrix(T *data) : mat(data) {}
};

// This should be replaced with DPC++ and SPIR-V functions.
float round_to_tf32(float a) {
  uint32_t tmp_uint = reinterpret_cast<uint32_t &>(a);
  tmp_uint += 0x1000u;     // Round to nearest by adding 1 to the 13th-lowest bit
  tmp_uint &= 0xFFFFE000u; // Zero out the bottom 13 bits
  float ret = reinterpret_cast<float &>(tmp_uint);
  return ret;
}
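As a side note on the bit manipulation above: tf32 keeps the float sign and exponent but only the top 10 mantissa bits, and adding 0x1000 before masking implements round-to-nearest on the 13 bits that are dropped. A small standalone sketch (hypothetical, using memcpy instead of reinterpret_cast for the type pun) that shows the effect:

#include <cstdint>
#include <cstdio>
#include <cstring>

static float to_tf32(float a) {
  uint32_t bits;
  std::memcpy(&bits, &a, sizeof(bits));
  bits += 0x1000u;     // round-to-nearest on the 13 low mantissa bits
  bits &= 0xFFFFE000u; // keep sign, exponent, and the top 10 mantissa bits
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  std::printf("%.10f -> %.10f\n", 1.000244140625f,
              to_tf32(1.000244140625f)); // 1 + 2^-12 rounds down to 1.0
  std::printf("%.10f -> %.10f\n", 1.00048828125f,
              to_tf32(1.00048828125f)); // 1 + 2^-11 rounds up to 1 + 2^-10
}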

template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
          size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
          size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
                     big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
                     big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
  size_t M = NUM_ROWS_C;
  size_t N = NUM_COLS_C;
  size_t K = NUM_COLS_A;

  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B);
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<float, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<float, 2> bufB(B.get_data(), range<2>(K, N));
  buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
     auto accC = bufC.get_access<access::mode::read_write>(cgh);
     auto accA = bufA.get_access<access::mode::read_write>(cgh);
     auto accB = bufB.get_access<access::mode::read_write>(cgh);

     cgh.parallel_for<class imatrix>(
         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]
         {
           // The submatrix API has to be accessed by all the work-items in a
           // sub-group. These functions are called once per sub-group, with no
           // code divergence between the work-items.
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           sub_group sg = spmd_item.get_sub_group();
           joint_matrix<precision::tf32, TM, TK> sub_a(sg);
           joint_matrix<precision::tf32, TK, TN, matrix_layout::packed_b> sub_b(
               sg);
           joint_matrix<float, TM, TN> sub_c(sg);
           joint_matrix_load(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, matrix_layout::row_major);
           for (int k = 0; k < K; k += TK) {
             joint_matrix_load(sg, sub_a,
                               accA.get_pointer() + (sg_startx * TM) * K + k, K,
                               matrix_layout::row_major);
             // Assume B is already in VNNI format.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k) * (N) +
                                   sg_starty / SG_SZ * TN,
                               N, matrix_layout::packed_b);
             // If no rounding-to-tf32 function is called, the mad function will
             // work on truncated floats.
             // TODO: change signature of __spirv_VectorInsertDynamic to have
             // two types: matrix type can be different from value type.
             for (int i = 0; i < sub_a.get_wi_data().length(); i++) {
               sub_a.get_wi_data()[i] = round_to_tf32(sub_a.get_wi_data()[i]);
             }
             for (int i = 0; i < sub_b.get_wi_data().length(); i++) {
               sub_b.get_wi_data()[i] = round_to_tf32(sub_b.get_wi_data()[i]);
             }
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           auto wi_slice_a = sub_a.get_wi_data();
           for (int i = 0; i < wi_slice_a.length(); i++) {
             float elem = wi_slice_a[i];
             wi_slice_a[i] *= 2;
           }
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, matrix_layout::row_major);
         }); // parallel for
   }).wait();
}

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
float A[MATRIX_M][MATRIX_K];
float B[MATRIX_K][MATRIX_N];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];

void matrix_multiply_ref(float *A_mem, float *B_mem, float *C_mem, int M, int N,
                         int K) {
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        float va = A_mem[m * K + k];
        float vb = B_mem[k * N + n];
        float acc = C_mem[m * N + n];
        C_mem[m * N + n] = acc + va * vb; // accumulate, not overwrite
      }
    }
}

int main() {
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
      A[i][j] = 1.0f * (i + j);
    }
  }
  for (int i = 0; i < MATRIX_K / 2; i++) {
    for (int j = 0; j < MATRIX_N * 2; j++) {
      B[i][j] = 2.0f * i + 3.0f * j;
    }
  }
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1.0;
      D[i][j] = 1.0;
    }
  }

  big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
Review comment: suggested a change to this line, and to add the right constructor in ...
  big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
  big_matrix<float, MATRIX_M, MATRIX_K> MA((float *)&A);
  big_matrix<float, MATRIX_K, MATRIX_N> MB((float *)&B);
  matrix_multiply(MC, MA, MB);
  matrix_multiply_ref((float *)A, (float *)B, (float *)D, MATRIX_M, MATRIX_N,
                      MATRIX_K / 2);

  bool res = true;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      if (C[i][j] != D[i][j])
        res = false;
    }
  }
  if (res)
    std::cout << "passed\n";
  else
    std::cout << "failed\n";
}
#endif // (SYCL_EXT_ONEAPI_MATRIX == 2)
Review comment: ... since users shouldn't construct a tf32.