@@ -322,6 +322,24 @@ static int64_t getIntExprValue(const Expr *E, ASTContext &Ctx) {
322
322
}
323
323
324
324
class MarkDeviceFunction : public RecursiveASTVisitor <MarkDeviceFunction> {
325
+ // Used to keep track of the constexpr depth, so we know whether to skip
326
+ // diagnostics.
327
+ unsigned ConstexprDepth = 0 ;
328
+ struct ConstexprDepthRAII {
329
+ MarkDeviceFunction &MDF;
330
+ bool Increment;
331
+
332
+ ConstexprDepthRAII (MarkDeviceFunction &MDF, bool Increment = true )
333
+ : MDF(MDF), Increment(Increment) {
334
+ if (Increment)
335
+ ++MDF.ConstexprDepth ;
336
+ }
337
+ ~ConstexprDepthRAII () {
338
+ if (Increment)
339
+ --MDF.ConstexprDepth ;
340
+ }
341
+ };
342
+
325
343
public:
326
344
MarkDeviceFunction (Sema &S)
327
345
: RecursiveASTVisitor<MarkDeviceFunction>(), SemaRef(S) {}
@@ -335,7 +353,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
335
353
// instantiation as template functions. It means that
336
354
// all functions used by kernel have already been parsed and have
337
355
// definitions.
338
- if (RecursiveSet.count (Callee)) {
356
+ if (RecursiveSet.count (Callee) && !ConstexprDepth ) {
339
357
SemaRef.Diag (e->getExprLoc (), diag::err_sycl_restrict)
340
358
<< Sema::KernelCallRecursiveFunction;
341
359
SemaRef.Diag (Callee->getSourceRange ().getBegin (),
@@ -386,6 +404,49 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
386
404
return true ;
387
405
}
388
406
407
+ // Skip checking rules on variables initialized during constant evaluation.
408
+ bool TraverseVarDecl (VarDecl *VD) {
409
+ ConstexprDepthRAII R (*this , VD->isConstexpr ());
410
+ return RecursiveASTVisitor::TraverseVarDecl (VD);
411
+ }
412
+
413
+ // Skip checking rules on template arguments, since these are constant
414
+ // expressions.
415
+ bool TraverseTemplateArgumentLoc (const TemplateArgumentLoc &ArgLoc) {
416
+ ConstexprDepthRAII R (*this );
417
+ return RecursiveASTVisitor::TraverseTemplateArgumentLoc (ArgLoc);
418
+ }
419
+
420
+ // Skip checking the static assert, both components are required to be
421
+ // constant expressions.
422
+ bool TraverseStaticAssertDecl (StaticAssertDecl *D) {
423
+ ConstexprDepthRAII R (*this );
424
+ return RecursiveASTVisitor::TraverseStaticAssertDecl (D);
425
+ }
426
+
427
+ // Make sure we skip the condition of the case, since that is a constant
428
+ // expression.
429
+ bool TraverseCaseStmt (CaseStmt *S) {
430
+ {
431
+ ConstexprDepthRAII R (*this );
432
+ if (!TraverseStmt (S->getLHS ()))
433
+ return false ;
434
+ if (!TraverseStmt (S->getRHS ()))
435
+ return false ;
436
+ }
437
+ return TraverseStmt (S->getSubStmt ());
438
+ }
439
+
440
+ // Skip checking the size expr, since a constant array type loc's size expr is
441
+ // a constant expression.
442
+ bool TraverseConstantArrayTypeLoc (const ConstantArrayTypeLoc &ArrLoc) {
443
+ if (!TraverseTypeLoc (ArrLoc.getElementLoc ()))
444
+ return false ;
445
+
446
+ ConstexprDepthRAII R (*this );
447
+ return TraverseStmt (ArrLoc.getSizeExpr ());
448
+ }
449
+
389
450
// The call graph for this translation unit.
390
451
CallGraph SYCLCG;
391
452
// The set of functions called by a kernel function.
0 commit comments