Skip to content

Commit fdc961f

Browse files
authored
Reverse translation of arithmetic instructions for cooperative matrixes (#2166)
Implement translation via SPIR-V friendly calls, as: the LLVM instructions are not capable to accept target extension types; matrix arithmetic instructions require additional carry additional rules, which LLVM can not perform (for example while technically Add for vectors and (flattened) matrices is the same - yet for matrices we need to perform extra checks, also mul instruction is complitely different). As for now some BE would need to recognize and define what to do with a call to __spirv_FMul(matrixA, matrixB). Better option is to map such SPIR-V to an intrinsic or define an appropriate type in LLVM (hence defining rules for GEP and other instructions) , but it's off the table now.
1 parent 21cc1d0 commit fdc961f

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

lib/SPIRV/SPIRVReader.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,9 @@ static void applyFPFastMathModeDecorations(const SPIRVValue *BV,
10961096
Value *SPIRVToLLVM::transShiftLogicalBitwiseInst(SPIRVValue *BV, BasicBlock *BB,
10971097
Function *F) {
10981098
SPIRVBinary *BBN = static_cast<SPIRVBinary *>(BV);
1099+
if (BV->getType()->isTypeCooperativeMatrixKHR()) {
1100+
return mapValue(BV, transSPIRVBuiltinFromInst(BBN, BB));
1101+
}
10991102
Instruction::BinaryOps BO;
11001103
auto OP = BBN->getOpCode();
11011104
if (isLogicalOpCode(OP))
@@ -2412,6 +2415,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
24122415
Builder.SetInsertPoint(BB);
24132416
}
24142417
SPIRVUnary *BC = static_cast<SPIRVUnary *>(BV);
2418+
if (BV->getType()->isTypeCooperativeMatrixKHR()) {
2419+
return mapValue(BV, transSPIRVBuiltinFromInst(BC, BB));
2420+
}
24152421
auto *Neg =
24162422
Builder.CreateNeg(transValue(BC->getOperand(0), F, BB), BV->getName());
24172423
if (auto *NegInst = dyn_cast<Instruction>(Neg)) {
@@ -2464,6 +2470,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
24642470

24652471
case OpFNegate: {
24662472
SPIRVUnary *BC = static_cast<SPIRVUnary *>(BV);
2473+
if (BV->getType()->isTypeCooperativeMatrixKHR()) {
2474+
return mapValue(BV, transSPIRVBuiltinFromInst(BC, BB));
2475+
}
24672476
auto *Neg = UnaryOperator::CreateFNeg(transValue(BC->getOperand(0), F, BB),
24682477
BV->getName(), BB);
24692478
applyFPFastMathModeDecorations(BV, Neg);

test/extensions/KHR/SPV_KHR_cooperative_matrix/arithmetic_instructions.ll

+24-9
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
; TODO: Validation is disabled till the moment the tools in CI are updated (passes locally)
77
; R/UN: spirv-val %t.spv
88

9-
; TODO: come up with an approach and implement reverse translation
10-
; R/UN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o %t.rev.bc
11-
; R/UN: llvm-dis %t.rev.bc
12-
; R/UN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
9+
; RUN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o %t.rev.bc
10+
; RUN: llvm-dis %t.rev.bc
11+
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
1312

1413
; CHECK-SPIRV: TypeInt [[#TypeInt:]] 32 0
1514
; CHECK-SPIRV: TypeCooperativeMatrixKHR [[#MatrixTypeInt:]] [[#TypeInt]]
@@ -21,6 +20,8 @@ target triple = "spir-unknown-unknown"
2120

2221
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixIn:]] [[#]] {{$}}
2322
; CHECK-SPIRV: SNegate [[#MatrixTypeInt]] [[#]] [[#MatrixIn]]
23+
; CHECK-LLVM: %1 = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructi(i32 0)
24+
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z15__spirv_SNegatePU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1)
2425
define spir_kernel void @testSNegate(i32 %a) #0 !kernel_arg_addr_space !10 !kernel_arg_access_qual !11 !kernel_arg_type !12 !kernel_arg_type_qual !9 !kernel_arg_base_type !12 {
2526
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
2627
%call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z15__spirv_SNegate(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1)
@@ -29,7 +30,8 @@ define spir_kernel void @testSNegate(i32 %a) #0 !kernel_arg_addr_space !10 !kern
2930

3031
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixIn:]] [[#]] {{$}}
3132
; CHECK-SPIRV: FNegate [[#MatrixTypeFloat]] [[#]] [[#MatrixIn]]
32-
; CHECK-LLVM: fneg
33+
; CHECK-LLVM: %0 = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
34+
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z15__spirv_FNegatePU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0)
3335
define spir_kernel void @testFNeg(float %a) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !9 {
3436
entry:
3537
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
@@ -40,6 +42,9 @@ entry:
4042
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
4143
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
4244
; CHECK-SPIRV: IAdd [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
45+
; CHECK-LLVM: %1 = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructi(i32 0)
46+
; CHECK-LLVM: %2 = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructi(i32 0)
47+
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_IAddPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
4348
define spir_kernel void @testIAdd(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_type_qual !7 !kernel_arg_base_type !6 {
4449
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
4550
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
@@ -50,6 +55,7 @@ define spir_kernel void @testIAdd(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !
5055
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
5156
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
5257
; CHECK-SPIRV: ISub [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
58+
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_ISubPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
5359
define spir_kernel void @testISub(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_type_qual !7 !kernel_arg_base_type !6 {
5460
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
5561
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
@@ -60,6 +66,7 @@ define spir_kernel void @testISub(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !
6066
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
6167
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
6268
; CHECK-SPIRV: IMul [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
69+
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_IMulPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
6370
define spir_kernel void @testIMul(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_type_qual !7 !kernel_arg_base_type !6 {
6471
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
6572
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
@@ -70,6 +77,7 @@ define spir_kernel void @testIMul(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !
7077
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
7178
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
7279
; CHECK-SPIRV: SDiv [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
80+
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_SDivPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
7381
define void @testSDiv(i32 %a, i32 %b) {
7482
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
7583
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
@@ -80,6 +88,7 @@ define void @testSDiv(i32 %a, i32 %b) {
8088
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
8189
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
8290
; CHECK-SPIRV: UDiv [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
91+
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_UDivPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
8392
define void @testUDiv(i32 %a, i32 %b) {
8493
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
8594
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
@@ -91,44 +100,50 @@ define void @testUDiv(i32 %a, i32 %b) {
91100
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}}
92101
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}}
93102
; CHECK-SPIRV: FAdd [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]]
103+
; CHECK-LLVM: %0 = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
104+
; CHECK-LLVM: %1 = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
105+
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FAddPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
94106
define spir_kernel void @testFAdd(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
95107
entry:
96108
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
97109
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
98-
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FAdd(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
110+
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FAdd(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
99111
ret void
100112
}
101113

102114
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}}
103115
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}}
104116
; CHECK-SPIRV: FSub [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]]
117+
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FSubPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
105118
define spir_kernel void @testFSub(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
106119
entry:
107120
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
108121
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
109-
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FSub(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
122+
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FSub(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
110123
ret void
111124
}
112125

113126
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}}
114127
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}}
115128
; CHECK-SPIRV: FMul [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]]
129+
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FMulPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
116130
define spir_kernel void @testFMul(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
117131
entry:
118132
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
119133
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
120-
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FMul(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
134+
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FMul(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
121135
ret void
122136
}
123137

124138
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}}
125139
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}}
126140
; CHECK-SPIRV: FDiv [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]]
141+
; CHECK-LLVM: %call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FDivPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3S1_(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
127142
define spir_kernel void @testFDiv(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
128143
entry:
129144
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
130145
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
131-
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FDiv(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
146+
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FDiv(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
132147
ret void
133148
}
134149

0 commit comments

Comments
 (0)