Skip to content

Commit 4447a50

Browse files
authored
Revert "[SYCL] Represent JointMatrixINTEL type as extension type" (#9071)
It appears to be, that mem2reg and SROA passes can't handle target extension type properly. It means, that with turned on optimizations alloca/load/store sequences of joint matrix types won't be eliminated. It results in a crash in IGC since it can't handle such case yet. Note, it means that matrix samples compiled with -O0 also don't work now. So we have to (temporary?) revert this patch. This reverts commit 6f8e456.
1 parent 097d21c commit 4447a50

File tree

5 files changed

+82
-115
lines changed

5 files changed

+82
-115
lines changed

clang/lib/CodeGen/CodeGenTypes.cpp

+59-84
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,65 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
5151
StringRef suffix) {
5252
SmallString<256> TypeName;
5353
llvm::raw_svector_ostream OS(TypeName);
54+
// If RD is spirv_JointMatrixINTEL type, mangle differently.
55+
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
56+
if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") {
57+
if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
58+
ArrayRef<TemplateArgument> TemplateArgs =
59+
TemplateDecl->getTemplateArgs().asArray();
60+
OS << "spirv.JointMatrixINTEL.";
61+
for (auto &TemplateArg : TemplateArgs) {
62+
OS << "_";
63+
if (TemplateArg.getKind() == TemplateArgument::Type) {
64+
llvm::Type *TTy = ConvertType(TemplateArg.getAsType());
65+
if (TTy->isIntegerTy()) {
66+
switch (TTy->getIntegerBitWidth()) {
67+
case 8:
68+
OS << "char";
69+
break;
70+
case 16:
71+
OS << "short";
72+
break;
73+
case 32:
74+
OS << "int";
75+
break;
76+
case 64:
77+
OS << "long";
78+
break;
79+
default:
80+
OS << "i" << TTy->getIntegerBitWidth();
81+
break;
82+
}
83+
} else if (TTy->isHalfTy()) {
84+
OS << "half";
85+
} else if (TTy->isFloatTy()) {
86+
OS << "float";
87+
} else if (TTy->isDoubleTy()) {
88+
OS << "double";
89+
} else if (TTy->isBFloatTy()) {
90+
OS << "bfloat16";
91+
} else if (TTy->isStructTy()) {
92+
StringRef LlvmTyName = TTy->getStructName();
93+
// Emit half/bfloat16/tf32 for sycl[::*]::{half,bfloat16,tf32}
94+
if (LlvmTyName.startswith("class.sycl::") ||
95+
LlvmTyName.startswith("class.__sycl_internal::"))
96+
LlvmTyName = LlvmTyName.rsplit("::").second;
97+
if (LlvmTyName != "half" && LlvmTyName != "bfloat16" &&
98+
LlvmTyName != "tf32")
99+
llvm_unreachable("Wrong matrix base type!");
100+
OS << LlvmTyName;
101+
} else {
102+
llvm_unreachable("Wrong matrix base type!");
103+
}
104+
} else if (TemplateArg.getKind() == TemplateArgument::Integral) {
105+
OS << TemplateArg.getAsIntegral();
106+
}
107+
}
108+
Ty->setName(OS.str());
109+
return;
110+
}
111+
}
112+
}
54113
OS << RD->getKindName() << '.';
55114

56115
// FIXME: We probably want to make more tweaks to the printing policy. For
@@ -401,78 +460,6 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
401460
return ResultType;
402461
}
403462

