Skip to content

Commit 348da24

Browse files
authored
Add 'Use' parameter to TypeJointMatrixINTEL (#1456)
'Use' is an optional parameter that shows where in a math operation the matrix is used. It must be the result of a constant instruction with scalar 'integer type'. Spec: intel/llvm#5944 Signed-off-by: Dmitry Sidorov <[email protected]>
1 parent 28a37e6 commit 348da24

File tree

9 files changed

+72
-46
lines changed

9 files changed

+72
-46
lines changed

lib/SPIRV/SPIRVReader.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,9 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool IsClassMember) {
471471
SS << kSPIRVTypeName::PostfixDelim << R << kSPIRVTypeName::PostfixDelim << C
472472
<< kSPIRVTypeName::PostfixDelim << L << kSPIRVTypeName::PostfixDelim
473473
<< S;
474+
if (auto *Use = MT->getUse())
475+
SS << kSPIRVTypeName::PostfixDelim
476+
<< static_cast<SPIRVConstant *>(Use)->getZExtIntValue();
474477
std::string Name =
475478
getSPIRVTypeName(kSPIRVTypeName::JointMatrixINTEL, SS.str());
476479
return mapType(T, getOrCreateOpaquePtrType(M, Name));

lib/SPIRV/SPIRVRegularizeLLVM.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,16 +479,26 @@ void SPIRVRegularizeLLVMBase::adaptStructTypes(StructType *ST) {
479479
auto *PtrTy = dyn_cast<PointerType>(ST->getElementType(0));
480480
assert(PtrTy &&
481481
"Expected a pointer to an array to represent joint matrix type");
482-
size_t TypeLayout[4] = {0, 0, 0, 0};
482+
std::vector<size_t> TypeLayout;
483483
ArrayType *ArrayTy = dyn_cast<ArrayType>(PtrTy->getPointerElementType());
484484
assert(ArrayTy && "Expected a pointer element type of an array type to "
485485
"represent joint matrix type");
486-
TypeLayout[0] = ArrayTy->getNumElements();
486+
TypeLayout.push_back(ArrayTy->getNumElements());
487487
for (size_t I = 1; I != 4; ++I) {
488488
ArrayTy = dyn_cast<ArrayType>(ArrayTy->getElementType());
489489
assert(ArrayTy &&
490490
"Expected a element type to represent joint matrix type");
491-
TypeLayout[I] = ArrayTy->getNumElements();
491+
TypeLayout.push_back(ArrayTy->getNumElements());
492+
}
493+
// JointMatrixINTEL type can have optional 'Use' parameter, which is encoded
494+
// as another array dimention. In case if it has default 'Unnecessary' (4)
495+
// parameter - ignore it.
496+
if (isa<ArrayType>(ArrayTy->getElementType())) {
497+
ArrayTy = cast<ArrayType>(ArrayTy->getElementType());
498+
uint32_t UseInt = ArrayTy->getNumElements();
499+
assert(UseInt <= 4 && "Use parameter encoded in the array must be < 5 ");
500+
if (UseInt != 4)
501+
TypeLayout.push_back(UseInt);
492502
}
493503

494504
auto *ElemTy = ArrayTy->getElementType();
@@ -542,6 +552,9 @@ void SPIRVRegularizeLLVMBase::adaptStructTypes(StructType *ST) {
542552
<< kSPIRVTypeName::PostfixDelim << std::to_string(TypeLayout[2] - 1)
543553
<< kSPIRVTypeName::PostfixDelim
544554
<< std::to_string(TypeLayout[3] - 1);
555+
if (TypeLayout.size() == 5)
556+
SPVName << kSPIRVTypeName::PostfixDelim
557+
<< std::to_string(TypeLayout[4] - 1);
545558
// Note, that this structure is not opaque and there is no way to make it
546559
// opaque but to recreate it entirely and replace it everywhere. Lets
547560
// keep the structure as is, dealing with it during SPIR-V generation.

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -578,12 +578,10 @@ SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
578578
consumeUnsignedInteger(Postfix, 10, N);
579579
return getUInt32(M, N);
580580
};
581-
SPIRVValue *Rows = transConstant(ParseInteger(Postfixes[1]));
582-
SPIRVValue *Columns = transConstant(ParseInteger(Postfixes[2]));
583-
SPIRVValue *Layout = transConstant(ParseInteger(Postfixes[3]));
584-
SPIRVValue *Scope = transConstant(ParseInteger(Postfixes[4]));
585-
return mapType(T, BM->addJointMatrixINTELType(transType(ElemTy), Rows,
586-
Columns, Layout, Scope));
581+
std::vector<SPIRVValue *> Args;
582+
for (size_t I = 1; I != Postfixes.size(); ++I)
583+
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));
584+
return mapType(T, BM->addJointMatrixINTELType(transType(ElemTy), Args));
587585
}
588586

