Skip to content

[NFC][SYCL] Use visitor to emit forward declarations #2670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,26 +388,6 @@ class SYCLIntegrationHeader {
: nullptr;
}

/// Emits a forward declaration for given declaration.
void emitFwdDecl(raw_ostream &O, const Decl *D,
SourceLocation KernelLocation);

/// Emits forward declarations of classes and template classes on which
/// declaration of given type depends. See example in the comments for the
/// implementation.
/// \param O
/// stream to emit to
/// \param T
/// type to emit forward declarations for
/// \param KernelLocation
/// source location of the SYCL kernel function, used to emit nicer
/// diagnostic messages if kernel name is missing
/// \param Emitted
/// a set of declarations forward declrations has been emitted for already
void emitForwardClassDecls(raw_ostream &O, QualType T,
SourceLocation KernelLocation,
llvm::SmallPtrSetImpl<const void *> &Emitted);

private:
/// Keeps invocation descriptors for each kernel invocation started by
/// SYCLIntegrationHeader::startKernel
Expand Down
325 changes: 161 additions & 164 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3344,58 +3344,6 @@ static const char *paramKind2Str(KernelParamKind K) {
#undef CASE
}

// Emits a forward declaration
void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D,
SourceLocation KernelLocation) {
// wrap the declaration into namespaces if needed
unsigned NamespaceCnt = 0;
std::string NSStr = "";
const DeclContext *DC = D->getDeclContext();

while (DC) {
auto *NS = dyn_cast_or_null<NamespaceDecl>(DC);

if (!NS) {
break;
}

++NamespaceCnt;
const StringRef NSInlinePrefix = NS->isInline() ? "inline " : "";
NSStr.insert(
0, Twine(NSInlinePrefix + "namespace " + NS->getName() + " { ").str());
DC = NS->getDeclContext();
}
O << NSStr;
if (NamespaceCnt > 0)
O << "\n";
// print declaration into a string:
PrintingPolicy P(D->getASTContext().getLangOpts());
P.adjustForCPlusPlusFwdDecl();
P.SuppressTypedefs = true;
P.SuppressUnwrittenScope = true;
std::string S;
llvm::raw_string_ostream SO(S);
D->print(SO, P);
O << SO.str();

if (const auto *ED = dyn_cast<EnumDecl>(D)) {
QualType T = ED->getIntegerType();
// Backup since getIntegerType() returns null for enum forward
// declaration with no fixed underlying type
if (T.isNull())
T = ED->getPromotionType();
O << " : " << T.getAsString();
}

O << ";\n";

// print closing braces for namespaces if needed
for (unsigned I = 0; I < NamespaceCnt; ++I)
O << "}";
if (NamespaceCnt > 0)
O << "\n";
}

