Skip to content

[SYCL] Implement SYCL 2020 specialization constants in Clang #3345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 26, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,8 @@ class SYCLIntegrationHeader {
kind_std_layout,
kind_sampler,
kind_pointer,
kind_last = kind_pointer
kind_specialization_constants_buffer,
kind_last = kind_specialization_constants_buffer
};

public:
Expand Down
208 changes: 187 additions & 21 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ enum KernelInvocationKind {

static constexpr llvm::StringLiteral InitMethodName = "__init";
static constexpr llvm::StringLiteral InitESIMDMethodName = "__init_esimd";
// Name of the kernel_handler member function that the generated kernel body
// calls on the local kernel_handler clone to pass in the specialization
// constants buffer.
static constexpr llvm::StringLiteral InitSpecConstantsBuffer =
"__init_specialization_constants_buffer";
static constexpr llvm::StringLiteral FinalizeMethodName = "__finalize";
constexpr unsigned MaxKernelArgsSize = 2048;

Expand Down Expand Up @@ -109,6 +111,10 @@ class Util {
/// specialization constant class.
static bool isSyclSpecConstantType(const QualType &Ty);

/// Checks whether given clang type is a full specialization of the SYCL
/// kernel_handler class.
static bool isSyclKernelHandlerType(const QualType &Ty);

// Checks declaration context hierarchy.
/// \param DC the context of the item to be checked.
/// \param Scopes the declaration scopes leading from the item context to the
Expand Down Expand Up @@ -616,11 +622,16 @@ class FindPFWGLambdaFnVisitor
auto *M = dyn_cast<CXXMethodDecl>(Call->getDirectCallee());
if (!M || (M->getOverloadedOperator() != OO_Call))
return true;
const int NumPFWGLambdaArgs = 2; // group and lambda obj

unsigned int NumPFWGLambdaArgs =
M->getNumParams() + 1; // group, optional kernel_handler and lambda obj
if (Call->getNumArgs() != NumPFWGLambdaArgs)
return true;
if (!Util::isSyclType(Call->getArg(1)->getType(), "group", true /*Tmpl*/))
return true;
if ((Call->getNumArgs() > 2) &&
!Util::isSyclKernelHandlerType(Call->getArg(2)->getType()))
return true;
if (Call->getArg(0)->getType()->getAsCXXRecordDecl() != LambdaObjTy)
return true;
LambdaFn = M; // call to PFWG lambda found - record the lambda
Expand Down Expand Up @@ -732,12 +743,7 @@ static ParamDesc makeParamDesc(const FieldDecl *Src, QualType Ty) {
Ctx.getTrivialTypeSourceInfo(Ty));
}

static ParamDesc makeParamDesc(ASTContext &Ctx, const CXXBaseSpecifier &Src,
QualType Ty) {
// TODO: There is no name for the base available, but duplicate names are
// seemingly already possible, so we'll give them all the same name for now.
// This only happens with the accessor types.
std::string Name = "_arg__base";
// Builds the descriptor for a kernel parameter called Name of type Ty: the
// type itself, the identifier interned in Ctx, and trivial type source info.
static ParamDesc makeParamDesc(ASTContext &Ctx, StringRef Name, QualType Ty) {
  auto *ParamId = &Ctx.Idents.get(Name);
  auto *ParamTSI = Ctx.getTrivialTypeSourceInfo(Ty);
  return std::make_tuple(Ty, ParamId, ParamTSI);
}
Expand Down Expand Up @@ -777,6 +783,38 @@ constructKernelName(Sema &S, FunctionDecl *KernelCallerFunc,
KernelNameType)};
}

// Returns true when the target triple is SPIR and carries no sub-architecture.
static bool isDefaultSPIRArch(ASTContext &Context) {
  const llvm::Triple &TargetTriple = Context.getTargetInfo().getTriple();
  return TargetTriple.isSPIR() &&
         TargetTriple.getSubArch() == llvm::Triple::NoSubArch;
}

