Skip to content

Commit c063b99

Browse files
MrSidimsbader
andauthored
[SYCL] Re-land "Represent JointMatrixINTEL type as extension type" (#9841)
This reverts commit 4447a50. Previous attempt: #8343 What changed: One extra patch is being added to the headers: ca0595b with this patch clang won't generate llvm.memcpy for trivial c'tor. So later on inst combine won't replace it with a cast to i64 followed by load + store which SROA + mem2reg won't be able to handle for target extension types. It adds: ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type which is represented as a pointer to a structure to LLVM extension type with the parameters that follow SPIR-V JointMatrixINTEL type. The expected representation is: target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%, %use%, (optional) %element_type_interpretation%) Better approach is to introduce joint matrix type to clang, but it's off the table now, since we are lacking OpenCL spec. Co-authored-by: Joshua Cranmer <[email protected]> --------- Signed-off-by: Sidorov, Dmitry <[email protected]> Co-authored-by: Alexey Bader <[email protected]>
1 parent 5f0dbfe commit c063b99

File tree

9 files changed

+161
-88
lines changed

9 files changed

+161
-88
lines changed

clang/lib/CodeGen/CodeGenTypes.cpp

+83-59
Original file line numberDiff line numberDiff line change
@@ -51,65 +51,6 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
5151
StringRef suffix) {
5252
SmallString<256> TypeName;
5353
llvm::raw_svector_ostream OS(TypeName);
54-
// If RD is spirv_JointMatrixINTEL type, mangle differently.
55-
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
56-
if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") {
57-
if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
58-
ArrayRef<TemplateArgument> TemplateArgs =
59-
TemplateDecl->getTemplateArgs().asArray();
60-
OS << "spirv.JointMatrixINTEL.";
61-
for (auto &TemplateArg : TemplateArgs) {
62-
OS << "_";
63-
if (TemplateArg.getKind() == TemplateArgument::Type) {
64-
llvm::Type *TTy = ConvertType(TemplateArg.getAsType());
65-
if (TTy->isIntegerTy()) {
66-
switch (TTy->getIntegerBitWidth()) {
67-
case 8:
68-
OS << "char";
69-
break;
70-
case 16:
71-
OS << "short";
72-
break;
73-
case 32:
74-
OS << "int";
75-
break;
76-
case 64:
77-
OS << "long";
78-
break;
79-
default:
80-
OS << "i" << TTy->getIntegerBitWidth();
81-
break;
82-
}
83-
} else if (TTy->isHalfTy()) {
84-
OS << "half";
85-
} else if (TTy->isFloatTy()) {
86-
OS << "float";
87-
} else if (TTy->isDoubleTy()) {
88-
OS << "double";
89-
} else if (TTy->isBFloatTy()) {
90-
OS << "bfloat16";
91-
} else if (TTy->isStructTy()) {
92-
StringRef LlvmTyName = TTy->getStructName();
93-
// Emit half/bfloat16/tf32 for sycl[::*]::{half,bfloat16,tf32}
94-
if (LlvmTyName.startswith("class.sycl::") ||
95-
LlvmTyName.startswith("class.__sycl_internal::"))
96-
LlvmTyName = LlvmTyName.rsplit("::").second;
97-
if (LlvmTyName != "half" && LlvmTyName != "bfloat16" &&
98-
LlvmTyName != "tf32")
99-
llvm_unreachable("Wrong matrix base type!");
100-
OS << LlvmTyName;
101-
} else {
102-
llvm_unreachable("Wrong matrix base type!");
103-
}
104-
} else if (TemplateArg.getKind() == TemplateArgument::Integral) {
105-
OS << TemplateArg.getAsIntegral();
106-
}
107-
}
108-
Ty->setName(OS.str());
109-
return;
110-
}
111-
}
112-
}
11354
OS << RD->getKindName() << '.';
11455

11556
// FIXME: We probably want to make more tweaks to the printing policy. For
@@ -460,6 +401,77 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
460401
return ResultType;
461402
}
462403

