|
7 | 7 | //===----------------------------------------------------------------------===//
|
8 | 8 |
|
9 | 9 | #include "PassDetail.h"
|
10 |
| - |
11 |
| -#include "clang/CIR/Dialect/IR/CIRDialect.h" |
12 |
| -#include "clang/CIR/Dialect/Passes.h" |
13 |
| - |
14 | 10 | #include "mlir/Dialect/Func/IR/FuncOps.h"
|
15 |
| - |
16 |
| -#include "mlir/IR/Matchers.h" |
17 | 11 | #include "mlir/IR/PatternMatch.h"
|
18 | 12 | #include "mlir/Support/LogicalResult.h"
|
19 | 13 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
| 14 | +#include "clang/CIR/Dialect/IR/CIRDialect.h" |
| 15 | +#include "clang/CIR/Dialect/Passes.h" |
20 | 16 |
|
21 | 17 | using namespace mlir;
|
22 | 18 | using namespace cir;
|
23 | 19 |
|
24 |
| -namespace { |
25 |
| - |
26 |
| -template <typename ScopeLikeOpTy> |
27 |
| -struct SimplifyRetYieldBlocks : public mlir::OpRewritePattern<ScopeLikeOpTy> { |
28 |
| - using OpRewritePattern<ScopeLikeOpTy>::OpRewritePattern; |
29 |
| - mlir::LogicalResult replaceScopeLikeOp(PatternRewriter &rewriter, |
30 |
| - ScopeLikeOpTy scopeLikeOp) const; |
31 |
| - |
32 |
| - SimplifyRetYieldBlocks(mlir::MLIRContext *context) |
33 |
| - : OpRewritePattern<ScopeLikeOpTy>(context, /*benefit=*/1) {} |
34 |
| - |
35 |
| - mlir::LogicalResult |
36 |
| - checkAndRewriteRegion(mlir::Region &r, |
37 |
| - mlir::PatternRewriter &rewriter) const { |
38 |
| - auto &blocks = r.getBlocks(); |
39 |
| - |
40 |
| - if (blocks.size() <= 1) |
41 |
| - return failure(); |
42 |
| - |
43 |
| - // Rewrite something like this: |
44 |
| - // |
45 |
| - // cir.if %2 { |
46 |
| - // %3 = cir.const(3 : i32) : i32 |
47 |
| - // cir.br ^bb1 |
48 |
| - // ^bb1: // pred: ^bb0 |
49 |
| - // cir.return %3 : i32 |
50 |
| - // } |
51 |
| - // |
52 |
| - // to this: |
53 |
| - // |
54 |
| - // cir.if %2 { |
55 |
| - // %3 = cir.const(3 : i32) : i32 |
56 |
| - // cir.return %3 : i32 |
57 |
| - // } |
58 |
| - // |
59 |
| - SmallPtrSet<mlir::Block *, 4> candidateBlocks; |
60 |
| - for (Block &block : blocks) { |
61 |
| - if (block.isEntryBlock()) |
62 |
| - continue; |
63 |
| - |
64 |
| - auto yieldVars = block.getOps<cir::YieldOp>(); |
65 |
| - for (cir::YieldOp yield : yieldVars) |
66 |
| - candidateBlocks.insert(yield.getOperation()->getBlock()); |
| 20 | +//===----------------------------------------------------------------------===// |
| 21 | +// Rewrite patterns |
| 22 | +//===----------------------------------------------------------------------===// |
67 | 23 |
|
68 |
| - auto retVars = block.getOps<cir::ReturnOp>(); |
69 |
| - for (cir::ReturnOp ret : retVars) |
70 |
| - candidateBlocks.insert(ret.getOperation()->getBlock()); |
71 |
| - } |
| 24 | +namespace { |
72 | 25 |
|
73 |
| - auto changed = mlir::failure(); |
74 |
| - for (auto *mergeSource : candidateBlocks) { |
75 |
| - if (!(mergeSource->hasNoSuccessors() && mergeSource->hasOneUse())) |
76 |
| - continue; |
77 |
| - auto *mergeDest = mergeSource->getSinglePredecessor(); |
78 |
| - if (!mergeDest || mergeDest->getNumSuccessors() != 1) |
79 |
| - continue; |
80 |
| - rewriter.eraseOp(mergeDest->getTerminator()); |
81 |
| - rewriter.mergeBlocks(mergeSource, mergeDest); |
82 |
| - changed = mlir::success(); |
| 26 | +/// Removes branches between two blocks if it is the only branch. |
| 27 | +/// |
| 28 | +/// From: |
| 29 | +/// ^bb0: |
| 30 | +/// cir.br ^bb1 |
| 31 | +/// ^bb1: // pred: ^bb0 |
| 32 | +/// cir.return |
| 33 | +/// |
| 34 | +/// To: |
| 35 | +/// ^bb0: |
| 36 | +/// cir.return |
| 37 | +struct RemoveRedudantBranches : public OpRewritePattern<BrOp> { |
| 38 | + using OpRewritePattern<BrOp>::OpRewritePattern; |
| 39 | + |
| 40 | + LogicalResult matchAndRewrite(BrOp op, |
| 41 | + PatternRewriter &rewriter) const final { |
| 42 | + Block *block = op.getOperation()->getBlock(); |
| 43 | + Block *dest = op.getDest(); |
| 44 | + |
| 45 | + // Single edge between blocks: merge it. |
| 46 | + if (block->getNumSuccessors() == 1 && |
| 47 | + dest->getSinglePredecessor() == block) { |
| 48 | + rewriter.eraseOp(op); |
| 49 | + rewriter.mergeBlocks(dest, block); |
| 50 | + return success(); |
83 | 51 | }
|
84 | 52 |
|
85 |
| - return changed; |
| 53 | + return failure(); |
86 | 54 | }
|
| 55 | +}; |
87 | 56 |
|
88 |
| - mlir::LogicalResult |
89 |
| - checkAndRewriteLoopCond(mlir::Region &condRegion, |
90 |
| - mlir::PatternRewriter &rewriter) const { |
91 |
| - SmallVector<Operation *> opsToSimplify; |
92 |
| - condRegion.walk([&](Operation *op) { |
93 |
| - if (isa<cir::BrCondOp>(op)) |
94 |
| - opsToSimplify.push_back(op); |
95 |
| - }); |
96 |
| - |
97 |
| - // Blocks should only contain one "yield" operation. |
98 |
| - auto trivialYield = [&](Block *b) { |
99 |
| - if (&b->front() != &b->back()) |
100 |
| - return false; |
101 |
| - return isa<YieldOp>(b->getTerminator()); |
102 |
| - }; |
103 |
| - |
104 |
| - if (opsToSimplify.size() != 1) |
105 |
| - return failure(); |
106 |
| - BrCondOp brCondOp = cast<cir::BrCondOp>(opsToSimplify[0]); |
| 57 | +/// Merges basic blocks of trivial conditional branches. This is useful when a |
| 58 | +/// the condition of conditional branch is a constant and the destinations of |
| 59 | +/// the conditional branch both have only one predecessor. |
| 60 | +/// |
| 61 | +/// From: |
| 62 | +/// ^bb0: |
| 63 | +/// %0 = cir.const(#true) : !cir.bool |
| 64 | +/// cir.brcond %0 ^bb1, ^bb2 |
| 65 | +/// ^bb1: // pred: ^bb0 |
| 66 | +/// cir.yield continue |
| 67 | +/// ^bb2: // pred: ^bb0 |
| 68 | +/// cir.yield |
| 69 | +/// |
| 70 | +/// To: |
| 71 | +/// ^bb0: |
| 72 | +/// cir.yield continue |
| 73 | +/// |
| 74 | +struct MergeTrivialConditionalBranches : public OpRewritePattern<BrCondOp> { |
| 75 | + using OpRewritePattern<BrCondOp>::OpRewritePattern; |
| 76 | + |
| 77 | + LogicalResult match(BrCondOp op) const final { |
| 78 | + return success(isa<ConstantOp>(op.getCond().getDefiningOp()) && |
| 79 | + op.getDestFalse()->hasOneUse() && |
| 80 | + op.getDestTrue()->hasOneUse()); |
| 81 | + } |
107 | 82 |
|
108 |
| - // TODO: leverage SCCP to get improved results. |
109 |
| - auto cstOp = dyn_cast<cir::ConstantOp>(brCondOp.getCond().getDefiningOp()); |
110 |
| - if (!cstOp || !cstOp.getValue().isa<mlir::cir::BoolAttr>() || |
111 |
| - !trivialYield(brCondOp.getDestTrue()) || |
112 |
| - !trivialYield(brCondOp.getDestFalse())) |
113 |
| - return failure(); |
| 83 | + /// Replace conditional branch with unconditional branch. |
| 84 | + void rewrite(BrCondOp op, PatternRewriter &rewriter) const final { |
| 85 | + auto constOp = llvm::cast<ConstantOp>(op.getCond().getDefiningOp()); |
| 86 | + bool cond = constOp.getValue().cast<cir::BoolAttr>().getValue(); |
| 87 | + Block *block = op.getOperation()->getBlock(); |
114 | 88 |
|
115 |
| - // If the condition is constant, no need to use brcond, just yield |
116 |
| - // properly, "yield" for false and "yield continue" for true. |
117 |
| - auto boolAttr = cstOp.getValue().cast<mlir::cir::BoolAttr>(); |
118 |
| - auto *falseBlock = brCondOp.getDestFalse(); |
119 |
| - auto *trueBlock = brCondOp.getDestTrue(); |
120 |
| - auto *currBlock = brCondOp.getOperation()->getBlock(); |
121 |
| - if (boolAttr.getValue()) { |
122 |
| - rewriter.eraseOp(opsToSimplify[0]); |
123 |
| - rewriter.mergeBlocks(trueBlock, currBlock); |
124 |
| - falseBlock->erase(); |
| 89 | + rewriter.eraseOp(op); |
| 90 | + if (cond) { |
| 91 | + rewriter.mergeBlocks(op.getDestTrue(), block); |
| 92 | + rewriter.eraseBlock(op.getDestFalse()); |
125 | 93 | } else {
|
126 |
| - rewriter.eraseOp(opsToSimplify[0]); |
127 |
| - rewriter.mergeBlocks(falseBlock, currBlock); |
128 |
| - trueBlock->erase(); |
| 94 | + rewriter.mergeBlocks(op.getDestFalse(), block); |
| 95 | + rewriter.eraseBlock(op.getDestTrue()); |
129 | 96 | }
|
130 |
| - if (cstOp.use_empty()) |
131 |
| - rewriter.eraseOp(cstOp); |
132 |
| - return success(); |
133 |
| - } |
134 |
| - |
135 |
| - mlir::LogicalResult |
136 |
| - matchAndRewrite(ScopeLikeOpTy op, |
137 |
| - mlir::PatternRewriter &rewriter) const override { |
138 |
| - return replaceScopeLikeOp(rewriter, op); |
139 | 97 | }
|
140 | 98 | };
|
141 | 99 |
|
142 |
| -// Specialize the template to account for the different build signatures for |
143 |
| -// IfOp, ScopeOp, FuncOp, SwitchOp, LoopOp. |
144 |
| -template <> |
145 |
| -mlir::LogicalResult |
146 |
| -SimplifyRetYieldBlocks<IfOp>::replaceScopeLikeOp(PatternRewriter &rewriter, |
147 |
| - IfOp ifOp) const { |
148 |
| - auto regionChanged = mlir::failure(); |
149 |
| - if (checkAndRewriteRegion(ifOp.getThenRegion(), rewriter).succeeded()) |
150 |
| - regionChanged = mlir::success(); |
151 |
| - if (checkAndRewriteRegion(ifOp.getElseRegion(), rewriter).succeeded()) |
152 |
| - regionChanged = mlir::success(); |
153 |
| - return regionChanged; |
154 |
| -} |
| 100 | +struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> { |
| 101 | + using OpRewritePattern<ScopeOp>::OpRewritePattern; |
155 | 102 |
|
156 |
| -template <> |
157 |
| -mlir::LogicalResult |
158 |
| -SimplifyRetYieldBlocks<ScopeOp>::replaceScopeLikeOp(PatternRewriter &rewriter, |
159 |
| - ScopeOp scopeOp) const { |
160 |
| - // Scope region empty: just remove scope. |
161 |
| - if (scopeOp.getRegion().empty()) { |
162 |
| - rewriter.eraseOp(scopeOp); |
163 |
| - return mlir::success(); |
| 103 | + LogicalResult match(ScopeOp op) const final { |
| 104 | + return success(op.getRegion().empty() || |
| 105 | + (op.getRegion().getBlocks().size() == 1 && |
| 106 | + op.getRegion().front().empty())); |
164 | 107 | }
|
165 | 108 |
|
166 |
| - // Scope region non-empty: clean it up. |
167 |
| - if (checkAndRewriteRegion(scopeOp.getRegion(), rewriter).succeeded()) |
168 |
| - return mlir::success(); |
169 |
| - |
170 |
| - return mlir::failure(); |
171 |
| -} |
172 |
| - |
173 |
| -template <> |
174 |
| -mlir::LogicalResult SimplifyRetYieldBlocks<cir::FuncOp>::replaceScopeLikeOp( |
175 |
| - PatternRewriter &rewriter, cir::FuncOp funcOp) const { |
176 |
| - auto regionChanged = mlir::failure(); |
177 |
| - if (checkAndRewriteRegion(funcOp.getRegion(), rewriter).succeeded()) |
178 |
| - regionChanged = mlir::success(); |
179 |
| - return regionChanged; |
180 |
| -} |
| 109 | + void rewrite(ScopeOp op, PatternRewriter &rewriter) const final { |
| 110 | + rewriter.eraseOp(op); |
| 111 | + } |
| 112 | +}; |
181 | 113 |
|
182 |
| -template <> |
183 |
| -mlir::LogicalResult SimplifyRetYieldBlocks<cir::SwitchOp>::replaceScopeLikeOp( |
184 |
| - PatternRewriter &rewriter, cir::SwitchOp switchOp) const { |
185 |
| - auto regionChanged = mlir::failure(); |
| 114 | +struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> { |
| 115 | + using OpRewritePattern<SwitchOp>::OpRewritePattern; |
186 | 116 |
|
187 |
| - // Empty switch statement: just remove it. |
188 |
| - if (!switchOp.getCases().has_value() || switchOp.getCases()->empty()) { |
189 |
| - rewriter.eraseOp(switchOp); |
190 |
| - return mlir::success(); |
| 117 | + LogicalResult match(SwitchOp op) const final { |
| 118 | + return success(op.getRegions().empty()); |
191 | 119 | }
|
192 | 120 |
|
193 |
| - // Non-empty switch statement: clean it up. |
194 |
| - for (auto &r : switchOp.getRegions()) { |
195 |
| - if (checkAndRewriteRegion(r, rewriter).succeeded()) |
196 |
| - regionChanged = mlir::success(); |
| 121 | + void rewrite(SwitchOp op, PatternRewriter &rewriter) const final { |
| 122 | + rewriter.eraseOp(op); |
197 | 123 | }
|
198 |
| - return regionChanged; |
199 |
| -} |
200 |
| - |
201 |
| -template <> |
202 |
| -mlir::LogicalResult SimplifyRetYieldBlocks<cir::LoopOp>::replaceScopeLikeOp( |
203 |
| - PatternRewriter &rewriter, cir::LoopOp loopOp) const { |
204 |
| - auto regionChanged = mlir::failure(); |
205 |
| - if (checkAndRewriteRegion(loopOp.getBody(), rewriter).succeeded()) |
206 |
| - regionChanged = mlir::success(); |
207 |
| - if (checkAndRewriteLoopCond(loopOp.getCond(), rewriter).succeeded()) |
208 |
| - regionChanged = mlir::success(); |
209 |
| - return regionChanged; |
210 |
| -} |
| 124 | +}; |
211 | 125 |
|
212 |
| -void getMergeCleanupsPatterns(RewritePatternSet &results, |
213 |
| - MLIRContext *context) { |
214 |
| - results.add<SimplifyRetYieldBlocks<IfOp>, SimplifyRetYieldBlocks<ScopeOp>, |
215 |
| - SimplifyRetYieldBlocks<cir::FuncOp>, |
216 |
| - SimplifyRetYieldBlocks<cir::SwitchOp>, |
217 |
| - SimplifyRetYieldBlocks<cir::LoopOp>>(context); |
218 |
| -} |
| 126 | +//===----------------------------------------------------------------------===// |
| 127 | +// MergeCleanupsPass |
| 128 | +//===----------------------------------------------------------------------===// |
219 | 129 |
|
220 | 130 | struct MergeCleanupsPass : public MergeCleanupsBase<MergeCleanupsPass> {
|
221 |
| - MergeCleanupsPass() = default; |
| 131 | + using MergeCleanupsBase::MergeCleanupsBase; |
| 132 | + |
| 133 | + // The same operation rewriting done here could have been performed |
| 134 | + // by CanonicalizerPass (adding hasCanonicalizer for target Ops and |
| 135 | + // implementing the same from above in CIRDialects.cpp). However, it's |
| 136 | + // currently too aggressive for static analysis purposes, since it might |
| 137 | + // remove things where a diagnostic can be generated. |
| 138 | + // |
| 139 | + // FIXME: perhaps we can add one more mode to GreedyRewriteConfig to |
| 140 | + // disable this behavior. |
222 | 141 | void runOnOperation() override;
|
223 | 142 | };
|
224 | 143 |
|
225 |
| -// The same operation rewriting done here could have been performed |
226 |
| -// by CanonicalizerPass (adding hasCanonicalizer for target Ops and implementing |
227 |
| -// the same from above in CIRDialects.cpp). However, it's currently too |
228 |
| -// aggressive for static analysis purposes, since it might remove things where |
229 |
| -// a diagnostic can be generated. |
230 |
| -// |
231 |
| -// FIXME: perhaps we can add one more mode to GreedyRewriteConfig to |
232 |
| -// disable this behavior. |
233 |
| -void MergeCleanupsPass::runOnOperation() { |
234 |
| - auto op = getOperation(); |
235 |
| - mlir::RewritePatternSet patterns(&getContext()); |
236 |
| - getMergeCleanupsPatterns(patterns, &getContext()); |
237 |
| - FrozenRewritePatternSet frozenPatterns(std::move(patterns)); |
| 144 | +void populateMergeCleanupPatterns(RewritePatternSet &patterns) { |
| 145 | + // clang-format off |
| 146 | + patterns.add< |
| 147 | + RemoveRedudantBranches, |
| 148 | + MergeTrivialConditionalBranches, |
| 149 | + RemoveEmptyScope, |
| 150 | + RemoveEmptySwitch |
| 151 | + >(patterns.getContext()); |
| 152 | + // clang-format on |
| 153 | +} |
238 | 154 |
|
239 |
| - SmallVector<Operation *> opsToSimplify; |
240 |
| - op->walk([&](Operation *op) { |
241 |
| - if (isa<cir::IfOp, cir::ScopeOp, cir::FuncOp, cir::SwitchOp, cir::LoopOp>( |
242 |
| - op)) |
243 |
| - opsToSimplify.push_back(op); |
| 155 | +void MergeCleanupsPass::runOnOperation() { |
| 156 | + // Collect rewrite patterns. |
| 157 | + RewritePatternSet patterns(&getContext()); |
| 158 | + populateMergeCleanupPatterns(patterns); |
| 159 | + |
| 160 | + // Collect operations to apply patterns. |
| 161 | + SmallVector<Operation *, 16> ops; |
| 162 | + getOperation()->walk([&](Operation *op) { |
| 163 | + if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp>(op)) |
| 164 | + ops.push_back(op); |
244 | 165 | });
|
245 | 166 |
|
246 |
| - for (auto *o : opsToSimplify) { |
247 |
| - bool erase = false; |
248 |
| - (void)applyOpPatternsAndFold(o, frozenPatterns, GreedyRewriteConfig(), |
249 |
| - &erase); |
250 |
| - } |
| 167 | + // Apply patterns. |
| 168 | + if (applyOpPatternsAndFold(ops, std::move(patterns)).failed()) |
| 169 | + signalPassFailure(); |
251 | 170 | }
|
| 171 | + |
252 | 172 | } // namespace
|
253 | 173 |
|
254 | 174 | std::unique_ptr<Pass> mlir::createMergeCleanupsPass() {
|
|
0 commit comments