@@ -335,9 +335,8 @@ class SYCLIntegrationHeader {
335
335
336
336
/// Signals that subsequent parameter descriptor additions will go to
337
337
/// the kernel with given name. Starts new kernel invocation descriptor.
338
- void startKernel(StringRef KernelName, QualType KernelNameType,
339
- StringRef KernelStableName, SourceLocation Loc, bool IsESIMD,
340
- bool IsUnnamedKernel);
338
+ void startKernel(const FunctionDecl *SyclKernel, QualType KernelNameType,
339
+ SourceLocation Loc, bool IsESIMD, bool IsUnnamedKernel);
341
340
342
341
/// Adds a kernel parameter descriptor to current kernel invocation
343
342
/// descriptor.
@@ -350,6 +349,17 @@ class SYCLIntegrationHeader {
350
349
/// Registers a specialization constant to emit info for it into the header.
351
350
void addSpecConstant(StringRef IDName, QualType IDType);
352
351
352
+ /// Update the names of a kernel description based on its SyclKernel.
353
+ void updateKernelNames(const FunctionDecl *SyclKernel, StringRef Name,
354
+ StringRef StableName) {
355
+ auto Itr = llvm::find_if(KernelDescs, [SyclKernel](const KernelDesc &KD) {
356
+ return KD.SyclKernel == SyclKernel;
357
+ });
358
+
359
+ assert(Itr != KernelDescs.end() && "Unknown kernel description");
360
+ Itr->updateKernelNames(Name, StableName);
361
+ }
362
+
353
363
/// Note which free functions (this_id, this_item, etc) are called within the
354
364
/// kernel
355
365
void setCallsThisId(bool B);
@@ -385,6 +395,9 @@ class SYCLIntegrationHeader {
385
395
386
396
// Kernel invocation descriptor
387
397
struct KernelDesc {
398
+ /// sycl_kernel function associated with this kernel.
399
+ const FunctionDecl *SyclKernel;
400
+
388
401
/// Kernel name.
389
402
std::string Name;
390
403
@@ -410,11 +423,15 @@ class SYCLIntegrationHeader {
410
423
// hasn't provided an explicit name for.
411
424
bool IsUnnamedKernel;
412
425
413
- KernelDesc(StringRef Name , QualType NameType, StringRef StableName ,
426
+ KernelDesc(const FunctionDecl *SyclKernel , QualType NameType,
414
427
SourceLocation KernelLoc, bool IsESIMD, bool IsUnnamedKernel)
415
- : Name(Name), NameType(NameType), StableName(StableName),
416
- KernelLocation(KernelLoc), IsESIMDKernel(IsESIMD),
417
- IsUnnamedKernel(IsUnnamedKernel) {}
428
+ : SyclKernel(SyclKernel), NameType(NameType), KernelLocation(KernelLoc),
429
+ IsESIMDKernel(IsESIMD), IsUnnamedKernel(IsUnnamedKernel) {}
430
+
431
+ void updateKernelNames(StringRef Name, StringRef StableName) {
432
+ this->Name = Name.str();
433
+ this->StableName = StableName.str();
434
+ }
418
435
};
419
436
420
437
/// Returns the latest invocation descriptor started by
@@ -13314,12 +13331,23 @@ class Sema final {
13314
13331
std::unique_ptr<SYCLIntegrationHeader> SyclIntHeader;
13315
13332
std::unique_ptr<SYCLIntegrationFooter> SyclIntFooter;
13316
13333
13334
+ // We need to store the list of the sycl_kernel functions and their associated
13335
+ // generated OpenCL Kernels so we can go back and re-name these after the
13336
+ // fact.
13337
+ llvm::SmallVector<std::pair<const FunctionDecl *, FunctionDecl *>>
13338
+ SyclKernelsToOpenCLKernels;
13339
+
13317
13340
// Used to suppress diagnostics during kernel construction, since these were
13318
13341
// already emitted earlier. Diagnosing during Kernel emissions also skips the
13319
13342
// useful notes that shows where the kernel was called.
13320
13343
bool DiagnosingSYCLKernel = false;
13321
13344
13322
13345
public:
13346
+ void addSyclOpenCLKernel(const FunctionDecl *SyclKernel,
13347
+ FunctionDecl *OpenCLKernel) {
13348
+ SyclKernelsToOpenCLKernels.emplace_back(SyclKernel, OpenCLKernel);
13349
+ }
13350
+
13323
13351
void addSyclDeviceDecl(Decl *d) { SyclDeviceDecls.insert(d); }
13324
13352
llvm::SetVector<Decl *> &syclDeviceDecls() { return SyclDeviceDecls; }
13325
13353
@@ -13361,6 +13389,7 @@ class Sema final {
13361
13389
void checkSYCLDeviceVarDecl(VarDecl *Var);
13362
13390
void copySYCLKernelAttrs(const CXXRecordDecl *KernelObj);
13363
13391
void ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, MangleContext &MC);
13392
+ void SetSYCLKernelNames();
13364
13393
void MarkDevices();
13365
13394
13366
13395
/// Get the number of fields or captures within the parsed type.
0 commit comments