404+
template <bool NeedTypeInterpret = false>
405+
llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
406+
ArrayRef<TemplateArgument> TemplateArgs,
407+
const unsigned Val = 0) {
408+
// TODO: we should actually have exactly 5 template parameters: 1 for
409+
// type and 4 for type parameters. But in previous version of the SPIR-V
410+
// spec we have Layout matrix type parameter, that was later removed.
411+
// Once we update to the newest version of the spec - this should be updated.
412+
assert((TemplateArgs.size() == 5 || TemplateArgs.size() == 6) &&
413+
"Wrong JointMatrixINTEL template parameters number");
414+
// This is required to represent optional 'Component Type Interpretation'
415+
// parameter
416+
std::vector<unsigned> Params;
417+
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
418+
assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
419+
"Wrong JointMatrixINTEL template parameter");
420+
Params.push_back(TemplateArgs[I].getAsIntegral().getExtValue());
421+
}
422+
// Don't add type interpretation for legacy matrices.
423+
// Legacy matrices has 5 template parameters, while new representation
424+
// has 6.
425+
if (NeedTypeInterpret && TemplateArgs.size() != 5)
426+
Params.push_back(Val);
427+
428+
return llvm::TargetExtType::get(CompTy->getContext(),
429+
"spirv.JointMatrixINTEL", {CompTy}, Params);
430+
}
431+
432+
/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
433+
/// which is represented as a pointer to a structure to LLVM extension type
434+
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
435+
/// The expected representation is:
436+
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
437+
/// %use%, (optional) %element_type_interpretation%)
438+
llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
439+
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
440+
ArrayRef<TemplateArgument> TemplateArgs =
441+
TemplateDecl->getTemplateArgs().asArray();
442+
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
443+
"1st JointMatrixINTEL template parameter must be type");
444+
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());
445+
446+
// Per JointMatrixINTEL spec the type can have an optional
447+
// 'Component Type Interpretation' parameter. We should emit it in case
448+
// if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as
449+
// matrix's components. Yet 'bfloat16' should be represented as 'int16' and
450+
// 'tf32' as 'float' types.
451+
if (CompTy->isStructTy()) {
452+
StringRef LlvmTyName = CompTy->getStructName();
453+
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
454+
if (LlvmTyName.startswith("class.sycl::") ||
455+
LlvmTyName.startswith("class.__sycl_internal::"))
456+
LlvmTyName = LlvmTyName.rsplit("::").second;
457+
if (LlvmTyName == "half") {
458+
CompTy = llvm::Type::getHalfTy(getLLVMContext());
459+
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
460+
} else if (LlvmTyName == "tf32") {
461+
CompTy = llvm::Type::getFloatTy(getLLVMContext());
462+
// 'tf32' interpretation is mapped to '0'
463+
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 0);
464+
} else if (LlvmTyName == "bfloat16") {
465+
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
466+
// 'bfloat16' interpretation is mapped to '1'
467+
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 1);
468+
} else {
469+
llvm_unreachable("Wrong matrix base type!");
470+
}
471+
}
472+
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
473+
}
474+
463475
/// ConvertType - Convert the specified type to its LLVM form.
464476
llvm::Type *CodeGenTypes::ConvertType(QualType T) {
465477
T = Context.getCanonicalType(T);
@@ -754,6 +766,18 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
754766
llvm::Type *PointeeType = ConvertTypeForMem(ETy);
755767
if (PointeeType->isVoidTy())
756768
PointeeType = llvm::Type::getInt8Ty(getLLVMContext());
769+
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
770+
const Type *ClangETy = ETy.getTypePtrOrNull();
771+
if (ClangETy && ClangETy->isStructureOrClassType()) {
772+
RecordDecl *RD = ClangETy->getAsCXXRecordDecl();
773+
if (RD && RD->getQualifiedNameAsString() ==
774+
"__spv::__spirv_JointMatrixINTEL") {
775+
ResultType = ConvertSYCLJointMatrixINTELType(RD);
776+
break;
777+
}
778+
}
779+
}
780+
757781
unsigned AS = getTargetAddressSpace(ETy);
758782
ResultType = llvm::PointerType::get(PointeeType, AS);
759783
break;

clang/lib/CodeGen/CodeGenTypes.h

+8
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,14 @@ class CodeGenTypes {
133133
/// memory representation is usually i8 or i32, depending on the target.
134134
llvm::Type *ConvertTypeForMem(QualType T, bool ForBitField = false);
135135

136+
/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
137+
/// which is represented as a pointer to a structure to LLVM extension type
138+
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
139+
/// The expected representation is:
140+
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
141+
/// %use%, (optional) %element_type_interpretation%)
142+
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);
143+
136144
/// GetFunctionType - Get the LLVM function type for \arg Info.
137145
llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info);
138146

clang/test/CodeGenSYCL/matrix.cpp

+17-17
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@
55
#include <stdint.h>
66

77
namespace __spv {
8-
template <typename T, size_t R, size_t C, uint32_t U, uint32_t S>
8+
template <typename T, size_t R, size_t C, uint32_t L, uint32_t S, uint32_t U>
99
struct __spirv_JointMatrixINTEL;
1010
}
1111

