@@ -51,6 +51,65 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
51
51
StringRef suffix) {
52
52
SmallString<256 > TypeName;
53
53
llvm::raw_svector_ostream OS (TypeName);
54
+ // If RD is spirv_JointMatrixINTEL type, mangle differently.
55
+ if (CGM.getTriple ().isSPIRV () || CGM.getTriple ().isSPIR ()) {
56
+ if (RD->getQualifiedNameAsString () == " __spv::__spirv_JointMatrixINTEL" ) {
57
+ if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
58
+ ArrayRef<TemplateArgument> TemplateArgs =
59
+ TemplateDecl->getTemplateArgs ().asArray ();
60
+ OS << " spirv.JointMatrixINTEL." ;
61
+ for (auto &TemplateArg : TemplateArgs) {
62
+ OS << " _" ;
63
+ if (TemplateArg.getKind () == TemplateArgument::Type) {
64
+ llvm::Type *TTy = ConvertType (TemplateArg.getAsType ());
65
+ if (TTy->isIntegerTy ()) {
66
+ switch (TTy->getIntegerBitWidth ()) {
67
+ case 8 :
68
+ OS << " char" ;
69
+ break ;
70
+ case 16 :
71
+ OS << " short" ;
72
+ break ;
73
+ case 32 :
74
+ OS << " int" ;
75
+ break ;
76
+ case 64 :
77
+ OS << " long" ;
78
+ break ;
79
+ default :
80
+ OS << " i" << TTy->getIntegerBitWidth ();
81
+ break ;
82
+ }
83
+ } else if (TTy->isHalfTy ()) {
84
+ OS << " half" ;
85
+ } else if (TTy->isFloatTy ()) {
86
+ OS << " float" ;
87
+ } else if (TTy->isDoubleTy ()) {
88
+ OS << " double" ;
89
+ } else if (TTy->isBFloatTy ()) {
90
+ OS << " bfloat16" ;
91
+ } else if (TTy->isStructTy ()) {
92
+ StringRef LlvmTyName = TTy->getStructName ();
93
+ // Emit half/bfloat16/tf32 for sycl[::*]::{half,bfloat16,tf32}
94
+ if (LlvmTyName.startswith (" class.sycl::" ) ||
95
+ LlvmTyName.startswith (" class.__sycl_internal::" ))
96
+ LlvmTyName = LlvmTyName.rsplit (" ::" ).second ;
97
+ if (LlvmTyName != " half" && LlvmTyName != " bfloat16" &&
98
+ LlvmTyName != " tf32" )
99
+ llvm_unreachable (" Wrong matrix base type!" );
100
+ OS << LlvmTyName;
101
+ } else {
102
+ llvm_unreachable (" Wrong matrix base type!" );
103
+ }
104
+ } else if (TemplateArg.getKind () == TemplateArgument::Integral) {
105
+ OS << TemplateArg.getAsIntegral ();
106
+ }
107
+ }
108
+ Ty->setName (OS.str ());
109
+ return ;
110
+ }
111
+ }
112
+ }
54
113
OS << RD->getKindName () << ' .' ;
55
114
56
115
// FIXME: We probably want to make more tweaks to the printing policy. For
@@ -401,78 +460,6 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
401
460
return ResultType;
402
461
}
403
462
404
- template <bool NeedTypeInterpret = false >
405
- llvm::Type *getJointMatrixINTELExtType (llvm::Type *CompTy,
406
- ArrayRef<TemplateArgument> TemplateArgs,
407
- const unsigned Val = 0 ) {
408
- // TODO: we should actually have exactly 5 template parameters: 1 for
409
- // type and 4 for type parameters. But in previous version of the SPIR-V
410
- // spec we have Layout matrix type parameter, that was later removed.
411
- // Once we update to the newest version of the spec - this should be updated.
412
- assert ((TemplateArgs.size () == 5 || TemplateArgs.size () == 6 ) &&
413
- " Wrong JointMatrixINTEL template parameters number" );
414
- // This is required to represent optional 'Component Type Interpretation'
415
- // parameter
416
- using ParamsType =
417
- typename std::conditional<NeedTypeInterpret, SmallVector<unsigned , 6 >,
418
- SmallVector<unsigned , 5 >>::type;
419
- ParamsType Params;
420
- if constexpr (NeedTypeInterpret)
421
- Params = {0 , 0 , 0 , 0 , 0 , Val};
422
- else
423
- Params = {0 , 0 , 0 , 0 , 0 };
424
- for (size_t I = 1 ; I != TemplateArgs.size (); ++I) {
425
- assert (TemplateArgs[I].getKind () == TemplateArgument::Integral &&
426
- " Wrong JointMatrixINTEL template parameter" );
427
- Params[I - 1 ] = TemplateArgs[I].getAsIntegral ().getExtValue ();
428
- }
429
- return llvm::TargetExtType::get (CompTy->getContext (),
430
- " spirv.JointMatrixINTEL" , {CompTy}, Params);
431
- }
432
-
433
- // / ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
434
- // / which is represented as a pointer to a structure to LLVM extension type
435
- // / with the parameters that follow SPIR-V JointMatrixINTEL type.
436
- // / The expected representation is:
437
- // / target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
438
- // / %use%, (optional) %element_type_interpretation%)
439
- llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType (RecordDecl *RD) {
440
- auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
441
- ArrayRef<TemplateArgument> TemplateArgs =
442
- TemplateDecl->getTemplateArgs ().asArray ();
443
- assert (TemplateArgs[0 ].getKind () == TemplateArgument::Type &&
444
- " 1st JointMatrixINTEL template parameter must be type" );
445
- llvm::Type *CompTy = ConvertType (TemplateArgs[0 ].getAsType ());
446
-
447
- // Per JointMatrixINTEL spec the type can have an optional
448
- // 'Component Type Interpretation' parameter. We should emit it in case
449
- // if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as
450
- // matrix's components. Yet 'bfloat16' should be represented as 'int16' and
451
- // 'tf32' as 'float' types.
452
- if (CompTy->isStructTy ()) {
453
- StringRef LlvmTyName = CompTy->getStructName ();
454
- // Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
455
- if (LlvmTyName.startswith (" class.sycl::" ) ||
456
- LlvmTyName.startswith (" class.__sycl_internal::" ))
457
- LlvmTyName = LlvmTyName.rsplit (" ::" ).second ;
458
- if (LlvmTyName == " half" ) {
459
- CompTy = llvm::Type::getHalfTy (getLLVMContext ());
460
- return getJointMatrixINTELExtType (CompTy, TemplateArgs);
461
- } else if (LlvmTyName == " tf32" ) {
462
- CompTy = llvm::Type::getFloatTy (getLLVMContext ());
463
- // 'tf32' interpretation is mapped to '0'
464
- return getJointMatrixINTELExtType<true >(CompTy, TemplateArgs, 0 );
465
- } else if (LlvmTyName == " bfloat16" ) {
466
- CompTy = llvm::Type::getInt16Ty (getLLVMContext ());
467
- // 'bfloat16' interpretation is mapped to '1'
468
- return getJointMatrixINTELExtType<true >(CompTy, TemplateArgs, 1 );
469
- } else {
470
- llvm_unreachable (" Wrong matrix base type!" );
471
- }
472
- }
473
- return getJointMatrixINTELExtType (CompTy, TemplateArgs);
474
- }
475
-
476
463
// / ConvertType - Convert the specified type to its LLVM form.
477
464
llvm::Type *CodeGenTypes::ConvertType (QualType T) {
478
465
T = Context.getCanonicalType (T);
@@ -758,18 +745,6 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
758
745
llvm::Type *PointeeType = ConvertTypeForMem (ETy);
759
746
if (PointeeType->isVoidTy ())
760
747
PointeeType = llvm::Type::getInt8Ty (getLLVMContext ());
761
- if (CGM.getTriple ().isSPIRV () || CGM.getTriple ().isSPIR ()) {
762
- const Type *ClangETy = ETy.getTypePtrOrNull ();
763
- if (ClangETy && ClangETy->isStructureOrClassType ()) {
764
- RecordDecl *RD = ClangETy->getAsCXXRecordDecl ();
765
- if (RD &&
766
- RD->getQualifiedNameAsString () == " __spv::__spirv_JointMatrixINTEL" ) {
767
- ResultType = ConvertSYCLJointMatrixINTELType (RD);
768
- break ;
769
- }
770
- }
771
- }
772
-
773
748
unsigned AS = getTargetAddressSpace (ETy);
774
749
ResultType = llvm::PointerType::get (PointeeType, AS);
775
750
break ;
0 commit comments