// Returns the parameter of KernelCallerFunc whose type is the SYCL
// kernel_handler class, or nullptr when there is no such parameter.
//
// Specialization constants in SYCL 2020 are not captured by the kernel lambda;
// they are accessed through the optional kernel_handler lambda argument.
// At most one kernel_handler parameter is expected (checked in asserts
// builds only).
static ParmVarDecl *getSyclKernelHandlerArg(FunctionDecl *KernelCallerFunc) {
  auto IsKernelHandler = [](ParmVarDecl *PVD) {
    return Util::isSyclKernelHandlerType(PVD->getType());
  };

  auto KHArg = std::find_if(KernelCallerFunc->param_begin(),
                            KernelCallerFunc->param_end(), IsKernelHandler);
  if (KHArg == KernelCallerFunc->param_end())
    return nullptr;

  // The duplicate check lives entirely inside the assert so that release
  // (NDEBUG) builds neither perform the extra scan nor end up with a local
  // variable that is only referenced from an assert.
  assert(std::none_of(std::next(KHArg), KernelCallerFunc->param_end(),
                      IsKernelHandler) &&
         "Too many kernel_handler arguments");

  return *KHArg;
}

// anonymous namespace so these don't get linkage.
namespace {

Expand Down Expand Up @@ -1642,10 +1680,20 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
}

void addParam(const CXXBaseSpecifier &BS, QualType FieldTy) {
// TODO: There is no name for the base available, but duplicate names are
// seemingly already possible, so we'll give them all the same name for now.
// This only happens with the accessor types.
StringRef Name = "_arg__base";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For anyone who comes along later, sees this "TODO", and wants to fix it: the type name of FieldTy might be a good addition here. Duplicate names don't cause problems, but a more descriptive name would be nice from a debuggability perspective.

ParamDesc newParamDesc =
makeParamDesc(SemaRef.getASTContext(), BS, FieldTy);
makeParamDesc(SemaRef.getASTContext(), Name, FieldTy);
addParam(newParamDesc, FieldTy);
}
// Appends a kernel parameter with the given name and type by building its
// descriptor and forwarding to the ParamDesc overload.
void addParam(StringRef Name, QualType ParamTy) {
  addParam(makeParamDesc(SemaRef.getASTContext(), Name, ParamTy), ParamTy);
}

void addParam(ParamDesc newParamDesc, QualType FieldTy) {
// Create a new ParmVarDecl based on the new info.
Expand Down Expand Up @@ -1946,6 +1994,18 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
return true;
}

// Generate kernel argument to intialize specialization constants. This
// argument is only generated when the target has no native support for
// specialization constants
void handleSyclKernelHandlerType() {
ASTContext &Context = SemaRef.getASTContext();
if (isDefaultSPIRArch(Context))
return;

StringRef Name = "_arg__specialization_constants_buffer";
addParam(Name, Context.getPointerType(Context.CharTy));
}

void setBody(CompoundStmt *KB) { KernelDecl->setBody(KB); }

FunctionDecl *getKernelDecl() { return KernelDecl; }
Expand Down Expand Up @@ -2091,28 +2151,46 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
// pointer-struct-wrapping code to ensure that we don't try to wrap
// non-top-level pointers.
uint64_t StructDepth = 0;
VarDecl *KernelHandlerClone = nullptr;

// Rewrites FunctionBody so that references to OriginalParam refer to
// LocalClone instead, and returns the transformed statement.
Stmt *replaceWithLocalClone(ParmVarDecl *OriginalParam, VarDecl *LocalClone,
                            Stmt *FunctionBody) {
  // A DeclRefExpr with a valid source location must point at a declaration
  // that is marked as used, so flag the clone before rewriting.
  LocalClone->setIsUsed();

  std::pair<DeclaratorDecl *, DeclaratorDecl *> ParamToClone(OriginalParam,
                                                             LocalClone);
  KernelBodyTransform Transform(ParamToClone, SemaRef);
  return Transform.TransformStmt(FunctionBody).get();
}

