Skip to content

Commit 44f6f57

Browse files
authored
Add support for SPV_INTEL_joint_matrix extension (#1165)
The spec is available at intel/llvm#4373 Signed-off-by: Alexey Sotkin <[email protected]>
1 parent e9671a5 commit 44f6f57

17 files changed

+432
-5
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,4 @@ EXT(SPV_INTEL_debug_module)
4646
EXT(SPV_INTEL_runtime_aligned)
4747
EXT(SPV_INTEL_arithmetic_fence)
4848
EXT(SPV_INTEL_bfloat16_conversion)
49+
EXT(SPV_INTEL_joint_matrix)

lib/SPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ add_llvm_library(LLVMSPIRVLib
4040
Analysis
4141
BitWriter
4242
Core
43+
Demangle
4344
IRReader
4445
Linker
4546
Support

lib/SPIRV/SPIRVInternal.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,22 @@ SPIRVMap<SPIRVExtInstSetKind, std::string, SPIRVExtSetShortName>::init() {
265265
typedef SPIRVMap<SPIRVExtInstSetKind, std::string, SPIRVExtSetShortName>
266266
SPIRVExtSetShortNameMap;
267267

268+
template <>
269+
inline void SPIRVMap<internal::InternalJointMatrixLayout, std::string>::init() {
270+
add(internal::RowMajor, "matrix.rowmajor");
271+
add(internal::ColumnMajor, "matrix.columnmajor");
272+
add(internal::PackedA, "matrix.packed.a");
273+
add(internal::PackedB, "matrix.packed.b");
274+
}
275+
typedef SPIRVMap<internal::InternalJointMatrixLayout, std::string>
276+
SPIRVMatrixLayoutMap;
277+
278+
template <> inline void SPIRVMap<spv::Scope, std::string>::init() {
279+
add(ScopeWorkgroup, "scope.workgroup");
280+
add(ScopeSubgroup, "scope.subgroup");
281+
}
282+
typedef SPIRVMap<spv::Scope, std::string> SPIRVMatrixScopeMap;
283+
268284
#define SPIR_MD_COMPILER_OPTIONS "opencl.compiler.options"
269285
#define SPIR_MD_KERNEL_ARG_ADDR_SPACE "kernel_arg_addr_space"
270286
#define SPIR_MD_KERNEL_ARG_ACCESS_QUAL "kernel_arg_access_qual"
@@ -312,6 +328,7 @@ const static char ConstantSampler[] = "ConstantSampler";
312328
const static char PipeStorage[] = "PipeStorage";
313329
const static char ConstantPipeStorage[] = "ConstantPipeStorage";
314330
const static char VmeImageINTEL[] = "VmeImageINTEL";
331+
const static char JointMatrixINTEL[] = "JointMatrixINTEL";
315332
} // namespace kSPIRVTypeName
316333

317334
namespace kSPR2TypeName {
@@ -1054,6 +1071,8 @@ bool postProcessBuiltinsReturningStruct(Module *M, bool IsCpp = false);
10541071

10551072
bool postProcessBuiltinsWithArrayArguments(Module *M, bool IsCpp = false);
10561073

1074+
template <typename T>
1075+
MetadataAsValue *map2MDString(LLVMContext &C, SPIRVValue *V);
10571076
} // namespace SPIRV
10581077

10591078
#endif // SPIRV_SPIRVINTERNAL_H

lib/SPIRV/SPIRVReader.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,23 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool IsClassMember) {
443443
SPIRAddressSpace::SPIRAS_Global));
444444
}
445445

446+
case internal::OpTypeJointMatrixINTEL: {
447+
auto *MT = static_cast<SPIRVTypeJointMatrixINTEL *>(T);
448+
auto R = static_cast<SPIRVConstant *>(MT->getRows())->getZExtIntValue();
449+
auto C = static_cast<SPIRVConstant *>(MT->getColumns())->getZExtIntValue();
450+
std::stringstream SS;
451+
SS << kSPIRVTypeName::PostfixDelim;
452+
SS << transTypeToOCLTypeName(MT->getCompType());
453+
auto L = static_cast<SPIRVConstant *>(MT->getLayout())->getZExtIntValue();
454+
auto S = static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue();
455+
SS << kSPIRVTypeName::PostfixDelim << R << kSPIRVTypeName::PostfixDelim << C
456+
<< kSPIRVTypeName::PostfixDelim << L << kSPIRVTypeName::PostfixDelim
457+
<< S;
458+
std::string Name =
459+
getSPIRVTypeName(kSPIRVTypeName::JointMatrixINTEL, SS.str());
460+
return mapType(T, getOrCreateOpaquePtrType(M, Name));
461+
}
462+
446463
default: {
447464
auto OC = T->getOpCode();
448465
if (isOpaqueGenericTypeOpCode(OC) || isSubgroupAvcINTELTypeOpCode(OC))
@@ -3097,7 +3114,7 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
30973114
bool AddRetTypePostfix = false;
30983115
if (OC == OpImageQuerySizeLod || OC == OpImageQuerySize ||
30993116
OC == OpImageRead || OC == OpSubgroupImageBlockReadINTEL ||
3100-
OC == OpSubgroupBlockReadINTEL)
3117+
OC == OpSubgroupBlockReadINTEL || OC == internal::OpJointMatrixLoadINTEL)
31013118
AddRetTypePostfix = true;
31023119

