8
8
9
9
#include " PassDetail.h"
10
10
#include " mlir/Dialect/Func/IR/FuncOps.h"
11
+ #include " mlir/IR/Block.h"
12
+ #include " mlir/IR/Operation.h"
11
13
#include " mlir/IR/PatternMatch.h"
14
+ #include " mlir/IR/Region.h"
12
15
#include " mlir/Support/LogicalResult.h"
13
16
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
14
17
#include " clang/CIR/Dialect/IR/CIRDialect.h"
15
18
#include " clang/CIR/Dialect/Passes.h"
19
+ #include " llvm/ADT/SmallVector.h"
16
20
17
21
using namespace mlir ;
18
22
using namespace cir ;
@@ -107,6 +111,92 @@ struct RemoveTrivialTry : public OpRewritePattern<TryOp> {
107
111
}
108
112
};
109
113
114
+ // / Simplify suitable ternary operations into select operations.
115
+ // /
116
+ // / Only those ternary operations that meet the following criteria can be
117
+ // / simplified:
118
+ // / - The true branch and the false branch cannot have any side effects;
119
+ // / - The true branch and the false branch cannot be "too costly" since both of
120
+ // / them will be executed after the folding happens.
121
+ // /
122
+ // / For now we only simplify those ternary operations whose true and false
123
+ // / branches either directly yield a value or directly yield a constant. That
124
+ // / is, both of the two branches of these ternary operation must either:
125
+ // / - Only contain a single cir.yield operation, or
126
+ // / - Contain a cir.const operation followed by a cir.yield operation that
127
+ // / yields the constant value produced by the cir.const operation.
128
+ // /
129
+ // / For example, we will simplify the following ternary operation:
130
+ // /
131
+ // / %0 = cir.ternary (%condition, true {
132
+ // / %1 = cir.const ...
133
+ // / cir.yield %1
134
+ // / } false {
135
+ // / cir.yield %2
136
+ // / })
137
+ // /
138
+ // / into the following sequence of operations:
139
+ // /
140
+ // / %1 = cir.const ...
141
+ // / %0 = cir.select if %condition then %1 else %2
142
+ struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
143
+ using OpRewritePattern<TernaryOp>::OpRewritePattern;
144
+
145
+ LogicalResult matchAndRewrite (TernaryOp op,
146
+ PatternRewriter &rewriter) const override {
147
+ llvm::SmallVector<mlir::Operation *> opsToHoist;
148
+
149
+ mlir::Value trueValue =
150
+ simplifyTernaryBranch (op.getTrueRegion (), opsToHoist);
151
+ if (!trueValue)
152
+ return mlir::failure ();
153
+
154
+ mlir::Value falseValue =
155
+ simplifyTernaryBranch (op.getFalseRegion (), opsToHoist);
156
+ if (!falseValue)
157
+ return mlir::failure ();
158
+
159
+ for (auto *hoistOp : opsToHoist)
160
+ rewriter.moveOpBefore (hoistOp, op);
161
+ rewriter.replaceOpWithNewOp <mlir::cir::SelectOp>(op, op.getCond (),
162
+ trueValue, falseValue);
163
+
164
+ return mlir::success ();
165
+ }
166
+
167
+ private:
168
+ mlir::Value simplifyTernaryBranch (
169
+ mlir::Region ®ion,
170
+ llvm::SmallVector<mlir::Operation *> &opsToHoist) const {
171
+ if (!region.hasOneBlock ())
172
+ return nullptr ;
173
+
174
+ mlir::Block &block = region.front ();
175
+
176
+ // The block can contain at most 2 operations: one cir.const operation
177
+ // followed by one cir.yield operation
178
+ if (block.getOperations ().size () > 2 )
179
+ return nullptr ;
180
+
181
+ auto yieldOp = mlir::cast<mlir::cir::YieldOp>(block.getTerminator ());
182
+ auto yieldValue = yieldOp.getArgs ()[0 ];
183
+ if (block.getOperations ().size () == 1 )
184
+ return yieldValue;
185
+
186
+ // The yielded value must be produced by a cir.const operation in the same
187
+ // block to make the branch simplifiable.
188
+ auto yieldValueDef = mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(
189
+ yieldValue.getDefiningOp ());
190
+ if (!yieldValueDef)
191
+ return nullptr ;
192
+ if (yieldValueDef->getBlock () != &block)
193
+ return nullptr ;
194
+
195
+ opsToHoist.push_back (yieldValueDef);
196
+ return yieldValue;
197
+ }
198
+ };
199
+
110
200
// ===----------------------------------------------------------------------===//
111
201
// CIRSimplifyPass
112
202
// ===----------------------------------------------------------------------===//
@@ -131,7 +221,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
131
221
RemoveRedundantBranches,
132
222
RemoveEmptyScope,
133
223
RemoveEmptySwitch,
134
- RemoveTrivialTry
224
+ RemoveTrivialTry,
225
+ SimplifyTernary
135
226
>(patterns.getContext ());
136
227
// clang-format on
137
228
}
@@ -146,8 +237,9 @@ void CIRSimplifyPass::runOnOperation() {
146
237
getOperation ()->walk ([&](Operation *op) {
147
238
// CastOp here is to perform a manual `fold` in
148
239
// applyOpPatternsAndFold
149
- if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp, SelectOp,
150
- ComplexCreateOp, ComplexRealOp, ComplexImagOp>(op))
240
+ if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp,
241
+ TernaryOp, SelectOp, ComplexCreateOp, ComplexRealOp, ComplexImagOp>(
242
+ op))
151
243
ops.push_back (op);
152
244
});
153
245
0 commit comments