@@ -69,6 +69,19 @@ class SCFWhileLoop {
69
69
mlir::ConversionPatternRewriter *rewriter;
70
70
};
71
71
72
+ class SCFDoLoop {
73
+ public:
74
+ SCFDoLoop (mlir::cir::DoWhileOp op, mlir::cir::DoWhileOp::Adaptor adaptor,
75
+ mlir::ConversionPatternRewriter *rewriter)
76
+ : DoOp(op), adaptor(adaptor), rewriter(rewriter) {}
77
+ void transferToSCFWhileOp ();
78
+
79
+ private:
80
+ mlir::cir::DoWhileOp DoOp;
81
+ mlir::cir::DoWhileOp::Adaptor adaptor;
82
+ mlir::ConversionPatternRewriter *rewriter;
83
+ };
84
+
72
85
static int64_t getConstant (mlir::cir::ConstantOp op) {
73
86
auto attr = op->getAttrs ().front ().getValue ();
74
87
const auto IntAttr = attr.dyn_cast <mlir::cir::IntAttr>();
@@ -261,6 +274,40 @@ void SCFWhileLoop::transferToSCFWhileOp() {
261
274
rewriter->eraseBlock (&scfWhileOp.getAfter ().back ());
262
275
}
263
276
277
+ void SCFDoLoop::transferToSCFWhileOp () {
278
+ // only support a simple do-while
279
+ // FIXME: can not support nested do-while
280
+
281
+ auto scfWhileOp = rewriter->create <mlir::scf::WhileOp>(
282
+ DoOp.getLoc (), DoOp->getResultTypes (), adaptor.getOperands ());
283
+
284
+ rewriter->createBlock (&scfWhileOp.getBefore ());
285
+ rewriter->createBlock (&scfWhileOp.getAfter ());
286
+
287
+ rewriter->cloneRegionBefore (DoOp.getBody (), &scfWhileOp.getBefore ().back ());
288
+ rewriter->eraseBlock (&scfWhileOp.getBefore ().back ());
289
+
290
+ rewriter->cloneRegionBefore (DoOp.getCond (), &scfWhileOp.getAfter ().back ());
291
+ rewriter->eraseBlock (&scfWhileOp.getAfter ().back ());
292
+
293
+ rewriter->inlineBlockBefore (&scfWhileOp.getAfter ().back (),
294
+ &scfWhileOp.getBefore ().back (),
295
+ scfWhileOp.getBefore ().back ().end ());
296
+
297
+ rewriter->createBlock (&scfWhileOp.getAfter ());
298
+
299
+ auto &beforeFrontBlock = scfWhileOp.getBefore ().front ();
300
+ for (auto it = beforeFrontBlock.begin (); it != beforeFrontBlock.end (); ++it) {
301
+ if (auto yieldOp = llvm::dyn_cast<mlir::cir::YieldOp>(&*it)) {
302
+ rewriter->eraseOp (yieldOp);
303
+ break ;
304
+ }
305
+ }
306
+
307
+ rewriter->setInsertionPointToEnd (&scfWhileOp.getAfter ().front ());
308
+ rewriter->create <mlir::scf::YieldOp>(DoOp.getLoc ());
309
+ }
310
+
264
311
class CIRForOpLowering : public mlir ::OpConversionPattern<mlir::cir::ForOp> {
265
312
public:
266
313
using OpConversionPattern<mlir::cir::ForOp>::OpConversionPattern;
@@ -291,6 +338,20 @@ class CIRWhileOpLowering
291
338
}
292
339
};
293
340
341
+ class CIRDoOpLowering : public mlir ::OpConversionPattern<mlir::cir::DoWhileOp> {
342
+ public:
343
+ using OpConversionPattern<mlir::cir::DoWhileOp>::OpConversionPattern;
344
+
345
+ mlir::LogicalResult
346
+ matchAndRewrite (mlir::cir::DoWhileOp op, OpAdaptor adaptor,
347
+ mlir::ConversionPatternRewriter &rewriter) const override {
348
+ SCFDoLoop loop (op, adaptor, &rewriter);
349
+ loop.transferToSCFWhileOp ();
350
+ rewriter.eraseOp (op);
351
+ return mlir::success ();
352
+ }
353
+ };
354
+
294
355
class CIRConditionOpLowering
295
356
: public mlir::OpConversionPattern<mlir::cir::ConditionOp> {
296
357
public:
@@ -314,8 +375,8 @@ class CIRConditionOpLowering
314
375
315
376
void populateCIRLoopToSCFConversionPatterns (mlir::RewritePatternSet &patterns,
316
377
mlir::TypeConverter &converter) {
317
- patterns.add <CIRForOpLowering, CIRWhileOpLowering, CIRConditionOpLowering>(
318
- converter, patterns.getContext ());
378
+ patterns.add <CIRForOpLowering, CIRWhileOpLowering, CIRConditionOpLowering,
379
+ CIRDoOpLowering>( converter, patterns.getContext ());
319
380
}
320
381
321
- } // namespace cir
382
+ } // namespace cir
0 commit comments