Skip to content

[SYCL][Matrix] Add support for tf32 type #5920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
14 changes: 8 additions & 6 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
#endif

#ifdef __SYCL_DEVICE_ONLY__
template <typename T, std::size_t R, std::size_t C,
template <typename T, typename Ts, std::size_t R, std::size_t C,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
__spirv_JointMatrixLoadINTEL(Ts *Ptr, std::size_t Stride,
__spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

Expand Down Expand Up @@ -97,16 +97,18 @@ template <typename T, std::size_t R, std::size_t C, __spv::MatrixLayout U,
extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, U, S> *);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixLayout U,
template <typename T, typename Ts, std::size_t R, std::size_t C,
__spv::MatrixLayout U,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic(
extern SYCL_EXTERNAL Ts __spirv_VectorExtractDynamic(
__spv::__spirv_JointMatrixINTEL<T, R, C, U, S> *, size_t i);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixLayout U,
template <typename T, typename Ts, std::size_t R, std::size_t C,
__spv::MatrixLayout U,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, U, S> *
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, U, S> *,
T val, size_t i);
Ts val, size_t i);

#ifndef __SPIRV_BUILTIN_DECLARATIONS__
#error \
Expand Down
64 changes: 49 additions & 15 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,42 +70,65 @@ struct joint_matrix {
}
};

template <typename Group, typename T, size_t NumRows, size_t NumCols,
matrix_layout Layout = matrix_layout::row_major,
// class tf32 should not hold actual data. It is a tag type only, an empty class
// with no member variables. Morally, it is equivalent to an enumeration--it
// just uses the type system to communicate the desired accuracy of arithmetic
// computations. Users can't construct a tf32
namespace precision {
class tf32 {};
Copy link
Contributor

@MrSidims MrSidims Aug 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since users shouldn't construct a tf32.

Suggested change
class tf32 {};
class tf32 {
tf32() = delete;
};

} // namespace precision

// Differentiating between the "element type" and the "storage element type"
template <typename T> struct helper_traits {
using element_type = T;
using storage_element_type = T;
using fill_argument_type = T;
};

template <> struct helper_traits<precision::tf32> {
using element_type = precision::tf32;
using storage_element_type = float;
using fill_argument_type = float;
};

