Skip to content

Commit 05a3014

Browse files
Merge branch 'llvm:main' into main
2 parents f347165 + dbe5613 commit 05a3014

File tree

32 files changed

+863
-178
lines changed

32 files changed

+863
-178
lines changed

clang/include/clang/AST/OpenMPClause.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9475,15 +9475,17 @@ class ConstOMPClauseVisitor :
94759475
class OMPClausePrinter final : public OMPClauseVisitor<OMPClausePrinter> {
94769476
raw_ostream &OS;
94779477
const PrintingPolicy &Policy;
9478+
unsigned Version;
94789479

94799480
/// Process clauses with list of variables.
94809481
template <typename T> void VisitOMPClauseList(T *Node, char StartSym);
94819482
/// Process motion clauses.
94829483
template <typename T> void VisitOMPMotionClause(T *Node);
94839484

94849485
public:
9485-
OMPClausePrinter(raw_ostream &OS, const PrintingPolicy &Policy)
9486-
: OS(OS), Policy(Policy) {}
9486+
OMPClausePrinter(raw_ostream &OS, const PrintingPolicy &Policy,
9487+
unsigned OpenMPVersion)
9488+
: OS(OS), Policy(Policy), Version(OpenMPVersion) {}
94879489

94889490
#define GEN_CLANG_CLAUSE_CLASS
94899491
#define CLAUSE_CLASS(Enum, Str, Class) void Visit##Class(Class *S);

clang/lib/AST/DeclPrinter.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1827,7 +1827,7 @@ void DeclPrinter::VisitOMPAllocateDecl(OMPAllocateDecl *D) {
18271827
Out << ")";
18281828
}
18291829
if (!D->clauselist_empty()) {
1830-
OMPClausePrinter Printer(Out, Policy);
1830+
OMPClausePrinter Printer(Out, Policy, Context.getLangOpts().OpenMP);
18311831
for (OMPClause *C : D->clauselists()) {
18321832
Out << " ";
18331833
Printer.Visit(C);
@@ -1838,7 +1838,7 @@ void DeclPrinter::VisitOMPAllocateDecl(OMPAllocateDecl *D) {
18381838
void DeclPrinter::VisitOMPRequiresDecl(OMPRequiresDecl *D) {
18391839
Out << "#pragma omp requires ";
18401840
if (!D->clauselist_empty()) {
1841-
OMPClausePrinter Printer(Out, Policy);
1841+
OMPClausePrinter Printer(Out, Policy, Context.getLangOpts().OpenMP);
18421842
for (auto I = D->clauselist_begin(), E = D->clauselist_end(); I != E; ++I)
18431843
Printer.Visit(*I);
18441844
}
@@ -1891,7 +1891,7 @@ void DeclPrinter::VisitOMPDeclareMapperDecl(OMPDeclareMapperDecl *D) {
18911891
Out << D->getVarName();
18921892
Out << ")";
18931893
if (!D->clauselist_empty()) {
1894-
OMPClausePrinter Printer(Out, Policy);
1894+
OMPClausePrinter Printer(Out, Policy, Context.getLangOpts().OpenMP);
18951895
for (auto *C : D->clauselists()) {
18961896
Out << " ";
18971897
Printer.Visit(C);

clang/lib/AST/OpenMPClause.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1821,7 +1821,7 @@ OMPThreadLimitClause *OMPThreadLimitClause::CreateEmpty(const ASTContext &C,
18211821
void OMPClausePrinter::VisitOMPIfClause(OMPIfClause *Node) {
18221822
OS << "if(";
18231823
if (Node->getNameModifier() != OMPD_unknown)
1824-
OS << getOpenMPDirectiveName(Node->getNameModifier()) << ": ";
1824+
OS << getOpenMPDirectiveName(Node->getNameModifier(), Version) << ": ";
18251825
Node->getCondition()->printPretty(OS, nullptr, Policy, 0);
18261826
OS << ")";
18271827
}
@@ -2049,7 +2049,7 @@ void OMPClausePrinter::VisitOMPAbsentClause(OMPAbsentClause *Node) {
20492049
for (auto &D : Node->getDirectiveKinds()) {
20502050
if (!First)
20512051
OS << ", ";
2052-
OS << getOpenMPDirectiveName(D);
2052+
OS << getOpenMPDirectiveName(D, Version);
20532053
First = false;
20542054
}
20552055
OS << ")";
@@ -2067,7 +2067,7 @@ void OMPClausePrinter::VisitOMPContainsClause(OMPContainsClause *Node) {
20672067
for (auto &D : Node->getDirectiveKinds()) {
20682068
if (!First)
20692069
OS << ", ";
2070-
OS << getOpenMPDirectiveName(D);
2070+
OS << getOpenMPDirectiveName(D, Version);
20712071
First = false;
20722072
}
20732073
OS << ")";

clang/lib/AST/StmtPrinter.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,9 @@ void StmtPrinter::VisitOMPCanonicalLoop(OMPCanonicalLoop *Node) {
737737

738738
void StmtPrinter::PrintOMPExecutableDirective(OMPExecutableDirective *S,
739739
bool ForceNoStmt) {
740-
OMPClausePrinter Printer(OS, Policy);
740+
unsigned OpenMPVersion =
741+
Context ? Context->getLangOpts().OpenMP : llvm::omp::FallbackVersion;
742+
OMPClausePrinter Printer(OS, Policy, OpenMPVersion);
741743
ArrayRef<OMPClause *> Clauses = S->clauses();
742744
for (auto *Clause : Clauses)
743745
if (Clause && !Clause->isImplicit()) {
@@ -964,14 +966,18 @@ void StmtPrinter::VisitOMPTeamsDirective(OMPTeamsDirective *Node) {
964966

965967
void StmtPrinter::VisitOMPCancellationPointDirective(
966968
OMPCancellationPointDirective *Node) {
969+
unsigned OpenMPVersion =
970+
Context ? Context->getLangOpts().OpenMP : llvm::omp::FallbackVersion;
967971
Indent() << "#pragma omp cancellation point "
968-
<< getOpenMPDirectiveName(Node->getCancelRegion());
972+
<< getOpenMPDirectiveName(Node->getCancelRegion(), OpenMPVersion);
969973
PrintOMPExecutableDirective(Node);
970974
}
971975

972976
void StmtPrinter::VisitOMPCancelDirective(OMPCancelDirective *Node) {
977+
unsigned OpenMPVersion =
978+
Context ? Context->getLangOpts().OpenMP : llvm::omp::FallbackVersion;
973979
Indent() << "#pragma omp cancel "
974-
<< getOpenMPDirectiveName(Node->getCancelRegion());
980+
<< getOpenMPDirectiveName(Node->getCancelRegion(), OpenMPVersion);
975981
PrintOMPExecutableDirective(Node);
976982
}
977983

clang/lib/Basic/OpenMPKinds.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,8 @@ void clang::getOpenMPCaptureRegions(
850850
case OMPD_master:
851851
return false;
852852
default:
853-
llvm::errs() << getOpenMPDirectiveName(LKind) << '\n';
853+
llvm::errs() << getOpenMPDirectiveName(LKind, llvm::omp::FallbackVersion)
854+
<< '\n';
854855
llvm_unreachable("Unexpected directive");
855856
}
856857
return false;

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,12 @@ class CIRGenFunction : public CIRGenTypeCache {
718718
SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
719719
const Stmt *associatedStmt);
720720

721+
template <typename Op, typename TermOp>
722+
mlir::LogicalResult emitOpenACCOpCombinedConstruct(
723+
mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
724+
SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
725+
const Stmt *loopStmt);
726+
721727
public:
722728
mlir::LogicalResult
723729
emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s);

clang/lib/CIR/CodeGen/CIRGenOpenACCClause.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,18 @@ class OpenACCClauseCIREmitter final
107107
.CaseLower("radeon", mlir::acc::DeviceType::Radeon);
108108
}
109109

110+
mlir::acc::GangArgType decodeGangType(OpenACCGangKind gk) {
111+
switch (gk) {
112+
case OpenACCGangKind::Num:
113+
return mlir::acc::GangArgType::Num;
114+
case OpenACCGangKind::Dim:
115+
return mlir::acc::GangArgType::Dim;
116+
case OpenACCGangKind::Static:
117+
return mlir::acc::GangArgType::Static;
118+
}
119+
llvm_unreachable("unknown gang kind");
120+
}
121+
110122
public:
111123
OpenACCClauseCIREmitter(OpTy &operation, CIRGen::CIRGenFunction &cgf,
112124
CIRGen::CIRGenBuilderTy &builder,
@@ -424,6 +436,42 @@ class OpenACCClauseCIREmitter final
424436
return clauseNotImplemented(clause);
425437
}
426438
}
439+
440+
void VisitGangClause(const OpenACCGangClause &clause) {
441+
if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
442+
if (clause.getNumExprs() == 0) {
443+
operation.addEmptyGang(builder.getContext(), lastDeviceTypeValues);
444+
} else {
445+
llvm::SmallVector<mlir::Value> values;
446+
llvm::SmallVector<mlir::acc::GangArgType> argTypes;
447+
for (unsigned i : llvm::index_range(0u, clause.getNumExprs())) {
448+
auto [kind, expr] = clause.getExpr(i);
449+
mlir::Location exprLoc = cgf.cgm.getLoc(expr->getBeginLoc());
450+
argTypes.push_back(decodeGangType(kind));
451+
if (kind == OpenACCGangKind::Dim) {
452+
llvm::APInt curValue =
453+
expr->EvaluateKnownConstInt(cgf.cgm.getASTContext());
454+
// The value is 1, 2, or 3, but the type isn't necessarily smaller
455+
// than 64.
456+
curValue = curValue.sextOrTrunc(64);
457+
values.push_back(
458+
createConstantInt(exprLoc, 64, curValue.getSExtValue()));
459+
} else if (isa<OpenACCAsteriskSizeExpr>(expr)) {
460+
values.push_back(createConstantInt(exprLoc, 64, -1));
461+
} else {
462+
values.push_back(createIntExpr(expr));
463+
}
464+
}
465+
466+
operation.addGangOperands(builder.getContext(), lastDeviceTypeValues,
467+
argTypes, values);
468+
}
469+
} else {
470+
// TODO: When we've implemented this for everything, switch this to an
471+
// unreachable. Combined constructs remain.
472+
return clauseNotImplemented(clause);
473+
}
474+
}
427475
};
428476

429477
template <typename OpTy>

clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,65 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
5656
return res;
5757
}
5858

59+
namespace {
60+
template <typename Op> struct CombinedType;
61+
template <> struct CombinedType<ParallelOp> {
62+
static constexpr mlir::acc::CombinedConstructsType value =
63+
mlir::acc::CombinedConstructsType::ParallelLoop;
64+
};
65+
template <> struct CombinedType<SerialOp> {
66+
static constexpr mlir::acc::CombinedConstructsType value =
67+
mlir::acc::CombinedConstructsType::SerialLoop;
68+
};
69+
template <> struct CombinedType<KernelsOp> {
70+
static constexpr mlir::acc::CombinedConstructsType value =
71+
mlir::acc::CombinedConstructsType::KernelsLoop;
72+
};
73+
} // namespace
74+
75+
template <typename Op, typename TermOp>
76+
mlir::LogicalResult CIRGenFunction::emitOpenACCOpCombinedConstruct(
77+
mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
78+
SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
79+
const Stmt *loopStmt) {
80+
mlir::LogicalResult res = mlir::success();
81+
82+
llvm::SmallVector<mlir::Type> retTy;
83+
llvm::SmallVector<mlir::Value> operands;
84+
85+
auto computeOp = builder.create<Op>(start, retTy, operands);
86+
computeOp.setCombinedAttr(builder.getUnitAttr());
87+
mlir::acc::LoopOp loopOp;
88+
89+
// First, emit the bodies of both operations, with the loop inside the body of
90+
// the combined construct.
91+
{
92+
mlir::Block &block = computeOp.getRegion().emplaceBlock();
93+
mlir::OpBuilder::InsertionGuard guardCase(builder);
94+
builder.setInsertionPointToEnd(&block);
95+
96+
LexicalScope ls{*this, start, builder.getInsertionBlock()};
97+
auto loopOp = builder.create<LoopOp>(start, retTy, operands);
98+
loopOp.setCombinedAttr(mlir::acc::CombinedConstructsTypeAttr::get(
99+
builder.getContext(), CombinedType<Op>::value));
100+
101+
{
102+
mlir::Block &innerBlock = loopOp.getRegion().emplaceBlock();
103+
mlir::OpBuilder::InsertionGuard guardCase(builder);
104+
builder.setInsertionPointToEnd(&innerBlock);
105+
106+
LexicalScope ls{*this, start, builder.getInsertionBlock()};
107+
res = emitStmt(loopStmt, /*useCurrentScope=*/true);
108+
109+
builder.create<mlir::acc::YieldOp>(end);
110+
}
111+
112+
builder.create<TermOp>(end);
113+
}
114+
115+
return res;
116+
}
117+
59118
template <typename Op>
60119
Op CIRGenFunction::emitOpenACCOp(
61120
mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
@@ -170,8 +229,25 @@ CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) {
170229

171230
mlir::LogicalResult CIRGenFunction::emitOpenACCCombinedConstruct(
172231
const OpenACCCombinedConstruct &s) {
173-
cgm.errorNYI(s.getSourceRange(), "OpenACC Combined Construct");
174-
return mlir::failure();
232+
mlir::Location start = getLoc(s.getSourceRange().getBegin());
233+
mlir::Location end = getLoc(s.getSourceRange().getEnd());
234+
235+
switch (s.getDirectiveKind()) {
236+
case OpenACCDirectiveKind::ParallelLoop:
237+
return emitOpenACCOpCombinedConstruct<ParallelOp, mlir::acc::YieldOp>(
238+
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
239+
s.getLoop());
240+
case OpenACCDirectiveKind::SerialLoop:
241+
return emitOpenACCOpCombinedConstruct<SerialOp, mlir::acc::YieldOp>(
242+
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
243+
s.getLoop());
244+
case OpenACCDirectiveKind::KernelsLoop:
245+
return emitOpenACCOpCombinedConstruct<KernelsOp, mlir::acc::TerminatorOp>(
246+
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
247+
s.getLoop());
248+
default:
249+
llvm_unreachable("invalid compute construct kind");
250+
}
175251
}
176252
mlir::LogicalResult CIRGenFunction::emitOpenACCEnterDataConstruct(
177253
const OpenACCEnterDataConstruct &s) {

0 commit comments

Comments
 (0)