Skip to content

Commit 745423b

Browse files
authored
[SYCL] Add support for __registered_kernels__ (#16485)
This change adds support for a new attribute ``__sycl_detail::__registered_kernels__``, which appears at translation unit scope. The parameter for this attribute is a list of pairs like: ``` [[__sycl_detail__::__registered_kernels__( {"foo", foo}, {"(void(*)(int, int*))iota", (void(*)(int, int*))iota}, {"kernel<float>", kernel<float>} )]]; ``` The first element in each pair is a string, and the second element is a constant expressiton for a pointer to a SYCL free function kernel. The change creates the kernel's wrapper function and generates module-level metadata of the form: ``` !sycl_registered_kernels = !{!0, !1} !0 = !{!"foo", !"mangled-name-of-wrapper-for-foo"} !1 = !{!"kernel<float>", !"mangled-name-of-wrapper-for-kernel<float>"} ``` where the first element in the pair of strings, is the first element of the pair in ``__registered_kernels__`` and the second element is the mangled named of the wrapper corresponding to the free function.
1 parent 8af1eb3 commit 745423b

11 files changed

+411
-20
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2147,6 +2147,22 @@ def SYCLAddIRAnnotationsMember : InheritableAttr {
21472147
let Documentation = [SYCLAddIRAnnotationsMemberDocs];
21482148
}
21492149

2150+
def SYCLRegisteredKernels : InheritableAttr {
2151+
let Spellings = [CXX11<"__sycl_detail__", "__registered_kernels__">];
2152+
let Args = [VariadicExprArgument<"Args">];
2153+
let LangOpts = [SYCLIsDevice, SilentlyIgnoreSYCLIsHost];
2154+
let Subjects = SubjectList<[Empty], ErrorDiag, "Translation Unit Scope">;
2155+
let AdditionalMembers = SYCLAddIRAttrCommonMembers.MemberCode;
2156+
let Documentation = [SYCLAddIRAnnotationsMemberDocs];
2157+
}
2158+
2159+
def SYCLRegisteredKernelName : InheritableAttr {
2160+
let Spellings = [];
2161+
let Subjects = SubjectList<[Function]>;
2162+
let Args = [StringArgument<"RegName">];
2163+
let Documentation = [InternalOnly];
2164+
}
2165+
21502166
def C11NoReturn : InheritableAttr {
21512167
let Spellings = [CustomKeyword<"_Noreturn">];
21522168
let Subjects = SubjectList<[Function], ErrorDiag>;

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12546,6 +12546,20 @@ def err_sycl_special_type_num_init_method : Error<
1254612546
def warn_launch_bounds_is_cuda_specific : Warning<
1254712547
"%0 attribute ignored, only applicable when targeting Nvidia devices">,
1254812548
InGroup<IgnoredAttributes>;
12549+
def err_registered_kernels_num_of_args : Error<
12550+
"'__registered_kernels__' attribute must have at least one argument">;
12551+
def err_registered_kernels_init_list : Error<
12552+
"argument to the '__registered_kernels__' attribute must be an "
12553+
"initializer list expression">;
12554+
def err_registered_kernels_init_list_pair_values : Error<
12555+
"each initializer list argument to the '__registered_kernels__' attribute "
12556+
"must contain a pair of values">;
12557+
def err_registered_kernels_resolve_function : Error<
12558+
"unable to resolve free function kernel '%0'">;
12559+
def err_registered_kernels_name_already_registered : Error<
12560+
"free function kernel has already been registered with '%0'; cannot register with '%1'">;
12561+
def err_not_sycl_free_function : Error<
12562+
"attempting to register a function that is not a SYCL free function as '%0'">;
1254912563

1255012564
def warn_cuda_maxclusterrank_sm_90 : Warning<
1255112565
"'maxclusterrank' requires sm_90 or higher, CUDA arch provided: %0, ignoring "

clang/include/clang/Sema/SemaSYCL.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,9 @@ class SemaSYCL : public SemaBase {
252252
// We need to store the list of the sycl_kernel functions and their associated
253253
// generated OpenCL Kernels so we can go back and re-name these after the
254254
// fact.
255-
llvm::SmallVector<std::pair<const FunctionDecl *, FunctionDecl *>>
256-
SyclKernelsToOpenCLKernels;
255+
using KernelFDPairs =
256+
llvm::SmallVector<std::pair<const FunctionDecl *, FunctionDecl *>>;
257+
KernelFDPairs SyclKernelsToOpenCLKernels;
257258

258259
// Used to suppress diagnostics during kernel construction, since these were
259260
// already emitted earlier. Diagnosing during Kernel emissions also skips the
@@ -296,11 +297,15 @@ class SemaSYCL : public SemaBase {
296297
llvm::DenseSet<QualType> Visited,
297298
ValueDecl *DeclToCheck);
298299

300+
const KernelFDPairs &getKernelFDPairs() { return SyclKernelsToOpenCLKernels; }
301+
299302
void addSyclOpenCLKernel(const FunctionDecl *SyclKernel,
300303
FunctionDecl *OpenCLKernel) {
301304
SyclKernelsToOpenCLKernels.emplace_back(SyclKernel, OpenCLKernel);
302305
}
303306

307+
void constructFreeFunctionKernel(FunctionDecl *FD, StringRef NameStr = "");
308+
304309
void addSyclDeviceDecl(Decl *d) { SyclDeviceDecls.insert(d); }
305310
llvm::SetVector<Decl *> &syclDeviceDecls() { return SyclDeviceDecls; }
306311

@@ -480,6 +485,7 @@ class SemaSYCL : public SemaBase {
480485
void handleSYCLIntelMaxWorkGroupsPerMultiprocessor(Decl *D,
481486
const ParsedAttr &AL);
482487
void handleSYCLScopeAttr(Decl *D, const ParsedAttr &AL);
488+
void handleSYCLRegisteredKernels(Decl *D, const ParsedAttr &AL);
483489

484490
void checkSYCLAddIRAttributesFunctionAttrConflicts(Decl *D);
485491

@@ -655,6 +661,10 @@ class SemaSYCL : public SemaBase {
655661
void addIntelReqdSubGroupSizeAttr(Decl *D, const AttributeCommonInfo &CI,
656662
Expr *E);
657663
void handleKernelEntryPointAttr(Decl *D, const ParsedAttr &AL);
664+
665+
// Used to check whether the function represented by FD is a SYCL
666+
// free function kernel or not.
667+
bool isFreeFunction(const FunctionDecl *FD);
658668
};
659669

660670
} // namespace clang

clang/lib/CodeGen/CodeGenFunction.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,12 @@ void CodeGenFunction::EmitKernelMetadata(const FunctionDecl *FD,
641641

642642
llvm::LLVMContext &Context = getLLVMContext();
643643

644+
if (getLangOpts().SYCLIsDevice)
645+
if (FD->hasAttr<SYCLRegisteredKernelNameAttr>())
646+
CGM.SYCLAddRegKernelNamePairs(
647+
FD->getAttr<SYCLRegisteredKernelNameAttr>()->getRegName(),
648+
FD->getNameAsString());
649+
644650
if (FD->hasAttr<OpenCLKernelAttr>() || FD->hasAttr<CUDAGlobalAttr>())
645651
CGM.GenKernelArgMetadata(Fn, FD, this);
646652

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,6 +1431,19 @@ void CodeGenModule::Release() {
14311431
AspectEnumValsMD->addOperand(
14321432
getAspectEnumValueMD(Context, TheModule.getContext(), ECD));
14331433
}
1434+
1435+
if (!SYCLRegKernelNames.empty()) {
1436+
std::vector<llvm::Metadata *> Nodes;
1437+
llvm::LLVMContext &Ctx = TheModule.getContext();
1438+
for (auto MDKernelNames : SYCLRegKernelNames) {
1439+
llvm::Metadata *Vals[] = {MDKernelNames.first, MDKernelNames.second};
1440+
Nodes.push_back(llvm::MDTuple::get(Ctx, Vals));
1441+
}
1442+
1443+
llvm::NamedMDNode *SYCLRegKernelsMD =
1444+
TheModule.getOrInsertNamedMetadata("sycl_registered_kernels");
1445+
SYCLRegKernelsMD->addOperand(llvm::MDNode::get(Ctx, Nodes));
1446+
}
14341447
}
14351448

14361449
// HLSL related end of code gen work items.

clang/lib/CodeGen/CodeGenModule.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,9 @@ class CodeGenModule : public CodeGenTypeCache {
456456
/// handled differently than regular annotations so they cannot share map.
457457
llvm::DenseMap<unsigned, llvm::Constant *> SYCLAnnotationArgs;
458458

459+
typedef std::pair<llvm::Metadata *, llvm::Metadata *> MetadataPair;
460+
SmallVector<MetadataPair, 4> SYCLRegKernelNames;
461+
459462
llvm::StringMap<llvm::GlobalVariable *> CFConstantStringMap;
460463

461464
llvm::DenseMap<llvm::Constant *, llvm::GlobalVariable *> ConstantStringMap;
@@ -1483,6 +1486,12 @@ class CodeGenModule : public CodeGenTypeCache {
14831486
llvm::Constant *EmitSYCLAnnotationArgs(
14841487
SmallVectorImpl<std::pair<std::string, std::string>> &Pairs);
14851488

1489+
void SYCLAddRegKernelNamePairs(StringRef First, StringRef Second) {
1490+
SYCLRegKernelNames.push_back(
1491+
std::make_pair(llvm::MDString::get(getLLVMContext(), First),
1492+
llvm::MDString::get(getLLVMContext(), Second)));
1493+
}
1494+
14861495
/// Add attributes from add_ir_attributes_global_variable on TND to GV.
14871496
void AddGlobalSYCLIRAttributes(llvm::GlobalVariable *GV,
14881497
const RecordDecl *RD);

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7479,6 +7479,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
74797479
case ParsedAttr::AT_SYCLAddIRAnnotationsMember:
74807480
S.SYCL().handleSYCLAddIRAnnotationsMemberAttr(D, AL);
74817481
break;
7482+
case ParsedAttr::AT_SYCLRegisteredKernels:
7483+
S.SYCL().handleSYCLRegisteredKernels(D, AL);
7484+
break;
74827485

74837486
// Swift attributes.
74847487
case ParsedAttr::AT_SwiftAsyncName:

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,10 +1148,10 @@ static target getAccessTarget(QualType FieldTy,
11481148

11491149
// FIXME: Free functions must have void return type and be declared at file
11501150
// scope, outside any namespaces.
1151-
static bool isFreeFunction(SemaSYCL &SemaSYCLRef, const FunctionDecl *FD) {
1151+
bool SemaSYCL::isFreeFunction(const FunctionDecl *FD) {
11521152
for (auto *IRAttr : FD->specific_attrs<SYCLAddIRAttributesFunctionAttr>()) {
11531153
SmallVector<std::pair<std::string, std::string>, 4> NameValuePairs =
1154-
IRAttr->getAttributeNameValuePairs(SemaSYCLRef.getASTContext());
1154+
IRAttr->getAttributeNameValuePairs(getASTContext());
11551155
for (const auto &NameValuePair : NameValuePairs) {
11561156
if (NameValuePair.first == "sycl-nd-range-kernel" ||
11571157
NameValuePair.first == "sycl-single-task-kernel") {
@@ -5291,7 +5291,7 @@ void SemaSYCL::SetSYCLKernelNames() {
52915291
SyclKernelsToOpenCLKernels) {
52925292
std::string CalculatedName, StableName;
52935293
StringRef KernelName;
5294-
if (isFreeFunction(*this, Pair.first)) {
5294+
if (isFreeFunction(Pair.first)) {
52955295
std::tie(CalculatedName, StableName) =
52965296
constructFreeFunctionKernelName(*this, Pair.first, *MangleCtx);
52975297
KernelName = CalculatedName;
@@ -5414,24 +5414,66 @@ void SemaSYCL::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
54145414
}
54155415
}
54165416

5417-
void ConstructFreeFunctionKernel(SemaSYCL &SemaSYCLRef, FunctionDecl *FD) {
5418-
SyclKernelArgsSizeChecker argsSizeChecker(SemaSYCLRef, FD->getLocation(),
5417+
static void addRegisteredKernelName(SemaSYCL &S, StringRef Str,
5418+
FunctionDecl *FD, SourceLocation Loc) {
5419+
if (!Str.empty())
5420+
FD->addAttr(SYCLRegisteredKernelNameAttr::CreateImplicit(S.getASTContext(),
5421+
Str, Loc));
5422+
}
5423+
5424+
static bool checkAndAddRegisteredKernelName(SemaSYCL &S, FunctionDecl *FD,
5425+
StringRef Str) {
5426+
using KernelPair = std::pair<const FunctionDecl *, FunctionDecl *>;
5427+
for (const KernelPair &Pair : S.getKernelFDPairs()) {
5428+
if (Pair.first == FD) {
5429+
// If the current list of free function entries already contains this
5430+
// free function, apply the name Str as an attribute. But if it already
5431+
// has an attribute name, issue a diagnostic instead.
5432+
if (!Str.empty()) {
5433+
if (!Pair.second->hasAttr<SYCLRegisteredKernelNameAttr>())
5434+
addRegisteredKernelName(S, Str, Pair.second, FD->getLocation());
5435+
else
5436+
S.Diag(FD->getLocation(),
5437+
diag::err_registered_kernels_name_already_registered)
5438+
<< Pair.second->getAttr<SYCLRegisteredKernelNameAttr>()
5439+
->getRegName()
5440+
<< Str;
5441+
}
5442+
// An empty name string implies a regular free kernel construction
5443+
// call, so simply return.
5444+
return false;
5445+
}
5446+
}
5447+
return true;
5448+
}
5449+
5450+
void SemaSYCL::constructFreeFunctionKernel(FunctionDecl *FD,
5451+
StringRef NameStr) {
5452+
if (!checkAndAddRegisteredKernelName(*this, FD, NameStr))
5453+
return;
5454+
5455+
SyclKernelArgsSizeChecker argsSizeChecker(*this, FD->getLocation(),
54195456
false /*IsSIMDKernel*/);
5420-
SyclKernelDeclCreator kernel_decl(SemaSYCLRef, FD->getLocation(),
5421-
FD->isInlined(), false /*IsSIMDKernel */,
5422-
FD);
5457+
SyclKernelDeclCreator kernel_decl(*this, FD->getLocation(), FD->isInlined(),
5458+
false /*IsSIMDKernel */, FD);
54235459

5424-
FreeFunctionKernelBodyCreator kernel_body(SemaSYCLRef, kernel_decl, FD);
5460+
FreeFunctionKernelBodyCreator kernel_body(*this, kernel_decl, FD);
54255461

5426-
SyclKernelIntHeaderCreator int_header(
5427-
SemaSYCLRef, SemaSYCLRef.getSyclIntegrationHeader(), FD->getType(), FD);
5462+
SyclKernelIntHeaderCreator int_header(*this, getSyclIntegrationHeader(),
5463+
FD->getType(), FD);
54285464

5429-
SyclKernelIntFooterCreator int_footer(SemaSYCLRef,
5430-
SemaSYCLRef.getSyclIntegrationFooter());
5431-
KernelObjVisitor Visitor{SemaSYCLRef};
5465+
SyclKernelIntFooterCreator int_footer(*this, getSyclIntegrationFooter());
5466+
KernelObjVisitor Visitor{*this};
54325467

54335468
Visitor.VisitFunctionParameters(FD, argsSizeChecker, kernel_decl, kernel_body,
54345469
int_header, int_footer);
5470+
5471+
assert(getKernelFDPairs().back().first == FD &&
5472+
"OpenCL Kernel not found for free function entry");
5473+
// Register the kernel name with the OpenCL kernel generated for the
5474+
// free function.
5475+
addRegisteredKernelName(*this, NameStr, getKernelFDPairs().back().second,
5476+
FD->getLocation());
54355477
}
54365478

54375479
// Figure out the sub-group for the this function. First we check the
@@ -5717,7 +5759,7 @@ void SemaSYCL::MarkDevices() {
57175759
}
57185760

57195761
void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) {
5720-
if (isFreeFunction(*this, FD)) {
5762+
if (isFreeFunction(FD)) {
57215763
SyclKernelDecompMarker DecompMarker(*this);
57225764
SyclKernelFieldChecker FieldChecker(*this);
57235765
SyclKernelUnionChecker UnionChecker(*this);
@@ -5736,7 +5778,7 @@ void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) {
57365778
if (!FieldChecker.isValid() || !UnionChecker.isValid())
57375779
return;
57385780

5739-
ConstructFreeFunctionKernel(*this, FD);
5781+
constructFreeFunctionKernel(FD);
57405782
}
57415783
}
57425784

@@ -6621,7 +6663,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
66216663
unsigned ShimCounter = 1;
66226664
int FreeFunctionCount = 0;
66236665
for (const KernelDesc &K : KernelDescs) {
6624-
if (!isFreeFunction(S, K.SyclKernel))
6666+
if (!S.isFreeFunction(K.SyclKernel))
66256667
continue;
66266668
++FreeFunctionCount;
66276669
// Generate forward declaration for free function.
@@ -6739,7 +6781,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
67396781
}
67406782
ShimCounter = 1;
67416783
for (const KernelDesc &K : KernelDescs) {
6742-
if (!isFreeFunction(S, K.SyclKernel))
6784+
if (!S.isFreeFunction(K.SyclKernel))
67436785
continue;
67446786

67456787
O << "\n// Definition of kernel_id of " << K.Name << "\n";

clang/lib/Sema/SemaSYCLDeclAttr.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3162,3 +3162,68 @@ void SemaSYCL::checkSYCLAddIRAttributesFunctionAttrConflicts(Decl *D) {
31623162
Diag(Attr->getLoc(), diag::warn_sycl_old_and_new_kernel_attributes)
31633163
<< Attr;
31643164
}
3165+
3166+
void SemaSYCL::handleSYCLRegisteredKernels(Decl *D, const ParsedAttr &A) {
3167+
// Check for SYCL device compilation context.
3168+
if (!getLangOpts().SYCLIsDevice)
3169+
return;
3170+
3171+
unsigned NumArgs = A.getNumArgs();
3172+
// When declared, we expect at least one item in the list.
3173+
if (NumArgs == 0) {
3174+
Diag(A.getLoc(), diag::err_registered_kernels_num_of_args);
3175+
return;
3176+
}
3177+
3178+
// Traverse through the items in the list.
3179+
for (unsigned I = 0; I < NumArgs; I++) {
3180+
assert(A.isArgExpr(I) && "Expected expression argument");
3181+
// Each item in the list must be an initializer list expression.
3182+
Expr *ArgExpr = A.getArgAsExpr(I);
3183+
if (!isa<InitListExpr>(ArgExpr)) {
3184+
Diag(ArgExpr->getExprLoc(), diag::err_registered_kernels_init_list);
3185+
return;
3186+
}
3187+
3188+
auto *ArgListE = cast<InitListExpr>(ArgExpr);
3189+
unsigned NumInits = ArgListE->getNumInits();
3190+
// Each init-list expression must have a pair of values.
3191+
if (NumInits != 2) {
3192+
Diag(ArgExpr->getExprLoc(),
3193+
diag::err_registered_kernels_init_list_pair_values);
3194+
return;
3195+
}
3196+
3197+
// The first value of the pair must be a string.
3198+
Expr *FirstExpr = ArgListE->getInit(0);
3199+
StringRef CurStr;
3200+
SourceLocation Loc = FirstExpr->getExprLoc();
3201+
if (!SemaRef.checkStringLiteralArgumentAttr(A, FirstExpr, CurStr, &Loc))
3202+
return;
3203+
3204+
// Resolve the FunctionDecl from the second value of the pair.
3205+
Expr *SecondE = ArgListE->getInit(1);
3206+
FunctionDecl *FD = nullptr;
3207+
if (auto *ULE = dyn_cast<UnresolvedLookupExpr>(SecondE)) {
3208+
FD = SemaRef.ResolveSingleFunctionTemplateSpecialization(ULE, true);
3209+
Loc = ULE->getExprLoc();
3210+
} else {
3211+
SecondE = SecondE->IgnoreParenCasts();
3212+
if (auto *DRE = dyn_cast<DeclRefExpr>(SecondE))
3213+
FD = dyn_cast<FunctionDecl>(DRE->getDecl());
3214+
Loc = SecondE->getExprLoc();
3215+
}
3216+
// Issue a diagnostic if we are unable to resolve the FunctionDecl.
3217+
if (!FD) {
3218+
Diag(Loc, diag::err_registered_kernels_resolve_function) << CurStr;
3219+
return;
3220+
}
3221+
// Issue a diagnostic is the FunctionDecl is not a SYCL free function.
3222+
if (!isFreeFunction(FD)) {
3223+
Diag(FD->getLocation(), diag::err_not_sycl_free_function) << CurStr;
3224+
return;
3225+
}
3226+
// Construct a free function kernel.
3227+
constructFreeFunctionKernel(FD, CurStr);
3228+
}
3229+
}

0 commit comments

Comments
 (0)