589587
SPIRVType *LLVMToSPIRVBase::transSPIRVOpaqueType(Type *T) {

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,8 @@ class SPIRVModuleImpl : public SPIRVModule {
232232
SPIRVEntry *addTypeStructContinuedINTEL(unsigned NumMembers) override;
233233
void closeStructType(SPIRVTypeStruct *T, bool) override;
234234
SPIRVTypeVector *addVectorType(SPIRVType *, SPIRVWord) override;
235-
SPIRVTypeJointMatrixINTEL *addJointMatrixINTELType(SPIRVType *, SPIRVValue *,
236-
SPIRVValue *, SPIRVValue *,
237-
SPIRVValue *) override;
235+
SPIRVTypeJointMatrixINTEL *
236+
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) override;
238237
SPIRVType *addOpaqueGenericType(Op) override;
239238
SPIRVTypeDeviceEvent *addDeviceEventType() override;
240239
SPIRVTypeQueue *addQueueType() override;
@@ -903,11 +902,10 @@ SPIRVTypeVector *SPIRVModuleImpl::addVectorType(SPIRVType *CompType,
903902
return addType(new SPIRVTypeVector(this, getId(), CompType, CompCount));
904903
}
905904

906-
SPIRVTypeJointMatrixINTEL *SPIRVModuleImpl::addJointMatrixINTELType(
907-
SPIRVType *CompType, SPIRVValue *Rows, SPIRVValue *Columns,
908-
SPIRVValue *Layout, SPIRVValue *Scope) {
909-
return addType(new SPIRVTypeJointMatrixINTEL(this, getId(), CompType, Rows,
910-
Columns, Layout, Scope));
905+
SPIRVTypeJointMatrixINTEL *
906+
SPIRVModuleImpl::addJointMatrixINTELType(SPIRVType *CompType,
907+
std::vector<SPIRVValue *> Args) {
908+
return addType(new SPIRVTypeJointMatrixINTEL(this, getId(), CompType, Args));
911909
}
912910

913911
SPIRVType *SPIRVModuleImpl::addOpaqueGenericType(Op TheOpCode) {

lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,7 @@ class SPIRVModule {
244244
virtual void closeStructType(SPIRVTypeStruct *, bool) = 0;
245245
virtual SPIRVTypeVector *addVectorType(SPIRVType *, SPIRVWord) = 0;
246246
virtual SPIRVTypeJointMatrixINTEL *
247-
addJointMatrixINTELType(SPIRVType *, SPIRVValue *, SPIRVValue *, SPIRVValue *,
248-
SPIRVValue *) = 0;
247+
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
249248
virtual SPIRVTypeVoid *addVoidType() = 0;
250249
virtual SPIRVType *addOpaqueGenericType(Op) = 0;
251250
virtual SPIRVTypeDeviceEvent *addDeviceEventType() = 0;

lib/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -275,16 +275,23 @@ void SPIRVTypeForwardPointer::decode(std::istream &I) {
275275
}
276276

277277
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL(
278-
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType, SPIRVValue *Rows,
279-
SPIRVValue *Columns, SPIRVValue *Layout, SPIRVValue *Scope)
280-
: SPIRVType(M, FixedWC, OC, TheId), CompType(CompType), Rows(Rows),
281-
Columns(Columns), Layout(Layout), Scope(Scope) {}
278+
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
279+
std::vector<SPIRVValue *> Args)
280+
: SPIRVType(M, FixedWC + Args.size(), OC, TheId), CompType(CompType),
281+
Args(Args) {}
282282

283283
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL()
284-
: SPIRVType(OC), CompType(nullptr), Rows(nullptr), Columns(nullptr),
285-
Layout(nullptr), Scope(nullptr) {}
284+
: SPIRVType(OC), CompType(nullptr),
285+
Args({nullptr, nullptr, nullptr, nullptr}) {}
286286

287-
_SPIRV_IMP_ENCDEC6(SPIRVTypeJointMatrixINTEL, Id, CompType, Rows, Columns,
288-
Layout, Scope)
287+
void SPIRVTypeJointMatrixINTEL::encode(spv_ostream &O) const {
288+
auto Encoder = getEncoder(O);
289+
Encoder << Id << CompType << Args;
290+
}
291+
292+
void SPIRVTypeJointMatrixINTEL::decode(std::istream &I) {
293+
auto Decoder = getDecoder(I);
294+
Decoder >> Id >> CompType >> Args;
295+
}
289296

290297
} // namespace SPIRV

lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,18 +1060,14 @@ class SPIRVTypeTokenINTEL : public SPIRVType {
10601060

10611061
class SPIRVTypeJointMatrixINTEL : public SPIRVType {
10621062
SPIRVType *CompType;
1063-
SPIRVValue *Rows;
1064-
SPIRVValue *Columns;
1065-
SPIRVValue *Layout;
1066-
SPIRVValue *Scope;
1063+
std::vector<SPIRVValue *> Args;
10671064

10681065
public:
10691066
const static Op OC = internal::OpTypeJointMatrixINTEL;
1070-
const static SPIRVWord FixedWC = 7;
1067+
const static SPIRVWord FixedWC = 3;
10711068
// Complete constructor
10721069
SPIRVTypeJointMatrixINTEL(SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
1073-
SPIRVValue *Rows, SPIRVValue *Columns,
1074-
SPIRVValue *Layout, SPIRVValue *Scope);
1070+
std::vector<SPIRVValue *> Args);
10751071
// Incomplete constructor
10761072
SPIRVTypeJointMatrixINTEL();
10771073
_SPIRV_DCL_ENCDEC
@@ -1081,11 +1077,16 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
10811077
SPIRVCapVec getRequiredCapability() const override {
10821078
return {internal::CapabilityJointMatrixINTEL};
10831079
}
1080+
void setWordCount(SPIRVWord WordCount) override {
1081+
SPIRVType::setWordCount(WordCount);
1082+
Args.resize(WordCount - FixedWC);
1083+
}
10841084
SPIRVType *getCompType() const { return CompType; }
1085-
SPIRVValue *getLayout() const { return Layout; }
1086-
SPIRVValue *getRows() const { return Rows; }
1087-
SPIRVValue *getColumns() const { return Columns; }
1088-
SPIRVValue *getScope() const { return Scope; }
1085+
SPIRVValue *getRows() const { return Args[0]; }
1086+
SPIRVValue *getColumns() const { return Args[1]; }
1087+
SPIRVValue *getLayout() const { return Args[2]; }
1088+
SPIRVValue *getScope() const { return Args[3]; }
1089+
SPIRVValue *getUse() const { return Args.size() > 4 ? Args[4] : nullptr; }
10891090
};
10901091

10911092
} // namespace SPIRV

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,14 @@ enum InternalLoopControlMask { ILoopControlLoopCountINTELMask = 0x1000000 };
9191
constexpr LinkageType LinkageTypeInternal =
9292
static_cast<LinkageType>(ILTInternal);
9393

94-
enum InternalJointMatrixLayout { RowMajor, ColumnMajor, PackedA, PackedB };
94+
enum InternalJointMatrixLayout {
95+
RowMajor = 0,
96+
ColumnMajor = 1,
97+
PackedA = 2,
98+
PackedB = 3
99+
};
100+
101+
enum InternalJointMatrixUse { MatrixA = 0, MatrixB = 1, Accumulator = 2 };
95102

96103
enum InternalBuiltIn {
97104
IBuiltInSubDeviceIDINTEL = 6135,

test/transcoding/SPV_INTEL_joint_matrix/joint_matrix.ll

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM
99

1010
; CHECK-PRE: %spirv.JointMatrixINTEL._short_2_2_0_3
11-
; CHECK-PRE: %spirv.JointMatrixINTEL._char_2_16_0_3
11+
; CHECK-PRE: %spirv.JointMatrixINTEL._char_2_16_0_3_0
1212
; CHECK-PRE: %spirv.JointMatrixINTEL._char_16_2_3_3
1313

1414
; CHECK-SPIRV: Capability JointMatrixINTEL
@@ -24,7 +24,7 @@
2424
; CHECK-SPIRV-DAG: Constant [[#IntTy]] [[#Sixteen:]] 16
2525
; CHECK-SPIRV-DAG: Constant [[#IntTy]] [[#FortyTwo:]] 42
2626
; CHECK-SPIRV: TypeJointMatrixINTEL [[#CTy:]] [[#ShortTy]] [[#Two]] [[#Two]] [[#Zero]] [[#Three]]
27-
; CHECK-SPIRV: TypeJointMatrixINTEL [[#ATy:]] [[#CharTy]] [[#Two]] [[#Sixteen]] [[#Zero]] [[#Three]]
27+
; CHECK-SPIRV: TypeJointMatrixINTEL [[#ATy:]] [[#CharTy]] [[#Two]] [[#Sixteen]] [[#Zero]] [[#Three]] [[#Zero]]
2828
; CHECK-SPIRV: TypeJointMatrixINTEL [[#BTy:]] [[#CharTy]] [[#Sixteen]] [[#Two]] [[#Three]] [[#Three]]
2929

3030
; CHECK-SPIRV: Function [[#]] [[#Kernel]]
@@ -48,14 +48,14 @@
4848

4949

5050
; CHECK-LLVM: %spirv.JointMatrixINTEL._short_2_2_0_3
51-
; CHECK-LLVM: %spirv.JointMatrixINTEL._char_2_16_0_3
51+
; CHECK-LLVM: %spirv.JointMatrixINTEL._char_2_16_0_3_0
5252
; CHECK-LLVM: %spirv.JointMatrixINTEL._char_16_2_3_3
5353

5454
; CHECK-LLVM: [[CLoaded:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z77__spirv_JointMatrixLoadINTEL_RPU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3PU3AS4sliii(i16 addrspace(4)* [[CPtr:%.*]], i64 [[Stride:%.*]], i32 0, i32 3, i32 0)
5555
; CHECK-LLVM: [[C:%.*]] = phi %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [ [[CLoaded]], %entry ], [ [[CMad:%.*]], %for.body.i ]
56-
; CHECK-LLVM: [[A:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_2_16_0_3 addrspace(1)* @_Z77__spirv_JointMatrixLoadINTEL_RPU3AS139__spirv_JointMatrixINTEL__char_2_16_0_3PU3AS4cliii(i8 addrspace(4)* [[APtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
56+
; CHECK-LLVM: [[A:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* @_Z79__spirv_JointMatrixLoadINTEL_RPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS4cliii(i8 addrspace(4)* [[APtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
5757
; CHECK-LLVM: [[B:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* @_Z77__spirv_JointMatrixLoadINTEL_RPU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS4cliii(i8 addrspace(4)* [[BPtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
58-
; CHECK-LLVM: [[CMad:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z27__spirv_JointMatrixMadINTELPU3AS139__spirv_JointMatrixINTEL__char_2_16_0_3PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
58+
; CHECK-LLVM: [[CMad:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z27__spirv_JointMatrixMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
5959
; CHECK-LLVM: call spir_func void @_Z29__spirv_JointMatrixStoreINTELPU3AS4sPU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3liii(i16 addrspace(4)* [[CPtr]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i64 [[Stride]], i32 0, i32 3, i32 0)
6060
; CHECK-LLVM: call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z26__spirv_CompositeConstructi(i32 42)
6161
; CHECK-LLVM: store i32 0, i32 addrspace(4)* [[StoredZero:%.*]], align 4
@@ -67,8 +67,8 @@ source_filename = "./joint_matrix_test.cpp"
6767
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
6868
target triple = "spir64-unknown-unknown"
6969

70-
%"struct.__spv::__spirv_JointMatrixINTEL" = type { [2 x [2 x [1 x [4 x i16]]]]* }
71-
%"struct.__spv::__spirv_JointMatrixINTEL.0" = type { [2 x [16 x [1 x [4 x i8]]]]* }
70+
%"struct.__spv::__spirv_JointMatrixINTEL" = type { [2 x [2 x [1 x [4 x [4 x i16]]]]]* }
71+
%"struct.__spv::__spirv_JointMatrixINTEL.0" = type { [2 x [16 x [1 x [4 x [1 x i8]]]]]* }
7272
%"struct.__spv::__spirv_JointMatrixINTEL.2" = type { [16 x [2 x [4 x [4 x i8]]]]* }
7373

7474
$_ZTSZ4mainE11matrix_test = comdat any

0 commit comments

Comments
 (0)