// Emits forward declarations of classes and template classes on which
// declaration of given type depends.
// For example, consider SimpleVadd
Expand Down Expand Up @@ -3432,126 +3380,176 @@ void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D,
// template <typename T> class MyTmplClass;
// template <typename T1, unsigned int N, typename ...T2> class SimpleVadd;
//
void SYCLIntegrationHeader::emitForwardClassDecls(
raw_ostream &O, QualType T, SourceLocation KernelLocation,
llvm::SmallPtrSetImpl<const void *> &Printed) {
class SYCLFwdDeclEmitter
: public TypeVisitor<SYCLFwdDeclEmitter>,
public ConstTemplateArgumentVisitor<SYCLFwdDeclEmitter> {
using InnerTypeVisitor = TypeVisitor<SYCLFwdDeclEmitter>;
using InnerTemplArgVisitor = ConstTemplateArgumentVisitor<SYCLFwdDeclEmitter>;
raw_ostream &OS;
llvm::SmallPtrSet<const NamedDecl *, 4> Printed;
PrintingPolicy Policy;

// peel off the pointer types and get the class/struct type:
for (; T->isPointerType(); T = T->getPointeeType())
;
const CXXRecordDecl *RD = T->getAsCXXRecordDecl();
void printForwardDecl(NamedDecl *D) {
// wrap the declaration into namespaces if needed
unsigned NamespaceCnt = 0;
std::string NSStr = "";
const DeclContext *DC = D->getDeclContext();

if (!RD) {
while (DC) {
const auto *NS = dyn_cast_or_null<NamespaceDecl>(DC);

return;
if (!NS)
break;

++NamespaceCnt;
const StringRef NSInlinePrefix = NS->isInline() ? "inline " : "";
NSStr.insert(
0,
Twine(NSInlinePrefix + "namespace " + NS->getName() + " { ").str());
DC = NS->getDeclContext();
}
OS << NSStr;
if (NamespaceCnt > 0)
OS << "\n";

D->print(OS, Policy);

if (const auto *ED = dyn_cast<EnumDecl>(D)) {
QualType T = ED->getIntegerType();
// Backup since getIntegerType() returns null for enum forward
// declaration with no fixed underlying type
if (T.isNull())
T = ED->getPromotionType();
OS << " : " << T.getAsString();
}

OS << ";\n";

// print closing braces for namespaces if needed
for (unsigned I = 0; I < NamespaceCnt; ++I)
OS << "}";
if (NamespaceCnt > 0)
OS << "\n";
}

// see if this is a template specialization ...
if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
// ... yes, it is template specialization:
// - first, recurse into template parameters and emit needed forward
// declarations
const TemplateArgumentList &Args = TSD->getTemplateArgs();
// Checks if we've already printed forward declaration and prints it if not.
void checkAndEmitForwardDecl(NamedDecl *D) {
if (Printed.insert(D).second)
printForwardDecl(D);
}

for (unsigned I = 0; I < Args.size(); I++) {
const TemplateArgument &Arg = Args[I];
void VisitTemplateArgs(ArrayRef<TemplateArgument> Args) {
for (size_t I = 0, E = Args.size(); I < E; ++I)
Visit(Args[I]);
}

switch (Arg.getKind()) {
case TemplateArgument::ArgKind::Type:
case TemplateArgument::ArgKind::Integral: {
QualType T = (Arg.getKind() == TemplateArgument::ArgKind::Type)
? Arg.getAsType()
: Arg.getIntegralType();

// Handle Kernel Name Type templated using enum type and value.
if (const auto *ET = T->getAs<EnumType>()) {
const EnumDecl *ED = ET->getDecl();
emitFwdDecl(O, ED, KernelLocation);
} else if (Arg.getKind() == TemplateArgument::ArgKind::Type)
emitForwardClassDecls(O, T, KernelLocation, Printed);
break;
}
case TemplateArgument::ArgKind::Pack: {
ArrayRef<TemplateArgument> Pack = Arg.getPackAsArray();
public:
SYCLFwdDeclEmitter(raw_ostream &OS, LangOptions LO) : OS(OS), Policy(LO) {
Policy.adjustForCPlusPlusFwdDecl();
Policy.SuppressTypedefs = true;
Policy.SuppressUnwrittenScope = true;
}

for (const auto &T : Pack) {
if (T.getKind() == TemplateArgument::ArgKind::Type) {
emitForwardClassDecls(O, T.getAsType(), KernelLocation, Printed);
}
}
break;
}
case TemplateArgument::ArgKind::Template: {
// recursion is not required, since the maximum possible nesting level
// equals two for template argument
//
// for example:
// template <typename T> class Bar;
// template <template <typename> class> class Baz;
// template <template <template <typename> class> class T>
// class Foo;
//
// The Baz is a template class. The Baz<Bar> is a class. The class Foo
// should be specialized with template class, not a class. The correct
// specialization of template class Foo is Foo<Baz>. The incorrect
// specialization of template class Foo is Foo<Baz<Bar>>. In this case
// template class Foo specialized by class Baz<Bar>, not a template
// class template <template <typename> class> class T as it should.
TemplateDecl *TD = Arg.getAsTemplate().getAsTemplateDecl();
TemplateParameterList *TemplateParams = TD->getTemplateParameters();
for (NamedDecl *P : *TemplateParams) {
// If template template paramter type has an enum value template
// parameter, forward declaration of enum type is required. Only enum
// values (not types) need to be handled. For example, consider the
// following kernel name type:
//
// template <typename EnumTypeOut, template <EnumValueIn EnumValue,
// typename TypeIn> class T> class Foo;
//
// The correct specialization for Foo (with enum type) is:
// Foo<EnumTypeOut, Baz>, where Baz is a template class.
//
// Therefore the forward class declarations generated in the
// integration header are:
// template <EnumValueIn EnumValue, typename TypeIn> class Baz;
// template <typename EnumTypeOut, template <EnumValueIn EnumValue,
// typename EnumTypeIn> class T> class Foo;
//
// This requires the following enum forward declarations:
// enum class EnumTypeOut : int; (Used to template Foo)
// enum class EnumValueIn : int; (Used to template Baz)
if (NonTypeTemplateParmDecl *TemplateParam =
dyn_cast<NonTypeTemplateParmDecl>(P)) {
QualType T = TemplateParam->getType();
if (const auto *ET = T->getAs<EnumType>()) {
const EnumDecl *ED = ET->getDecl();
emitFwdDecl(O, ED, KernelLocation);
}
}
}
if (Printed.insert(TD).second) {
emitFwdDecl(O, TD, KernelLocation);
}
break;
}
default:
break; // nop
}
void Visit(QualType T) {
if (T.isNull())
return;
InnerTypeVisitor::Visit(T.getTypePtr());
}

void Visit(const TemplateArgument &TA) {
if (TA.isNull())
return;
InnerTemplArgVisitor::Visit(TA);
}

void VisitPointerType(const PointerType *T) {
// Peel off the pointer types.
QualType PT = T->getPointeeType();
while (PT->isPointerType())
PT = PT->getPointeeType();
Visit(PT);
}

void VisitTagType(const TagType *T) {
TagDecl *TD = T->getDecl();
if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(TD)) {
// - first, recurse into template parameters and emit needed forward
// declarations
ArrayRef<TemplateArgument> Args = TSD->getTemplateArgs().asArray();
VisitTemplateArgs(Args);
// - second, emit forward declaration for the template class being
// specialized
ClassTemplateDecl *CTD = TSD->getSpecializedTemplate();
assert(CTD && "template declaration must be available");

checkAndEmitForwardDecl(CTD);
return;
}
// - second, emit forward declaration for the template class being
// specialized
ClassTemplateDecl *CTD = TSD->getSpecializedTemplate();
assert(CTD && "template declaration must be available");
checkAndEmitForwardDecl(TD);
}

void VisitTypeTemplateArgument(const TemplateArgument &TA) {
QualType T = TA.getAsType();
Visit(T);
}

if (Printed.insert(CTD).second) {
emitFwdDecl(O, CTD, KernelLocation);
void VisitIntegralTemplateArgument(const TemplateArgument &TA) {
QualType T = TA.getIntegralType();
if (const EnumType *ET = T->getAs<EnumType>())
VisitTagType(ET);
}

void VisitTemplateTemplateArgument(const TemplateArgument &TA) {
// recursion is not required, since the maximum possible nesting level
// equals two for template argument
//
// for example:
// template <typename T> class Bar;
// template <template <typename> class> class Baz;
// template <template <template <typename> class> class T>
// class Foo;
//
// The Baz is a template class. The Baz<Bar> is a class. The class Foo
// should be specialized with template class, not a class. The correct
// specialization of template class Foo is Foo<Baz>. The incorrect
// specialization of template class Foo is Foo<Baz<Bar>>. In this case
// template class Foo specialized by class Baz<Bar>, not a template
// class template <template <typename> class> class T as it should.
TemplateDecl *TD = TA.getAsTemplate().getAsTemplateDecl();
TemplateParameterList *TemplateParams = TD->getTemplateParameters();
for (NamedDecl *P : *TemplateParams) {
// If template template parameter type has an enum value template
// parameter, forward declaration of enum type is required. Only enum
// values (not types) need to be handled. For example, consider the
// following kernel name type:
//
// template <typename EnumTypeOut, template <EnumValueIn EnumValue,
// typename TypeIn> class T> class Foo;
//
// The correct specialization for Foo (with enum type) is:
// Foo<EnumTypeOut, Baz>, where Baz is a template class.
//
// Therefore the forward class declarations generated in the
// integration header are:
// template <EnumValueIn EnumValue, typename TypeIn> class Baz;
// template <typename EnumTypeOut, template <EnumValueIn EnumValue,
// typename EnumTypeIn> class T> class Foo;
//
// This requires the following enum forward declarations:
// enum class EnumTypeOut : int; (Used to template Foo)
// enum class EnumValueIn : int; (Used to template Baz)
if (NonTypeTemplateParmDecl *TemplateParam =
dyn_cast<NonTypeTemplateParmDecl>(P))
if (const EnumType *ET = TemplateParam->getType()->getAs<EnumType>())
VisitTagType(ET);
}
} else if (Printed.insert(RD).second) {
// emit forward declarations for "leaf" classes in the template parameter
// tree;
emitFwdDecl(O, RD, KernelLocation);
checkAndEmitForwardDecl(TD);
}
}

void VisitPackTemplateArgument(const TemplateArgument &TA) {
VisitTemplateArgs(TA.getPackAsArray());
}
};

class SYCLKernelNameTypePrinter
: public TypeVisitor<SYCLKernelNameTypePrinter>,
Expand Down Expand Up @@ -3709,10 +3707,9 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
if (!UnnamedLambdaSupport) {
O << "// Forward declarations of templated kernel function types:\n";

llvm::SmallPtrSet<const void *, 4> Printed;
for (const KernelDesc &K : KernelDescs) {
emitForwardClassDecls(O, K.NameType, K.KernelLocation, Printed);
}
SYCLFwdDeclEmitter FwdDeclEmitter(O, S.getLangOpts());
for (const KernelDesc &K : KernelDescs)
FwdDeclEmitter.Visit(K.NameType);
}
O << "\n";

Expand Down