31033120
bool IsRetSigned = false;

lib/SPIRV/SPIRVRegularizeLLVM.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
#include "SPIRVInternal.h"
4242
#include "libSPIRV/SPIRVDebug.h"
4343

44+
#include "llvm/ADT/StringExtras.h" // llvm::isDigit
45+
#include "llvm/Demangle/Demangle.h"
4446
#include "llvm/IR/InstVisitor.h"
4547
#include "llvm/IR/Instructions.h"
4648
#include "llvm/IR/Operator.h"
@@ -104,7 +106,7 @@ class SPIRVRegularizeLLVMBase {
104106
void buildUMulWithOverflowFunc(Function *UMulFunc);
105107

106108
static std::string lowerLLVMIntrinsicName(IntrinsicInst *II);
107-
109+
void adaptStructTypes(StructType *ST);
108110
static char ID;
109111

110112
private:
@@ -291,6 +293,58 @@ void SPIRVRegularizeLLVMBase::lowerUMulWithOverflow(
291293
UMulIntrinsic->setCalledFunction(UMulFunc);
292294
}
293295

296+
void SPIRVRegularizeLLVMBase::adaptStructTypes(StructType *ST) {
297+
if (!ST->hasName())
298+
return;
299+
StringRef STName = ST->getName();
300+
STName.consume_front("struct.");
301+
StringRef MangledName = STName.substr(0, STName.find('.'));
302+
303+
// Demangle the name of a template struct and parse the template
304+
// parameters which look like:
305+
// <signed char, 2ul, 2ul, (spv::MatrixLayout)0, (spv::Scope)3>
306+
// The result should look like SPIR-V friendly LLVM IR:
307+
// %spirv.JointMatrixINTEL._char_2_2_0_3
308+
if (MangledName.startswith("_ZTSN5__spv24__spirv_JointMatrixINTEL")) {
309+
std::string DemangledName = llvm::demangle(MangledName.str());
310+
StringRef Name(DemangledName);
311+
Name = Name.slice(Name.find('<') + 1, Name.rfind('>'));
312+
std::stringstream SPVName;
313+
// Name = signed char, 2ul, 2ul, (spv::MatrixLayout)0, (spv::Scope)3
314+
auto P = Name.split(", ");
315+
// P.first = "signed char
316+
// P.second = "2ul, 2ul, (spv::MatrixLayout)0, (spv::Scope)3"
317+
StringRef ElemType = P.first;
318+
// remove possile qualifiers, like "const" or "signed"
319+
ElemType.consume_back(" const");
320+
size_t Space = ElemType.rfind(' ');
321+
if (Space != StringRef::npos)
322+
ElemType = ElemType.substr(Space + 1);
323+
P = P.second.split(", ");
324+
// P.first = "2ul"
325+
// P.second = "2ul, (spv::MatrixLayout)0, (spv::Scope)3"
326+
StringRef Rows = P.first.take_while(llvm::isDigit);
327+
P = P.second.split(", ");
328+
// P.first = "2ul"
329+
// P.second = "(spv::MatrixLayout)0, (spv::Scope)3"
330+
StringRef Cols = P.first.take_while(llvm::isDigit);
331+
P = P.second.split(", ");
332+
// P.first = "(spv::MatrixLayout)0"
333+
// P.second = "(spv::Scope)3"
334+
StringRef Layout = P.first.substr(P.first.rfind(')') + 1);
335+
StringRef Scope = P.second.substr(P.second.rfind(')') + 1);
336+
337+
SPVName << kSPIRVTypeName::PrefixAndDelim
338+
<< kSPIRVTypeName::JointMatrixINTEL << kSPIRVTypeName::Delimiter
339+
<< kSPIRVTypeName::PostfixDelim << ElemType.str()
340+
<< kSPIRVTypeName::PostfixDelim << Rows.str()
341+
<< kSPIRVTypeName::PostfixDelim << Cols.str()
342+
<< kSPIRVTypeName::PostfixDelim << Layout.str()
343+
<< kSPIRVTypeName::PostfixDelim << Scope.str();
344+
ST->setName(SPVName.str());
345+
}
346+
}
347+
294348
bool SPIRVRegularizeLLVMBase::runRegularizeLLVM(Module &Module) {
295349
M = &Module;
296350
Ctx = &M->getContext();
@@ -430,6 +484,9 @@ bool SPIRVRegularizeLLVMBase::regularize() {
430484
}
431485
}
432486

487+
for (StructType *ST : M->getIdentifiedStructTypes())
488+
adaptStructTypes(ST);
489+
433490
if (SPIRVDbgSaveRegularizedModule)
434491
saveLLVMModule(M, RegularizedModuleTmpFile);
435492
return true;

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "llvm/Bitcode/BitcodeWriter.h"
5353
#include "llvm/IR/IRBuilder.h"
5454
#include "llvm/IR/IntrinsicInst.h"
55+
#include "llvm/IR/Metadata.h"
5556
#include "llvm/Support/CommandLine.h"
5657
#include "llvm/Support/Debug.h"
5758
#include "llvm/Support/ErrorHandling.h"
@@ -148,7 +149,17 @@ std::string mapLLVMTypeToOCLType(const Type *Ty, bool Signed) {
148149
Ss << mapLLVMTypeToOCLType(EleTy, Signed) << Size;
149150
return Ss.str();
150151
}
151-
return "invalid_type";
152+
// It is expected that `Ty` can be mapped to `ReturnType` from "Optional
153+
// Postfixes for SPIR-V Builtin Function Names" section of
154+
// SPIRVRepresentationInLLVM.rst document (aka SPIRV-friendly IR).
155+
// If `Ty` is not a scalar or vector type mentioned in the document (return
156+
// value of some SPIR-V instructions may be represented as pointer to a struct
157+
// in LLVM IR) we can mangle the type.
158+
BuiltinFuncMangleInfo MangleInfo;
159+
std::string MangledName =
160+
mangleBuiltin("", const_cast<Type *>(Ty), &MangleInfo);
161+
// Remove "_Z0"(3 characters) from the front of the name
162+
return MangledName.erase(0, 3);
152163
}
153164

