-
Notifications
You must be signed in to change notification settings - Fork 770
[SYCL] Parallel-for range correction to improve group size selection by GPU driver #2703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
a23f9ad
047123f
73f50bd
2aad33d
ac6bf28
da1a3ab
e186605
eaacd8a
535745f
5c6c841
dc20fa1
d62e2f1
b860666
8677a2b
c340ccf
e08a478
0b878dc
d6773cb
59ae778
81b777c
900aca8
094c01d
700c056
6bfede9
383ae96
e8e7f74
e8b0872
e6d42a0
4b9093e
a2a6ded
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,6 +80,10 @@ class Util { | |
/// stream class. | ||
static bool isSyclStreamType(const QualType &Ty); | ||
|
||
/// Checks whether given clang type is a full specialization of the SYCL | ||
/// item class. | ||
static bool isSyclItemType(const QualType &Ty); | ||
|
||
/// Checks whether given clang type is a full specialization of the SYCL | ||
/// half class. | ||
static bool isSyclHalfType(const QualType &Ty); | ||
|
@@ -511,6 +515,21 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> { | |
FunctionDecl *FD = WorkList.back().first; | ||
FunctionDecl *ParentFD = WorkList.back().second; | ||
|
||
// To implement rounding-up of a parallel-for range the | ||
// SYCL header implementation modifies the kernel call like this: | ||
// auto Wrapper = [=](TransformedArgType Arg) { | ||
// if (Arg[0] >= NumWorkItems[0]) | ||
// return; | ||
// Arg.set_allowed_range(NumWorkItems); | ||
// KernelFunc(Arg); | ||
// }; | ||
// | ||
// This transformation leads to a condition where a kernel body | ||
// function becomes callable from a new kernel body function. | ||
// Hence this test. | ||
if ((ParentFD == KernelBody) && isSYCLKernelBodyFunction(FD)) | ||
KernelBody = FD; | ||
|
||
if ((ParentFD == SYCLKernel) && isSYCLKernelBodyFunction(FD)) { | ||
assert(!KernelBody && "inconsistent call graph - only one kernel body " | ||
"function can be called"); | ||
|
@@ -642,6 +661,39 @@ class FindPFWGLambdaFnVisitor | |
const CXXRecordDecl *LambdaObjTy; | ||
}; | ||
|
||
// Searches for a call to PF lambda function and captures it. | ||
class FindPFLambdaFnVisitor | ||
: public RecursiveASTVisitor<FindPFLambdaFnVisitor> { | ||
Fznamznon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
public: | ||
// Creates a visitor that has not captured any lambda function yet. | ||
// LambdaObjTy - lambda type of the PF lambda object; VisitCallExpr compares | ||
// the type of each candidate call's first argument against this record. | ||
FindPFLambdaFnVisitor(const CXXRecordDecl *LambdaObjTy) | ||
: LambdaFn(nullptr), LambdaObjTy(LambdaObjTy) {} | ||
|
||
// Inspects one call expression and records its callee when the call is an
// operator() invocation on the PF lambda object.
//
// A call matches when all of the following hold:
//   * the direct callee is a method that overloads operator();
//   * the call has exactly two arguments;
//   * argument 0 has the lambda type this visitor was constructed with;
//   * argument 1 is a SYCL "id" or "item".
//
// Returns false (stop the traversal) once a matching call has been recorded,
// and true (keep searching) otherwise.
bool VisitCallExpr(CallExpr *Call) {
  // getDirectCallee() yields null for calls that have no direct callee (for
  // example, a call through a function pointer). Plain dyn_cast<> must not be
  // handed a null pointer, so the null-tolerant form is required here.
  auto *M = dyn_cast_or_null<CXXMethodDecl>(Call->getDirectCallee());
  if (!M || (M->getOverloadedOperator() != OO_Call))
    return true;

  const unsigned NumPFLambdaArgs = 2; // lambda obj and its id/item argument
  if (Call->getNumArgs() != NumPFLambdaArgs)
    return true;

  QualType IndexTy = Call->getArg(1)->getType();
  if (!Util::isSyclType(IndexTy, "id", true /*Tmpl*/) &&
      !Util::isSyclType(IndexTy, "item", true /*Tmpl*/))
    return true;

  if (Call->getArg(0)->getType()->getAsCXXRecordDecl() != LambdaObjTy)
    return true;

  LambdaFn = M; // call to PF lambda found - record the lambda
  return false; // ... and stop searching
}
|
||
// Returns the lambda function recorded by VisitCallExpr, or nullptr if no | ||
// matching call has been recorded. | ||
CXXMethodDecl *getLambdaFn() const { return LambdaFn; } | ||
|
||
private: | ||
CXXMethodDecl *LambdaFn; | ||
const CXXRecordDecl *LambdaObjTy; | ||
}; | ||
|
||
class MarkWIScopeFnVisitor : public RecursiveASTVisitor<MarkWIScopeFnVisitor> { | ||
public: | ||
MarkWIScopeFnVisitor(ASTContext &Ctx) : Ctx(Ctx) {} | ||
|
@@ -2687,15 +2739,65 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { | |
return !SemaRef.getASTContext().hasSameType(FD->getType(), Ty); | ||
} | ||
|
||
// Sets a flag if the kernel is a parallel_for that calls the | ||
// free function API "this_item". | ||
void setThisItemIsCalled(const CXXRecordDecl *KernelObj, | ||
FunctionDecl *KernelFunc) { | ||
if (getKernelInvocationKind(KernelFunc) != InvokeParallelFor) | ||
return; | ||
|
||
FindPFLambdaFnVisitor V(KernelObj); | ||
V.TraverseStmt(KernelFunc->getBody()); | ||
CXXMethodDecl *WGLambdaFn = V.getLambdaFn(); | ||
if (!WGLambdaFn) | ||
return; | ||
|
||
// The call graph for this translation unit. | ||
CallGraph SYCLCG; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. AFAIK we already build a callgraph in SemaSYCL. Can we try to re-use it? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, it would be nice to reuse that infrastructure. I first tried pursuing that approach. The result of a scan for calls to this_item would have to be saved somewhere. The existing callgraph traversal leads to various function "attributes" being set. This would be fine, except that calls_this_item is not an attribute. We could define an internal attribute for that. Would that be acceptable? If yes, it would simplify the SemaSYCL changes quite a bit. How to add such an attribute? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't see why not. You can check SYCLRequiresDecomposition for an example of an internal attribute. @premanandrao @Fznamznon could you please confirm this is ok? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it should be ok. |
||
SYCLCG.addToCallGraph(SemaRef.getASTContext().getTranslationUnitDecl()); | ||
using ChildParentPair = std::pair<FunctionDecl *, FunctionDecl *>; | ||
llvm::SmallPtrSet<FunctionDecl *, 16> Visited; | ||
llvm::SmallVector<ChildParentPair, 16> WorkList; | ||
WorkList.push_back({WGLambdaFn, nullptr}); | ||
|
||
while (!WorkList.empty()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code seems also pretty similar to code which walks through another CallGraph in SemaSYCL. I think we need to merge them one day, but this is not in scope of your PR, I just wanted to capture the problem. |
||
FunctionDecl *FD = WorkList.back().first; | ||
WorkList.pop_back(); | ||
if (!Visited.insert(FD).second) | ||
continue; // We've already seen this Decl | ||
|
||
if (FD->isFunctionOrMethod() && FD->getIdentifier() && | ||
!FD->getName().empty() && "this_item" == FD->getName() && | ||
Util::isSyclItemType(FD->getReturnType())) { | ||
Header.setCallsThisItem(true); | ||
return; | ||
} | ||
|
||
CallGraphNode *N = SYCLCG.getNode(FD); | ||
if (!N) | ||
continue; | ||
|
||
for (const CallGraphNode *CI : *N) { | ||
if (auto *Callee = dyn_cast<FunctionDecl>(CI->getDecl())) { | ||
Callee = Callee->getMostRecentDecl(); | ||
if (!Visited.count(Callee)) | ||
WorkList.push_back({Callee, FD}); | ||
} | ||
} | ||
} | ||
} | ||
|
||
public: | ||
static constexpr const bool VisitInsideSimpleContainers = false; | ||
SyclKernelIntHeaderCreator(Sema &S, SYCLIntegrationHeader &H, | ||
const CXXRecordDecl *KernelObj, QualType NameType, | ||
StringRef Name, StringRef StableName) | ||
StringRef Name, StringRef StableName, | ||
FunctionDecl *KernelFunc) | ||
: SyclKernelFieldHandler(S), Header(H) { | ||
bool IsSIMDKernel = isESIMDKernelType(KernelObj); | ||
Header.startKernel(Name, NameType, StableName, KernelObj->getLocation(), | ||
IsSIMDKernel); | ||
setThisItemIsCalled(KernelObj, KernelFunc); | ||
} | ||
|
||
bool handleSyclAccessorType(const CXXRecordDecl *RD, | ||
|
@@ -3142,7 +3244,7 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, | |
SyclKernelIntHeaderCreator int_header( | ||
*this, getSyclIntegrationHeader(), KernelObj, | ||
calculateKernelNameType(Context, KernelCallerFunc), KernelName, | ||
StableName); | ||
StableName, KernelCallerFunc); | ||
|
||
KernelObjVisitor Visitor{*this}; | ||
Visitor.VisitRecordBases(KernelObj, kernel_decl, kernel_body, int_header); | ||
|
@@ -3854,6 +3956,9 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { | |
O << " __SYCL_DLL_LOCAL\n"; | ||
O << " static constexpr bool isESIMD() { return " << K.IsESIMDKernel | ||
<< "; }\n"; | ||
O << " __SYCL_DLL_LOCAL\n"; | ||
O << " static constexpr bool callsThisItem() { return "; | ||
O << K.CallsThisItem << "; }\n"; | ||
O << "};\n"; | ||
CurStart += N; | ||
} | ||
|
@@ -3912,6 +4017,12 @@ void SYCLIntegrationHeader::addSpecConstant(StringRef IDName, QualType IDType) { | |
SpecConsts.emplace_back(std::make_pair(IDType, IDName.str())); | ||
} | ||
|
||
void SYCLIntegrationHeader::setCallsThisItem(bool B) { | ||
KernelDesc *K = getCurKernelDesc(); | ||
assert(K && "no kernels"); | ||
K->CallsThisItem = B; | ||
} | ||
|
||
SYCLIntegrationHeader::SYCLIntegrationHeader(DiagnosticsEngine &_Diag, | ||
bool _UnnamedLambdaSupport, | ||
Sema &_S) | ||
|
@@ -3933,6 +4044,10 @@ bool Util::isSyclStreamType(const QualType &Ty) { | |
return isSyclType(Ty, "stream"); | ||
} | ||
|
||
bool Util::isSyclItemType(const QualType &Ty) { | ||
return isSyclType(Ty, "item", true /*Tmpl*/); | ||
} | ||
|
||
bool Util::isSyclHalfType(const QualType &Ty) { | ||
const StringRef &Name = "half"; | ||
std::array<DeclContextDesc, 5> Scopes = { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
// RUN: %clang_cc1 -fsycl -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown-sycldevice -fsycl-int-header=%t.h %s -fsyntax-only | ||
// RUN: FileCheck -input-file=%t.h %s | ||
|
||
// This test checks that compiler generates correct kernel description | ||
// for parallel_for kernels that use the this_item API. | ||
|
||
// CHECK: __SYCL_INLINE_NAMESPACE(cl) { | ||
// CHECK-NEXT: namespace sycl { | ||
// CHECK-NEXT: namespace detail { | ||
|
||
// CHECK: static constexpr | ||
// CHECK-NEXT: const char* const kernel_names[] = { | ||
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3GNU", | ||
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3EMU", | ||
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3OWL", | ||
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3RAT" | ||
// CHECK-NEXT: }; | ||
|
||
// CHECK:template <> struct KernelInfo<class GNU> { | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr const char* getName() { return "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3GNU"; } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr unsigned getNumParams() { return 0; } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr const kernel_param_desc_t& getParamDesc(unsigned i) { | ||
// CHECK-NEXT: return kernel_signatures[i+0]; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr bool isESIMD() { return 0; } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr bool callsThisItem() { return 0; } | ||
// CHECK-NEXT:}; | ||
// CHECK-NEXT:template <> struct KernelInfo<class EMU> { | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr const char* getName() { return "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3EMU"; } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr unsigned getNumParams() { return 0; } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr const kernel_param_desc_t& getParamDesc(unsigned i) { | ||
// CHECK-NEXT: return kernel_signatures[i+0]; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr bool isESIMD() { return 0; } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr bool callsThisItem() { return 1; } | ||
// CHECK-NEXT:}; | ||
// CHECK-NEXT:template <> struct KernelInfo<class OWL> { | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr const char* getName() { return "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3OWL"; } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr unsigned getNumParams() { return 0; } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr const kernel_param_desc_t& getParamDesc(unsigned i) { | ||
// CHECK-NEXT: return kernel_signatures[i+0]; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr bool isESIMD() { return 0; } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr bool callsThisItem() { return 0; } | ||
// CHECK-NEXT:}; | ||
// CHECK-NEXT:template <> struct KernelInfo<class RAT> { | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr const char* getName() { return "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3RAT"; } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr unsigned getNumParams() { return 0; } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr const kernel_param_desc_t& getParamDesc(unsigned i) { | ||
// CHECK-NEXT: return kernel_signatures[i+0]; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr bool isESIMD() { return 0; } | ||
// CHECK-NEXT: __SYCL_DLL_LOCAL | ||
// CHECK-NEXT: static constexpr bool callsThisItem() { return 1; } | ||
// CHECK-NEXT:}; | ||
|
||
#include "sycl.hpp" | ||
|
||
using namespace cl::sycl; | ||
|
||
// Submits four parallel_for kernels that cover both kernel argument kinds | ||
// (item and id), each with and without a call to this_item. The expected | ||
// callsThisItem() value for every kernel is listed in the checks above. | ||
int main() { | ||
::queue myQueue; | ||
myQueue.submit([&](::handler &cgh) { | ||
// GNU: item argument, this_item is not called. | ||
cgh.parallel_for<class GNU>(::range<1>(1), | ||
[=](::item<1> I) {}); | ||
// EMU: item argument, this_item is called. | ||
cgh.parallel_for<class EMU>( | ||
::range<1>(1), | ||
[=](::item<1> I) { ::this_item<1>(); }); | ||
// OWL: id argument, this_item is not called. | ||
cgh.parallel_for<class OWL>(::range<1>(1), | ||
[=](::id<1> I) {}); | ||
// RAT: id argument, this_item is called. | ||
cgh.parallel_for<class RAT>(::range<1>(1), [=](::id<1> I) { | ||
::this_item<1>(); | ||
}); | ||
}); | ||
|
||
return 0; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I found this while working on something else. Is there anything we can do to make this MUCH more selective? The problem we have now is that someone who uses a lambda (or operator()) inside their top-level lambda will have things misdiagnosed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isSYCLKernelBodyFunction has a simplistic implementation, not introduced by this PR, by the way. One way to improve matters is to recognize the kernel early during parsing, and use an internal attribute to mark it as a KernelBody.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, but this use of it actually is a pretty nasty breaking change.
I'm not sure what opportunity the parser has to do that marking, AND it would likely break your patch (since there is no way to mark the 2nd lambda there).
I think we might need some sort of way of having this library opt-into pulling the body-attributes in from the child.