Skip to content

Commit 6f8e456

Browse files
authored
[SYCL] Represent JointMatrixINTEL type as extension type (#8343)
This patch is build on top of https://reviews.llvm.org/D141008 It adds: ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type which is represented as a pointer to a structure to LLVM extension type with the parameters that follow SPIR-V JointMatrixINTEL type. The expected representation is: target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%, %use%, (optional) %element_type_interpretation%) Better approach is to introduce joint matrix type to clang, but it's off the table now, since we are lacking OpenCL spec. [The SPIR-V spec](https://github.com/intel/llvm/blob/389bdaf40bafb0eb139af3c19b01c846b4497ed4/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc) --------- Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent d666b95 commit 6f8e456

File tree

5 files changed

+115
-82
lines changed

5 files changed

+115
-82
lines changed

clang/lib/CodeGen/CodeGenTypes.cpp

+84-59
Original file line numberDiff line numberDiff line change
@@ -51,65 +51,6 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
5151
StringRef suffix) {
5252
SmallString<256> TypeName;
5353
llvm::raw_svector_ostream OS(TypeName);
54-
// If RD is spirv_JointMatrixINTEL type, mangle differently.
55-
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
56-
if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") {
57-
if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
58-
ArrayRef<TemplateArgument> TemplateArgs =
59-
TemplateDecl->getTemplateArgs().asArray();
60-
OS << "spirv.JointMatrixINTEL.";
61-
for (auto &TemplateArg : TemplateArgs) {
62-
OS << "_";
63-
if (TemplateArg.getKind() == TemplateArgument::Type) {
64-
llvm::Type *TTy = ConvertType(TemplateArg.getAsType());
65-
if (TTy->isIntegerTy()) {
66-
switch (TTy->getIntegerBitWidth()) {
67-
case 8:
68-
OS << "char";
69-
break;
70-
case 16:
71-
OS << "short";
72-
break;
73-
case 32:
74-
OS << "int";
75-
break;
76-
case 64:
77-
OS << "long";
78-
break;
79-
default:
80-
OS << "i" << TTy->getIntegerBitWidth();
81-
break;
82-
}
83-
} else if (TTy->isHalfTy()) {
84-
OS << "half";
85-
} else if (TTy->isFloatTy()) {
86-
OS << "float";
87-
} else if (TTy->isDoubleTy()) {
88-
OS << "double";
89-
} else if (TTy->isBFloatTy()) {
90-
OS << "bfloat16";
91-
} else if (TTy->isStructTy()) {
92-
StringRef LlvmTyName = TTy->getStructName();
93-
// Emit half/bfloat16/tf32 for sycl[::*]::{half,bfloat16,tf32}
94-
if (LlvmTyName.startswith("class.sycl::") ||
95-
LlvmTyName.startswith("class.__sycl_internal::"))
96-
LlvmTyName = LlvmTyName.rsplit("::").second;
97-
if (LlvmTyName != "half" && LlvmTyName != "bfloat16" &&
98-
LlvmTyName != "tf32")
99-
llvm_unreachable("Wrong matrix base type!");
100-
OS << LlvmTyName;
101-
} else {
102-
llvm_unreachable("Wrong matrix base type!");
103-
}
104-
} else if (TemplateArg.getKind() == TemplateArgument::Integral) {
105-
OS << TemplateArg.getAsIntegral();
106-
}
107-
}
108-
Ty->setName(OS.str());
109-
return;
110-
}
111-
}
112-
}
11354
OS << RD->getKindName() << '.';
11455

11556
// FIXME: We probably want to make more tweaks to the printing policy. For
@@ -460,6 +401,78 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
460401
return ResultType;
461402
}
462403

