Skip to content

Commit ee03f5f

Browse files
authored
[OpaquePointers] Add support for JointMatrixINTEL target ext type (#1852)
The expected representation is: target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%, %use%, (optional) %element_type_interpretation%) TODO: figure out, how to deal with the switch from old API (Matrix has Layout) to new API (Layout was removed) Depends on: #1799 intel/llvm#8343
1 parent b0c6c70 commit ee03f5f

File tree

2 files changed

+164
-1
lines changed

2 files changed

+164
-1
lines changed

lib/SPIRV/SPIRVWriter.cpp

+13-1
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
481481
auto CastAccess = [](unsigned Val) {
482482
return static_cast<SPIRVAccessQualifierKind>(Val);
483483
};
484-
switch (Opcode) {
484+
switch (static_cast<size_t>(Opcode)) {
485485
case OpTypePipe: {
486486
auto *PipeT = BM->addPipeType();
487487
PipeT->setPipeAcessQualifier(CastAccess(TargetTy->getIntParameter(0)));
@@ -509,6 +509,18 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
509509
return mapType(T, BM->addQueueType());
510510
case OpTypeDeviceEvent:
511511
return mapType(T, BM->addDeviceEventType());
512+
case internal::OpTypeJointMatrixINTEL: {
513+
// The expected representation is:
514+
// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%,
515+
// %layout%, %scope%, %use%,
516+
// (optional) %element_type_interpretation%)
517+
auto *ElemTy = transType(TargetTy->getTypeParameter(0));
518+
ArrayRef<unsigned> Ops = TargetTy->int_params();
519+
std::vector<SPIRVValue *> Args;
520+
for (const auto &Op : Ops)
521+
Args.emplace_back(transConstant(getUInt32(M, Op)));
522+
return mapType(T, BM->addJointMatrixINTELType(ElemTy, Args));
523+
}
512524
default:
513525
return mapType(T, BM->addOpaqueGenericType(Opcode));
514526
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_joint_matrix -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc -opaque-pointers=0
7+
; RUN: llvm-dis -opaque-pointers=0 < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
8+
9+
; CHECK-SPIRV-DAG: Capability JointMatrixINTEL
10+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix"
11+
; CHECK-SPIRV-DAG: TypeInt [[#Int8Ty:]] 8 0
12+
; CHECK-SPIRV-DAG: TypeInt [[#Int32Ty:]] 32 0
13+
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const12:]] 12
14+
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const3:]] 3
15+
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const2:]] 2
16+
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const0:]] 0
17+
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const48:]] 48
18+
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const1:]] 1
19+
; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTy1:]] [[#Int32Ty]] [[#Const12]] [[#Const12]] [[#Const3]] [[#Const3]] [[#Const2]]
20+
; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTy2:]] [[#Int8Ty]] [[#Const12]] [[#Const48]] [[#Const0]] [[#Const3]] [[#Const0]]
21+
; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTy3:]] [[#Int8Ty]] [[#Const48]] [[#Const12]] [[#Const2]] [[#Const3]] [[#Const1]]
22+
23+
; CHECK-LLVM-DAG: %spirv.JointMatrixINTEL._int_12_12_3_3_2 = type opaque
24+
; CHECK-LLVM-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3_0 = type opaque
25+
; CHECK-LLVM-DAG: %spirv.JointMatrixINTEL._char_48_12_2_3_1 = type opaque
26+
27+
; ModuleID = 'test-matrix-opaque.bc'
28+
source_filename = "matrix-int8-test.cpp"
29+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
30+
target triple = "spir64-unknown-unknown"
31+
32+
%"class.sycl::_V1::range" = type { %"class.sycl::_V1::detail::array" }
33+
%"class.sycl::_V1::detail::array" = type { [2 x i64] }
34+
%"class.sycl::_V1::id" = type { %"class.sycl::_V1::detail::array" }
35+
36+
$_ZTSZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6_EERS0_IT0_XT1_EXT2_EERS0_IS4_XT3_EXT4_EEENKUlRN4sycl3_V17handlerEE_clESC_E7imatrix = comdat any
37+
38+
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
39+
@__spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
40+
41+
; Function Attrs: convergent norecurse
42+
define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6_EERS0_IT0_XT1_EXT2_EERS0_IS4_XT3_EXT4_EEENKUlRN4sycl3_V17handlerEE_clESC_E7imatrix(ptr addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accB5, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accB6, ptr addrspace(1) noundef align 4 %_arg_accC, i64 noundef %_arg_N, i64 noundef %_arg_K) local_unnamed_addr #0 comdat {
43+
entry:
44+
%sub_c.sroa.0.i = alloca target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2), align 8
45+
%ref.tmp29.sroa.0.i = alloca target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2), align 8
46+
%agg.tmp15.sroa.0.sroa.2.0..sroa_idx = getelementptr inbounds %"class.sycl::_V1::range", ptr %_arg_accB5, i64 0, i32 0, i32 0, i64 1
47+
%agg.tmp15.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp15.sroa.0.sroa.2.0..sroa_idx, align 8
48+
%0 = getelementptr inbounds %"class.sycl::_V1::id", ptr %_arg_accB6, i64 0, i32 0, i32 0, i64 0
49+
%agg.tmp16.sroa.0.sroa.0.0.copyload = load i64, ptr %0, align 8
50+
%agg.tmp16.sroa.0.sroa.2.0..sroa_idx = getelementptr inbounds %"class.sycl::_V1::id", ptr %_arg_accB6, i64 0, i32 0, i32 0, i64 1
51+
%agg.tmp16.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp16.sroa.0.sroa.2.0..sroa_idx, align 8
52+
%mul.i4.i.i.i.i45 = mul i64 %agg.tmp16.sroa.0.sroa.0.0.copyload, %agg.tmp15.sroa.0.sroa.2.0.copyload
53+
%add.i6.i.i.i.i46 = add i64 %mul.i4.i.i.i.i45, %agg.tmp16.sroa.0.sroa.2.0.copyload
54+
%add.ptr.i47 = getelementptr inbounds i8, ptr addrspace(1) %_arg_accB, i64 %add.i6.i.i.i.i46
55+
%1 = load <3 x i64>, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
56+
%2 = extractelement <3 x i64> %1, i64 1
57+
%3 = extractelement <3 x i64> %1, i64 0
58+
%4 = load <3 x i64>, ptr addrspace(1) @__spirv_BuiltInLocalInvocationId, align 32
59+
%5 = extractelement <3 x i64> %4, i64 1
60+
%6 = extractelement <3 x i64> %4, i64 0
61+
%cmp.i.i = icmp ult i64 %2, 2147483648
62+
%cmp.i54.i = icmp ult i64 %3, 2147483648
63+
%cmp.i56.i = icmp ult i64 %5, 2147483648
64+
%sub.i = sub nsw i64 %2, %5
65+
%cmp.i58.i = icmp ult i64 %6, 2147483648
66+
%sub5.i = sub nsw i64 %3, %6
67+
%sub_c.sroa.0.i.0.i.0..sroa_cast = bitcast ptr %sub_c.sroa.0.i to ptr
68+
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %sub_c.sroa.0.i.0.i.0..sroa_cast)
69+
%call.i.i = tail call spir_func noundef target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) @_Z26__spirv_CompositeConstructIiLm12ELm12ELN5__spv9MatrixUseE2ELNS0_12MatrixLayoutE3ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEES6_(i32 noundef 0) #4
70+
store target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) %call.i.i, ptr %sub_c.sroa.0.i, align 8
71+
%mul.i = mul nsw i64 %sub.i, 12
72+
%div2452.i = lshr i64 %sub5.i, 4
73+
%mul26.i = mul i64 %div2452.i, 48
74+
%div.i = udiv i64 %_arg_K, 48
75+
%mul11.i = mul i64 %mul.i, %_arg_K
76+
%add.ptr.i93.i = getelementptr inbounds i8, ptr addrspace(1) %_arg_accA, i64 %mul11.i
77+
%idx.neg.i.i104.i = sub i64 0, %add.i6.i.i.i.i46
78+
%add.ptr.i.i105141.i = getelementptr i8, ptr addrspace(1) %add.ptr.i47, i64 %mul26.i
79+
%mul22.i = shl i64 %_arg_N, 2
80+
%add.ptr.i108140.i = getelementptr i8, ptr addrspace(1) %add.ptr.i.i105141.i, i64 %idx.neg.i.i104.i
81+
%ref.tmp29.sroa.0.i.0.i.0..sroa_cast = bitcast ptr %ref.tmp29.sroa.0.i to ptr
82+
%7 = bitcast ptr %ref.tmp29.sroa.0.i to ptr
83+
%8 = bitcast ptr %sub_c.sroa.0.i to ptr
84+
br label %for.cond.i
85+
86+
for.cond.i: ; preds = %for.body.i, %entry
87+
%k.0.i = phi i32 [ 0, %entry ], [ %add.i, %for.body.i ]
88+
%conv.i = zext i32 %k.0.i to i64
89+
%cmp.i = icmp ugt i64 %div.i, %conv.i
90+
br i1 %cmp.i, label %for.body.i, label %_ZZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6_EERS0_IT0_XT1_EXT2_EERS0_IS4_XT3_EXT4_EEENKUlRN4sycl3_V17handlerEE_clESC_ENKUlNSA_7nd_itemILi2EEEE_clESF_.exit
91+
92+
for.body.i: ; preds = %for.cond.i
93+
%mul12.i = mul nsw i32 %k.0.i, 48
94+
%conv13.i = zext i32 %mul12.i to i64
95+
%add.ptr.i96.i = getelementptr inbounds i8, ptr addrspace(1) %add.ptr.i93.i, i64 %conv13.i
96+
%call.ascast.i66.i = addrspacecast ptr addrspace(1) %add.ptr.i96.i to ptr addrspace(4)
97+
%call1.i.i = tail call spir_func noundef target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0) @_Z28__spirv_JointMatrixLoadINTELIaLm12ELm48ELN5__spv9MatrixUseE0ELNS0_12MatrixLayoutE0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEPS6_mS2_S4_i(ptr addrspace(4) noundef %call.ascast.i66.i, i64 noundef %_arg_K, i32 noundef 0, i32 noundef 3, i32 noundef 0) #4
98+
%div20.i = mul nsw i32 %k.0.i, 12
99+
%conv21.i = zext i32 %div20.i to i64
100+
%mul23.i = mul i64 %mul22.i, %conv21.i
101+
%add.ptr.i111.i = getelementptr i8, ptr addrspace(1) %add.ptr.i108140.i, i64 %mul23.i
102+
%call.ascast.i72.i = addrspacecast ptr addrspace(1) %add.ptr.i111.i to ptr addrspace(4)
103+
%call1.i73.i = tail call spir_func noundef target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1) @_Z28__spirv_JointMatrixLoadINTELIaLm48ELm12ELN5__spv9MatrixUseE1ELNS0_12MatrixLayoutE2ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEPS6_mS2_S4_i(ptr addrspace(4) noundef %call.ascast.i72.i, i64 noundef %mul22.i, i32 noundef 2, i32 noundef 3, i32 noundef 0) #4
104+
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %ref.tmp29.sroa.0.i.0.i.0..sroa_cast)
105+
%sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0.125.i = load target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2), ptr %sub_c.sroa.0.i, align 8
106+
%call.i77.i = tail call spir_func noundef target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) @_Z27__spirv_JointMatrixMadINTELIaiLm12ELm48ELm12ELN5__spv9MatrixUseE0ELS1_1ELS1_2ELNS0_12MatrixLayoutE0ELS2_2ELS2_3ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT9_EXT10_EXT6_EEEPNS5_IT_XT1_EXT2_EXT7_EXT10_EXT4_EEEPNS5_IS9_XT2_EXT3_EXT8_EXT10_EXT5_EEES8_S4_(target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0) noundef %call1.i.i, target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1) noundef %call1.i73.i, target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) noundef %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0.125.i, i32 noundef 3) #4
107+
store target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) %call.i77.i, ptr %ref.tmp29.sroa.0.i, align 8
108+
%ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0..i = load i64, ptr %7, align 8
109+
store i64 %ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0..i, ptr %8, align 8
110+
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %ref.tmp29.sroa.0.i.0.i.0..sroa_cast)
111+
%add.i = add nuw nsw i32 %k.0.i, 1
112+
br label %for.cond.i
113+
114+
_ZZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6_EERS0_IT0_XT1_EXT2_EERS0_IS4_XT3_EXT4_EEENKUlRN4sycl3_V17handlerEE_clESC_ENKUlNSA_7nd_itemILi2EEEE_clESF_.exit: ; preds = %for.cond.i
115+
%mul37.i = mul i64 %mul.i, %_arg_N
116+
%add.ptr.i.i = getelementptr inbounds i32, ptr addrspace(1) %_arg_accC, i64 %mul37.i
117+
%mul39.i = mul nuw i64 %div2452.i, 12
118+
%add.ptr.i81.i = getelementptr inbounds i32, ptr addrspace(1) %add.ptr.i.i, i64 %mul39.i
119+
%call.ascast.i.i = addrspacecast ptr addrspace(1) %add.ptr.i81.i to ptr addrspace(4)
120+
%sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i = load target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2), ptr %sub_c.sroa.0.i, align 8
121+
tail call spir_func void @_Z29__spirv_JointMatrixStoreINTELIiLm12ELm12ELN5__spv9MatrixUseE2ELNS0_12MatrixLayoutE3ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS5_XT0_EXT1_EXT3_EXT4_EXT2_EEEmS2_S4_i(ptr addrspace(4) noundef %call.ascast.i.i, target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) noundef %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i, i64 noundef %_arg_N, i32 noundef 0, i32 noundef 3, i32 noundef 0) #4
122+
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %sub_c.sroa.0.i.0.i.0..sroa_cast)
123+
ret void
124+
}
125+
126+
; Function Attrs: convergent
127+
declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) @_Z26__spirv_CompositeConstructIiLm12ELm12ELN5__spv9MatrixUseE2ELNS0_12MatrixLayoutE3ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEES6_(i32 noundef) local_unnamed_addr #2
128+
129+
; Function Attrs: convergent
130+
declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0) @_Z28__spirv_JointMatrixLoadINTELIaLm12ELm48ELN5__spv9MatrixUseE0ELNS0_12MatrixLayoutE0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEPS6_mS2_S4_i(ptr addrspace(4) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #2
131+
132+
; Function Attrs: convergent
133+
declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1) @_Z28__spirv_JointMatrixLoadINTELIaLm48ELm12ELN5__spv9MatrixUseE1ELNS0_12MatrixLayoutE2ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEPS6_mS2_S4_i(ptr addrspace(4) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #2
134+
135+
; Function Attrs: convergent
136+
declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) @_Z27__spirv_JointMatrixMadINTELIaiLm12ELm48ELm12ELN5__spv9MatrixUseE0ELS1_1ELS1_2ELNS0_12MatrixLayoutE0ELS2_2ELS2_3ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT9_EXT10_EXT6_EEEPNS5_IT_XT1_EXT2_EXT7_EXT10_EXT4_EEEPNS5_IS9_XT2_EXT3_EXT8_EXT10_EXT5_EEES8_S4_(target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0) noundef, target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1) noundef, target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) noundef, i32 noundef) local_unnamed_addr #2
137+
138+
; Function Attrs: convergent
139+
declare dso_local spir_func void @_Z29__spirv_JointMatrixStoreINTELIiLm12ELm12ELN5__spv9MatrixUseE2ELNS0_12MatrixLayoutE3ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS5_XT0_EXT1_EXT3_EXT4_EXT2_EEEmS2_S4_i(ptr addrspace(4) noundef, target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #2
140+
141+
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite)
142+
declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #3
143+
144+
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite)
145+
declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) #3
146+
147+
attributes #0 = { convergent norecurse "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="matrix-int8-test.cpp" "uniform-work-group-size"="true" }
148+
attributes #1 = { nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: readwrite) }
149+
attributes #2 = { convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
150+
attributes #3 = { nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) }
151+
attributes #4 = { convergent }

0 commit comments

Comments
 (0)