@@ -3130,21 +3130,14 @@ class SYCLKernelNameTypeVisitor
3130
3130
void Visit (QualType T) {
3131
3131
if (T.isNull ())
3132
3132
return ;
3133
+
3133
3134
const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
3134
- if (!RD) {
3135
- if (T->isNullPtrType ()) {
3136
- S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
3137
- << KernelNameType;
3138
- S.Diag (KernelInvocationFuncLoc, diag::note_invalid_type_in_sycl_kernel)
3139
- << /* kernel name cannot be a type in the std namespace */ 2 << T;
3140
- IsInvalid = true ;
3141
- }
3142
- return ;
3143
- }
3144
3135
// If KernelNameType has template args visit each template arg via
3145
3136
// ConstTemplateArgumentVisitor
3146
- if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
3137
+ if (const auto *TSD =
3138
+ dyn_cast_or_null<ClassTemplateSpecializationDecl>(RD)) {
3147
3139
ArrayRef<TemplateArgument> Args = TSD->getTemplateArgs ().asArray ();
3140
+
3148
3141
VisitTemplateArgs (Args);
3149
3142
} else {
3150
3143
InnerTypeVisitor::Visit (T.getTypePtr ());
@@ -3157,62 +3150,104 @@ class SYCLKernelNameTypeVisitor
3157
3150
InnerTemplArgVisitor::Visit (TA);
3158
3151
}
3159
3152
3160
- void VisitEnumType (const EnumType *T) {
3161
- const EnumDecl *ED = T->getDecl ();
3162
- if (!ED->isScoped () && !ED->isFixed ()) {
3163
- S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
3153
+ void VisitBuiltinType (const BuiltinType *TT) {
3154
+ if (TT->isNullPtrType ()) {
3155
+ S.Diag (KernelInvocationFuncLoc, diag::err_nullptr_t_type_in_sycl_kernel)
3164
3156
<< KernelNameType;
3165
- S.Diag (KernelInvocationFuncLoc, diag::note_invalid_type_in_sycl_kernel)
3166
- << /* Unscoped enum requires fixed underlying type */ 1
3167
- << QualType (ED->getTypeForDecl (), 0 );
3157
+
3168
3158
IsInvalid = true ;
3169
3159
}
3160
+ return ;
3170
3161
}
3171
3162
3172
- void VisitRecordType (const RecordType *T) {
3173
- return VisitTagDecl (T->getDecl ());
3174
- }
3163
+ void VisitTagType (const TagType *TT) {
3164
+ return DiagnoseKernelNameType (TT->getDecl ());
3165
+ }
3166
+
3167
+ void DiagnoseKernelNameType (const NamedDecl *DeclNamed) {
3168
+ /*
3169
+ This is a helper function which throws an error if the kernel name
3170
+ declaration is:
3171
+ * declared within namespace 'std' (at any level)
3172
+ e.g., namespace std { namespace literals { class Whatever; } }
3173
+ h.single_task<std::literals::Whatever>([]() {});
3174
+ * declared within an anonymous namespace (at any level)
3175
+ e.g., namespace foo { namespace { class Whatever; } }
3176
+ h.single_task<foo::Whatever>([]() {});
3177
+ * declared within a function
3178
+ e.g., void foo() { struct S { int i; };
3179
+ h.single_task<S>([]() {}); }
3180
+ * declared within another tag
3181
+ e.g., struct S { struct T { int i } t; };
3182
+ h.single_task<S::T>([]() {});
3183
+ */
3184
+
3185
+ if (const auto *ED = dyn_cast<EnumDecl>(DeclNamed)) {
3186
+ if (!ED->isScoped () && !ED->isFixed ()) {
3187
+ S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
3188
+ << /* unscoped enum requires fixed underlying type */ 1
3189
+ << DeclNamed;
3190
+ IsInvalid = true ;
3191
+ }
3192
+ }
3175
3193
3176
- void VisitTagDecl (const TagDecl *Tag) {
3177
3194
bool UnnamedLambdaEnabled =
3178
3195
S.getASTContext ().getLangOpts ().SYCLUnnamedLambda ;
3179
- const DeclContext *DeclCtx = Tag ->getDeclContext ();
3196
+ const DeclContext *DeclCtx = DeclNamed ->getDeclContext ();
3180
3197
if (DeclCtx && !UnnamedLambdaEnabled) {
3181
- auto *NameSpace = dyn_cast_or_null<NamespaceDecl>(DeclCtx);
3182
- if (NameSpace && NameSpace->isStdNamespace ()) {
3183
- S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
3184
- << KernelNameType;
3185
- S.Diag (KernelInvocationFuncLoc, diag::note_invalid_type_in_sycl_kernel)
3186
- << /* kernel name cannot be a type in the std namespace */ 2
3187
- << QualType (Tag->getTypeForDecl (), 0 );
3188
- IsInvalid = true ;
3189
- return ;
3190
- }
3191
- if (!DeclCtx->isTranslationUnit () && !isa<NamespaceDecl>(DeclCtx)) {
3192
- const bool KernelNameIsMissing = Tag->getName ().empty ();
3193
- if (KernelNameIsMissing) {
3194
- S.Diag (KernelInvocationFuncLoc,
3195
- diag::err_sycl_kernel_incorrectly_named)
3196
- << KernelNameType;
3198
+
3199
+ // Check if the kernel name declaration is declared within namespace
3200
+ // "std" or "anonymous" namespace (at any level).
3201
+ while (!DeclCtx->isTranslationUnit () && isa<NamespaceDecl>(DeclCtx)) {
3202
+ const auto *NSDecl = cast<NamespaceDecl>(DeclCtx);
3203
+ if (NSDecl->isStdNamespace ()) {
3197
3204
S.Diag (KernelInvocationFuncLoc,
3198
- diag::note_invalid_type_in_sycl_kernel )
3199
- << /* unnamed type used in a SYCL kernel name */ 3 ;
3205
+ diag::err_invalid_std_type_in_sycl_kernel )
3206
+ << KernelNameType << DeclNamed ;
3200
3207
IsInvalid = true ;
3201
3208
return ;
3202
3209
}
3203
- if (Tag-> isCompleteDefinition ()) {
3210
+ if (NSDecl-> isAnonymousNamespace ()) {
3204
3211
S.Diag (KernelInvocationFuncLoc,
3205
3212
diag::err_sycl_kernel_incorrectly_named)
3213
+ << /* kernel name should be globally visible */ 0
3206
3214
<< KernelNameType;
3207
- S.Diag (KernelInvocationFuncLoc,
3208
- diag::note_invalid_type_in_sycl_kernel)
3209
- << /* kernel name is not globally-visible */ 0
3210
- << QualType (Tag->getTypeForDecl (), 0 );
3211
3215
IsInvalid = true ;
3212
- } else {
3213
- S.Diag (KernelInvocationFuncLoc, diag::warn_sycl_implicit_decl);
3214
- S.Diag (Tag->getSourceRange ().getBegin (), diag::note_previous_decl)
3215
- << Tag->getName ();
3216
+ return ;
3217
+ }
3218
+ DeclCtx = DeclCtx->getParent ();
3219
+ }
3220
+
3221
+ // Check if the kernel name is a Tag declaration
3222
+ // local to a non-namespace scope (i.e. Inside a function or within
3223
+ // another Tag etc).
3224
+ if (!DeclCtx->isTranslationUnit () && !isa<NamespaceDecl>(DeclCtx)) {
3225
+ if (const auto *Tag = dyn_cast<TagDecl>(DeclNamed)) {
3226
+ bool UnnamedLambdaUsed = Tag->getIdentifier () == nullptr ;
3227
+
3228
+ if (UnnamedLambdaUsed) {
3229
+ S.Diag (KernelInvocationFuncLoc,
3230
+ diag::err_sycl_kernel_incorrectly_named)
3231
+ << /* unnamed lambda used */ 2 << KernelNameType;
3232
+
3233
+ IsInvalid = true ;
3234
+ return ;
3235
+ }
3236
+ // Check if the declaration is completely defined within a
3237
+ // function or class/struct.
3238
+
3239
+ if (Tag->isCompleteDefinition ()) {
3240
+ S.Diag (KernelInvocationFuncLoc,
3241
+ diag::err_sycl_kernel_incorrectly_named)
3242
+ << /* kernel name should be globally visible */ 0
3243
+ << KernelNameType;
3244
+
3245
+ IsInvalid = true ;
3246
+ } else {
3247
+ S.Diag (KernelInvocationFuncLoc, diag::warn_sycl_implicit_decl);
3248
+ S.Diag (DeclNamed->getLocation (), diag::note_previous_decl)
3249
+ << DeclNamed->getName ();
3250
+ }
3216
3251
}
3217
3252
}
3218
3253
}
@@ -3221,15 +3256,15 @@ class SYCLKernelNameTypeVisitor
3221
3256
void VisitTypeTemplateArgument (const TemplateArgument &TA) {
3222
3257
QualType T = TA.getAsType ();
3223
3258
if (const auto *ET = T->getAs <EnumType>())
3224
- VisitEnumType (ET);
3259
+ VisitTagType (ET);
3225
3260
else
3226
3261
Visit (T);
3227
3262
}
3228
3263
3229
3264
void VisitIntegralTemplateArgument (const TemplateArgument &TA) {
3230
3265
QualType T = TA.getIntegralType ();
3231
3266
if (const EnumType *ET = T->getAs <EnumType>())
3232
- VisitEnumType (ET);
3267
+ VisitTagType (ET);
3233
3268
}
3234
3269
3235
3270
void VisitTemplateTemplateArgument (const TemplateArgument &TA) {
@@ -3240,7 +3275,7 @@ class SYCLKernelNameTypeVisitor
3240
3275
if (NonTypeTemplateParmDecl *TemplateParam =
3241
3276
dyn_cast<NonTypeTemplateParmDecl>(P))
3242
3277
if (const EnumType *ET = TemplateParam->getType ()->getAs <EnumType>())
3243
- VisitEnumType (ET);
3278
+ VisitTagType (ET);
3244
3279
}
3245
3280
}
3246
3281
@@ -3301,7 +3336,7 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
3301
3336
3302
3337
// Emit diagnostics for SYCL device kernels only
3303
3338
if (LangOpts.SYCLIsDevice )
3304
- KernelNameTypeVisitor.Visit (KernelNameType);
3339
+ KernelNameTypeVisitor.Visit (KernelNameType. getCanonicalType () );
3305
3340
Visitor.VisitRecordBases (KernelObj, FieldChecker, UnionChecker, DecompMarker);
3306
3341
Visitor.VisitRecordFields (KernelObj, FieldChecker, UnionChecker,
3307
3342
DecompMarker);
0 commit comments