154165
std::string mapSPIRVTypeToOCLType(SPIRVType *Ty, bool Signed) {
@@ -2122,4 +2133,16 @@ std::string getSPIRVFriendlyIRFunctionName(const std::string &UniqName,
21222133
return mangleBuiltin(UniqName, ArgTys, &MangleInfo);
21232134
}
21242135

2136+
template <typename T>
2137+
MetadataAsValue *map2MDString(LLVMContext &C, SPIRVValue *V) {
2138+
if (V->getOpCode() != OpConstant)
2139+
return nullptr;
2140+
uint64_t Const = static_cast<SPIRVConstant *>(V)->getZExtIntValue();
2141+
std::string Str = SPIRVMap<T, std::string>::map(static_cast<T>(Const));
2142+
return MetadataAsValue::get(C, MDString::get(C, Str));
2143+
}
2144+
template MetadataAsValue *
2145+
map2MDString<internal::InternalJointMatrixLayout>(LLVMContext &, SPIRVValue *);
2146+
template MetadataAsValue *map2MDString<spv::Scope>(LLVMContext &, SPIRVValue *);
2147+
21252148
} // namespace SPIRV

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,38 @@ SPIRVType *LLVMToSPIRVBase::transSPIRVOpaqueType(Type *T) {
561561
return mapType(T, BM->addQueueType());
562562
else if (TN == kSPIRVTypeName::PipeStorage)
563563
return mapType(T, BM->addPipeStorageType());
564-
else
564+
else if (TN == kSPIRVTypeName::JointMatrixINTEL) {
565+
Type *ElemTy = nullptr;
566+
StringRef Ty{Postfixes[0]};
567+
auto NumBits = llvm::StringSwitch<unsigned>(Ty)
568+
.Case("char", 8)
569+
.Case("short", 16)
570+
.Case("int", 32)
571+
.Case("long", 64)
572+
.Default(0);
573+
if (NumBits)
574+
ElemTy = IntegerType::get(M->getContext(), NumBits);
575+
else if (Ty == "half")
576+
ElemTy = Type::getHalfTy(M->getContext());
577+
else if (Ty == "float")
578+
ElemTy = Type::getFloatTy(M->getContext());
579+
else if (Ty == "double")
580+
ElemTy = Type::getDoubleTy(M->getContext());
581+
else
582+
llvm_unreachable("Unexpected type for matrix!");
583+
584+
auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
585+
unsigned long long N = 0;
586+
consumeUnsignedInteger(Postfix, 10, N);
587+
return getUInt32(M, N);
588+
};
589+
SPIRVValue *Rows = transConstant(ParseInteger(Postfixes[1]));
590+
SPIRVValue *Columns = transConstant(ParseInteger(Postfixes[2]));
591+
SPIRVValue *Layout = transConstant(ParseInteger(Postfixes[3]));
592+
SPIRVValue *Scope = transConstant(ParseInteger(Postfixes[4]));
593+
return mapType(T, BM->addJointMatrixINTELType(transType(ElemTy), Rows,
594+
Columns, Layout, Scope));
595+
} else
565596
return mapType(T,
566597
BM->addOpaqueGenericType(SPIRVOpaqueTypeOpCodeMap::map(TN)));
567598
}

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3251,6 +3251,28 @@ class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
32513251
_SPIRV_OP(ConvertFToBF16INTEL)
32523252
_SPIRV_OP(ConvertBF16ToFINTEL)
32533253
#undef _SPIRV_OP
3254+
3255+
class SPIRVJointMatrixINTELInstBase : public SPIRVInstTemplateBase {
3256+
protected:
3257+
llvm::Optional<ExtensionID> getRequiredExtension() const override {
3258+
return ExtensionID::SPV_INTEL_joint_matrix;
3259+
}
3260+
};
3261+
3262+
class SPIRVJointMatrixINTELInst : public SPIRVJointMatrixINTELInstBase {
3263+
SPIRVCapVec getRequiredCapability() const override {
3264+
return getVec(internal::CapabilityJointMatrixINTEL);
3265+
}
3266+
};
3267+
3268+
#define _SPIRV_OP(x, ...) \
3269+
typedef SPIRVInstTemplate<SPIRVJointMatrixINTELInst, internal::Op##x##INTEL, \
3270+
__VA_ARGS__> \
3271+
SPIRV##x##INTEL;
3272+
_SPIRV_OP(JointMatrixLoad, true, 6, true)
3273+
_SPIRV_OP(JointMatrixStore, false, 5, true)
3274+
_SPIRV_OP(JointMatrixMad, true, 7)
3275+
#undef _SPIRV_OP
32543276
} // namespace SPIRV
32553277

