24
24
#include " clang/CIR/Dialect/IR/CIRTypes.h"
25
25
#include " clang/CIR/LowerToMLIR.h"
26
26
#include " clang/CIR/Passes.h"
27
+ #include " llvm/ADT/TypeSwitch.h"
27
28
28
29
using namespace cir ;
29
30
using namespace llvm ;
@@ -55,6 +56,19 @@ class SCFLoop {
55
56
int64_t step = 0 ;
56
57
};
57
58
59
+ class SCFWhileLoop {
60
+ public:
61
+ SCFWhileLoop (mlir::cir::WhileOp op, mlir::cir::WhileOp::Adaptor adaptor,
62
+ mlir::ConversionPatternRewriter *rewriter)
63
+ : whileOp(op), adaptor(adaptor), rewriter(rewriter) {}
64
+ void transferToSCFWhileOp ();
65
+
66
+ private:
67
+ mlir::cir::WhileOp whileOp;
68
+ mlir::cir::WhileOp::Adaptor adaptor;
69
+ mlir::ConversionPatternRewriter *rewriter;
70
+ };
71
+
58
72
static int64_t getConstant (mlir::cir::ConstantOp op) {
59
73
auto attr = op->getAttrs ().front ().getValue ();
60
74
const auto IntAttr = attr.dyn_cast <mlir::cir::IntAttr>();
@@ -233,6 +247,20 @@ void SCFLoop::transferToSCFForOp() {
233
247
});
234
248
}
235
249
250
+ void SCFWhileLoop::transferToSCFWhileOp () {
251
+ auto scfWhileOp = rewriter->create <mlir::scf::WhileOp>(
252
+ whileOp->getLoc (), whileOp->getResultTypes (), adaptor.getOperands ());
253
+ rewriter->createBlock (&scfWhileOp.getBefore ());
254
+ rewriter->createBlock (&scfWhileOp.getAfter ());
255
+
256
+ rewriter->cloneRegionBefore (whileOp.getCond (),
257
+ &scfWhileOp.getBefore ().back ());
258
+ rewriter->eraseBlock (&scfWhileOp.getBefore ().back ());
259
+
260
+ rewriter->cloneRegionBefore (whileOp.getBody (), &scfWhileOp.getAfter ().back ());
261
+ rewriter->eraseBlock (&scfWhileOp.getAfter ().back ());
262
+ }
263
+
236
264
class CIRForOpLowering : public mlir ::OpConversionPattern<mlir::cir::ForOp> {
237
265
public:
238
266
using OpConversionPattern<mlir::cir::ForOp>::OpConversionPattern;
@@ -248,9 +276,46 @@ class CIRForOpLowering : public mlir::OpConversionPattern<mlir::cir::ForOp> {
248
276
}
249
277
};
250
278
279
+ class CIRWhileOpLowering
280
+ : public mlir::OpConversionPattern<mlir::cir::WhileOp> {
281
+ public:
282
+ using OpConversionPattern<mlir::cir::WhileOp>::OpConversionPattern;
283
+
284
+ mlir::LogicalResult
285
+ matchAndRewrite (mlir::cir::WhileOp op, OpAdaptor adaptor,
286
+ mlir::ConversionPatternRewriter &rewriter) const override {
287
+ SCFWhileLoop loop (op, adaptor, &rewriter);
288
+ loop.transferToSCFWhileOp ();
289
+ rewriter.eraseOp (op);
290
+ return mlir::success ();
291
+ }
292
+ };
293
+
294
+ class CIRConditionOpLowering
295
+ : public mlir::OpConversionPattern<mlir::cir::ConditionOp> {
296
+ public:
297
+ using OpConversionPattern<mlir::cir::ConditionOp>::OpConversionPattern;
298
+ mlir::LogicalResult
299
+ matchAndRewrite (mlir::cir::ConditionOp op, OpAdaptor adaptor,
300
+ mlir::ConversionPatternRewriter &rewriter) const override {
301
+ auto *parentOp = op->getParentOp ();
302
+ return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
303
+ .Case <mlir::scf::WhileOp>([&](auto ) {
304
+ auto condition = adaptor.getCondition ();
305
+ auto i1Condition = rewriter.create <mlir::arith::TruncIOp>(
306
+ op.getLoc (), rewriter.getI1Type (), condition);
307
+ rewriter.replaceOpWithNewOp <mlir::scf::ConditionOp>(
308
+ op, i1Condition, parentOp->getOperands ());
309
+ return mlir::success ();
310
+ })
311
+ .Default ([](auto ) { return mlir::failure (); });
312
+ }
313
+ };
314
+
251
315
void populateCIRLoopToSCFConversionPatterns (mlir::RewritePatternSet &patterns,
252
316
mlir::TypeConverter &converter) {
253
- patterns.add <CIRForOpLowering>(converter, patterns.getContext ());
317
+ patterns.add <CIRForOpLowering, CIRWhileOpLowering, CIRConditionOpLowering>(
318
+ converter, patterns.getContext ());
254
319
}
255
320
256
- } // namespace cir
321
+ } // namespace cir
0 commit comments