12-
// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1
13-
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1> *matrix) {}
12+
// CHECK: @_Z2f1{{.*}}(target("spirv.JointMatrixINTEL", float, 5, 10, 0, 1, 0)
13+
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1, 0> *matrix) {}
1414

15-
// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0
16-
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0> *matrix) {}
15+
// CHECK: @_Z2f2{{.*}}(target("spirv.JointMatrixINTEL", i64, 10, 2, 0, 0, 0)
16+
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0, 0> *matrix) {}
1717

18-
// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0
19-
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {}
18+
// CHECK: @_Z2f3{{.*}}(target("spirv.JointMatrixINTEL", i8, 10, 2, 0, 0, 0)
19+
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0, 0> *matrix) {}
2020

2121
namespace sycl {
2222
class half {};
@@ -25,17 +25,17 @@ namespace sycl {
2525
}
2626
typedef sycl::half my_half;
2727

28-
// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0
29-
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0> *matrix) {}
28+
// CHECK: @_Z2f4{{.*}}(target("spirv.JointMatrixINTEL", half, 10, 2, 0, 0, 0)
29+
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0, 0> *matrix) {}
3030

31-
// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0
32-
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {}
31+
// CHECK: @_Z2f5{{.*}}(target("spirv.JointMatrixINTEL", i16, 10, 2, 0, 0, 0, 1)
32+
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0, 0> *matrix) {}
3333

34-
// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0
35-
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {}
34+
// CHECK: @_Z2f6{{.*}}(target("spirv.JointMatrixINTEL", i128, 10, 2, 0, 0, 0)
35+
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0, 0> *matrix) {}
3636

37-
// CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._tf32_10_2_0_0
38-
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0> *matrix) {}
37+
// CHECK: @_Z2f7{{.*}}(target("spirv.JointMatrixINTEL", float, 10, 2, 0, 0, 0, 0)
38+
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0, 0> *matrix) {}
3939

40-
// CHECK: @_Z2f8{{.*}}(%spirv.JointMatrixINTEL._double_5_10_0_1
41-
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1> *matrix) {}
40+
// CHECK: @_Z2f8{{.*}}(target("spirv.JointMatrixINTEL", double, 5, 10, 0, 1, 0)
41+
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1, 0> *matrix) {}

sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp

+18
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,24 @@ struct joint_matrix {
6969
get_wi_data() {
7070
return wi_data<T, NumRows, NumCols, Layout, Group>(*this);
7171
}
72+
73+
#ifdef __SYCL_DEVICE_ONLY__
74+
#if defined(__SPIR__)
75+
// Generate a non-trivial assignment operator and copy c'tor that prevents
76+
// memcpy from being generated.
77+
// TODO: to remove, when either IGC can handle alloca JointMatrix or
78+
// combination of InstCombine + SROA + mem2reg can remove it
79+
joint_matrix(const joint_matrix &other) {
80+
spvm = other.spvm;
81+
return *this;
82+
}
83+
84+
joint_matrix &operator=(const joint_matrix &rhs) {
85+
spvm = rhs.spvm;
86+
return *this;
87+
}
88+
#endif // defined(__SPIR__)
89+
#endif
7290
};
7391

7492
template <typename Group, typename T, size_t NumRows, size_t NumCols,

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

+17
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,23 @@ struct joint_matrix {
4242
PI_ERROR_INVALID_DEVICE);
4343
#endif
4444
}
45+
#ifdef __SYCL_DEVICE_ONLY__
46+
#if defined(__SPIR__)
47+
// Generate a non-trivial assignment operator and copy c'tor that prevents
48+
// memcpy from being generated.
49+
// TODO: to remove, when either IGC can handle alloca JointMatrix or
50+
// combination of InstCombine + SROA + mem2reg can remove it
51+
joint_matrix(const joint_matrix &other) {
52+
spvm = other.spvm;
53+
return *this;
54+
}
55+
56+
joint_matrix &operator=(const joint_matrix &rhs) {
57+
spvm = rhs.spvm;
58+
return *this;
59+
}
60+
#endif // defined(__SPIR__)
61+
#endif
4562
};
4663

4764
#ifdef __SYCL_DEVICE_ONLY__

sycl/test/check_device_code/matrix/matrix_load_store_as.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
// RUN: %clangxx -fsycl-device-only -S -emit-llvm -o - %s | FileCheck %s
22

