Skip to content

Commit b9529a6

Browse files
committed
[SYCL] Represent JointMatrixINTEL type as extension type
Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 24ec33c commit b9529a6

File tree

3 files changed

+111
-76
lines changed

3 files changed

+111
-76
lines changed

clang/lib/CodeGen/CodeGenTypes.cpp

+85-59
Original file line numberDiff line numberDiff line change
@@ -51,65 +51,6 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
5151
StringRef suffix) {
5252
SmallString<256> TypeName;
5353
llvm::raw_svector_ostream OS(TypeName);
54-
// If RD is spirv_JointMatrixINTEL type, mangle differently.
55-
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
56-
if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") {
57-
if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
58-
ArrayRef<TemplateArgument> TemplateArgs =
59-
TemplateDecl->getTemplateArgs().asArray();
60-
OS << "spirv.JointMatrixINTEL.";
61-
for (auto &TemplateArg : TemplateArgs) {
62-
OS << "_";
63-
if (TemplateArg.getKind() == TemplateArgument::Type) {
64-
llvm::Type *TTy = ConvertType(TemplateArg.getAsType());
65-
if (TTy->isIntegerTy()) {
66-
switch (TTy->getIntegerBitWidth()) {
67-
case 8:
68-
OS << "char";
69-
break;
70-
case 16:
71-
OS << "short";
72-
break;
73-
case 32:
74-
OS << "int";
75-
break;
76-
case 64:
77-
OS << "long";
78-
break;
79-
default:
80-
OS << "i" << TTy->getIntegerBitWidth();
81-
break;
82-
}
83-
} else if (TTy->isHalfTy()) {
84-
OS << "half";
85-
} else if (TTy->isFloatTy()) {
86-
OS << "float";
87-
} else if (TTy->isDoubleTy()) {
88-
OS << "double";
89-
} else if (TTy->isBFloatTy()) {
90-
OS << "bfloat16";
91-
} else if (TTy->isStructTy()) {
92-
StringRef LlvmTyName = TTy->getStructName();
93-
// Emit half/bfloat16/tf32 for sycl[::*]::{half,bfloat16,tf32}
94-
if (LlvmTyName.startswith("class.sycl::") ||
95-
LlvmTyName.startswith("class.__sycl_internal::"))
96-
LlvmTyName = LlvmTyName.rsplit("::").second;
97-
if (LlvmTyName != "half" && LlvmTyName != "bfloat16" &&
98-
LlvmTyName != "tf32")
99-
llvm_unreachable("Wrong matrix base type!");
100-
OS << LlvmTyName;
101-
} else {
102-
llvm_unreachable("Wrong matrix base type!");
103-
}
104-
} else if (TemplateArg.getKind() == TemplateArgument::Integral) {
105-
OS << TemplateArg.getAsIntegral();
106-
}
107-
}
108-
Ty->setName(OS.str());
109-
return;
110-
}
111-
}
112-
}
11354
OS << RD->getKindName() << '.';
11455

11556
// FIXME: We probably want to make more tweaks to the printing policy. For
@@ -460,6 +401,79 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
460401
return ResultType;
461402
}
462403

404+
template <bool NeedTypeInterpret = false>
405+
llvm::Type* getJointMatrixINTELExtType(llvm::Type *CompTy,
406+
ArrayRef<TemplateArgument> TemplateArgs,
407+
const unsigned Val = 0) {
408+
// TODO: we should actually have exactly 5 template parameters: 1 for
409+
// type and 4 for type parameters. But in previous version of the SPIR-V
410+
// spec we have Layout matrix type parameter, that was later removed.
411+
// Once we update to the newest version of the spec - this should be updated.
412+
assert((TemplateArgs.size() == 5 || TemplateArgs.size() == 6) &&
413+
"Wrong JointMatrixINTEL template parameters number");
414+
// This is required to represent optional Optional
415+
// 'Component Type Interpretation' parameter
416+
using ParamsType =
417+
typename std::conditional<NeedTypeInterpret, SmallVector<unsigned, 6>,
418+
SmallVector<unsigned, 5>>::type;
419+
ParamsType Params;
420+
if constexpr (NeedTypeInterpret)
421+
Params = { 0, 0, 0, 0, 0, Val};
422+
else
423+
Params = { 0, 0, 0, 0, 0};
424+
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
425+
assert(TemplateArgs[I].getKind() ==
426+
TemplateArgument::Integral &&
427+
"Wrong JointMatrixINTEL template parameter");
428+
Params[I - 1] = TemplateArgs[I].getAsIntegral().getExtValue();
429+
}
430+
return llvm::TargetExtType::get(
431+
CompTy->getContext(), "spirv.JointMatrixINTEL", {CompTy}, Params);
432+
}
433+
434+
/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
435+
/// which is represented as a pointer to a structure to LLVM extension type
436+
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
437+
/// The expected representation is:
438+
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
439+
/// %use%, (optional) %element_type_interpretation%)
440+
llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
441+
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
442+
ArrayRef<TemplateArgument> TemplateArgs =
443+
TemplateDecl->getTemplateArgs().asArray();
444+
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
445+
"1st JointMatrixINTEL template parameter must be type");
446+
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());
447+
448+
// Per JointMatrixINTEL spec the type can have an Optional
449+
// 'Component Type Interpretation' parameter. We should emit it in case
450+
// if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as
451+
// matrix's components. Yet bfloat16 should be represented as 'int16' and
452+
// 'tf32' as 'float' types.
453+
if (CompTy->isStructTy()) {
454+
StringRef LlvmTyName = CompTy->getStructName();
455+
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
456+
if (LlvmTyName.startswith("class.sycl::") ||
457+
LlvmTyName.startswith("class.__sycl_internal::"))
458+
LlvmTyName = LlvmTyName.rsplit("::").second;
459+
if (LlvmTyName == "half") {
460+
CompTy = llvm::Type::getHalfTy(getLLVMContext());
461+
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
462+
} else if (LlvmTyName == "tf32") {
463+
CompTy = llvm::Type::getHalfTy(getLLVMContext());
464+
// 'tf32' interpretation is mapped to '0'
465+
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 0);
466+
} else if (LlvmTyName == "bfloat16") {
467+
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
468+
// 'bfloat16' interpretation is mapped to '1'
469+
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 1);
470+
} else {
471+
llvm_unreachable("Wrong matrix base type!");
472+
}
473+
}
474+
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
475+
}
476+
463477
/// ConvertType - Convert the specified type to its LLVM form.
464478
llvm::Type *CodeGenTypes::ConvertType(QualType T) {
465479
T = Context.getCanonicalType(T);
@@ -745,6 +759,18 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
745759
llvm::Type *PointeeType = ConvertTypeForMem(ETy);
746760
if (PointeeType->isVoidTy())
747761
PointeeType = llvm::Type::getInt8Ty(getLLVMContext());
762+
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
763+
const Type* ClangETy = ETy.getTypePtrOrNull();
764+
if (ClangETy && ClangETy->isStructureOrClassType()) {
765+
RecordDecl *RD = ClangETy->getAsCXXRecordDecl();
766+
if (RD->getQualifiedNameAsString() ==
767+
"__spv::__spirv_JointMatrixINTEL") {
768+
ResultType = ConvertSYCLJointMatrixINTELType(RD);
769+
break;
770+
}
771+
}
772+
}
773+
748774
unsigned AS = getTargetAddressSpace(ETy);
749775
ResultType = llvm::PointerType::get(PointeeType, AS);
750776
break;