404-
template <bool NeedTypeInterpret = false>
405-
llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
406-
ArrayRef<TemplateArgument> TemplateArgs,
407-
const unsigned Val = 0) {
408-
// TODO: we should actually have exactly 5 template parameters: 1 for
409-
// type and 4 for type parameters. But in previous version of the SPIR-V
410-
// spec we have Layout matrix type parameter, that was later removed.
411-
// Once we update to the newest version of the spec - this should be updated.
412-
assert((TemplateArgs.size() == 5 || TemplateArgs.size() == 6) &&
413-
"Wrong JointMatrixINTEL template parameters number");
414-
// This is required to represent optional 'Component Type Interpretation'
415-
// parameter
416-
using ParamsType =
417-
typename std::conditional<NeedTypeInterpret, SmallVector<unsigned, 6>,
418-
SmallVector<unsigned, 5>>::type;
419-
ParamsType Params;
420-
if constexpr (NeedTypeInterpret)
421-
Params = {0, 0, 0, 0, 0, Val};
422-
else
423-
Params = {0, 0, 0, 0, 0};
424-
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
425-
assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
426-
"Wrong JointMatrixINTEL template parameter");
427-
Params[I - 1] = TemplateArgs[I].getAsIntegral().getExtValue();
428-
}
429-
return llvm::TargetExtType::get(CompTy->getContext(),
430-
"spirv.JointMatrixINTEL", {CompTy}, Params);
431-
}
432-
433-
/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
434-
/// which is represented as a pointer to a structure to LLVM extension type
435-
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
436-
/// The expected representation is:
437-
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
438-
/// %use%, (optional) %element_type_interpretation%)
439-
llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
440-
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
441-
ArrayRef<TemplateArgument> TemplateArgs =
442-
TemplateDecl->getTemplateArgs().asArray();
443-
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
444-
"1st JointMatrixINTEL template parameter must be type");
445-
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());
446-
447-
// Per JointMatrixINTEL spec the type can have an optional
448-
// 'Component Type Interpretation' parameter. We should emit it in case
449-
// if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as
450-
// matrix's components. Yet 'bfloat16' should be represented as 'int16' and
451-
// 'tf32' as 'float' types.
452-
if (CompTy->isStructTy()) {
453-
StringRef LlvmTyName = CompTy->getStructName();
454-
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
455-
if (LlvmTyName.startswith("class.sycl::") ||
456-
LlvmTyName.startswith("class.__sycl_internal::"))
457-
LlvmTyName = LlvmTyName.rsplit("::").second;
458-
if (LlvmTyName == "half") {
459-
CompTy = llvm::Type::getHalfTy(getLLVMContext());
460-
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
461-
} else if (LlvmTyName == "tf32") {
462-
CompTy = llvm::Type::getFloatTy(getLLVMContext());
463-
// 'tf32' interpretation is mapped to '0'
464-
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 0);
465-
} else if (LlvmTyName == "bfloat16") {
466-
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
467-
// 'bfloat16' interpretation is mapped to '1'
468-
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 1);
469-
} else {
470-
llvm_unreachable("Wrong matrix base type!");
471-
}
472-
}
473-
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
474-
}
475-
476463
/// ConvertType - Convert the specified type to its LLVM form.
477464
llvm::Type *CodeGenTypes::ConvertType(QualType T) {
478465
T = Context.getCanonicalType(T);
@@ -758,18 +745,6 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
758745
llvm::Type *PointeeType = ConvertTypeForMem(ETy);
759746
if (PointeeType->isVoidTy())
760747
PointeeType = llvm::Type::getInt8Ty(getLLVMContext());
761-
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
762-
const Type *ClangETy = ETy.getTypePtrOrNull();
763-
if (ClangETy && ClangETy->isStructureOrClassType()) {
764-
RecordDecl *RD = ClangETy->getAsCXXRecordDecl();
765-
if (RD &&
766-
RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") {
767-
ResultType = ConvertSYCLJointMatrixINTELType(RD);
768-
break;
769-
}
770-
}
771-
}
772-
773748
unsigned AS = getTargetAddressSpace(ETy);
774749
ResultType = llvm::PointerType::get(PointeeType, AS);
775750
break;

clang/lib/CodeGen/CodeGenTypes.h