32563278
#endif // SPIRV_LIBSPIRV_SPIRVINSTRUCTION_H

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,9 @@ class SPIRVModuleImpl : public SPIRVModule {
244244
SPIRVEntry *addTypeStructContinuedINTEL(unsigned NumMembers) override;
245245
void closeStructType(SPIRVTypeStruct *T, bool) override;
246246
SPIRVTypeVector *addVectorType(SPIRVType *, SPIRVWord) override;
247+
SPIRVTypeJointMatrixINTEL *addJointMatrixINTELType(SPIRVType *, SPIRVValue *,
248+
SPIRVValue *, SPIRVValue *,
249+
SPIRVValue *) override;
247250
SPIRVType *addOpaqueGenericType(Op) override;
248251
SPIRVTypeDeviceEvent *addDeviceEventType() override;
249252
SPIRVTypeQueue *addQueueType() override;
@@ -897,6 +900,14 @@ SPIRVTypeVector *SPIRVModuleImpl::addVectorType(SPIRVType *CompType,
897900
SPIRVWord CompCount) {
898901
return addType(new SPIRVTypeVector(this, getId(), CompType, CompCount));
899902
}
903+
904+
SPIRVTypeJointMatrixINTEL *SPIRVModuleImpl::addJointMatrixINTELType(
905+
SPIRVType *CompType, SPIRVValue *Rows, SPIRVValue *Columns,
906+
SPIRVValue *Layout, SPIRVValue *Scope) {
907+
return addType(new SPIRVTypeJointMatrixINTEL(this, getId(), CompType, Rows,
908+
Columns, Layout, Scope));
909+
}
910+
900911
SPIRVType *SPIRVModuleImpl::addOpaqueGenericType(Op TheOpCode) {
901912
return addType(new SPIRVTypeOpaqueGeneric(TheOpCode, this, getId()));
902913
}

lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class SPIRVAsmINTEL;
9595
class SPIRVAsmCallINTEL;
9696
class SPIRVTypeBufferSurfaceINTEL;
9797
class SPIRVTypeTokenINTEL;
98+
class SPIRVTypeJointMatrixINTEL;
9899

99100
typedef SPIRVBasicBlock SPIRVLabel;
100101
struct SPIRVTypeImageDescriptor;
@@ -242,6 +243,9 @@ class SPIRVModule {
242243
virtual SPIRVEntry *addTypeStructContinuedINTEL(unsigned NumMembers) = 0;
243244
virtual void closeStructType(SPIRVTypeStruct *, bool) = 0;
244245
virtual SPIRVTypeVector *addVectorType(SPIRVType *, SPIRVWord) = 0;
246+
virtual SPIRVTypeJointMatrixINTEL *
247+
addJointMatrixINTELType(SPIRVType *, SPIRVValue *, SPIRVValue *, SPIRVValue *,
248+
SPIRVValue *) = 0;
245249
virtual SPIRVTypeVoid *addVoidType() = 0;
246250
virtual SPIRVType *addOpaqueGenericType(Op) = 0;
247251
virtual SPIRVTypeDeviceEvent *addDeviceEventType() = 0;

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
579579
add(CapabilityMax, "Max");
580580
add(internal::CapabilityFPArithmeticFenceINTEL, "FPArithmeticFenceINTEL");
581581
add(internal::CapabilityBfloat16ConversionINTEL, "Bfloat16ConversionINTEL");
582+
add(internal::CapabilityJointMatrixINTEL, "JointMatrixINTEL");
582583
}
583584
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
584585

lib/SPIRV/libSPIRV/SPIRVOpCode.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,8 @@ inline bool isTypeOpCode(Op OpCode) {
215215
unsigned OC = OpCode;
216216
return (OpTypeVoid <= OC && OC <= OpTypePipe) || OC == OpTypePipeStorage ||
217217
isSubgroupAvcINTELTypeOpCode(OpCode) || OC == OpTypeVmeImageINTEL ||
218-
isVCOpCode(OpCode) || OC == internal::OpTypeTokenINTEL;
218+
isVCOpCode(OpCode) || OC == internal::OpTypeTokenINTEL ||
219+
OC == internal::OpTypeJointMatrixINTEL;
219220
}
220221

221222
inline bool isSpecConstantOpCode(Op OpCode) {

lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,7 @@ _SPIRV_OP_INTERNAL(TypeTokenINTEL, internal::OpTypeTokenINTEL)
88
_SPIRV_OP_INTERNAL(ArithmeticFenceINTEL, internal::OpArithmeticFenceINTEL)
99
_SPIRV_OP_INTERNAL(ConvertFToBF16INTEL, internal::OpConvertFToBF16INTEL)
1010
_SPIRV_OP_INTERNAL(ConvertBF16ToFINTEL, internal::OpConvertBF16ToFINTEL)
11+
_SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
12+
_SPIRV_OP_INTERNAL(JointMatrixLoadINTEL, internal::OpJointMatrixLoadINTEL)
13+
_SPIRV_OP_INTERNAL(JointMatrixStoreINTEL, internal::OpJointMatrixStoreINTEL)
14+
_SPIRV_OP_INTERNAL(JointMatrixMadINTEL, internal::OpJointMatrixMadINTEL)

lib/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,4 +265,18 @@ void SPIRVTypeForwardPointer::decode(std::istream &I) {
265265
SPIRVId PointerId;
266266
Decoder >> PointerId >> SC;
267267
}
268+
269+
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL(
270+
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType, SPIRVValue *Rows,
271+
SPIRVValue *Columns, SPIRVValue *Layout, SPIRVValue *Scope)
272+
: SPIRVType(M, FixedWC, OC, TheId), CompType(CompType), Rows(Rows),
273+
Columns(Columns), Layout(Layout), Scope(Scope) {}
274+
275+
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL()
276+
: SPIRVType(OC), CompType(nullptr), Rows(nullptr), Columns(nullptr),
277+
Layout(nullptr), Scope(nullptr) {}
278+
279+
_SPIRV_IMP_ENCDEC6(SPIRVTypeJointMatrixINTEL, Id, CompType, Rows, Columns,
280+
Layout, Scope)
281+
268282
} // namespace SPIRV

0 commit comments

Comments
 (0)