Skip to content

Commit d482ac3

Browse files
author
Erich Keane
authored
[SYCL] Implement anonymous namespace spec-id functionality. (#3576)
In order to get the specialization_id metadata to be usable in the integration footer, we need to insert a 'shim' function at each anonymous namespace so that we can do a lookup at the 'right' location, then look it up again later by the type. This implements all that functionality based on my latest understanding of the design.
1 parent e6733e4 commit d482ac3

File tree

5 files changed

+745
-38
lines changed

5 files changed

+745
-38
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 163 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4491,10 +4491,22 @@ SYCLIntegrationHeader::SYCLIntegrationHeader(bool _UnnamedLambdaSupport,
44914491
: UnnamedLambdaSupport(_UnnamedLambdaSupport), S(_S) {}
44924492

44934493
void SYCLIntegrationFooter::addVarDecl(const VarDecl *VD) {
4494+
// Skip the dependent version of these variables, we only care about them
4495+
// after instantiation.
4496+
if (VD->getDeclContext()->isDependentContext())
4497+
return;
44944498
// Step 1: ensure that this is of the correct type-spec-constant template
44954499
// specialization).
4496-
if (!Util::isSyclSpecIdType(VD->getType()))
4497-
return;
4500+
if (!Util::isSyclSpecIdType(VD->getType())) {
4501+
// Handle the case where this could be a deduced type, such as a deduction
4502+
// guide. We have to do this here since this function, unlike most of the
4503+
// rest of this file, is called during Sema instead of after it. We will
4504+
// also have to filter out after deduction later.
4505+
QualType Ty = VD->getType().getCanonicalType();
4506+
4507+
if (!Ty->isUndeducedType())
4508+
return;
4509+
}
44984510
// Step 2: ensure that this is a static member, or a namespace-scope.
44994511
// Note that isLocalVarDeclorParm excludes thread-local and static-local
45004512
// intentionally, as there is no way to 'spell' one of those in the
@@ -4536,26 +4548,162 @@ void SYCLIntegrationFooter::emitSpecIDName(raw_ostream &O, const VarDecl *VD) {
45364548
O << "";
45374549
}
45384550

4539-
bool SYCLIntegrationFooter::emit(raw_ostream &O) {
4551+
template <typename BeforeFn, typename AfterFn>
4552+
static void PrintNSHelper(BeforeFn Before, AfterFn After, raw_ostream &OS,
4553+
const DeclContext *DC) {
4554+
if (DC->isTranslationUnit())
4555+
return;
4556+
4557+
const auto *CurDecl = cast<Decl>(DC);
4558+
// Ensure we are in the canonical version, so that we know we have the 'full'
4559+
// name of the thing.
4560+
CurDecl = CurDecl->getCanonicalDecl();
4561+
4562+
// We are intentionally skipping linkage decls and record decls. Namespaces
4563+
// can appear in a linkage decl, but not a record decl, so we don't have to
4564+
// worry about the names getting messed up from that. We handle record decls
4565+
// later when printing the name of the thing.
4566+
const auto *NS = dyn_cast<NamespaceDecl>(CurDecl);
4567+
if (NS)
4568+
Before(OS, NS);
4569+
4570+
if (const DeclContext *NewDC = CurDecl->getDeclContext())
4571+
PrintNSHelper(Before, After, OS, NewDC);
4572+
4573+
if (NS)
4574+
After(OS, NS);
4575+
}
4576+
4577+
static void PrintNamespaces(raw_ostream &OS, const DeclContext *DC) {
4578+
PrintNSHelper([](raw_ostream &OS, const NamespaceDecl *NS) {},
4579+
[](raw_ostream &OS, const NamespaceDecl *NS) {
4580+
if (NS->isInline())
4581+
OS << "inline ";
4582+
OS << "namespace ";
4583+
if (!NS->isAnonymousNamespace())
4584+
OS << NS->getName() << " ";
4585+
OS << "{\n";
4586+
},
4587+
OS, DC);
4588+
}
4589+
4590+
static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
4591+
PrintNSHelper(
4592+
[](raw_ostream &OS, const NamespaceDecl *NS) {
4593+
OS << "} // ";
4594+
if (NS->isInline())
4595+
OS << "inline ";
4596+
4597+
OS << "namespace ";
4598+
if (!NS->isAnonymousNamespace())
4599+
OS << NS->getName();
4600+
4601+
OS << '\n';
4602+
},
4603+
[](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC);
4604+
}
4605+
4606+
static std::string EmitSpecIdShim(raw_ostream &OS, unsigned &ShimCounter,
4607+
const std::string &LastShim,
4608+
const NamespaceDecl *AnonNS) {
4609+
std::string NewShimName =
4610+
"__sycl_detail::__spec_id_shim_" + std::to_string(ShimCounter) + "()";
4611+
// Print opening-namespace
4612+
PrintNamespaces(OS, Decl::castToDeclContext(AnonNS));
4613+
OS << "namespace __sycl_detail {\n";
4614+
OS << "static constexpr decltype(" << LastShim << ") &__spec_id_shim_"
4615+
<< ShimCounter << "() {\n";
4616+
OS << " return " << LastShim << ";\n";
4617+
OS << "}\n";
4618+
OS << "} // namespace __sycl_detail \n";
4619+
PrintNSClosingBraces(OS, Decl::castToDeclContext(AnonNS));
4620+
4621+
++ShimCounter;
4622+
return std::move(NewShimName);
4623+
}
4624+
4625+
// Emit the list of shims required for a DeclContext, calls itself recursively.
4626+
static void EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
4627+
const DeclContext *DC,
4628+
std::string &NameForLastShim) {
4629+
if (DC->isTranslationUnit()) {
4630+
NameForLastShim = "::" + NameForLastShim;
4631+
return;
4632+
}
4633+
4634+
const auto *CurDecl = cast<Decl>(DC)->getCanonicalDecl();
4635+
4636+
// We skip linkage decls, since they don't modify the Qualified name.
4637+
if (const auto *RD = dyn_cast<RecordDecl>(CurDecl)) {
4638+
NameForLastShim = RD->getNameAsString() + "::" + NameForLastShim;
4639+
} else if (const auto *ND = dyn_cast<NamespaceDecl>(CurDecl)) {
4640+
if (ND->isAnonymousNamespace()) {
4641+
// Print current shim, reset 'name for last shim'.
4642+
NameForLastShim = EmitSpecIdShim(OS, ShimCounter, NameForLastShim, ND);
4643+
} else {
4644+
NameForLastShim = ND->getNameAsString() + "::" + NameForLastShim;
4645+
}
4646+
} else {
4647+
// FIXME: I don't believe there are other declarations that these variables
4648+
// could possibly find themselves in. LinkageDecls don't change the
4649+
// qualified name, so there is nothing to do here. At one point we should
4650+
// probably convince ourselves that this is entire list and remove this
4651+
// comment.
4652+
assert((isa<LinkageSpecDecl, ExternCContextDecl>(CurDecl)) &&
4653+
"Unhandled decl type");
4654+
}
4655+
4656+
EmitSpecIdShims(OS, ShimCounter, CurDecl->getDeclContext(), NameForLastShim);
4657+
}
4658+
4659+
// Emit the list of shims required for a variable declaration.
4660+
// Returns a string containing the FQN of the 'top most' shim, including its
4661+
// function call parameters.
4662+
static std::string EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
4663+
const VarDecl *VD) {
4664+
assert(VD->isInAnonymousNamespace() &&
4665+
"Function assumes this is in an anonymous namespace");
4666+
std::string RelativeName = VD->getNameAsString();
4667+
EmitSpecIdShims(OS, ShimCounter, VD->getDeclContext(), RelativeName);
4668+
return std::move(RelativeName);
4669+
}
4670+
4671+
bool SYCLIntegrationFooter::emit(raw_ostream &OS) {
45404672
PrintingPolicy Policy{S.getLangOpts()};
45414673
Policy.adjustForCPlusPlusFwdDecl();
45424674
Policy.SuppressTypedefs = true;
45434675
Policy.SuppressUnwrittenScope = true;
45444676

4545-
for (const VarDecl *D : SpecConstants) {
4546-
O << "template<>\n";
4547-
O << "inline const char *get_spec_constant_symbolic_ID<";
4548-
// Emit the FQN for this, but we probably need to do some funny-business for
4549-
// anonymous namespaces.
4550-
D->printQualifiedName(O, Policy);
4551-
O << ">() {\n";
4552-
O << " return \"";
4553-
emitSpecIDName(O, D);
4554-
O << "\";\n";
4555-
O << "}\n";
4677+
// Used to uniquely name the 'shim's as we generate the names in each
4678+
// anonymous namespace.
4679+
unsigned ShimCounter = 0;
4680+
for (const VarDecl *VD : SpecConstants) {
4681+
VD = VD->getCanonicalDecl();
4682+
if (VD->isInAnonymousNamespace()) {
4683+
std::string TopShim = EmitSpecIdShims(OS, ShimCounter, VD);
4684+
OS << "namespace sycl {\n";
4685+
OS << "namespace detail {\n";
4686+
OS << "template<>\n";
4687+
OS << "inline const char *get_spec_constant_symbolic_ID<" << TopShim
4688+
<< ">() {\n";
4689+
OS << " return " << TopShim << ";\n";
4690+
} else {
4691+
OS << "namespace sycl {\n";
4692+
OS << "namespace detail {\n";
4693+
OS << "template<>\n";
4694+
OS << "inline const char *get_spec_constant_symbolic_ID<::";
4695+
VD->printQualifiedName(OS, Policy);
4696+
OS << ">() {\n";
4697+
OS << " return \"";
4698+
emitSpecIDName(OS, VD);
4699+
OS << "\";\n";
4700+
}
4701+
OS << "}\n";
4702+
OS << "} // namespace detail\n";
4703+
OS << "} // namespace sycl\n";
45564704
}
45574705

4558-
O << "#include <CL/sycl/detail/spec_const_integration.hpp>\n";
4706+
OS << "#include <CL/sycl/detail/spec_const_integration.hpp>\n";
45594707
return true;
45604708
}
45614709

clang/test/CodeGenSYCL/Inputs/sycl.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,10 @@ template <typename T> class specialization_id {
321321
T MDefaultValue;
322322
};
323323

324+
#if __cplusplus >= 201703L
325+
template<typename T> specialization_id(T) -> specialization_id<T>;
326+
#endif // C++17.
327+
324328
#define ATTR_SYCL_KERNEL __attribute__((sycl_kernel))
325329
template <typename KernelName = auto_name, typename KernelType>
326330
ATTR_SYCL_KERNEL void kernel_single_task(const KernelType &kernelFunc) { // #KernelSingleTask

0 commit comments

Comments
 (0)