3+
// Check that SROA and mem2reg won't leave alloca of matrix type in IR
4+
// CHECK-NOT: alloca target("spirv.JointMatrixINTEL"
5+
36
// check that correct address spaces are used to load from and store to
47
#define SYCL_EXT_ONEAPI_MATRIX_VERSION 4
58
#include <sycl/sycl.hpp>
@@ -39,16 +42,16 @@ int main(void) {
3942
it.barrier(access::fence_space::local_space);
4043

4144
// A should load from local address space
42-
// CHECK: %{{.*}} = tail call spir_func noundef %spirv.JointMatrixINTEL._short_8_16_0_3_0 addrspace(4)* @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(3)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
45+
// CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 8, 16, 0, 3, 0) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(3)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
4346
joint_matrix_load(
4447
sg, tA,
4548
tileA.template get_multi_ptr<sycl::access::decorated::yes>(), 16);
4649
// B should load from global address space
47-
// CHECK: %{{.*}} = tail call spir_func noundef %spirv.JointMatrixINTEL._short_16_16_2_3_1 addrspace(4)* @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(1)* noundef %{{.*}}, i64 noundef 32, i32 noundef 2, i32 noundef 3, i32 noundef 0) #{{.*}}
50+
// CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(1)* noundef %{{.*}}, i64 noundef 32, i32 noundef 2, i32 noundef 3, i32 noundef 0) #{{.*}}
4851
joint_matrix_load(sg, tB, pB, 32);
4952
tC = joint_matrix_mad(sg, tA, tB, tC);
5053
// C should store to global address space
51-
// CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(float addrspace(1)* noundef %{{.*}}, %spirv.JointMatrixINTEL._float_8_16_3_3_2 addrspace(4)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
54+
// CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(float addrspace(1)* noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
5255
joint_matrix_store(sg, tC, pC, 16, layout::row_major);
5356
});
5457
});

sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
// RUN: %clangxx -fsycl-device-only -S -emit-llvm -o - %s | FileCheck %s
22

3+
// Check that SROA and mem2reg won't leave alloca of matrix type in IR
4+
// CHECK-NOT: alloca target("spirv.JointMatrixINTEL"
5+
36
// check that correct address spaces are used to load from and store to
47
#define SYCL_EXT_ONEAPI_MATRIX_VERSION 1
58
#include <sycl/sycl.hpp>
@@ -36,17 +39,17 @@ int main(void) {
3639
it.barrier(access::fence_space::local_space);
3740

3841
// A should load from local address space
39-
// CHECK: %{{.*}} = tail call spir_func noundef %spirv.JointMatrixINTEL._short_8_16_0_3 addrspace(4)* @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(3)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
42+
// CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 8, 16, 0, 3) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(3)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
4043
joint_matrix_load(
4144
sg, tA,
4245
tileA.template get_multi_ptr<sycl::access::decorated::yes>(), 16,
4346
matrix_layout::row_major);
4447
// B should load from global address space
45-
// CHECK: %{{.*}} = tail call spir_func noundef %spirv.JointMatrixINTEL._short_16_16_3_3 addrspace(4)* @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(1)* noundef %{{.*}}, i64 noundef 32, i32 noundef [[#]], i32 noundef 3, i32 noundef 0) #{{.*}}
48+
// CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 3, 3) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(i16 addrspace(1)* noundef %{{.*}}, i64 noundef 32, i32 noundef [[#]], i32 noundef 3, i32 noundef 0) #{{.*}}
4649
joint_matrix_load(sg, tB, pB, 32, matrix_layout::packed_b);
4750
tC = joint_matrix_mad(sg, tA, tB, tC);
4851
// C should store to global address space
49-
// CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(float addrspace(1)* noundef %{{.*}}, %spirv.JointMatrixINTEL._float_8_16_0_3 addrspace(4)* noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
52+
// CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(float addrspace(1)* noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 0, 3) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}}
5053
joint_matrix_store(sg, tC, pC, 16, matrix_layout::row_major);
5154
});
5255
});

sycl/test/matrix/legacy/matrix-int8-test.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: %clangxx -fsycl -fsycl-device-only -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 -S -emit-llvm -o - %s | FileCheck %s
22

3-
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type opaque
4-
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type opaque
5-
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type opaque
3+
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3)
4+
// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 0, 3)
5+
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 3, 3)
66

77
#include <iostream>
88
#include <sycl/sycl.hpp>

sycl/test/matrix/matrix-int8-test.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: %clangxx -fsycl -fsycl-device-only -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -O2 -S -emit-llvm -o - %s | FileCheck %s
22

3-
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3_0 = type opaque
4-
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_3_3_2 = type opaque
5-
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_2_3_1 = type opaque
3+
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0)
4+
// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2)
5+
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1)
66

77
#include <iostream>
88
#include <sycl/sycl.hpp>

0 commit comments

Comments
 (0)