404+
template <bool NeedTypeInterpret = false>
405+
llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
406+
ArrayRef<TemplateArgument> TemplateArgs,
407+
const unsigned Val = 0) {
408+
// TODO: we should actually have exactly 5 template parameters: 1 for
409+
// type and 4 for type parameters. But in previous version of the SPIR-V
410+
// spec we have Layout matrix type parameter, that was later removed.
411+
// Once we update to the newest version of the spec - this should be updated.
412+
assert((TemplateArgs.size() == 5 || TemplateArgs.size() == 6) &&
413+
"Wrong JointMatrixINTEL template parameters number");
414+
// This is required to represent optional 'Component Type Interpretation'
415+
// parameter
416+
using ParamsType =
417+
typename std::conditional<NeedTypeInterpret, SmallVector<unsigned, 6>,
418+
SmallVector<unsigned, 5>>::type;
419+
ParamsType Params;
420+
if constexpr (NeedTypeInterpret)
421+
Params = {0, 0, 0, 0, 0, Val};
422+
else
423+
Params = {0, 0, 0, 0, 0};
424+
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
425+
assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
426+
"Wrong JointMatrixINTEL template parameter");
427+
Params[I - 1] = TemplateArgs[I].getAsIntegral().getExtValue();
428+
}
429+
return llvm::TargetExtType::get(CompTy->getContext(),
430+
"spirv.JointMatrixINTEL", {CompTy}, Params);
431+
}
432+
433+
/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
434+
/// which is represented as a pointer to a structure to LLVM extension type
435+
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
436+
/// The expected representation is:
437+
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
438+
/// %use%, (optional) %element_type_interpretation%)
439+
llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
440+
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
441+
ArrayRef<TemplateArgument> TemplateArgs =
442+
TemplateDecl->getTemplateArgs().asArray();
443+
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
444+
"1st JointMatrixINTEL template parameter must be type");
445+
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());
446+
447+
// Per JointMatrixINTEL spec the type can have an optional
448+
// 'Component Type Interpretation' parameter. We should emit it in case
449+
// if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as
450+
// matrix's components. Yet 'bfloat16' should be represented as 'int16' and
451+
// 'tf32' as 'float' types.
452+
if (CompTy->isStructTy()) {
453+
StringRef LlvmTyName = CompTy->getStructName();
454+
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
455+
if (LlvmTyName.startswith("class.sycl::") ||
456+
LlvmTyName.startswith("class.__sycl_internal::"))
457+
LlvmTyName = LlvmTyName.rsplit("::").second;
458+
if (LlvmTyName == "half") {
459+
CompTy = llvm::Type::getHalfTy(getLLVMContext());
460+
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
461+
} else if (LlvmTyName == "tf32") {
462+
CompTy = llvm::Type::getFloatTy(getLLVMContext());
463+
// 'tf32' interpretation is mapped to '0'
464+
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 0);
465+
} else if (LlvmTyName == "bfloat16") {
466+
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
467+
// 'bfloat16' interpretation is mapped to '1'
468+
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 1);
469+
} else {
470+
llvm_unreachable("Wrong matrix base type!");
471+
}
472+
}
473+
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
474+
}
475+
463476
/// ConvertType - Convert the specified type to its LLVM form.
464477
llvm::Type *CodeGenTypes::ConvertType(QualType T) {
465478
T = Context.getCanonicalType(T);
@@ -745,6 +758,18 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
745758
llvm::Type *PointeeType = ConvertTypeForMem(ETy);
746759
if (PointeeType->isVoidTy())
747760
PointeeType = llvm::Type::getInt8Ty(getLLVMContext());
761+
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
762+
const Type *ClangETy = ETy.getTypePtrOrNull();
763+
if (ClangETy && ClangETy->isStructureOrClassType()) {
764+
RecordDecl *RD = ClangETy->getAsCXXRecordDecl();
765+
if (RD &&
766+
RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") {
767+
ResultType = ConvertSYCLJointMatrixINTELType(RD);
768+
break;
769+
}
770+
}
771+
}
772+
748773
unsigned AS = getTargetAddressSpace(ETy);
749774
ResultType = llvm::PointerType::get(PointeeType, AS);
750775
break;

clang/lib/CodeGen/CodeGenTypes.h