template <typename Group, typename Te, size_t NumRows, size_t NumCols,
matrix_layout Layout = matrix_layout::row_major, typename Tm,
access::address_space Space>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_load(Group sg,
joint_matrix<T, NumRows, NumCols, Layout, Group> &res,
multi_ptr<T, Space> src, size_t stride, matrix_layout MemL) {
joint_matrix<Te, NumRows, NumCols, Layout, Group> &res,
multi_ptr<Tm, Space> src, size_t stride, matrix_layout MemL) {
#ifdef __SYCL_DEVICE_ONLY__
T *Ptr = src.get();
// For non tf32 case, check that Te is the same that Tm
Tm *Ptr = src.get();
using Ts = typename helper_traits<Te>::storage_element_type;
switch (MemL) {
default:
assert(false && "Invalid Memory Layout!");
case matrix_layout::row_major:
res.spvm =
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
__spirv_JointMatrixLoadINTEL<Te, Ts, NumRows, NumCols,
spv_matrix_layout_traits<Layout>::value>(
Ptr, stride, __spv::MatrixLayout::RowMajor,
spv_scope_traits<Group>::value);
break;
case matrix_layout::col_major:
res.spvm =
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
__spirv_JointMatrixLoadINTEL<Te, Ts, NumRows, NumCols,
spv_matrix_layout_traits<Layout>::value>(
Ptr, stride, __spv::MatrixLayout::ColumnMajor,
spv_scope_traits<Group>::value);
break;
case matrix_layout::packed_a:
res.spvm =
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
__spirv_JointMatrixLoadINTEL<Te, Ts, NumRows, NumCols,
spv_matrix_layout_traits<Layout>::value>(
Ptr, stride, __spv::MatrixLayout::PackedA,
spv_scope_traits<Group>::value);
break;
case matrix_layout::packed_b:
res.spvm =
__spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
__spirv_JointMatrixLoadINTEL<Te, Ts, NumRows, NumCols,
spv_matrix_layout_traits<Layout>::value>(
Ptr, stride, __spv::MatrixLayout::PackedB,
spv_scope_traits<Group>::value);
Expand Down Expand Up @@ -231,12 +254,16 @@ class wi_element {
std::size_t idx;

public:
using storage_element_type = typename helper_traits<T>::storage_element_type;
wi_element(joint_matrix<T, NumRows, NumCols, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}
operator T() {
operator storage_element_type() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
// __spirv_VectorExtractDynamic returns storage_element_type
storage_element_type elem =
__spirv_VectorExtractDynamic<T, storage_element_type>(M.spvm, idx);
return elem;
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_INVALID_DEVICE);
Expand All @@ -245,7 +272,10 @@ class wi_element {

explicit operator bool() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast<T>(0);
// __spirv_VectorExtractDynamic returns storage_element_type
storage_element_type elems =
__spirv_VectorExtractDynamic<T, storage_element_type>(M.spvm, idx);
return elems != static_cast<storage_element_type>(0);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_INVALID_DEVICE);
Expand All @@ -254,7 +284,9 @@ class wi_element {

template <typename T2> wi_element &operator=(const T2 &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
M.spvm = __spirv_VectorInsertDynamic(M.spvm, static_cast<T>(rhs), idx);
// __spirv_VectorInsertDynamic takes storage_element_type as argument
M.spvm = __spirv_VectorInsertDynamic(
M.spvm, static_cast<storage_element_type>(rhs), idx);
return *this;
#else
(void)rhs;
Expand All @@ -279,10 +311,12 @@ class wi_element {
#if __SYCL_DEVICE_ONLY__
#define OP(op) \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps it is a good opportunity to remove all these macros?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yubingex007-a11y, I remember you change this code to use macros and make it more compact.
The code was before expanded for each of the ops. Bing changed it to remove the redundancy.
@keryell what do you suggest we should use instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes macros are the best or only reasonable solution.
In that case use protected names like __DPC_SYCL_OP or whatever to avoid the case where a user decides to use in her program:

#define OP something

:-)

template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
storage_element_type elems = \
__spirv_VectorExtractDynamic<T, storage_element_type>(M.spvm, idx); \
M.spvm = __spirv_VectorInsertDynamic( \
M.spvm, \
static_cast<T>(__spirv_VectorExtractDynamic(M.spvm, idx) \
op static_cast<T>(rhs)), \
static_cast<storage_element_type>( \
elems op static_cast<storage_element_type>(rhs)), \
idx); \
return *this; \
}
Expand Down
174 changes: 174 additions & 0 deletions sycl/test/matrix/matrix-tf32-test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// RUN: %clangxx -fsycl -O2 %s -o %t.out

#include <sycl/sycl.hpp>
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
#include <iostream>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

auto constexpr SG_SZ = 8;

#define TM 8
#define TN SG_SZ
#define TK 16

// Thin, non-owning wrapper around a row-major NUM_ROWS x NUM_COLS matrix.
// The wrapped storage is owned by the caller and must outlive this object.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
  // Raw pointer to caller-owned, row-major storage (kept public for
  // compatibility with existing users of this test helper).
  T *mat;

  // Wraps externally owned storage; does not copy or take ownership.
  big_matrix(T *data) : mat(data) {}
  // Returns the raw pointer to the wrapped data.
  T *get_data() { return mat; }
  // Re-points the wrapper at new caller-owned storage.
  void set_data(T *data) { mat = data; }
};

// This should be replaced with DPC++ and SPIR-V built-in functions once
// they are available.
// Rounds a 32-bit float to tf32 precision: tf32 keeps the sign, the full
// 8-bit exponent and the top 10 mantissa bits.  Rounding to nearest is done
// by adding half of the discarded range (bit 12) before truncating the low
// 13 mantissa bits.
float round_to_tf32(float a) {
  // Type-pun through memcpy instead of reinterpret_cast<uint32_t&> to
  // avoid a strict-aliasing violation (UB); compilers lower this to a
  // plain register move.
  uint32_t tmp_uint;
  std::memcpy(&tmp_uint, &a, sizeof(tmp_uint));
  tmp_uint += 0x1000u;     // Round up on the 13th-from-last mantissa bit
  tmp_uint &= 0xFFFFE000u; // Zero out the bottom 13 mantissa bits
  float ret;
  std::memcpy(&ret, &tmp_uint, sizeof(ret));
  return ret;
}

// Device-side tiled GEMM using the SYCL joint_matrix extension:
// C += A * B, computed in (TM x TN) output tiles, with A and B loaded as
// tf32 matrices and C accumulated in float.  Elements of A and B are
// explicitly rounded to tf32 via round_to_tf32 before the mad; without
// that call the mad would operate on truncated floats.
// NOTE(review): B is assumed to already be in packed (VNNI) layout --
// confirm against the layout the caller produces.
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
          size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
          size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
                     big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
                     big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
  size_t M = NUM_ROWS_C;
  size_t N = NUM_COLS_C;
  size_t K = NUM_COLS_A;

  // Dimensions must agree for C(MxN) = A(MxK) * B(KxN).
  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B);
  // One sub-group per (TM x TN) tile of the output.
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<float, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<float, 2> bufB(B.get_data(), range<2>(K, N));
  buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
     auto accC = bufC.get_access<access::mode::read_write>(cgh);
     auto accA = bufA.get_access<access::mode::read_write>(cgh);
     auto accB = bufB.get_access<access::mode::read_write>(cgh);

     cgh.parallel_for<class imatrix>(
         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]

         {
           // The submatrix API has to be accessed by all the work-items in
           // a sub-group: these functions are called once per sub-group,
           // with no code divergence between the work-items.
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           // Top-left coordinates of this sub-group's output tile.
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           sub_group sg = spmd_item.get_sub_group();
           // A and B are declared as tf32 matrices (tag type only; the
           // storage element type is float), C accumulates in float.
           joint_matrix<precision::tf32, TM, TK> sub_a(sg);
           joint_matrix<precision::tf32, TK, TN, matrix_layout::packed_b> sub_b(
               sg);
           joint_matrix<float, TM, TN> sub_c(sg);
           // Load the current C tile so the mad accumulates into it.
           joint_matrix_load(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, matrix_layout::row_major);
           // Walk the K dimension one TK-wide slab at a time.
           for (int k = 0; k < K; k += TK) {
             joint_matrix_load(sg, sub_a,
                               accA.get_pointer() + (sg_startx * TM) * K + k, K,
                               matrix_layout::row_major);
             // Assume we are already in VNNI format.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k) * (N) +
                                   sg_starty / SG_SZ * TN,
                               N, matrix_layout::packed_b);
             // If no rounding-to-tf32 function is called, the mad function
             // will work on truncated floats.
             // TODO: change signature of __spirv_VectorInsertDynamic to have
             // two types: matrix type can be different from value type
             for (int i = 0; i < sub_a.get_wi_data().length(); i++) {
               sub_a.get_wi_data()[i] = round_to_tf32(sub_a.get_wi_data()[i]);
             }
             for (int i = 0; i < sub_b.get_wi_data().length(); i++) {
               sub_b.get_wi_data()[i] = round_to_tf32(sub_b.get_wi_data()[i]);
             }
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           // Exercise the wi_element read/write path on A after the GEMM.
           // NOTE(review): 'elem' is unused -- the read exercises the
           // extract path while the *= goes through the proxy; confirm
           // this is intentional.
           auto wi_slice_a = sub_a.get_wi_data();
           for (int i = 0; i < wi_slice_a.length(); i++) {
             float elem = wi_slice_a[i];
             wi_slice_a[i] *= 2;
           }
           // Write the accumulated C tile back out.
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, matrix_layout::row_major);
         }); // parallel for
   }).wait();
}

