Skip to content

Commit 3aeb3e0

Browse files
sitio-coutolanza
authored andcommitted
[CIR][Transforms][NFC] Refactor MergeCleanups pass (#384)
Breaks the pass into smaller more manageable rewrites.
1 parent 0affd8d commit 3aeb3e0

File tree

1 file changed

+123
-203
lines changed

1 file changed

+123
-203
lines changed

clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp

Lines changed: 123 additions & 203 deletions
Original file line numberDiff line numberDiff line change
@@ -7,248 +7,168 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "PassDetail.h"
10-
11-
#include "clang/CIR/Dialect/IR/CIRDialect.h"
12-
#include "clang/CIR/Dialect/Passes.h"
13-
1410
#include "mlir/Dialect/Func/IR/FuncOps.h"
15-
16-
#include "mlir/IR/Matchers.h"
1711
#include "mlir/IR/PatternMatch.h"
1812
#include "mlir/Support/LogicalResult.h"
1913
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
15+
#include "clang/CIR/Dialect/Passes.h"
2016

2117
using namespace mlir;
2218
using namespace cir;
2319

24-
namespace {
25-
26-
template <typename ScopeLikeOpTy>
27-
struct SimplifyRetYieldBlocks : public mlir::OpRewritePattern<ScopeLikeOpTy> {
28-
using OpRewritePattern<ScopeLikeOpTy>::OpRewritePattern;
29-
mlir::LogicalResult replaceScopeLikeOp(PatternRewriter &rewriter,
30-
ScopeLikeOpTy scopeLikeOp) const;
31-
32-
SimplifyRetYieldBlocks(mlir::MLIRContext *context)
33-
: OpRewritePattern<ScopeLikeOpTy>(context, /*benefit=*/1) {}
34-
35-
mlir::LogicalResult
36-
checkAndRewriteRegion(mlir::Region &r,
37-
mlir::PatternRewriter &rewriter) const {
38-
auto &blocks = r.getBlocks();
39-
40-
if (blocks.size() <= 1)
41-
return failure();
42-
43-
// Rewrite something like this:
44-
//
45-
// cir.if %2 {
46-
// %3 = cir.const(3 : i32) : i32
47-
// cir.br ^bb1
48-
// ^bb1: // pred: ^bb0
49-
// cir.return %3 : i32
50-
// }
51-
//
52-
// to this:
53-
//
54-
// cir.if %2 {
55-
// %3 = cir.const(3 : i32) : i32
56-
// cir.return %3 : i32
57-
// }
58-
//
59-
SmallPtrSet<mlir::Block *, 4> candidateBlocks;
60-
for (Block &block : blocks) {
61-
if (block.isEntryBlock())
62-
continue;
63-
64-
auto yieldVars = block.getOps<cir::YieldOp>();
65-
for (cir::YieldOp yield : yieldVars)
66-
candidateBlocks.insert(yield.getOperation()->getBlock());
20+
//===----------------------------------------------------------------------===//
21+
// Rewrite patterns
22+
//===----------------------------------------------------------------------===//
6723

68-
auto retVars = block.getOps<cir::ReturnOp>();
69-
for (cir::ReturnOp ret : retVars)
70-
candidateBlocks.insert(ret.getOperation()->getBlock());
71-
}
24+
namespace {
7225

73-
auto changed = mlir::failure();
74-
for (auto *mergeSource : candidateBlocks) {
75-
if (!(mergeSource->hasNoSuccessors() && mergeSource->hasOneUse()))
76-
continue;
77-
auto *mergeDest = mergeSource->getSinglePredecessor();
78-
if (!mergeDest || mergeDest->getNumSuccessors() != 1)
79-
continue;
80-
rewriter.eraseOp(mergeDest->getTerminator());
81-
rewriter.mergeBlocks(mergeSource, mergeDest);
82-
changed = mlir::success();
26+
/// Removes branches between two blocks if it is the only branch.
27+
///
28+
/// From:
29+
/// ^bb0:
30+
/// cir.br ^bb1
31+
/// ^bb1: // pred: ^bb0
32+
/// cir.return
33+
///
34+
/// To:
35+
/// ^bb0:
36+
/// cir.return
37+
struct RemoveRedudantBranches : public OpRewritePattern<BrOp> {
38+
using OpRewritePattern<BrOp>::OpRewritePattern;
39+
40+
LogicalResult matchAndRewrite(BrOp op,
41+
PatternRewriter &rewriter) const final {
42+
Block *block = op.getOperation()->getBlock();
43+
Block *dest = op.getDest();
44+
45+
// Single edge between blocks: merge it.
46+
if (block->getNumSuccessors() == 1 &&
47+
dest->getSinglePredecessor() == block) {
48+
rewriter.eraseOp(op);
49+
rewriter.mergeBlocks(dest, block);
50+
return success();
8351
}
8452

85-
return changed;
53+
return failure();
8654
}
55+
};
8756

88-
mlir::LogicalResult
89-
checkAndRewriteLoopCond(mlir::Region &condRegion,
90-
mlir::PatternRewriter &rewriter) const {
91-
SmallVector<Operation *> opsToSimplify;
92-
condRegion.walk([&](Operation *op) {
93-
if (isa<cir::BrCondOp>(op))
94-
opsToSimplify.push_back(op);
95-
});
96-
97-
// Blocks should only contain one "yield" operation.
98-
auto trivialYield = [&](Block *b) {
99-
if (&b->front() != &b->back())
100-
return false;
101-
return isa<YieldOp>(b->getTerminator());
102-
};
103-
104-
if (opsToSimplify.size() != 1)
105-
return failure();
106-
BrCondOp brCondOp = cast<cir::BrCondOp>(opsToSimplify[0]);
57+
/// Merges basic blocks of trivial conditional branches. This is useful when a
58+
/// the condition of conditional branch is a constant and the destinations of
59+
/// the conditional branch both have only one predecessor.
60+
///
61+
/// From:
62+
/// ^bb0:
63+
/// %0 = cir.const(#true) : !cir.bool
64+
/// cir.brcond %0 ^bb1, ^bb2
65+
/// ^bb1: // pred: ^bb0
66+
/// cir.yield continue
67+
/// ^bb2: // pred: ^bb0
68+
/// cir.yield
69+
///
70+
/// To:
71+
/// ^bb0:
72+
/// cir.yield continue
73+
///
74+
struct MergeTrivialConditionalBranches : public OpRewritePattern<BrCondOp> {
75+
using OpRewritePattern<BrCondOp>::OpRewritePattern;
76+
77+
LogicalResult match(BrCondOp op) const final {
78+
return success(isa<ConstantOp>(op.getCond().getDefiningOp()) &&
79+
op.getDestFalse()->hasOneUse() &&
80+
op.getDestTrue()->hasOneUse());
81+
}
10782

108-
// TODO: leverage SCCP to get improved results.
109-
auto cstOp = dyn_cast<cir::ConstantOp>(brCondOp.getCond().getDefiningOp());
110-
if (!cstOp || !cstOp.getValue().isa<mlir::cir::BoolAttr>() ||
111-
!trivialYield(brCondOp.getDestTrue()) ||
112-
!trivialYield(brCondOp.getDestFalse()))
113-
return failure();
83+
/// Replace conditional branch with unconditional branch.
84+
void rewrite(BrCondOp op, PatternRewriter &rewriter) const final {
85+
auto constOp = llvm::cast<ConstantOp>(op.getCond().getDefiningOp());
86+
bool cond = constOp.getValue().cast<cir::BoolAttr>().getValue();
87+
Block *block = op.getOperation()->getBlock();
11488

115-
// If the condition is constant, no need to use brcond, just yield
116-
// properly, "yield" for false and "yield continue" for true.
117-
auto boolAttr = cstOp.getValue().cast<mlir::cir::BoolAttr>();
118-
auto *falseBlock = brCondOp.getDestFalse();
119-
auto *trueBlock = brCondOp.getDestTrue();
120-
auto *currBlock = brCondOp.getOperation()->getBlock();
121-
if (boolAttr.getValue()) {
122-
rewriter.eraseOp(opsToSimplify[0]);
123-
rewriter.mergeBlocks(trueBlock, currBlock);
124-
falseBlock->erase();
89+
rewriter.eraseOp(op);
90+
if (cond) {
91+
rewriter.mergeBlocks(op.getDestTrue(), block);
92+
rewriter.eraseBlock(op.getDestFalse());
12593
} else {
126-
rewriter.eraseOp(opsToSimplify[0]);
127-
rewriter.mergeBlocks(falseBlock, currBlock);
128-
trueBlock->erase();
94+
rewriter.mergeBlocks(op.getDestFalse(), block);
95+
rewriter.eraseBlock(op.getDestTrue());
12996
}
130-
if (cstOp.use_empty())
131-
rewriter.eraseOp(cstOp);
132-
return success();
133-
}
134-
135-
mlir::LogicalResult
136-
matchAndRewrite(ScopeLikeOpTy op,
137-
mlir::PatternRewriter &rewriter) const override {
138-
return replaceScopeLikeOp(rewriter, op);
13997
}
14098
};
14199

142-
// Specialize the template to account for the different build signatures for
143-
// IfOp, ScopeOp, FuncOp, SwitchOp, LoopOp.
144-
template <>
145-
mlir::LogicalResult
146-
SimplifyRetYieldBlocks<IfOp>::replaceScopeLikeOp(PatternRewriter &rewriter,
147-
IfOp ifOp) const {
148-
auto regionChanged = mlir::failure();
149-
if (checkAndRewriteRegion(ifOp.getThenRegion(), rewriter).succeeded())
150-
regionChanged = mlir::success();
151-
if (checkAndRewriteRegion(ifOp.getElseRegion(), rewriter).succeeded())
152-
regionChanged = mlir::success();
153-
return regionChanged;
154-
}
100+
struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
101+
using OpRewritePattern<ScopeOp>::OpRewritePattern;
155102

156-
template <>
157-
mlir::LogicalResult
158-
SimplifyRetYieldBlocks<ScopeOp>::replaceScopeLikeOp(PatternRewriter &rewriter,
159-
ScopeOp scopeOp) const {
160-
// Scope region empty: just remove scope.
161-
if (scopeOp.getRegion().empty()) {
162-
rewriter.eraseOp(scopeOp);
163-
return mlir::success();
103+
LogicalResult match(ScopeOp op) const final {
104+
return success(op.getRegion().empty() ||
105+
(op.getRegion().getBlocks().size() == 1 &&
106+
op.getRegion().front().empty()));
164107
}
165108

166-
// Scope region non-empty: clean it up.
167-
if (checkAndRewriteRegion(scopeOp.getRegion(), rewriter).succeeded())
168-
return mlir::success();
169-
170-
return mlir::failure();
171-
}
172-
173-
template <>
174-
mlir::LogicalResult SimplifyRetYieldBlocks<cir::FuncOp>::replaceScopeLikeOp(
175-
PatternRewriter &rewriter, cir::FuncOp funcOp) const {
176-
auto regionChanged = mlir::failure();
177-
if (checkAndRewriteRegion(funcOp.getRegion(), rewriter).succeeded())
178-
regionChanged = mlir::success();
179-
return regionChanged;
180-
}
109+
void rewrite(ScopeOp op, PatternRewriter &rewriter) const final {
110+
rewriter.eraseOp(op);
111+
}
112+
};
181113

182-
template <>
183-
mlir::LogicalResult SimplifyRetYieldBlocks<cir::SwitchOp>::replaceScopeLikeOp(
184-
PatternRewriter &rewriter, cir::SwitchOp switchOp) const {
185-
auto regionChanged = mlir::failure();
114+
struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
115+
using OpRewritePattern<SwitchOp>::OpRewritePattern;
186116

187-
// Empty switch statement: just remove it.
188-
if (!switchOp.getCases().has_value() || switchOp.getCases()->empty()) {
189-
rewriter.eraseOp(switchOp);
190-
return mlir::success();
117+
LogicalResult match(SwitchOp op) const final {
118+
return success(op.getRegions().empty());
191119
}
192120

193-
// Non-empty switch statement: clean it up.
194-
for (auto &r : switchOp.getRegions()) {
195-
if (checkAndRewriteRegion(r, rewriter).succeeded())
196-
regionChanged = mlir::success();
121+
void rewrite(SwitchOp op, PatternRewriter &rewriter) const final {
122+
rewriter.eraseOp(op);
197123
}
198-
return regionChanged;
199-
}
200-
201-
template <>
202-
mlir::LogicalResult SimplifyRetYieldBlocks<cir::LoopOp>::replaceScopeLikeOp(
203-
PatternRewriter &rewriter, cir::LoopOp loopOp) const {
204-
auto regionChanged = mlir::failure();
205-
if (checkAndRewriteRegion(loopOp.getBody(), rewriter).succeeded())
206-
regionChanged = mlir::success();
207-
if (checkAndRewriteLoopCond(loopOp.getCond(), rewriter).succeeded())
208-
regionChanged = mlir::success();
209-
return regionChanged;
210-
}
124+
};
211125

212-
void getMergeCleanupsPatterns(RewritePatternSet &results,
213-
MLIRContext *context) {
214-
results.add<SimplifyRetYieldBlocks<IfOp>, SimplifyRetYieldBlocks<ScopeOp>,
215-
SimplifyRetYieldBlocks<cir::FuncOp>,
216-
SimplifyRetYieldBlocks<cir::SwitchOp>,
217-
SimplifyRetYieldBlocks<cir::LoopOp>>(context);
218-
}
126+
//===----------------------------------------------------------------------===//
127+
// MergeCleanupsPass
128+
//===----------------------------------------------------------------------===//
219129

220130
struct MergeCleanupsPass : public MergeCleanupsBase<MergeCleanupsPass> {
221-
MergeCleanupsPass() = default;
131+
using MergeCleanupsBase::MergeCleanupsBase;
132+
133+
// The same operation rewriting done here could have been performed
134+
// by CanonicalizerPass (adding hasCanonicalizer for target Ops and
135+
// implementing the same from above in CIRDialects.cpp). However, it's
136+
// currently too aggressive for static analysis purposes, since it might
137+
// remove things where a diagnostic can be generated.
138+
//
139+
// FIXME: perhaps we can add one more mode to GreedyRewriteConfig to
140+
// disable this behavior.
222141
void runOnOperation() override;
223142
};
224143

225-
// The same operation rewriting done here could have been performed
226-
// by CanonicalizerPass (adding hasCanonicalizer for target Ops and implementing
227-
// the same from above in CIRDialects.cpp). However, it's currently too
228-
// aggressive for static analysis purposes, since it might remove things where
229-
// a diagnostic can be generated.
230-
//
231-
// FIXME: perhaps we can add one more mode to GreedyRewriteConfig to
232-
// disable this behavior.
233-
void MergeCleanupsPass::runOnOperation() {
234-
auto op = getOperation();
235-
mlir::RewritePatternSet patterns(&getContext());
236-
getMergeCleanupsPatterns(patterns, &getContext());
237-
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
144+
void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
145+
// clang-format off
146+
patterns.add<
147+
RemoveRedudantBranches,
148+
MergeTrivialConditionalBranches,
149+
RemoveEmptyScope,
150+
RemoveEmptySwitch
151+
>(patterns.getContext());
152+
// clang-format on
153+
}
238154

239-
SmallVector<Operation *> opsToSimplify;
240-
op->walk([&](Operation *op) {
241-
if (isa<cir::IfOp, cir::ScopeOp, cir::FuncOp, cir::SwitchOp, cir::LoopOp>(
242-
op))
243-
opsToSimplify.push_back(op);
155+
void MergeCleanupsPass::runOnOperation() {
156+
// Collect rewrite patterns.
157+
RewritePatternSet patterns(&getContext());
158+
populateMergeCleanupPatterns(patterns);
159+
160+
// Collect operations to apply patterns.
161+
SmallVector<Operation *, 16> ops;
162+
getOperation()->walk([&](Operation *op) {
163+
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp>(op))
164+
ops.push_back(op);
244165
});
245166

246-
for (auto *o : opsToSimplify) {
247-
bool erase = false;
248-
(void)applyOpPatternsAndFold(o, frozenPatterns, GreedyRewriteConfig(),
249-
&erase);
250-
}
167+
// Apply patterns.
168+
if (applyOpPatternsAndFold(ops, std::move(patterns)).failed())
169+
signalPassFailure();
251170
}
171+
252172
} // namespace
253173

254174
std::unique_ptr<Pass> mlir::createMergeCleanupsPass() {

0 commit comments

Comments
 (0)