diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index e0eefa5f72155..c64123f010dd6 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -346,6 +346,9 @@ class SYCLIntegrationHeader { /// Registers a specialization constant to emit info for it into the header. void addSpecConstant(StringRef IDName, QualType IDType); + /// Notes that this_item is called within the kernel. + void setCallsThisItem(bool B); + private: // Kernel actual parameter descriptor. struct KernelParamDesc { @@ -382,6 +385,9 @@ class SYCLIntegrationHeader { /// Descriptor of kernel actual parameters. SmallVector Params; + // Whether kernel calls this_item() + bool CallsThisItem; + KernelDesc() = default; }; diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 5ab884860dcd9..1dd2c8bb64fa8 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -99,10 +99,23 @@ class Util { /// \param Tmpl whether the class is template instantiation or simple record static bool isSyclType(const QualType &Ty, StringRef Name, bool Tmpl = false); + /// Checks whether given function is a standard SYCL API function with given + /// name. + /// \param FD the function being checked. + /// \param Name the function name to be checked against. + static bool isSyclFunction(const FunctionDecl *FD, StringRef Name); + /// Checks whether given clang type is a full specialization of the SYCL /// specialization constant class. static bool isSyclSpecConstantType(const QualType &Ty); + // Checks declaration context hierarchy. + /// \param DC the context of the item to be checked. + /// \param Scopes the declaration scopes leading from the item context to the + /// translation unit (excluding the latter) + static bool matchContext(const DeclContext *DC, + ArrayRef Scopes); + /// Checks whether given clang type is declared in the given hierarchy of /// declaration contexts. /// \param Ty the clang type being checked @@ -511,6 +524,21 @@ class MarkDeviceFunction : public RecursiveASTVisitor { FunctionDecl *FD = WorkList.back().first; FunctionDecl *ParentFD = WorkList.back().second; + // To implement rounding-up of a parallel-for range the + // SYCL header implementation modifies the kernel call like this: + // auto Wrapper = [=](TransformedArgType Arg) { + // if (Arg[0] >= NumWorkItems[0]) + // return; + // Arg.set_allowed_range(NumWorkItems); + // KernelFunc(Arg); + // }; + // + // This transformation leads to a condition where a kernel body + // function becomes callable from a new kernel body function. + // Hence this test. + if ((ParentFD == KernelBody) && isSYCLKernelBodyFunction(FD)) + KernelBody = FD; + if ((ParentFD == SYCLKernel) && isSYCLKernelBodyFunction(FD)) { assert(!KernelBody && "inconsistent call graph - only one kernel body " "function can be called"); @@ -2691,15 +2719,63 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { return !SemaRef.getASTContext().hasSameType(FD->getType(), Ty); } + // Sets a flag if the kernel is a parallel_for that calls the + // free function API "this_item". + void setThisItemIsCalled(const CXXRecordDecl *KernelObj, + FunctionDecl *KernelFunc) { + if (getKernelInvocationKind(KernelFunc) != InvokeParallelFor) + return; + + const CXXMethodDecl *WGLambdaFn = getOperatorParens(KernelObj); + if (!WGLambdaFn) + return; + + // The call graph for this translation unit. + CallGraph SYCLCG; + SYCLCG.addToCallGraph(SemaRef.getASTContext().getTranslationUnitDecl()); + using ChildParentPair = + std::pair; + llvm::SmallPtrSet Visited; + llvm::SmallVector WorkList; + WorkList.push_back({WGLambdaFn, nullptr}); + + while (!WorkList.empty()) { + const FunctionDecl *FD = WorkList.back().first; + WorkList.pop_back(); + if (!Visited.insert(FD).second) + continue; // We've already seen this Decl + + // Check whether this call is to sycl::this_item(). + if (Util::isSyclFunction(FD, "this_item")) { + Header.setCallsThisItem(true); + return; + } + + CallGraphNode *N = SYCLCG.getNode(FD); + if (!N) + continue; + + for (const CallGraphNode *CI : *N) { + if (auto *Callee = dyn_cast(CI->getDecl())) { + Callee = Callee->getMostRecentDecl(); + if (!Visited.count(Callee)) + WorkList.push_back({Callee, FD}); + } + } + } + } + public: static constexpr const bool VisitInsideSimpleContainers = false; SyclKernelIntHeaderCreator(Sema &S, SYCLIntegrationHeader &H, const CXXRecordDecl *KernelObj, QualType NameType, - StringRef Name, StringRef StableName) + StringRef Name, StringRef StableName, + FunctionDecl *KernelFunc) : SyclKernelFieldHandler(S), Header(H) { bool IsSIMDKernel = isESIMDKernelType(KernelObj); Header.startKernel(Name, NameType, StableName, KernelObj->getLocation(), IsSIMDKernel); + setThisItemIsCalled(KernelObj, KernelFunc); } bool handleSyclAccessorType(const CXXRecordDecl *RD, @@ -3146,7 +3222,7 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, SyclKernelIntHeaderCreator int_header( *this, getSyclIntegrationHeader(), KernelObj, calculateKernelNameType(Context, KernelCallerFunc), KernelName, - StableName); + StableName, KernelCallerFunc); KernelObjVisitor Visitor{*this}; Visitor.VisitRecordBases(KernelObj, kernel_decl, kernel_body, int_header); @@ -3858,6 +3934,9 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { O << " __SYCL_DLL_LOCAL\n"; O << " static constexpr bool isESIMD() { return " << K.IsESIMDKernel << "; }\n"; + O << " __SYCL_DLL_LOCAL\n"; + O << " static constexpr bool callsThisItem() { return "; + O << K.CallsThisItem << "; }\n"; O << "};\n"; CurStart += N; } @@ -3916,6 +3995,12 @@ void SYCLIntegrationHeader::addSpecConstant(StringRef IDName, QualType IDType) { SpecConsts.emplace_back(std::make_pair(IDType, IDName.str())); } +void SYCLIntegrationHeader::setCallsThisItem(bool B) { + KernelDesc *K = getCurKernelDesc(); + assert(K && "no kernels"); + K->CallsThisItem = B; +} + SYCLIntegrationHeader::SYCLIntegrationHeader(DiagnosticsEngine &_Diag, bool _UnnamedLambdaSupport, Sema &_S) @@ -3983,6 +4068,21 @@ bool Util::isSyclType(const QualType &Ty, StringRef Name, bool Tmpl) { return matchQualifiedTypeName(Ty, Scopes); } +bool Util::isSyclFunction(const FunctionDecl *FD, StringRef Name) { + if (!FD->isFunctionOrMethod() || !FD->getIdentifier() || + FD->getName().empty() || Name != FD->getName()) + return false; + + const DeclContext *DC = FD->getDeclContext(); + if (DC->isTranslationUnit()) + return false; + + std::array Scopes = { + Util::DeclContextDesc{clang::Decl::Kind::Namespace, "cl"}, + Util::DeclContextDesc{clang::Decl::Kind::Namespace, "sycl"}}; + return matchContext(DC, Scopes); +} + bool Util::isAccessorPropertyListType(const QualType &Ty) { const StringRef &Name = "accessor_property_list"; std::array Scopes = { @@ -3993,21 +4093,15 @@ bool Util::isAccessorPropertyListType(const QualType &Ty) { return matchQualifiedTypeName(Ty, Scopes); } -bool Util::matchQualifiedTypeName(const QualType &Ty, - ArrayRef Scopes) { - // The idea: check the declaration context chain starting from the type +bool Util::matchContext(const DeclContext *Ctx, + ArrayRef Scopes) { + // The idea: check the declaration context chain starting from the item // itself. At each step check the context is of expected kind // (namespace) and name. - const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl(); - - if (!RecTy) - return false; // only classes/structs supported - const auto *Ctx = cast(RecTy); StringRef Name = ""; for (const auto &Scope : llvm::reverse(Scopes)) { clang::Decl::Kind DK = Ctx->getDeclKind(); - if (DK != Scope.first) return false; @@ -4021,7 +4115,7 @@ bool Util::matchQualifiedTypeName(const QualType &Ty, Name = cast(Ctx)->getName(); break; default: - llvm_unreachable("matchQualifiedTypeName: decl kind not supported"); + llvm_unreachable("matchContext: decl kind not supported"); } if (Name != Scope.second) return false; @@ -4029,3 +4123,13 @@ bool Util::matchQualifiedTypeName(const QualType &Ty, } return Ctx->isTranslationUnit(); } + +bool Util::matchQualifiedTypeName(const QualType &Ty, + ArrayRef Scopes) { + const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl(); + + if (!RecTy) + return false; // only classes/structs supported + const auto *Ctx = cast(RecTy); + return Util::matchContext(Ctx, Scopes); +} diff --git a/clang/test/CodeGenSYCL/Inputs/sycl.hpp b/clang/test/CodeGenSYCL/Inputs/sycl.hpp index 72d1c284b39f2..0f71db428018a 100644 --- a/clang/test/CodeGenSYCL/Inputs/sycl.hpp +++ b/clang/test/CodeGenSYCL/Inputs/sycl.hpp @@ -118,6 +118,18 @@ struct id { int Data; }; +template struct item { + template + item(T... args) {} // fake constructor +private: + // Some fake field added to see using of item arguments in the + // kernel wrapper + int Data; +}; + +template item +this_item() { return item{}; } + template struct range { template diff --git a/clang/test/CodeGenSYCL/kernel-by-reference.cpp b/clang/test/CodeGenSYCL/kernel-by-reference.cpp index 6502cddf602d8..f5bbac0e75730 100644 --- a/clang/test/CodeGenSYCL/kernel-by-reference.cpp +++ b/clang/test/CodeGenSYCL/kernel-by-reference.cpp @@ -15,7 +15,7 @@ int simple_add(int i) { int main() { queue q; #if defined(SYCL2020) - // expected-warning@Inputs/sycl.hpp:286 {{Passing kernel functions by value is deprecated in SYCL 2020}} + // expected-warning@Inputs/sycl.hpp:298 {{Passing kernel functions by value is deprecated in SYCL 2020}} // expected-note@+3 {{in instantiation of function template specialization}} #endif q.submit([&](handler &h) { @@ -23,7 +23,7 @@ int main() { }); #if defined(SYCL2017) - // expected-warning@Inputs/sycl.hpp:281 {{Passing of kernel functions by reference is a SYCL 2020 extension}} + // expected-warning@Inputs/sycl.hpp:293 {{Passing of kernel functions by reference is a SYCL 2020 extension}} // expected-note@+3 {{in instantiation of function template specialization}} #endif q.submit([&](handler &h) { diff --git a/clang/test/CodeGenSYCL/parallel_for_this_item.cpp b/clang/test/CodeGenSYCL/parallel_for_this_item.cpp new file mode 100755 index 0000000000000..422a1bad33373 --- /dev/null +++ b/clang/test/CodeGenSYCL/parallel_for_this_item.cpp @@ -0,0 +1,114 @@ +// RUN: %clang_cc1 -fsycl -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown-sycldevice -fsycl-int-header=%t.h %s -fsyntax-only +// RUN: FileCheck -input-file=%t.h %s + +// This test checks that compiler generates correct kernel description +// for parallel_for kernels that use the this_item API. + +// CHECK: __SYCL_INLINE_NAMESPACE(cl) { +// CHECK-NEXT: namespace sycl { +// CHECK-NEXT: namespace detail { + +// CHECK: static constexpr +// CHECK-NEXT: const char* const kernel_names[] = { +// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3GNU", +// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3EMU", +// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3OWL", +// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3RAT" +// CHECK-NEXT: }; + +// CHECK:template <> struct KernelInfo { +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr const char* getName() { return "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3GNU"; } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr unsigned getNumParams() { return 0; } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr const kernel_param_desc_t& getParamDesc(unsigned i) { +// CHECK-NEXT: return kernel_signatures[i+0]; +// CHECK-NEXT: } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr bool isESIMD() { return 0; } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr bool callsThisItem() { return 0; } +// CHECK-NEXT:}; +// CHECK-NEXT:template <> struct KernelInfo { +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr const char* getName() { return "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3EMU"; } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr unsigned getNumParams() { return 0; } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr const kernel_param_desc_t& getParamDesc(unsigned i) { +// CHECK-NEXT: return kernel_signatures[i+0]; +// CHECK-NEXT: } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr bool isESIMD() { return 0; } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr bool callsThisItem() { return 1; } +// CHECK-NEXT:}; +// CHECK-NEXT:template <> struct KernelInfo { +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr const char* getName() { return "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3OWL"; } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr unsigned getNumParams() { return 0; } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr const kernel_param_desc_t& getParamDesc(unsigned i) { +// CHECK-NEXT: return kernel_signatures[i+0]; +// CHECK-NEXT: } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr bool isESIMD() { return 0; } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr bool callsThisItem() { return 0; } +// CHECK-NEXT:}; +// CHECK-NEXT:template <> struct KernelInfo { +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr const char* getName() { return "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3RAT"; } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr unsigned getNumParams() { return 0; } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr const kernel_param_desc_t& getParamDesc(unsigned i) { +// CHECK-NEXT: return kernel_signatures[i+0]; +// CHECK-NEXT: } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr bool isESIMD() { return 0; } +// CHECK-NEXT: __SYCL_DLL_LOCAL +// CHECK-NEXT: static constexpr bool callsThisItem() { return 1; } +// CHECK-NEXT:}; + +#include "sycl.hpp" + +using namespace cl::sycl; + +SYCL_EXTERNAL item<1> g() { return this_item<1>(); } +SYCL_EXTERNAL item<1> f() { return g(); } + +// This is a similar-looking this_item function but not the real one. +template item this_item(int i) { return item<1>{i}; } + +// This is a method named this_item but not the real one. +class C { +public: + template item this_item() { return item<1>{66}; }; +}; + +int main() { + queue myQueue; + myQueue.submit([&](::handler &cgh) { + // This kernel does not call sycl::this_item + cgh.parallel_for(range<1>(1), + [=](item<1> I) { this_item<1>(55); }); + + // This kernel calls sycl::this_item + cgh.parallel_for(range<1>(1), + [=](::item<1> I) { this_item<1>(); }); + + // This kernel does not call sycl::this_item + cgh.parallel_for(range<1>(1), [=](id<1> I) { + class C c; + c.this_item<1>(); + }); + + // This kernel calls sycl::this_item + cgh.parallel_for(range<1>(1), [=](id<1> I) { f(); }); + }); + + return 0; +} diff --git a/sycl/doc/EnvironmentVariables.md b/sycl/doc/EnvironmentVariables.md index adc9e6d0ab89f..2b3e7d6577879 100644 --- a/sycl/doc/EnvironmentVariables.md +++ b/sycl/doc/EnvironmentVariables.md @@ -29,6 +29,8 @@ subject to change. Do not rely on these variables in production code. | SYCL_PI_LEVEL_ZERO_MAX_COMMAND_LIST_CACHE | Positive integer | Maximum number of oneAPI Level Zero Command lists that can be allocated with no reuse before throwing an "out of resources" error. Default is 20000, threshold may be increased based on resource availabilty and workload demand. | | SYCL_PI_LEVEL_ZERO_DISABLE_USM_ALLOCATOR | Any(\*) | Disable USM allocator in Level Zero plugin (each memory request will go directly to Level Zero runtime) | | SYCL_PI_LEVEL_ZERO_BATCH_SIZE | Integer | Sets a preferred number of commands to batch into a command list before executing the command list. A value of 0 causes the batch size to be adjusted dynamically. A value greater than 0 specifies fixed size batching, with the batch size set to the specified value. The default is 0. | +| SYCL_PARALLEL_FOR_RANGE_ROUNDING_TRACE | Any(\*) | Enables tracing of parallel_for invocations with rounded-up ranges. | +| SYCL_DISABLE_PARALLEL_FOR_RANGE_ROUNDING | Any(\*) | Disables automatic rounding-up of parallel_for invocation ranges. | `(*) Note: Any means this environment variable is effective when set to any non-null value.` diff --git a/sycl/include/CL/sycl/detail/kernel_desc.hpp b/sycl/include/CL/sycl/detail/kernel_desc.hpp index 5c53b67f46b93..f3bf02b1b1492 100644 --- a/sycl/include/CL/sycl/detail/kernel_desc.hpp +++ b/sycl/include/CL/sycl/detail/kernel_desc.hpp @@ -57,6 +57,7 @@ template struct KernelInfo { } static constexpr const char *getName() { return ""; } static constexpr bool isESIMD() { return 0; } + static constexpr bool callsThisItem() { return false; } }; #else template struct KernelInfoData { @@ -67,6 +68,7 @@ template struct KernelInfoData { } static constexpr const char *getName() { return ""; } static constexpr bool isESIMD() { return 0; } + static constexpr bool callsThisItem() { return false; } }; // C++14 like index_sequence and make_index_sequence diff --git a/sycl/include/CL/sycl/handler.hpp b/sycl/include/CL/sycl/handler.hpp index 5fe5fc91ac4b6..f0eb9c4872fc2 100644 --- a/sycl/include/CL/sycl/handler.hpp +++ b/sycl/include/CL/sycl/handler.hpp @@ -120,6 +120,14 @@ template struct get_kernel_name_t { using name = Type; }; +// Used when parallel_for range is rounded-up. +template class __pf_kernel_wrapper; + +template struct get_kernel_wrapper_name_t { + using name = __pf_kernel_wrapper< + typename get_kernel_name_t::name>; +}; + template struct check_fn_signature { static_assert(std::integral_constant::value, "Second template parameter is required to be of function type"); @@ -742,23 +750,92 @@ class __SYCL_EXPORT handler { void parallel_for_lambda_impl(range NumWorkItems, KernelType KernelFunc) { throwIfActionIsCreated(); - using NameT = - typename detail::get_kernel_name_t::name; using LambdaArgType = sycl::detail::lambda_arg_type>; + + // If 1D kernel argument is an integral type, convert it to sycl::item<1> using TransformedArgType = - typename detail::conditional_t::value && - Dims == 1, - item, LambdaArgType>; + typename std::conditional::value && + Dims == 1, + item, LambdaArgType>::type; + using NameT = + typename detail::get_kernel_name_t::name; + + // The work group size preferred by this device. + // A reasonable choice for rounding up the range is 32. + constexpr size_t GoodLocalSizeX = 32; + + // Disable the rounding-up optimizations under these conditions: + // 1. The env var SYCL_DISABLE_PARALLEL_FOR_RANGE_ROUNDING is set. + // 2. The string SYCL_DISABLE_PARALLEL_FOR_RANGE_ROUNDING is in + // the kernel name. + // 3. The kernel is provided via an interoperability method. + // 4. The API "this_item" is used inside the kernel. + // 5. The range is already a multiple of the rounding factor. + // + // Cases 3 and 4 could be supported with extra effort. + // As an optimization for the common case it is an + // implementation choice to not support those scenarios. + // Note that "this_item" is a free function, i.e. not tied to any + // specific id or item. When concurrent parallel_fors are executing + // on a device it is difficult to tell which parallel_for the call is + // being made from. One could replicate portions of the + // call-graph to make this_item calls kernel-specific but this is + // not considered worthwhile. + + // Get the kernal name to check condition 3. + std::string KName = typeid(NameT *).name(); + using KI = detail::KernelInfo; + bool DisableRounding = + (getenv("SYCL_DISABLE_PARALLEL_FOR_RANGE_ROUNDING") != nullptr) || + (KName.find("SYCL_DISABLE_PARALLEL_FOR_RANGE_ROUNDING") != + std::string::npos) || + (KI::getName() == nullptr || KI::getName()[0] == '\0') || + (KI::callsThisItem()); + + // Perform range rounding if rounding-up is enabled + // and the user-specified range is not a multiple of a "good" value. + if (!DisableRounding && NumWorkItems[0] % GoodLocalSizeX != 0) { + // It is sufficient to round up just the first dimension. + // Multiplying the rounded-up value of the first dimension + // by the values of the remaining dimensions (if any) + // will yield a rounded-up value for the total range. + size_t NewValX = + ((NumWorkItems[0] + GoodLocalSizeX - 1) / GoodLocalSizeX) * + GoodLocalSizeX; + using NameWT = typename detail::get_kernel_wrapper_name_t::name; + if (getenv("SYCL_PARALLEL_FOR_RANGE_ROUNDING_TRACE") != nullptr) + std::cout << "parallel_for range adjusted from " << NumWorkItems[0] + << " to " << NewValX << std::endl; + auto Wrapper = [=](TransformedArgType Arg) { + if (Arg[0] >= NumWorkItems[0]) + return; + Arg.set_allowed_range(NumWorkItems); + KernelFunc(Arg); + }; + + range AdjustedRange = NumWorkItems; + AdjustedRange.set_range_dim0(NewValX); #ifdef __SYCL_DEVICE_ONLY__ - (void)NumWorkItems; - kernel_parallel_for(KernelFunc); + kernel_parallel_for(Wrapper); #else - detail::checkValueRange(NumWorkItems); - MNDRDesc.set(std::move(NumWorkItems)); - StoreLambda( - std::move(KernelFunc)); - MCGType = detail::CG::KERNEL; + detail::checkValueRange(AdjustedRange); + MNDRDesc.set(std::move(AdjustedRange)); + StoreLambda( + std::move(Wrapper)); + MCGType = detail::CG::KERNEL; #endif + } else { +#ifdef __SYCL_DEVICE_ONLY__ + (void)NumWorkItems; + kernel_parallel_for(KernelFunc); +#else + detail::checkValueRange(NumWorkItems); + MNDRDesc.set(std::move(NumWorkItems)); + StoreLambda( + std::move(KernelFunc)); + MCGType = detail::CG::KERNEL; +#endif + } } /// Defines and invokes a SYCL kernel function for the specified range. diff --git a/sycl/include/CL/sycl/id.hpp b/sycl/include/CL/sycl/id.hpp index 16d176b8b698d..151657aa661e8 100644 --- a/sycl/include/CL/sycl/id.hpp +++ b/sycl/include/CL/sycl/id.hpp @@ -239,6 +239,10 @@ template class id : public detail::array { __SYCL_GEN_OPT(^=) #undef __SYCL_GEN_OPT + +private: + friend class handler; + void set_allowed_range(range rnwi) { (void)rnwi[0]; } }; namespace detail { diff --git a/sycl/include/CL/sycl/item.hpp b/sycl/include/CL/sycl/item.hpp index 9d9a879815294..a8aa9c8ef09f5 100644 --- a/sycl/include/CL/sycl/item.hpp +++ b/sycl/include/CL/sycl/item.hpp @@ -118,6 +118,9 @@ template class item { friend class detail::Builder; private: + friend class handler; + void set_allowed_range(const range rnwi) { MImpl.MExtent = rnwi; } + detail::ItemBase MImpl; }; diff --git a/sycl/include/CL/sycl/range.hpp b/sycl/include/CL/sycl/range.hpp index 0fdfa3cb9c494..32337109f97a9 100644 --- a/sycl/include/CL/sycl/range.hpp +++ b/sycl/include/CL/sycl/range.hpp @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include @@ -141,6 +142,13 @@ template class range : public detail::array { __SYCL_GEN_OPT(^=) #undef __SYCL_GEN_OPT + +private: + friend class handler; + friend class detail::Builder; + + // Adjust the first dim of the range + void set_range_dim0(const size_t dim0) { this->common_array[0] = dim0; } }; #ifdef __cpp_deduction_guides diff --git a/sycl/test/basic_tests/parallel_for_range_roundup.cpp b/sycl/test/basic_tests/parallel_for_range_roundup.cpp new file mode 100755 index 0000000000000..a4a8f45c2ae92 --- /dev/null +++ b/sycl/test/basic_tests/parallel_for_range_roundup.cpp @@ -0,0 +1,192 @@ +// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out +// RUN: env SYCL_PARALLEL_FOR_RANGE_ROUNDING_TRACE=1 %GPU_RUN_PLACEHOLDER %t.out %GPU_CHECK_PLACEHOLDER + +#include + +using namespace sycl; + +range<1> Range1 = {0}; +range<2> Range2 = {0, 0}; +range<3> Range3 = {0, 0, 0}; + +void check(const char *msg, size_t v, size_t ref) { + std::cout << msg << v << std::endl; + assert(v == ref); +} + +int try_item1(size_t size) { + range<1> Size{size}; + int Counter = 0; + { + buffer, 1> BufRange(&Range1, 1); + buffer BufCounter(&Counter, 1); + queue myQueue; + + myQueue.submit([&](handler &cgh) { + auto AccRange = BufRange.get_access(cgh); + auto AccCounter = BufCounter.get_access(cgh); + cgh.parallel_for(Size, [=](item<1> ITEM) { + AccCounter[0].fetch_add(1); + AccRange[0] = ITEM.get_range(0); + }); + }); + myQueue.wait(); + } + check("Size seen by user = ", Range1.get(0), size); + check("Counter = ", Counter, size); + return 0; +} + +void try_item2(size_t size) { + range<2> Size{size, size}; + int Counter = 0; + { + buffer, 1> BufRange(&Range2, 1); + buffer BufCounter(&Counter, 1); + queue myQueue; + + myQueue.submit([&](handler &cgh) { + auto AccRange = BufRange.get_access(cgh); + auto AccCounter = BufCounter.get_access(cgh); + cgh.parallel_for(Size, [=](item<2> ITEM) { + AccCounter[0].fetch_add(1); + AccRange[0][0] = ITEM.get_range(0); + }); + }); + myQueue.wait(); + } + check("Size seen by user = ", Range2.get(0), size); + check("Counter = ", Counter, size * size); +} + +void try_item3(size_t size) { + range<3> Size{size, size, size}; + int Counter = 0; + { + buffer, 1> BufRange(&Range3, 1); + buffer BufCounter(&Counter, 1); + queue myQueue; + + myQueue.submit([&](handler &cgh) { + auto AccRange = BufRange.get_access(cgh); + auto AccCounter = BufCounter.get_access(cgh); + cgh.parallel_for(Size, [=](item<3> ITEM) { + AccCounter[0].fetch_add(1); + AccRange[0][0] = ITEM.get_range(0); + }); + }); + myQueue.wait(); + } + check("Size seen by user = ", Range3.get(0), size); + check("Counter = ", Counter, size * size * size); +} + +void try_id1(size_t size) { + range<1> Size{size}; + int Counter = 0; + { + buffer, 1> BufRange(&Range1, 1); + buffer BufCounter(&Counter, 1); + queue myQueue; + + myQueue.submit([&](handler &cgh) { + auto AccRange = BufRange.get_access(cgh); + auto AccCounter = BufCounter.get_access(cgh); + cgh.parallel_for(Size, [=](id<1> ID) { + AccCounter[0].fetch_add(1); + AccRange[0] = ID[0]; + }); + }); + myQueue.wait(); + } + check("Counter = ", Counter, size); +} + +void try_id2(size_t size) { + range<2> Size{size, size}; + int Counter = 0; + { + buffer, 1> BufRange(&Range2, 1); + buffer BufCounter(&Counter, 1); + queue myQueue; + + myQueue.submit([&](handler &cgh) { + auto AccRange = BufRange.get_access(cgh); + auto AccCounter = BufCounter.get_access(cgh); + cgh.parallel_for(Size, [=](id<2> ID) { + AccCounter[0].fetch_add(1); + AccRange[0][0] = ID[0]; + }); + }); + myQueue.wait(); + } + check("Counter = ", Counter, size * size); +} + +void try_id3(size_t size) { + range<3> Size{size, size, size}; + int Counter = 0; + { + buffer, 1> BufRange(&Range3, 1); + buffer BufCounter(&Counter, 1); + queue myQueue; + + myQueue.submit([&](handler &cgh) { + auto AccRange = BufRange.get_access(cgh); + auto AccCounter = BufCounter.get_access(cgh); + cgh.parallel_for(Size, [=](id<3> ID) { + AccCounter[0].fetch_add(1); + AccRange[0][0] = ID[0]; + }); + }); + myQueue.wait(); + } + check("Counter = ", Counter, size * size * size); +} + +int main() { + int x; + + x = 10; + try_item1(x); + try_item2(x); + try_item3(x); + try_id1(x); + try_id2(x); + try_id3(x); + + x = 256; + try_item1(x); + try_item2(x); + try_item3(x); + try_id1(x); + try_id2(x); + try_id3(x); + + return 0; +} + +// CHECK: parallel_for range adjusted from 10 to 32 +// CHECK-NEXT: Size seen by user = 10 +// CHECK-NEXT: Counter = 10 +// CHECK-NEXT: parallel_for range adjusted from 10 to 32 +// CHECK-NEXT: Size seen by user = 10 +// CHECK-NEXT: Counter = 100 +// CHECK-NEXT: parallel_for range adjusted from 10 to 32 +// CHECK-NEXT: Size seen by user = 10 +// CHECK-NEXT: Counter = 1000 +// CHECK-NEXT: parallel_for range adjusted from 10 to 32 +// CHECK-NEXT: Counter = 10 +// CHECK-NEXT: parallel_for range adjusted from 10 to 32 +// CHECK-NEXT: Counter = 100 +// CHECK-NEXT: parallel_for range adjusted from 10 to 32 +// CHECK-NEXT: Counter = 1000 +// CHECK-NEXT: Size seen by user = 256 +// CHECK-NEXT: Counter = 256 +// CHECK-NEXT: Size seen by user = 256 +// CHECK-NEXT: Counter = 65536 +// CHECK-NEXT: Size seen by user = 256 +// CHECK-NEXT: Counter = 16777216 +// CHECK-NEXT: Counter = 256 +// CHECK-NEXT: Counter = 65536 +// CHECK-NEXT: Counter = 16777216