Skip to content

Commit 06f667a

Browse files
author
Erich Keane
authored
[SYCL] Allow recursive function calls in a constexpr context. (#2105)
It doesn't really make sense to restrict recursion for something that is forced to be evaluated at constexpr time, since this is a restriction due to device limitations. This patch creates a constexpr-context count for all of the constexpr contexts I could create, and creates a counter. It is implemented this way to permit a future implementer to add other diagnostics (such as DLLImport?) as permissible in constexpr.
1 parent 9b9639a commit 06f667a

File tree

2 files changed

+138
-1
lines changed

2 files changed

+138
-1
lines changed

clang/lib/Sema/SemaSYCL.cpp

+62-1
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,24 @@ static int64_t getIntExprValue(const Expr *E, ASTContext &Ctx) {
322322
}
323323

324324
class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
325+
// Used to keep track of the constexpr depth, so we know whether to skip
326+
// diagnostics.
327+
unsigned ConstexprDepth = 0;
328+
struct ConstexprDepthRAII {
329+
MarkDeviceFunction &MDF;
330+
bool Increment;
331+
332+
ConstexprDepthRAII(MarkDeviceFunction &MDF, bool Increment = true)
333+
: MDF(MDF), Increment(Increment) {
334+
if (Increment)
335+
++MDF.ConstexprDepth;
336+
}
337+
~ConstexprDepthRAII() {
338+
if (Increment)
339+
--MDF.ConstexprDepth;
340+
}
341+
};
342+
325343
public:
326344
MarkDeviceFunction(Sema &S)
327345
: RecursiveASTVisitor<MarkDeviceFunction>(), SemaRef(S) {}
@@ -335,7 +353,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
335353
// instantiation as template functions. It means that
336354
// all functions used by kernel have already been parsed and have
337355
// definitions.
338-
if (RecursiveSet.count(Callee)) {
356+
if (RecursiveSet.count(Callee) && !ConstexprDepth) {
339357
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict)
340358
<< Sema::KernelCallRecursiveFunction;
341359
SemaRef.Diag(Callee->getSourceRange().getBegin(),
@@ -386,6 +404,49 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
386404
return true;
387405
}
388406

407+
// Skip checking rules on variables initialized during constant evaluation.
408+
bool TraverseVarDecl(VarDecl *VD) {
409+
ConstexprDepthRAII R(*this, VD->isConstexpr());
410+
return RecursiveASTVisitor::TraverseVarDecl(VD);
411+
}
412+
413+
// Skip checking rules on template arguments, since these are constant
414+
// expressions.
415+
bool TraverseTemplateArgumentLoc(const TemplateArgumentLoc &ArgLoc) {
416+
ConstexprDepthRAII R(*this);
417+
return RecursiveASTVisitor::TraverseTemplateArgumentLoc(ArgLoc);
418+
}
419+
420+
// Skip checking the static assert, both components are required to be
421+
// constant expressions.
422+
bool TraverseStaticAssertDecl(StaticAssertDecl *D) {
423+
ConstexprDepthRAII R(*this);
424+
return RecursiveASTVisitor::TraverseStaticAssertDecl(D);
425+
}
426+
427+
// Make sure we skip the condition of the case, since that is a constant
428+
// expression.
429+
bool TraverseCaseStmt(CaseStmt *S) {
430+
{
431+
ConstexprDepthRAII R(*this);
432+
if (!TraverseStmt(S->getLHS()))
433+
return false;
434+
if (!TraverseStmt(S->getRHS()))
435+
return false;
436+
}
437+
return TraverseStmt(S->getSubStmt());
438+
}
439+
440+
// Skip checking the size expr, since a constant array type loc's size expr is
441+
// a constant expression.
442+
bool TraverseConstantArrayTypeLoc(const ConstantArrayTypeLoc &ArrLoc) {
443+
if (!TraverseTypeLoc(ArrLoc.getElementLoc()))
444+
return false;
445+
446+
ConstexprDepthRAII R(*this);
447+
return TraverseStmt(ArrLoc.getSizeExpr());
448+
}
449+
389450
// The call graph for this translation unit.
390451
CallGraph SYCLCG;
391452
// The set of functions called by a kernel function.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// RUN: %clang_cc1 -fsycl -fsycl-is-device -fcxx-exceptions -Wno-return-type -verify -fsyntax-only -std=c++20 -Werror=vla %s
2+
3+
template <typename name, typename Func>
4+
__attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {
5+
kernelFunc();
6+
}
7+
8+
// expected-note@+1{{function implemented using recursion declared here}}
9+
constexpr int constexpr_recurse1(int n);
10+
11+
// expected-note@+1 3{{function implemented using recursion declared here}}
12+
constexpr int constexpr_recurse(int n) {
13+
if (n)
14+
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
15+
return constexpr_recurse1(n - 1);
16+
return 103;
17+
}
18+
19+
constexpr int constexpr_recurse1(int n) {
20+
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
21+
return constexpr_recurse(n) + 1;
22+
}
23+
24+
template <int I>
25+
void bar() {}
26+
27+
template <int... args>
28+
void bar2() {}
29+
30+
enum class SomeE {
31+
Value = constexpr_recurse(5)
32+
};
33+
34+
struct ConditionallyExplicitCtor {
35+
explicit(constexpr_recurse(5) == 103) ConditionallyExplicitCtor(int i) {}
36+
};
37+
38+
void conditionally_noexcept() noexcept(constexpr_recurse(5)) {}
39+
40+
// All of the uses of constexpr_recurse here are forced constant expressions, so
41+
// they should not diagnose.
42+
void constexpr_recurse_test() {
43+
constexpr int i = constexpr_recurse(1);
44+
bar<constexpr_recurse(2)>();
45+
bar2<1, 2, constexpr_recurse(2)>();
46+
static_assert(constexpr_recurse(2) == 105, "");
47+
48+
int j;
49+
switch (105) {
50+
case constexpr_recurse(2):
51+
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
52+
j = constexpr_recurse(5);
53+
break;
54+
}
55+
56+
SomeE e = SomeE::Value;
57+
58+
int ce_array[constexpr_recurse(5)];
59+
60+
conditionally_noexcept();
61+
62+
if constexpr ((bool)SomeE::Value) {
63+
}
64+
65+
ConditionallyExplicitCtor c(1);
66+
}
67+
68+
void constexpr_recurse_test_err() {
69+
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
70+
int i = constexpr_recurse(1);
71+
}
72+
73+
int main() {
74+
kernel_single_task<class fake_kernel>([]() { constexpr_recurse_test(); });
75+
kernel_single_task<class fake_kernel>([]() { constexpr_recurse_test_err(); });
76+
}

0 commit comments

Comments
 (0)