-8
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,6 @@ class CodeGenTypes {
133133
/// memory representation is usually i8 or i32, depending on the target.
134134
llvm::Type *ConvertTypeForMem(QualType T, bool ForBitField = false);
135135

136-
/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
137-
/// which is represented as a pointer to a structure to LLVM extension type
138-
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
139-
/// The expected representation is:
140-
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
141-
/// %use%, (optional) %element_type_interpretation%)
142-
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);
143-
144136
/// GetFunctionType - Get the LLVM function type for \arg Info.
145137
llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info);
146138

clang/test/CodeGenSYCL/matrix.cpp

+17-17
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@
55
#include <stdint.h>
66

77
namespace __spv {
8-
template <typename T, size_t R, size_t C, uint32_t L, uint32_t S, uint32_t U>
8+
template <typename T, size_t R, size_t C, uint32_t U, uint32_t S>
99
struct __spirv_JointMatrixINTEL;
1010
}
1111

12-
// CHECK: @_Z2f1{{.*}}(target("spirv.JointMatrixINTEL", float, 5, 10, 0, 1, 0)
13-
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1, 0> *matrix) {}
12+
// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1
13+
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1> *matrix) {}
1414

15-
// CHECK: @_Z2f2{{.*}}(target("spirv.JointMatrixINTEL", i64, 10, 2, 0, 0, 0)
16-
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0, 0> *matrix) {}
15+
// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0
16+
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0> *matrix) {}
1717

18-
// CHECK: @_Z2f3{{.*}}(target("spirv.JointMatrixINTEL", i8, 10, 2, 0, 0, 0)
19-
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0, 0> *matrix) {}
18+
// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0
19+
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {}
2020

2121
namespace sycl {
2222
class half {};
@@ -25,17 +25,17 @@ namespace sycl {
2525
}
2626
typedef sycl::half my_half;
2727

28-
// CHECK: @_Z2f4{{.*}}(target("spirv.JointMatrixINTEL", half, 10, 2, 0, 0, 0)
29-
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0, 0> *matrix) {}
28+
// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0
29+
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0> *matrix) {}
3030

31-
// CHECK: @_Z2f5{{.*}}(target("spirv.JointMatrixINTEL", i16, 10, 2, 0, 0, 0, 1)
32-
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0, 0> *matrix) {}
31+
// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0
32+
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {}
3333

34-
// CHECK: @_Z2f6{{.*}}(target("spirv.JointMatrixINTEL", i128, 10, 2, 0, 0, 0)
35-
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0, 0> *matrix) {}
34+
// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0
35+
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {}
3636

37-
// CHECK: @_Z2f7{{.*}}(target("spirv.JointMatrixINTEL", float, 10, 2, 0, 0, 0, 0)
38-
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0, 0> *matrix) {}
37+
// CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._tf32_10_2_0_0
38+
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0> *matrix) {}
3939

40-
// CHECK: @_Z2f8{{.*}}(target("spirv.JointMatrixINTEL", double, 5, 10, 0, 1, 0)
41-
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1, 0> *matrix) {}
40+
// CHECK: @_Z2f8{{.*}}(%spirv.JointMatrixINTEL._double_5_10_0_1
41+
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1> *matrix) {}

sycl/test/matrix/legacy/matrix-int8-test.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: %clangxx -fsycl -fsycl-device-only -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 -S -emit-llvm -o - %s | FileCheck %s
22

3-
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0)
4-
// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 0, 3, 0)
5-
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 3, 3, 0)
3+
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type opaque
4+
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type opaque
5+
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type opaque
66

77
#include <iostream>
88
#include <sycl/sycl.hpp>

sycl/test/matrix/matrix-int8-test.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: %clangxx -fsycl -fsycl-device-only -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -O2 -S -emit-llvm -o - %s | FileCheck %s
22

3-
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0)
4-
// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2)
5-
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1)
3+
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3_0 = type opaque
4+
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_3_3_2 = type opaque
5+
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_2_3_1 = type opaque
66

77
#include <iostream>
88
#include <sycl/sycl.hpp>

0 commit comments

Comments
 (0)