+8
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,14 @@ class CodeGenTypes {
133133
/// memory representation is usually i8 or i32, depending on the target.
134134
llvm::Type *ConvertTypeForMem(QualType T, bool ForBitField = false);
135135

136+
/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
137+
/// which is represented as a pointer to a structure to LLVM extension type
138+
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
139+
/// The expected representation is:
140+
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
141+
/// %use%, (optional) %element_type_interpretation%)
142+
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);
143+
136144
/// GetFunctionType - Get the LLVM function type for \arg Info.
137145
llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info);
138146

clang/test/CodeGenSYCL/matrix.cpp

+17-17
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@
55
#include <stdint.h>
66

77
namespace __spv {
8-
template <typename T, size_t R, size_t C, uint32_t U, uint32_t S>
8+
template <typename T, size_t R, size_t C, uint32_t L, uint32_t S, uint32_t U>
99
struct __spirv_JointMatrixINTEL;
1010
}
1111

12-
// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1
13-
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1> *matrix) {}
12+
// CHECK: @_Z2f1{{.*}}(target("spirv.JointMatrixINTEL", float, 5, 10, 0, 1, 0)
13+
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1, 0> *matrix) {}
1414

15-
// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0
16-
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0> *matrix) {}
15+
// CHECK: @_Z2f2{{.*}}(target("spirv.JointMatrixINTEL", i64, 10, 2, 0, 0, 0)
16+
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0, 0> *matrix) {}
1717

18-
// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0
19-
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {}
18+
// CHECK: @_Z2f3{{.*}}(target("spirv.JointMatrixINTEL", i8, 10, 2, 0, 0, 0)
19+
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0, 0> *matrix) {}
2020

2121
namespace sycl {
2222
class half {};
@@ -25,17 +25,17 @@ namespace sycl {
2525
}
2626
typedef sycl::half my_half;
2727

28-
// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0
29-
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0> *matrix) {}
28+
// CHECK: @_Z2f4{{.*}}(target("spirv.JointMatrixINTEL", half, 10, 2, 0, 0, 0)
29+
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0, 0> *matrix) {}
3030

31-
// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0
32-
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {}
31+
// CHECK: @_Z2f5{{.*}}(target("spirv.JointMatrixINTEL", i16, 10, 2, 0, 0, 0, 1)
32+
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0, 0> *matrix) {}
3333

34-
// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0
35-
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {}
34+
// CHECK: @_Z2f6{{.*}}(target("spirv.JointMatrixINTEL", i128, 10, 2, 0, 0, 0)
35+
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0, 0> *matrix) {}
3636

37-
// CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._tf32_10_2_0_0
38-
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0> *matrix) {}
37+
// CHECK: @_Z2f7{{.*}}(target("spirv.JointMatrixINTEL", float, 10, 2, 0, 0, 0, 0)
38+
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0, 0> *matrix) {}
3939

40-
// CHECK: @_Z2f8{{.*}}(%spirv.JointMatrixINTEL._double_5_10_0_1
41-
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1> *matrix) {}
40+
// CHECK: @_Z2f8{{.*}}(target("spirv.JointMatrixINTEL", double, 5, 10, 0, 1, 0)
41+
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1, 0> *matrix) {}

sycl/test/matrix/legacy/matrix-int8-test.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: %clangxx -fsycl -fsycl-device-only -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 -S -emit-llvm -o - %s | FileCheck %s
22

3-
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type opaque
4-
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type opaque
5-
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type opaque
3+
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0)
4+
// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 0, 3, 0)
5+
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 3, 3, 0)
66

77
#include <iostream>
88
#include <sycl/sycl.hpp>

sycl/test/matrix/matrix-int8-test.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: %clangxx -fsycl -fsycl-device-only -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -O2 -S -emit-llvm -o - %s | FileCheck %s
22

3-
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3_0 = type opaque
4-
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_3_3_2 = type opaque
5-
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_2_3_1 = type opaque
3+
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0)
4+
// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2)
5+
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1)
66

77
#include <iostream>
88
#include <sycl/sycl.hpp>

0 commit comments

Comments
 (0)