// Global problem sizes: two tiles in each dimension.
static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
// Host-side matrices: C is computed on the device, D by the host reference.
float A[MATRIX_M][MATRIX_K];
float B[MATRIX_K][MATRIX_N];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];

// Host reference GEMM: C += A * B for row-major A (MxK), B (KxN), C (MxN).
// Fix: the previous version read the accumulator into 'acc' but never used
// it, and each k iteration overwrote C_mem[m*N+n] with va * vb instead of
// accumulating -- so C ended up as A[m][K-1] * B[K-1][n] with the initial
// C value discarded.  A reference multiply must accumulate, matching the
// device path that accumulates through joint_matrix_mad.
void matrix_multiply_ref(float *A_mem, float *B_mem, float *C_mem, int M, int N,
                         int K) {
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      float acc = C_mem[m * N + n]; // start from the existing C value
      for (int k = 0; k < K; k++) {
        float va = A_mem[m * K + k];
        float vb = B_mem[k * N + n];
        acc += va * vb;
      }
      C_mem[m * N + n] = acc;
    }
}

// Driver: fills A, B (packed view), C and D, runs the device GEMM into C
// and the host reference into D, then compares the two element-wise.
int main() {
  // A[i][j] = i + j: simple deterministic pattern.
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
      A[i][j] = 1.0f * (i + j);
    }
  }
  // B is filled through its packed (VNNI) view: K/2 rows of 2*N elements
  // alias the same K x N storage.  Index through a flat pointer so no
  // per-row bound of B[][MATRIX_N] is exceeded (the old B[i][j] with
  // j up to 2*N-1 was out-of-bounds UB on each row).
  float *B_flat = (float *)B;
  for (int i = 0; i < MATRIX_K / 2; i++) {
    for (int j = 0; j < MATRIX_N * 2; j++) {
      B_flat[i * (MATRIX_N * 2) + j] = 2.0f * i + 3.0f * j;
    }
  }
  // Both accumulators start from the same non-zero value so the comparison
  // also checks that C is accumulated into, not overwritten.
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1.0;
      D[i][j] = 1.0;
    }
  }

  big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
  big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
  big_matrix<float, MATRIX_M, MATRIX_K> MA((float *)&A);
  big_matrix<float, MATRIX_K, MATRIX_N> MB((float *)&B);
  matrix_multiply(MC, MA, MB);
  // NOTE(review): the reference is invoked with K = MATRIX_K / 2 -- confirm
  // this matches the packed-B interpretation used by the device kernel.
  matrix_multiply_ref((float *)A, (float *)B, (float *)D, MATRIX_M, MATRIX_N,
                      MATRIX_K / 2);

  bool res = true;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      if (C[i][j] != D[i][j])
        res = false;
    }
  }
  if (res)
    std::cout << "passed\n";
  else
    std::cout << "failed\n";
  return 0;
}
#endif // (SYCL_EXT_ONEAPI_MATRIX == 2)