Skip to content

Commit deb6ee9

Browse files
authored
Reverse translation of access + load/store operations for cooperative matrix (#2165)
Implement translation via SPIR-V friendly calls, as: the LLVM instructions are not capable to accept target extension types; cooperative matrix is an opaque object and accessing elements is implementation defined, hence we can't use GEP to which AccessChain naturally maps, since GEP has a different meaning. As for now some BE would need to recognize and define what to do with a call to __spirv_AccessChain(matrix, index). Better option is to map such SPIR-V to an intrinsic or define an appropriate type in LLVM (hence defining rules for GEP and other instructions) , but it's off the table now.
1 parent fdc961f commit deb6ee9

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

lib/SPIRV/SPIRVReader.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -2182,6 +2182,10 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
21822182
auto *AC = static_cast<SPIRVAccessChainBase *>(BV);
21832183
auto *Base = transValue(AC->getBase(), F, BB);
21842184
SPIRVType *BaseSPVTy = AC->getBase()->getType();
2185+
if (BaseSPVTy->isTypePointer() &&
2186+
BaseSPVTy->getPointerElementType()->isTypeCooperativeMatrixKHR()) {
2187+
return mapValue(BV, transSPIRVBuiltinFromInst(AC, BB));
2188+
}
21852189
Type *BaseTy =
21862190
BaseSPVTy->isTypeVector()
21872191
? transType(

test/extensions/KHR/SPV_KHR_cooperative_matrix/access_store.ll

+16-7
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
44
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
55

6-
; TODO: come up with an approach and implement reverse translation
7-
; R/UN: llvm-spirv -r %t.spv -o %t.rev.bc
8-
; R/UN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
6+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
98

109
; CHECK-SPIRV: TypeInt [[#TypeInt:]] 32 0
1110
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const0:]] 0
@@ -15,28 +14,38 @@
1514
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const42:]] 42
1615

1716
; CHECK-SPIRV: TypeCooperativeMatrixKHR [[#TypeMatrix:]] [[#TypeInt]] [[#Const3]] [[#Const12]] [[#Const12]] [[#Const0]]
18-
; CHECK-SPIRV: TypePointer [[#Type:]] 7 [[#TypeInt]]
17+
; CHECK-SPIRV: TypePointer [[#TypeMatrixPtr:]] 7 [[#TypeMatrix]]
18+
; CHECK-SPIRV: TypePointer [[#TypeIntPtr:]] 7 [[#TypeInt]]
1919

20+
; CHECK-SPIRV: Variable [[#TypeMatrixPtr]] [[#VarMatrixPtr:]] 7
2021
; CHECK-SPIRV: CompositeConstruct [[#TypeMatrix]] [[#Composite:]] [[#Const0]]
21-
; CHECK-SPIRV: AccessChain [[#Type]] [[#Res:]] [[#Composite]] [[#Const1]]
22+
; CHECK-SPIRV: Store [[#VarMatrixPtr]] [[#Composite]]
23+
; CHECK-SPIRV: AccessChain [[#TypeIntPtr]] [[#Res:]] [[#VarMatrixPtr]] [[#Const1]]
2224
; CHECK-SPIRV: Store [[#Res]] [[#Const42]]
2325

26+
; CHECK-LLVM: %0 = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0)
27+
; CHECK-LLVM: %Obj = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) @_Z26__spirv_CompositeConstructi(i32 0)
28+
; CHECK-LLVM: store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) %Obj, ptr %0
29+
; CHECK-LLVM: %call = call spir_func ptr @_Z19__spirv_AccessChainPPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_0i(ptr %0, i32 1)
30+
; CHECK-LLVM: store i32 42, ptr %call
2431

2532
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
2633
target triple = "spir64-unknown-unknown"
2734

2835
; Function Attrs: mustprogress uwtable
2936
define dso_local void @_Z3fooi(i32 noundef %idx) local_unnamed_addr #0 {
3037
entry:
38+
%0 = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0), align 8
3139
%Obj = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) @_Z26__spirv_CompositeConstruct(i32 noundef 0) #4
32-
%call = call noundef ptr @_Z19__spirv_AccessChainP6Matrixii(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) %Obj, i32 noundef 1)
40+
store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) %Obj, ptr %0, align 8
41+
%call = call noundef ptr @_Z19__spirv_AccessChainP6Matrixii(ptr %0, i32 noundef 1)
3342
call void @_Z13__spirv_StorePii(ptr noundef %call, i32 noundef 42)
3443
ret void
3544
}
3645

3746
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) @_Z26__spirv_CompositeConstruct(i32 noundef) local_unnamed_addr #2
3847

39-
declare noundef ptr @_Z19__spirv_AccessChainP6Matrixii(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) noundef, i32 noundef) local_unnamed_addr #2
48+
declare noundef ptr @_Z19__spirv_AccessChainP6Matrixii(ptr noundef, i32 noundef) local_unnamed_addr #2
4049

4150
declare void @_Z13__spirv_StorePii(ptr noundef, i32 noundef) local_unnamed_addr #2
4251

0 commit comments

Comments
 (0)