// Using the statements/init expressions that we've created, this generates
// the kernel body compound stmt. CompoundStmt needs to know its number of
// statements in advance to allocate it, so we cannot do this as we go along.
CompoundStmt *createKernelBody() {
// Push the Kernel function scope to ensure the scope isn't empty
SemaRef.PushFunctionScope();

// Initialize kernel object local clone
assert(CollectionInitExprs.size() == 1 &&
"Should have been popped down to just the first one");
KernelObjClone->setInit(CollectionInitExprs.back());
Stmt *FunctionBody = KernelCallerFunc->getBody();

ParmVarDecl *KernelObjParam = *(KernelCallerFunc->param_begin());

// DeclRefExpr with valid source location but with decl which is not marked
// as used is invalid.
KernelObjClone->setIsUsed();
std::pair<DeclaratorDecl *, DeclaratorDecl *> MappingPair =
std::make_pair(KernelObjParam, KernelObjClone);

// Push the Kernel function scope to ensure the scope isn't empty
SemaRef.PushFunctionScope();
KernelBodyTransform KBT(MappingPair, SemaRef);
Stmt *NewBody = KBT.TransformStmt(FunctionBody).get();
// Replace references to the kernel object in kernel body, to use the
// compiler generated local clone
Stmt *NewBody =
replaceWithLocalClone(KernelCallerFunc->getParamDecl(0), KernelObjClone,
KernelCallerFunc->getBody());

// If kernel_handler argument is passed by SYCL kernel, replace references
// to this argument in kernel body, to use the compiler generated local
// clone
if (ParmVarDecl *KernelHandlerParam =
getSyclKernelHandlerArg(KernelCallerFunc))
NewBody = replaceWithLocalClone(KernelHandlerParam, KernelHandlerClone,
NewBody);

// Use transformed body (with clones) as kernel body
BodyStmts.push_back(NewBody);

BodyStmts.insert(BodyStmts.end(), FinalizeStmts.begin(),
Expand Down Expand Up @@ -2412,6 +2490,39 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
return true;
}

// Emits, into the kernel body statements, the call to the
// __init_specialization_constants_buffer method on the local kernel handler
// clone. KernelHandlerClone must already have been created.
void handleSpecialType(QualType KernelHandlerTy) {
  const auto *HandlerRD = KernelHandlerClone->getType()->getAsCXXRecordDecl();

  DeclRefExpr *CloneRef = DeclRefExpr::Create(
      SemaRef.Context, NestedNameSpecifierLoc(), KernelCallerSrcLoc,
      KernelHandlerClone, false, DeclarationNameInfo(), KernelHandlerTy,
      VK_LValue);

  // The clone is the member-expression base only for the duration of this
  // one call; restore the stack afterwards.
  MemberExprBases.push_back(CloneRef);
  createSpecialMethodCall(HandlerRD, InitSpecConstantsBuffer, BodyStmts);
  MemberExprBases.pop_back();
}

// Creates the local variable that stands in for the kernel_handler argument
// and stores it in KernelHandlerClone. The variable is declared in DC with
// the identifier and type of KernelHandlerArg, is located at
// KernelCallerSrcLoc, and is given a default-initialization expression.
void createKernelHandlerClone(ASTContext &Ctx, DeclContext *DC,
ParmVarDecl *KernelHandlerArg) {
QualType Ty = KernelHandlerArg->getType();
TypeSourceInfo *TSInfo = Ctx.getTrivialTypeSourceInfo(Ty);
KernelHandlerClone =
VarDecl::Create(Ctx, DC, KernelCallerSrcLoc, KernelCallerSrcLoc,
KernelHandlerArg->getIdentifier(), Ty, TSInfo, SC_None);

// Default initialize clone: build and perform a default-initialization
// sequence (no initializer arguments) for the new variable.
InitializedEntity VarEntity =
InitializedEntity::InitializeVariable(KernelHandlerClone);
InitializationKind InitKind =
InitializationKind::CreateDefault(KernelCallerSrcLoc);
InitializationSequence InitSeq(SemaRef, VarEntity, InitKind, None);
ExprResult Init = InitSeq.Perform(SemaRef, VarEntity, InitKind, None);
// NOTE(review): Init is used without an isInvalid() check; this presumably
// relies on kernel_handler always being default-constructible -- confirm.
KernelHandlerClone->setInit(
SemaRef.MaybeCreateExprWithCleanups(Init.get()));
KernelHandlerClone->setInitStyle(VarDecl::CallInit);
}

public:
static constexpr const bool VisitInsideSimpleContainers = false;
SyclKernelBodyCreator(Sema &S, SyclKernelDeclCreator &DC,
Expand Down Expand Up @@ -2516,6 +2627,28 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
return true;
}

// Declares a default-initialized local clone of the kernel handler at the
// start of the kernel body and, when the target has no native support for
// specialization constants, initializes it from the generated buffer argument.
void handleSyclKernelHandlerType(ParmVarDecl *KernelHandlerArg) {
  // Build the clone; the result is stored in KernelHandlerClone.
  createKernelHandlerClone(SemaRef.getASTContext(), DeclCreator.getKernelDecl(),
                           KernelHandlerArg);

  // Emit the clone's declaration statement into the OpenCL kernel body.
  Stmt *CloneDeclStmt = new (SemaRef.Context) DeclStmt(
      DeclGroupRef(KernelHandlerClone), KernelCallerSrcLoc, KernelCallerSrcLoc);
  BodyStmts.push_back(CloneDeclStmt);

  // Targets with native specialization constant support need nothing more.
  if (isDefaultSPIRArch(SemaRef.Context))
    return;

  // Otherwise generate
  //   KernelHandlerClone.__init_specialization_constants_buffer(
  //       specialization_constants_buffer)
  // where specialization_constants_buffer is the compiler generated kernel
  // argument of type char*.
  handleSpecialType(KernelHandlerArg->getType());
}

bool enterStream(const CXXRecordDecl *RD, FieldDecl *FD, QualType Ty) final {
++StructDepth;
// Add a dummy init expression to catch the accessor initializers.
Expand Down Expand Up @@ -2870,6 +3003,22 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
return true;
}

// Records the integration header entry for `specialization_constants_buffer`,
// the compiler generated kernel argument used to initialize SYCL 2020
// specialization constants. That argument exists only when the target has no
// native support for specialization constants, so nothing is recorded
// otherwise.
void handleSyclKernelHandlerType(QualType Ty) {
  ASTContext &Context = SemaRef.getASTContext();
  if (!isDefaultSPIRArch(Context)) {
    // The kernel_handler argument is not captured, hence it is not part of
    // the kernel object and its offset is zero.
    const QualType CharPtrTy = Context.getPointerType(Context.CharTy);
    addParam(CharPtrTy,
             SYCLIntegrationHeader::kind_specialization_constants_buffer, 0);
  }
}

bool enterStream(const CXXRecordDecl *, FieldDecl *FD, QualType Ty) final {
++StructDepth;
CurOffset += offsetOf(FD, Ty);
Expand Down Expand Up @@ -3257,6 +3406,13 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
KernelObjVisitor Visitor{*this};
Visitor.VisitRecordBases(KernelObj, kernel_decl, kernel_body, int_header);
Visitor.VisitRecordFields(KernelObj, kernel_decl, kernel_body, int_header);

if (ParmVarDecl *KernelHandlerArg =
getSyclKernelHandlerArg(KernelCallerFunc)) {
kernel_decl.handleSyclKernelHandlerType();
kernel_body.handleSyclKernelHandlerType(KernelHandlerArg);
int_header.handleSyclKernelHandlerType(KernelHandlerArg->getType());
}
}

void Sema::MarkDevice(void) {
Expand Down Expand Up @@ -3504,6 +3660,7 @@ static const char *paramKind2Str(KernelParamKind K) {
CASE(accessor);
CASE(std_layout);
CASE(sampler);
CASE(specialization_constants_buffer);
CASE(pointer);
}
return "<ERROR>";
Expand Down Expand Up @@ -4089,6 +4246,15 @@ bool Util::isSyclSpecConstantType(const QualType &Ty) {
return matchQualifiedTypeName(Ty, Scopes);
}

// Returns true when Ty names the class cl::sycl::kernel_handler.
bool Util::isSyclKernelHandlerType(const QualType &Ty) {
  // Declaration contexts from the outermost namespace down to the class.
  std::array<DeclContextDesc, 3> Scopes = {
      Util::DeclContextDesc{clang::Decl::Kind::Namespace, "cl"},
      Util::DeclContextDesc{clang::Decl::Kind::Namespace, "sycl"},
      Util::DeclContextDesc{Decl::Kind::CXXRecord, "kernel_handler"}};
  return matchQualifiedTypeName(Ty, Scopes);
}

bool Util::isSyclBufferLocationType(const QualType &Ty) {
const StringRef &PropertyName = "buffer_location";
const StringRef &InstanceName = "instance";
Expand Down
20 changes: 20 additions & 0 deletions clang/test/CodeGenSYCL/Inputs/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,22 @@ class spec_constant {
} // namespace experimental
} // namespace ONEAPI

// Minimal stand-in for the SYCL 2020 kernel_handler class used by the tests.
class kernel_handler {
private:
  // Empty body; only the member's name and signature are provided here.
  void __init_specialization_constants_buffer(
      char *specialization_constants_buffer) {}
};

// Shorthand for the sycl_kernel attribute.
#define ATTR_SYCL_KERNEL __attribute__((sycl_kernel))
// Kernel function template carrying the sycl_kernel attribute; invokes the
// kernel functor with no arguments.
template <typename KernelName = auto_name, typename KernelType>
ATTR_SYCL_KERNEL void kernel_single_task(const KernelType &kernelFunc) {
kernelFunc();
}

// NOTE(review): this repeats the ATTR_SYCL_KERNEL definition given before the
// first kernel_single_task overload with an identical replacement list; it
// looks redundant -- confirm before removing.
#define ATTR_SYCL_KERNEL __attribute__((sycl_kernel))
// Overload for kernels that take a kernel_handler: the handler is received by
// value and passed to the kernel functor.
template <typename KernelName = auto_name, typename KernelType>
ATTR_SYCL_KERNEL void kernel_single_task(const KernelType &kernelFunc, kernel_handler kh) {
kernelFunc(kh);
}

template <typename KernelName = auto_name, typename KernelType>
ATTR_SYCL_KERNEL void kernel_single_task_2017(KernelType kernelFunc) {
kernelFunc();
Expand Down Expand Up @@ -347,6 +357,16 @@ class handler {
#endif
}

// single_task overload for kernels that take a kernel_handler. When
// __SYCL_DEVICE_ONLY__ is defined the call is forwarded to
// kernel_single_task<NameT>; otherwise the functor is invoked directly with kh.
template <typename KernelName = auto_name, typename KernelType>
void single_task(const KernelType &kernelFunc, kernel_handler kh) {
// Kernel name type derived from KernelName and KernelType.
using NameT = typename get_kernel_name_t<KernelName, KernelType>::name;
#ifdef __SYCL_DEVICE_ONLY__
kernel_single_task<NameT>(kernelFunc, kh);
#else
kernelFunc(kh);
#endif
}

template <typename KernelName = auto_name, typename KernelType>
void single_task_2017(KernelType kernelFunc) {
using NameT = typename get_kernel_name_t<KernelName, KernelType>::name;
Expand Down
Loading