@@ -51,65 +51,6 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
51
51
StringRef suffix) {
52
52
SmallString<256 > TypeName;
53
53
llvm::raw_svector_ostream OS (TypeName);
54
- // If RD is spirv_JointMatrixINTEL type, mangle differently.
55
- if (CGM.getTriple ().isSPIRV () || CGM.getTriple ().isSPIR ()) {
56
- if (RD->getQualifiedNameAsString () == " __spv::__spirv_JointMatrixINTEL" ) {
57
- if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
58
- ArrayRef<TemplateArgument> TemplateArgs =
59
- TemplateDecl->getTemplateArgs ().asArray ();
60
- OS << " spirv.JointMatrixINTEL." ;
61
- for (auto &TemplateArg : TemplateArgs) {
62
- OS << " _" ;
63
- if (TemplateArg.getKind () == TemplateArgument::Type) {
64
- llvm::Type *TTy = ConvertType (TemplateArg.getAsType ());
65
- if (TTy->isIntegerTy ()) {
66
- switch (TTy->getIntegerBitWidth ()) {
67
- case 8 :
68
- OS << " char" ;
69
- break ;
70
- case 16 :
71
- OS << " short" ;
72
- break ;
73
- case 32 :
74
- OS << " int" ;
75
- break ;
76
- case 64 :
77
- OS << " long" ;
78
- break ;
79
- default :
80
- OS << " i" << TTy->getIntegerBitWidth ();
81
- break ;
82
- }
83
- } else if (TTy->isHalfTy ()) {
84
- OS << " half" ;
85
- } else if (TTy->isFloatTy ()) {
86
- OS << " float" ;
87
- } else if (TTy->isDoubleTy ()) {
88
- OS << " double" ;
89
- } else if (TTy->isBFloatTy ()) {
90
- OS << " bfloat16" ;
91
- } else if (TTy->isStructTy ()) {
92
- StringRef LlvmTyName = TTy->getStructName ();
93
- // Emit half/bfloat16/tf32 for sycl[::*]::{half,bfloat16,tf32}
94
- if (LlvmTyName.startswith (" class.sycl::" ) ||
95
- LlvmTyName.startswith (" class.__sycl_internal::" ))
96
- LlvmTyName = LlvmTyName.rsplit (" ::" ).second ;
97
- if (LlvmTyName != " half" && LlvmTyName != " bfloat16" &&
98
- LlvmTyName != " tf32" )
99
- llvm_unreachable (" Wrong matrix base type!" );
100
- OS << LlvmTyName;
101
- } else {
102
- llvm_unreachable (" Wrong matrix base type!" );
103
- }
104
- } else if (TemplateArg.getKind () == TemplateArgument::Integral) {
105
- OS << TemplateArg.getAsIntegral ();
106
- }
107
- }
108
- Ty->setName (OS.str ());
109
- return ;
110
- }
111
- }
112
- }
113
54
OS << RD->getKindName () << ' .' ;
114
55
115
56
// FIXME: We probably want to make more tweaks to the printing policy. For
@@ -460,6 +401,79 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
460
401
return ResultType;
461
402
}
462
403
404
+ template <bool NeedTypeInterpret = false >
405
+ llvm::Type* getJointMatrixINTELExtType (llvm::Type *CompTy,
406
+ ArrayRef<TemplateArgument> TemplateArgs,
407
+ const unsigned Val = 0 ) {
408
+ // TODO: we should actually have exactly 5 template parameters: 1 for
409
+ // type and 4 for type parameters. But in previous version of the SPIR-V
410
+ // spec we have Layout matrix type parameter, that was later removed.
411
+ // Once we update to the newest version of the spec - this should be updated.
412
+ assert ((TemplateArgs.size () == 5 || TemplateArgs.size () == 6 ) &&
413
+ " Wrong JointMatrixINTEL template parameters number" );
414
+ // This is required to represent optional Optional
415
+ // 'Component Type Interpretation' parameter
416
+ using ParamsType =
417
+ typename std::conditional<NeedTypeInterpret, SmallVector<unsigned , 6 >,
418
+ SmallVector<unsigned , 5 >>::type;
419
+ ParamsType Params;
420
+ if constexpr (NeedTypeInterpret)
421
+ Params = { 0 , 0 , 0 , 0 , 0 , Val};
422
+ else
423
+ Params = { 0 , 0 , 0 , 0 , 0 };
424
+ for (size_t I = 1 ; I != TemplateArgs.size (); ++I) {
425
+ assert (TemplateArgs[I].getKind () ==
426
+ TemplateArgument::Integral &&
427
+ " Wrong JointMatrixINTEL template parameter" );
428
+ Params[I - 1 ] = TemplateArgs[I].getAsIntegral ().getExtValue ();
429
+ }
430
+ return llvm::TargetExtType::get (
431
+ CompTy->getContext (), " spirv.JointMatrixINTEL" , {CompTy}, Params);
432
+ }
433
+
434
+ // / ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
435
+ // / which is represented as a pointer to a structure to LLVM extension type
436
+ // / with the parameters that follow SPIR-V JointMatrixINTEL type.
437
+ // / The expected representation is:
438
+ // / target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
439
+ // / %use%, (optional) %element_type_interpretation%)
440
+ llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType (RecordDecl *RD) {
441
+ auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
442
+ ArrayRef<TemplateArgument> TemplateArgs =
443
+ TemplateDecl->getTemplateArgs ().asArray ();
444
+ assert (TemplateArgs[0 ].getKind () == TemplateArgument::Type &&
445
+ " 1st JointMatrixINTEL template parameter must be type" );
446
+ llvm::Type *CompTy = ConvertType (TemplateArgs[0 ].getAsType ());
447
+
448
+ // Per JointMatrixINTEL spec the type can have an Optional
449
+ // 'Component Type Interpretation' parameter. We should emit it in case
450
+ // if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as
451
+ // matrix's components. Yet bfloat16 should be represented as 'int16' and
452
+ // 'tf32' as 'float' types.
453
+ if (CompTy->isStructTy ()) {
454
+ StringRef LlvmTyName = CompTy->getStructName ();
455
+ // Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
456
+ if (LlvmTyName.startswith (" class.sycl::" ) ||
457
+ LlvmTyName.startswith (" class.__sycl_internal::" ))
458
+ LlvmTyName = LlvmTyName.rsplit (" ::" ).second ;
459
+ if (LlvmTyName == " half" ) {
460
+ CompTy = llvm::Type::getHalfTy (getLLVMContext ());
461
+ return getJointMatrixINTELExtType (CompTy, TemplateArgs);
462
+ } else if (LlvmTyName == " tf32" ) {
463
+ CompTy = llvm::Type::getHalfTy (getLLVMContext ());
464
+ // 'tf32' interpretation is mapped to '0'
465
+ return getJointMatrixINTELExtType<true >(CompTy, TemplateArgs, 0 );
466
+ } else if (LlvmTyName == " bfloat16" ) {
467
+ CompTy = llvm::Type::getInt16Ty (getLLVMContext ());
468
+ // 'bfloat16' interpretation is mapped to '1'
469
+ return getJointMatrixINTELExtType<true >(CompTy, TemplateArgs, 1 );
470
+ } else {
471
+ llvm_unreachable (" Wrong matrix base type!" );
472
+ }
473
+ }
474
+ return getJointMatrixINTELExtType (CompTy, TemplateArgs);
475
+ }
476
+
463
477
// / ConvertType - Convert the specified type to its LLVM form.
464
478
llvm::Type *CodeGenTypes::ConvertType (QualType T) {
465
479
T = Context.getCanonicalType (T);
@@ -745,6 +759,18 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
745
759
llvm::Type *PointeeType = ConvertTypeForMem (ETy);
746
760
if (PointeeType->isVoidTy ())
747
761
PointeeType = llvm::Type::getInt8Ty (getLLVMContext ());
762
+ if (CGM.getTriple ().isSPIRV () || CGM.getTriple ().isSPIR ()) {
763
+ const Type* ClangETy = ETy.getTypePtrOrNull ();
764
+ if (ClangETy && ClangETy->isStructureOrClassType ()) {
765
+ RecordDecl *RD = ClangETy->getAsCXXRecordDecl ();
766
+ if (RD->getQualifiedNameAsString () ==
767
+ " __spv::__spirv_JointMatrixINTEL" ) {
768
+ ResultType = ConvertSYCLJointMatrixINTELType (RD);
769
+ break ;
770
+ }
771
+ }
772
+ }
773
+
748
774
unsigned AS = getTargetAddressSpace (ETy);
749
775
ResultType = llvm::PointerType::get (PointeeType, AS);
750
776
break ;
0 commit comments