diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index c15e58e9d28b1..bde9632b19e7d 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -3757,8 +3757,14 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { PrintingPolicy Policy(LO); Policy.SuppressTypedefs = true; Policy.SuppressUnwrittenScope = true; + SYCLFwdDeclEmitter FwdDeclEmitter(O, S.getLangOpts()); if (SpecConsts.size() > 0) { + O << "// Forward declarations of templated spec constant types:\n"; + for (const auto &SC : SpecConsts) + FwdDeclEmitter.Visit(SC.first); + O << "\n"; + // Remove duplicates. std::sort(SpecConsts.begin(), SpecConsts.end(), [](const SpecConstID &SC1, const SpecConstID &SC2) { @@ -3772,10 +3778,12 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { // Here can do faster comparison of types. return SC1.first == SC2.first; }); + O << "// Specialization constants IDs:\n"; for (const auto &P : llvm::make_range(SpecConsts.begin(), End)) { O << "template <> struct sycl::detail::SpecConstantInfo<"; - O << P.first.getAsString(Policy); + SYCLKernelNameTypePrinter Printer(O, Policy); + Printer.Visit(P.first); O << "> {\n"; O << " static constexpr const char* getName() {\n"; O << " return \"" << P.second << "\";\n"; @@ -3786,8 +3794,6 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { if (!UnnamedLambdaSupport) { O << "// Forward declarations of templated kernel function types:\n"; - - SYCLFwdDeclEmitter FwdDeclEmitter(O, S.getLangOpts()); for (const KernelDesc &K : KernelDescs) FwdDeclEmitter.Visit(K.NameType); } diff --git a/clang/test/CodeGenSYCL/int_header_spec_const.cpp b/clang/test/CodeGenSYCL/int_header_spec_const.cpp index 39e3d2b05ad0f..c9ca63bcecf64 100644 --- a/clang/test/CodeGenSYCL/int_header_spec_const.cpp +++ b/clang/test/CodeGenSYCL/int_header_spec_const.cpp @@ -18,6 +18,10 @@ class MyUInt32Const; class MyFloatConst; class MyDoubleConst; +namespace test { +class MySpecConstantWithinANamespace; +}; + int main() { // Create specialization constants. cl::sycl::ONEAPI::experimental::spec_constant i1(false); @@ -32,13 +36,31 @@ int main() { cl::sycl::ONEAPI::experimental::spec_constant ui32(0); cl::sycl::ONEAPI::experimental::spec_constant f32(0); cl::sycl::ONEAPI::experimental::spec_constant f64(0); + // Kernel name can be used as a spec constant name + cl::sycl::ONEAPI::experimental::spec_constant spec1(0); + // Spec constant name can be declared within a namespace + cl::sycl::ONEAPI::experimental::spec_constant spec2(0); double val; double *ptr = &val; // to avoid "unused" warnings + // CHECK: // Forward declarations of templated spec constant types: + // CHECK: class MyInt8Const; + // CHECK: class MyUInt8Const; + // CHECK: class MyInt16Const; + // CHECK: class MyUInt16Const; + // CHECK: class MyInt32Const; + // CHECK: class MyUInt32Const; + // CHECK: class MyFloatConst; + // CHECK: class MyDoubleConst; + // CHECK: class SpecializedKernel; + // CHECK: namespace test { + // CHECK: class MySpecConstantWithinANamespace; + // CHECK: } + cl::sycl::kernel_single_task([=]() { *ptr = i1.get() + - // CHECK-DAG: template <> struct sycl::detail::SpecConstantInfo { + // CHECK-DAG: template <> struct sycl::detail::SpecConstantInfo<::MyBoolConst> { // CHECK-DAG-NEXT: static constexpr const char* getName() { // CHECK-DAG-NEXT: return "_ZTS11MyBoolConst"; // CHECK-DAG-NEXT: } @@ -58,7 +80,11 @@ int main() { // CHECK-DAG: return "_ZTS13MyUInt32Const"; f32.get() + // CHECK-DAG: return "_ZTS12MyFloatConst"; - f64.get(); - // CHECK-DAG: return "_ZTS13MyDoubleConst"; + f64.get() + + // CHECK-DAG: return "_ZTS13MyDoubleConst"; + spec1.get() + + // CHECK-DAG: return "_ZTS17SpecializedKernel" + spec2.get(); + // CHECK-DAG: return "_ZTSN4test30MySpecConstantWithinANamespaceE" }); }