-
Notifications
You must be signed in to change notification settings - Fork 770
[SYCL] Parallel-for range correction to improve group size selection by GPU driver #2703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
a23f9ad
047123f
73f50bd
2aad33d
ac6bf28
da1a3ab
e186605
eaacd8a
535745f
5c6c841
dc20fa1
d62e2f1
b860666
8677a2b
c340ccf
e08a478
0b878dc
d6773cb
59ae778
81b777c
900aca8
094c01d
700c056
6bfede9
383ae96
e8e7f74
e8b0872
e6d42a0
4b9093e
a2a6ded
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -510,6 +510,22 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> { | |||||
FunctionDecl *FD = WorkList.back().first; | ||||||
FunctionDecl *ParentFD = WorkList.back().second; | ||||||
|
||||||
// To implement rounding-up of a parallel-for range | ||||||
// a kernel call is modified like this: | ||||||
// auto Wrapper = [=](TransformedArgType Arg) { | ||||||
// if (Arg[0] >= NumWorkItems[0]) | ||||||
// return; | ||||||
// Arg.set_allowed_range(NumWorkItems); | ||||||
// KernelFunc(Arg); | ||||||
// }; | ||||||
// | ||||||
// This transformation leads to a condition where a kernel body | ||||||
// function becomes callable from a new kernel body function. | ||||||
// Hence this test. | ||||||
if ((ParentFD == KernelBody) && isSYCLKernelBodyFunction(FD)) { | ||||||
KernelBody = FD; | ||||||
} | ||||||
Fznamznon marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
if ((ParentFD == SYCLKernel) && isSYCLKernelBodyFunction(FD)) { | ||||||
assert(!KernelBody && "inconsistent call graph - only one kernel body " | ||||||
"function can be called"); | ||||||
|
@@ -641,6 +657,39 @@ class FindPFWGLambdaFnVisitor | |||||
const CXXRecordDecl *LambdaObjTy; | ||||||
}; | ||||||
|
||||||
// Searches for a call to PF lambda function and captures it. | ||||||
class FindPFLambdaFnVisitor | ||||||
: public RecursiveASTVisitor<FindPFLambdaFnVisitor> { | ||||||
Fznamznon marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
public: | ||||||
  // Creates a visitor with no lambda function recorded yet (LambdaFn is | ||||||
  // nullptr). | ||||||
  // LambdaObjTy - lambda type of the PF lambda object | ||||||
  FindPFLambdaFnVisitor(const CXXRecordDecl *LambdaObjTy) | ||||||
      : LambdaFn(nullptr), LambdaObjTy(LambdaObjTy) {} | ||||||
|
||||||
bool VisitCallExpr(CallExpr *Call) { | ||||||
auto *M = dyn_cast<CXXMethodDecl>(Call->getDirectCallee()); | ||||||
if (!M || (M->getOverloadedOperator() != OO_Call)) | ||||||
return true; | ||||||
const int NumPFLambdaArgs = 2; // range and lambda obj | ||||||
if (Call->getNumArgs() != NumPFLambdaArgs) | ||||||
return true; | ||||||
QualType Range = Call->getArg(1)->getType(); | ||||||
if (!Util::isSyclType(Range, "id", true /*Tmpl*/) && | ||||||
!Util::isSyclType(Range, "item", true /*Tmpl*/)) | ||||||
return true; | ||||||
if (Call->getArg(0)->getType()->getAsCXXRecordDecl() != LambdaObjTy) | ||||||
return true; | ||||||
LambdaFn = M; // call to PF lambda found - record the lambda | ||||||
return false; // ... and stop searching | ||||||
} | ||||||
|
||||||
  // Returns the captured lambda function, or nullptr if VisitCallExpr has | ||||||
  // not recorded a matching call. | ||||||
  CXXMethodDecl *getLambdaFn() const { return LambdaFn; } | ||||||
|
||||||
private: | ||||||
CXXMethodDecl *LambdaFn; | ||||||
const CXXRecordDecl *LambdaObjTy; | ||||||
}; | ||||||
|
||||||
class MarkWIScopeFnVisitor : public RecursiveASTVisitor<MarkWIScopeFnVisitor> { | ||||||
public: | ||||||
MarkWIScopeFnVisitor(ASTContext &Ctx) : Ctx(Ctx) {} | ||||||
|
@@ -2653,13 +2702,62 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { | |||||
return !SemaRef.getASTContext().hasSameType(FD->getType(), Ty); | ||||||
} | ||||||
|
||||||
// Sets a flag if the kernel is a parallel_for that calls the | ||||||
// free function API "this_item". | ||||||
void setThisItemIsCalled(const CXXRecordDecl *KernelObj, | ||||||
FunctionDecl *KernelFunc) { | ||||||
if (getKernelInvocationKind(KernelFunc) != InvokeParallelFor) | ||||||
return; | ||||||
|
||||||
FindPFLambdaFnVisitor V(KernelObj); | ||||||
V.TraverseStmt(KernelFunc->getBody()); | ||||||
CXXMethodDecl *WGLambdaFn = V.getLambdaFn(); | ||||||
if (!WGLambdaFn) | ||||||
return; | ||||||
|
||||||
// The call graph for this translation unit. | ||||||
CallGraph SYCLCG; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AFAIK we already build a callgraph in SemaSYCL. Can we try to re-use it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it would be nice to reuse that infrastructure. I first tried pursuing that approach. The result of a scan for calls to this_item would have to be saved somewhere. The existing callgraph traversal leads to various function "attributes" being set. This would be fine, except that calls_this_item is not an attribute. We could define an internal attribute for that. Would that be acceptable? If yes, it would simplify the SemaSYCL changes quite a bit. How would one add such an attribute? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see why not. You can check SYCLRequiresDecomposition for an example of an internal attribute. @premanandrao @Fznamznon could you please confirm this is ok? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it should be ok. |
||||||
SYCLCG.addToCallGraph(SemaRef.getASTContext().getTranslationUnitDecl()); | ||||||
typedef std::pair<FunctionDecl *, FunctionDecl *> ChildParentPair; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok |
||||||
llvm::SmallPtrSet<FunctionDecl *, 16> Visited; | ||||||
llvm::SmallVector<ChildParentPair, 16> WorkList; | ||||||
WorkList.push_back({WGLambdaFn, nullptr}); | ||||||
|
||||||
while (!WorkList.empty()) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code also seems pretty similar to the code that walks through another CallGraph in SemaSYCL. I think we need to merge them one day, but this is not in the scope of your PR; I just wanted to capture the problem. |
||||||
FunctionDecl *FD = WorkList.back().first; | ||||||
WorkList.pop_back(); | ||||||
if (!Visited.insert(FD).second) | ||||||
continue; // We've already seen this Decl | ||||||
|
||||||
if (FD->isFunctionOrMethod() && FD->getIdentifier() && | ||||||
!FD->getName().empty() && "this_item" == FD->getName()) { | ||||||
Fznamznon marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
Header.setCallsThisItem(true); | ||||||
return; | ||||||
} | ||||||
|
||||||
CallGraphNode *N = SYCLCG.getNode(FD); | ||||||
if (!N) | ||||||
continue; | ||||||
|
||||||
for (const CallGraphNode *CI : *N) { | ||||||
if (auto *Callee = dyn_cast<FunctionDecl>(CI->getDecl())) { | ||||||
Callee = Callee->getMostRecentDecl(); | ||||||
if (!Visited.count(Callee)) | ||||||
WorkList.push_back({Callee, FD}); | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
public: | ||||||
static constexpr const bool VisitInsideSimpleContainers = false; | ||||||
SyclKernelIntHeaderCreator(Sema &S, SYCLIntegrationHeader &H, | ||||||
const CXXRecordDecl *KernelObj, QualType NameType, | ||||||
StringRef Name, StringRef StableName) | ||||||
StringRef Name, StringRef StableName, | ||||||
FunctionDecl *KernelFunc) | ||||||
: SyclKernelFieldHandler(S), Header(H) { | ||||||
Header.startKernel(Name, NameType, StableName, KernelObj->getLocation()); | ||||||
setThisItemIsCalled(KernelObj, KernelFunc); | ||||||
} | ||||||
|
||||||
bool handleSyclAccessorType(const CXXRecordDecl *RD, | ||||||
|
@@ -3101,7 +3199,7 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, | |||||
SyclKernelIntHeaderCreator int_header( | ||||||
*this, getSyclIntegrationHeader(), KernelObj, | ||||||
calculateKernelNameType(Context, KernelCallerFunc), KernelName, | ||||||
StableName); | ||||||
StableName, KernelCallerFunc); | ||||||
|
||||||
KernelObjVisitor Visitor{*this}; | ||||||
Visitor.VisitRecordBases(KernelObj, kernel_decl, kernel_body, int_header); | ||||||
|
@@ -3810,6 +3908,9 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { | |||||
O << "getParamDesc(unsigned i) {\n"; | ||||||
O << " return kernel_signatures[i+" << CurStart << "];\n"; | ||||||
O << " }\n"; | ||||||
O << " __SYCL_DLL_LOCAL\n"; | ||||||
O << " static constexpr bool callsThisItem() { return "; | ||||||
O << K.CallsThisItem << "; }\n"; | ||||||
O << "};\n"; | ||||||
CurStart += N; | ||||||
} | ||||||
|
@@ -3866,6 +3967,12 @@ void SYCLIntegrationHeader::addSpecConstant(StringRef IDName, QualType IDType) { | |||||
SpecConsts.emplace_back(std::make_pair(IDType, IDName.str())); | ||||||
} | ||||||
|
||||||
// Stores B in the CallsThisItem field of the kernel descriptor returned by | ||||||
// getCurKernelDesc(). Asserts that such a descriptor exists. | ||||||
void SYCLIntegrationHeader::setCallsThisItem(bool B) { | ||||||
  auto *K = getCurKernelDesc(); | ||||||
Fznamznon marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
  assert(K && "no kernels"); | ||||||
  K->CallsThisItem = B; | ||||||
} | ||||||
|
||||||
SYCLIntegrationHeader::SYCLIntegrationHeader(DiagnosticsEngine &_Diag, | ||||||
bool _UnnamedLambdaSupport, | ||||||
Sema &_S) | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -120,6 +120,14 @@ template <typename Type> struct get_kernel_name_t<detail::auto_name, Type> { | |
using name = Type; | ||
}; | ||
|
||
// Used when parallel_for range is rounded-up. | ||
template <typename Type> class __pf_kernel_wrapper; | ||
|
||
// Maps a kernel name type to the name type of its wrapper kernel: | ||
// __pf_kernel_wrapper instantiated with the name that get_kernel_name_t | ||
// yields for detail::auto_name and Type. | ||
template <typename Type> struct get_kernel_wrapper_name_t { | ||
  using name = __pf_kernel_wrapper< | ||
      typename get_kernel_name_t<detail::auto_name, Type>::name>; | ||
}; | ||
|
||
template <typename, typename T> struct check_fn_signature { | ||
static_assert(std::integral_constant<T, false>::value, | ||
"Second template parameter is required to be of function type"); | ||
|
@@ -728,23 +736,79 @@ class __SYCL_EXPORT handler { | |
void parallel_for_lambda_impl(range<Dims> NumWorkItems, | ||
KernelType KernelFunc) { | ||
throwIfActionIsCreated(); | ||
using NameT = | ||
typename detail::get_kernel_name_t<KernelName, KernelType>::name; | ||
using LambdaArgType = sycl::detail::lambda_arg_type<KernelType, item<Dims>>; | ||
|
||
// If 1D kernel argument is an integral type, convert it to sycl::item<1> | ||
using TransformedArgType = | ||
typename detail::conditional_t<std::is_integral<LambdaArgType>::value && | ||
Dims == 1, | ||
item<Dims>, LambdaArgType>; | ||
typename std::conditional<std::is_integral<LambdaArgType>::value && | ||
Dims == 1, | ||
item<Dims>, LambdaArgType>::type; | ||
using NameT = | ||
typename detail::get_kernel_name_t<KernelName, KernelType>::name; | ||
|
||
// A reasonable choice for rounding up the range is 32. | ||
constexpr size_t GoodLocalSizeX = 32; | ||
|
||
// Disable the rounding-up optimizations under these conditions: | ||
// 1. The env var SYCL_OPT_PFWGS_DISABLE is set | ||
// 2. When the string SYCL_OPT_PFWGS_DISABLE is in the kernel name. | ||
// 3. The kernel is created and invoked without an integration header entry. | ||
// 4. The API "this_item" is used inside the kernel. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the rationale for this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My question was actually about point "4." Otherwise I find point "3." quite clear. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah thanks, very useful. Now I can understand what's going on. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've added to the comments. |
||
// 5. The range is already a multiple of the rounding factor. | ||
|
||
    // Get the kernel name to check condition 2. | ||
std::string KName = typeid(NameT *).name(); | ||
using KI = detail::KernelInfo<KernelName>; | ||
bool DisableRounding = | ||
(getenv("SYCL_OPT_PFWGS_DISABLE") != nullptr) || | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can env var names be less cryptic? Also we need to document them in https://github.com/intel/llvm/blob/sycl/sycl/doc/EnvironmentVariables.md There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
(KName.find("SYCL_OPT_PFWGS_DISABLE") != std::string::npos) || | ||
(KI::getName() == nullptr || KI::getName()[0] == '\0') || | ||
(KI::callsThisItem()); | ||
|
||
// Perform range rounding if rounding-up is enabled | ||
// and the user-specified range is not a multiple of a "good" value. | ||
if (!DisableRounding && NumWorkItems[0] % GoodLocalSizeX != 0) { | ||
rdeodhar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// It is sufficient to round up just the first dimension. | ||
// Multiplying the rounded-up value of the first dimension | ||
// by the values of the remaining dimensions (if any) | ||
// will yield a rounded-up value for the total range. | ||
size_t NewValX = | ||
((NumWorkItems[0] + GoodLocalSizeX - 1) / GoodLocalSizeX) * | ||
GoodLocalSizeX; | ||
using NameWT = typename detail::get_kernel_wrapper_name_t<NameT>::name; | ||
if (getenv("SYCL_OPT_PFWGS_TRACE") != nullptr) | ||
std::cerr << "***** Adjusted size from " << NumWorkItems[0] << " to " | ||
<< NewValX << " *****\n"; | ||
auto Wrapper = [=](TransformedArgType Arg) { | ||
if (Arg[0] >= NumWorkItems[0]) | ||
return; | ||
Arg.set_allowed_range(NumWorkItems); | ||
KernelFunc(Arg); | ||
}; | ||
|
||
range<Dims> AdjustedRange = NumWorkItems; | ||
AdjustedRange.set_range(NewValX); | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
(void)NumWorkItems; | ||
kernel_parallel_for<NameT, TransformedArgType>(KernelFunc); | ||
kernel_parallel_for<NameWT, TransformedArgType>(Wrapper); | ||
#else | ||
detail::checkValueRange<Dims>(NumWorkItems); | ||
MNDRDesc.set(std::move(NumWorkItems)); | ||
StoreLambda<NameT, KernelType, Dims, TransformedArgType>( | ||
std::move(KernelFunc)); | ||
MCGType = detail::CG::KERNEL; | ||
detail::checkValueRange<Dims>(AdjustedRange); | ||
MNDRDesc.set(std::move(AdjustedRange)); | ||
StoreLambda<NameWT, decltype(Wrapper), Dims, TransformedArgType>( | ||
std::move(Wrapper)); | ||
MCGType = detail::CG::KERNEL; | ||
#endif | ||
} else { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
(void)NumWorkItems; | ||
kernel_parallel_for<NameT, TransformedArgType>(KernelFunc); | ||
#else | ||
detail::checkValueRange<Dims>(NumWorkItems); | ||
MNDRDesc.set(std::move(NumWorkItems)); | ||
StoreLambda<NameT, KernelType, Dims, TransformedArgType>( | ||
std::move(KernelFunc)); | ||
MCGType = detail::CG::KERNEL; | ||
#endif | ||
} | ||
} | ||
|
||
/// Defines and invokes a SYCL kernel function for the specified range. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
|
||
#pragma once | ||
#include <CL/sycl/detail/array.hpp> | ||
#include <CL/sycl/detail/helpers.hpp> | ||
|
||
#include <stdexcept> | ||
#include <type_traits> | ||
|
@@ -141,6 +142,13 @@ template <int dimensions = 1> class range : public detail::array<dimensions> { | |
__SYCL_GEN_OPT(^=) | ||
|
||
#undef __SYCL_GEN_OPT | ||
|
||
private: | ||
friend class handler; | ||
friend class detail::Builder; | ||
|
||
  // Adjust the first dim of the range: overwrites element 0 of the | |
  // underlying common_array with dim0; other elements are not touched. | |
  void set_range(const size_t dim0) { this->common_array[0] = dim0; } | |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This sets the first dim - better to reflect this in the name (e.g. set_zero_dim). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed name. |
||
}; | ||
|
||
#ifdef __cpp_deduction_guides | ||
|
Uh oh!
There was an error while loading. Please reload this page.