clang/lib/CodeGen/CodeGenTypes.h

+8
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,14 @@ class CodeGenTypes {
133133
/// memory representation is usually i8 or i32, depending on the target.
134134
llvm::Type *ConvertTypeForMem(QualType T, bool ForBitField = false);
135135

136+
/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
137+
/// which is represented as a pointer to a structure to LLVM extension type
138+
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
139+
/// The expected representation is:
140+
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
141+
/// %use%, (optional) %element_type_interpretation%)
142+
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);
143+
136144
/// GetFunctionType - Get the LLVM function type for \arg Info.
137145
llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info);
138146

clang/test/CodeGenSYCL/matrix.cpp

+18-17
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,19 @@
55
#include <stdint.h>
66

77
namespace __spv {
8-
template <typename T, size_t R, size_t C, uint32_t U, uint32_t S>
8+
template <typename T, size_t R, size_t C, uint32_t L, uint32_t S, uint32_t U>
99
struct __spirv_JointMatrixINTEL;
1010
}
1111

12-
// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1
13-
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1> *matrix) {}
12+
target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1, 1)
13+
// CHECK: @_Z2f1{{.*}}(target("spirv.JointMatrixINTEL", float, 5, 10, 0, 1, 0)
14+
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1, 0> *matrix) {}
1415

15-
// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0
16-
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0> *matrix) {}
16+
// CHECK: @_Z2f2{{.*}}(target("spirv.JointMatrixINTEL", i64, 10, 2, 0, 0, 0)
17+
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0, 0> *matrix) {}
1718

18-
// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0
19-
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {}
19+
// CHECK: @_Z2f3{{.*}}(target("spirv.JointMatrixINTEL", i8, 10, 2, 0, 0, 0)
20+
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0, 0> *matrix) {}
2021

2122
namespace sycl {
2223
class half {};
@@ -25,17 +26,17 @@ namespace sycl {
2526
}
2627
typedef sycl::half my_half;
2728

28-
// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0
29-
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0> *matrix) {}
29+
// CHECK: @_Z2f4{{.*}}(target("spirv.JointMatrixINTEL", half, 10, 2, 0, 0, 0)
30+
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0, 0> *matrix) {}
3031

31-
// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0
32-
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {}
32+
// CHECK: @_Z2f5{{.*}}(target("spirv.JointMatrixINTEL", i16, 10, 2, 0, 0, 0, 1)
33+
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0, 0> *matrix) {}
3334

34-
// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0
35-
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {}
35+
// CHECK: @_Z2f6{{.*}}(target("spirv.JointMatrixINTEL", i128, 10, 2, 0, 0, 0)
36+
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0, 0> *matrix) {}
3637

37-
// CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._tf32_10_2_0_0
38-
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0> *matrix) {}
38+
// CHECK: @_Z2f7{{.*}}(target("spirv.JointMatrixINTEL", float, 10, 2, 0, 0, 0, 0)
39+
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0, 0> *matrix) {}
3940

40-
// CHECK: @_Z2f8{{.*}}(%spirv.JointMatrixINTEL._double_5_10_0_1
41-
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1> *matrix) {}
41+
// CHECK: @_Z2f8{{.*}}(target("spirv.JointMatrixINTEL", double, 5, 10, 0, 1, 0)
42+
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1, 0> *matrix) {}

0 commit comments

Comments
 (0)